Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 144 additions & 2 deletions lib/bigdecimal/math.rb
Original file line number Diff line number Diff line change
Expand Up @@ -737,12 +737,149 @@ def gamma(x, prec)
return pi.div(gamma(1 - x, prec2).mult(sin, prec2), prec)
elsif x.frac.zero? && x < 1000 * prec
return _gamma_positive_integer(x, prec2).mult(1, prec)
elsif x < (prec / prec.bit_length)**2
return _gamma_lagrange(x, prec2).mult(1, prec)
end

a, sum = _gamma_spouge_sum_part(x, prec2)
(x + (a - 1)).power(x - 0.5, prec2).mult(BigMath.exp(1 - x, prec2), prec2).mult(sum, prec)
end

# Calculates prod{x-k} and its coefficients for given ks, xn and prec with baby-step giant-step method.
# xn is an array of precalculated powers of x: [1, x, x**2, x**3, ...]
private_class_method def _x_minus_k_prod_coef(ks, xn, prec) # :nodoc:
coef = [1]
ks.each do |k|
coef_next = [0] * (coef.size + 1)
coef.each_with_index do |c, i|
coef_next[i] -= k * c
coef_next[i + 1] += c
end
coef = coef_next
end

prd = coef.each_with_index.map do |c, i|
xn[i].mult(c, prec)
end.reduce do |sum, value|
sum.add(value, prec)
end
[prd, coef]
end

# Calculate approximate gamma by Lagrange interpolation of f(x) = b**x / x!
# Nodes are placed at x_i = b-l, b-l+1, ..., b+l.
#
# Mathematically, we use the barycentric interpolation form:
# f(x) \approx \omega(x) \sum_{i} \frac{w_i f(x_i)}{x - x_i}
# Therefore, \Gamma(x+1) = x! = b**x / f(x)
#
# Estimated time complexity:
# - O(N * log^3 N) for small-digit/rational x (Binary Splitting)
# - O(N^2) for full-digit x (Baby-step Giant-step)
# Precondition: 0 < x < const * (prec / prec.bit_length)**2
private_class_method def _gamma_lagrange(x, prec) # :nodoc:
x = BigDecimal(x) - 1
l = prec

# Shift x to ensure the scaled bell curve b**x/x! is wide enough to cover
# all 2*l+1 interpolation nodes without losing significant digits.
shift = x < 2 * prec ? 2 * prec - x.floor : 0
x += shift
b = x.round

# c0 represents the common large factorial part factored out from the weights
# to avoid computing massive numbers in every term.
c0s = [*(1..b-l), *(1..2*l)]
c0s = c0s.each_slice(2).map {|a, b| b ? BigDecimal(a).mult(b, prec) : BigDecimal(a) } while c0s.size != 1
c0 = c0s.first

# --- Reference: Naive interpolation logic ---
# Optimize this calculation for full-digit-x case and small-digit-x case.
# sum = BigDecimal(0)
# prod = [*(b-l)..(b+l), *(0...shift)].map {|i| x - i }.reduce { _1.mult(_2, prec) }
# c = BigDecimal(1) # represents w_i * f(x_i) (normalized)
# ((b-l)..(b+l)).each do |i|
# if i != b - l
# c = c.mult(-b * (b + l - i + 1), prec).div((i - b + l) * i, prec)
# end
# sum = sum.add(c.div(x - i, prec), prec)
# end
# --------------------------------------------

if x.n_significant_digits > prec / prec.bit_length
# Reduce full-precision multiplications/divisions by Baby-Step Giant-Step method

batch_size = prec.bit_length
# When expanding prod{x-k}, the coefficient of x**n might be huge.
# Increase internal calculation precision to avoid catastrophic cancellation.
internal_xn_prec = prec + (Math.log10(b + l) * batch_size).ceil
xn = [BigDecimal(1)]
xn << xn.last.mult(x, internal_xn_prec) while xn.size <= batch_size

c = BigDecimal(1)
sum = BigDecimal(0)
prod = BigDecimal(1)

((b-l)..(b+l)).to_a.each_slice(batch_size) do |batch_ks|
# Calculate prod{x-k} in this batch
batch_prod, prod_coef = _x_minus_k_prod_coef(batch_ks, xn, internal_xn_prec)

# Calculate coefficients of batch_prod / (x-k) using Synthetic Division (Ruffini's rule)
batch_coef = [0] * batch_ks.size
c_scale = 1r
batch_ks.each do |k|
c_scale = c_scale * (-b * (b + l - k + 1)) / ((k - b + l) * k) if k != b - l
rem = 0
(batch_ks.size - 1).downto(0) do |i|
quo = prod_coef[i + 1] + rem
rem = quo * k
batch_coef[i] += c_scale * quo
end
end

batch_sum = BigDecimal(0)
batch_coef.each_with_index do |coef, i|
batch_sum = batch_sum.add(xn[i].mult(coef.numerator, internal_xn_prec).div(coef.denominator, internal_xn_prec), internal_xn_prec)
end
sum = sum.add(batch_sum.mult(c, prec).div(batch_prod, prec), prec)
c = c.mult(c_scale.numerator, prec).div(c_scale.denominator, prec)
prod = prod.mult(batch_prod, prec)
end

