[心得] 计算π到小数点下十亿位 ─ 超进化版

楼主: Schottky (顺风相送)   2021-02-26 03:27:19
和上篇一样,这篇是把计算圆周率π的程式照样修改,
执行速度也快了两倍多,
原本十亿位需要 1 小时 28 分钟,缩短到 40 分钟
速度变快的关键主要出在三处:
1. mpfr.div() 原本写错了,应该要把整数的 mpz 转成浮点数 mpfr 再除比较快,
我却写错了,直接拿 mpz 去除,不知为何这样会让运算速度变得超级慢。
2. 写档时 write_string() 里面有做排版,而这个排版的循环是一个字符一个字符
处理,速度非常的慢,我改成一次处理 50 个字符一行,速度就快了 50 倍。
3. 进度条处理得不好,不需更新进度时应该尽快 return,我却做了多余的
数学运算,多做一次当然没什么,多做七千万次就有影响效能了。
这个 Divide and Conquer 的写法很适合 multi-processing,
以及进度条可以改用 tqdm module,这些建议都不错,
不过请体谅我才刚学 Python 没几天,需要点时间消化 (汗)
这次程式码也放在 https://ideone.com/6YO1zU 方便大家复制贴上
#!/usr/bin/env python3
#
# pi.py - Calculate Pi
#
import sys
import time
import math
import gmpy2
from gmpy2 import mpfr
from gmpy2 import mpz
#
# Global Variables
#
count = 0
total = 0
grad = 0
step = 0
#
# Show Progress
#
def progress_init(max):
global count, total, grad, step
total = max
count = 0
step = int(total / 1000)
grad = int(step / 2)
def progress():
global count, total, grad, step
if (count > grad):
grad += step
g = int(math.floor(72.5*count/total+0.5))
p = int(math.floor(1000.5*count/total+0.5))
msg = "H" * g + "-" * (72-g) + " " + str(p/10) + "%\r"
if (grad > total):
msg += "\n"
print(msg, sep="", end="", flush=True)
#
# Write digit string
#
def write_string(digit_string):
fd = open("pi-py.txt", mode="w")
fd.write(" pi = ")
fd.write(digit_string[0])
fd.write(".")
for c in range(1, len(digit_string), 50):
if (c != 1):
fd.write("\t")
fd.write(digit_string[c:c+50])
if ((c % 1000) == 951):
fd.write(" << ")
fd.write(str(c+49))
fd.write("\r\n")
elif ((c % 500) == 451):
fd.write(" <\r\n")
else:
fd.write("\r\n")
# Final new-line
fd.write("\r\n")
fd.close()
#
# Recursive funcion.
#
def s(a, b, max):
global count
m = math.ceil((a + b) / 2)
if (b - a == 1):
if (a == 0):
r = 120 # 6!
q = mpz(640320**3)
p = gmpy2.sub( gmpy2.mul(q, 13591409),
gmpy2.mul(r, 13591409+545140134) )
else:
r = mpz(8 * (a*6+1) * (a*6+3) * (a*6+5))
q = mpz((b*640320)**3)
if ((b%2) == 0):
p = gmpy2.mul(mpz(13591409 + b*545140134), r)
else:
p = gmpy2.mul(mpz(-13591409 - b*545140134), r)
else:
p1, q1, r1 = s(a, m, max)
p2, q2, r2 = s(m, b, max)
# Merge
p = gmpy2.add( gmpy2.mul(p1, q2), gmpy2.mul(p2, r1) )
q = gmpy2.mul(q1, q2)
if (b != max):
r = gmpy2.mul(r1, r2)
else:
r = 0
count += 1
progress()
return p, q, r
#
# Calculate e
#
def calc_pi(digits):
global total
d = digits+1
n_terms = math.ceil(d*math.log(10)/(3*math.log(53360)))
precision = math.ceil(d * math.log2(10)) + 4
print("d = ", d, ", n = ", n_terms, ", precision = ", precision, sep="")
print("gmpy2 version:", gmpy2.version())
print("MP version:", gmpy2.mp_version())
print("MPFR version:", gmpy2.mpfr_version())
max_precision = gmpy2.get_max_precision()
print("max_precision =", max_precision)
max_emax = gmpy2.get_emax_max()
print("max_emax =", max_emax)
if (max_precision < precision):
print("Error! Max precision is too small! Program terminated.")
return
gmpy2.get_context().precision = precision
gmpy2.get_context().emax = max_emax
print("Real precision = ", gmpy2.get_context().precision)
progress_init(n_terms * 2 - 1) # Initialize progress bar
# Recursion
start_time = time.monotonic_ns()
p, q, r = s(0, n_terms, n_terms)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Recursion:", elapsed, "seconds.")
start_time = time.monotonic_ns()
q = gmpy2.mul(q, 426880)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Multiply by 426880:", elapsed, "seconds.")
start_time = time.monotonic_ns()
pf = mpfr(p)
qf = mpfr(q)
ef = gmpy2.div(qf, pf)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Grand Division:", elapsed, "seconds.")
start_time = time.monotonic_ns()
ef = gmpy2.mul(ef, gmpy2.sqrt(10005))
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Multiply by sqrt(10005):", elapsed, "seconds.")
# Convert to decimal digits
start_time = time.monotonic_ns()
estr, exp, prec = mpfr.digits(ef)
estr = estr[0:d]
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Convert to decimal digits:", elapsed, "seconds.")
# Write to file
start_time = time.monotonic_ns()
write_string(estr)
end_time = time.monotonic_ns()
elapsed = (end_time - start_time) / 1000000000
print("Write to file:", elapsed, "seconds.")
#
# main program
#
if __name__ == '__main__':
argc = len(sys.argv)
if (argc >= 2):
digits = int(sys.argv[1])
else:
digits = 100000
calc_pi(digits)
# End of pi.py
作者: idletime (idle)   2021-03-06 23:06:00
榨效能专用

Links booklink

Contact Us: admin [ a t ] ucptt.com