# Perform shift.times {|i| prod = prod.mult(x - i, prec) } with batch processing
shift.times.to_a.each_slice(batch_size) do |batch_ks|
shift_prod, _prod_coef = _x_minus_k_prod_coef(batch_ks, xn, internal_xn_prec)
prod = prod.mult(shift_prod, prec)
end
else
# Binary Splitting Method (BSM) for short-digit inputs
prods = ((b-l)..(b+l)).map {|i| x - i } + shift.times.map {|i| x - i }
prods = prods.each_slice(2).map {|a, b| b ? a.mult(b, prec) : a } while prods.size != 1
prod = prods.first

# State represents: [Base_Denominator, Numerator, Denominator_Multiplier]
fractions = (b - l + 1..b + l).map do |i|
denominator = (x - i).mult((i - b + l) * i, prec)
numerator = (x - i + 1).mult(-b * (b + l - i + 1), prec)
[denominator, numerator, denominator]
end

while fractions.size > 1
fractions = fractions.each_slice(2).map do |a, b|
b ||= [BigDecimal(1), BigDecimal(0), BigDecimal(1)]
# Merge operation for BSM:
# a[0]/a[2] + a[1]/a[2] * (b[0]/b[2] + b[1]/b[2] * rest)
# = (a[0]*b[2] + a[1]*b[0]) / (a[2]*b[2]) + (a[1]*b[1]) / (a[2]*b[2]) * rest
[a[0].mult(b[2], prec).add(a[1].mult(b[0], prec), prec), a[1].mult(b[1], prec), a[2].mult(b[2], prec)]
end
end
sum = fractions[0][0].add(fractions[0][1], prec).div(fractions[0][2], prec).div(x - b + l, prec)
end

# Reconstruct Gamma(x_original) by reversing the scaling and applying shift formula
BigDecimal(b).power(x - (b - l), prec).div(prod.mult(sum, prec), prec).mult(c0, prec)
end

# call-seq:
# BigMath.lgamma(decimal, numeric) -> [BigDecimal, Integer]
#
Expand Down Expand Up @@ -790,8 +927,13 @@ def lgamma(x, prec)
end

prec2 += [-diff1_exponent, -diff2_exponent, 0].max
a, sum = _gamma_spouge_sum_part(x, prec2)
log_gamma = BigMath.log(sum, prec2).add((x - 0.5).mult(BigMath.log(x.add(a - 1, prec2), prec2), prec2) + 1 - x, prec)

if x < (prec / prec.bit_length)**2
log_gamma = BigMath.log(_gamma_lagrange(x, prec2), prec)
else
a, sum = _gamma_spouge_sum_part(x, prec2)
log_gamma = BigMath.log(sum, prec2).add((x - 0.5).mult(BigMath.log(x.add(a - 1, prec2), prec2), prec2) + 1 - x, prec)
end
[log_gamma, 1]
end
end
Expand Down
12 changes: 6 additions & 6 deletions test/bigdecimal/test_bigmath.rb
Original file line number Diff line number Diff line change
Expand Up @@ -567,11 +567,11 @@ def test_gamma
BigDecimal('0.28242294079603478742934215780245355184774949260912e456569'),
BigMath.gamma(100000, 50)
)
precisions = [50, 100, 150]
assert_converge_in_precision(precisions) {|n| gamma(BigDecimal("0.3"), n) }
assert_converge_in_precision(precisions) {|n| gamma(BigDecimal("-1.9" + "9" * 30), n) }
assert_converge_in_precision(precisions) {|n| gamma(BigDecimal("1234.56789"), n) }
assert_converge_in_precision(precisions) {|n| gamma(BigDecimal("-987.654321"), n) }
precisions = [100, 200, 300, 400]
assert_converge_in_precision() {|n| gamma(BigDecimal("0.3"), n) }
assert_converge_in_precision() {|n| gamma(BigDecimal("-1.9" + "9" * 30), n) }
assert_converge_in_precision() {|n| gamma(BigDecimal("1234.56789"), n) }
assert_converge_in_precision() {|n| gamma(BigDecimal("-987.654321"), n) }
end

def test_lgamma
Expand All @@ -587,7 +587,7 @@ def test_lgamma
assert_equal(sign, bigsign)
end
assert_equal([BigMath.log(PI(120).sqrt(120), 100), 1], lgamma(BigDecimal("0.5"), 100))
precisions = [50, 100, 150]
precisions = [50, 100, 150, 200]
assert_converge_in_precision(precisions) {|n| lgamma(BigDecimal("0." + "9" * 80), n).first }
assert_converge_in_precision(precisions) {|n| lgamma(BigDecimal("1." + "0" * 80 + "1"), n).first }
assert_converge_in_precision(precisions) {|n| lgamma(BigDecimal("1." + "9" * 80), n).first }
Expand Down