diff --git a/.translate/state/numba.md.yml b/.translate/state/numba.md.yml index c5f1f07..73a8e2c 100644 --- a/.translate/state/numba.md.yml +++ b/.translate/state/numba.md.yml @@ -1,6 +1,6 @@ -source-sha: cc9c3256dc35bd277cb25d0089f0a0452c0fa94e -synced-at: "2026-03-20" +source-sha: a897eb875ecbf8975872fd6bc0a6b2e91a6ba701 +synced-at: "2026-04-12" model: claude-sonnet-4-6 -mode: NEW -section-count: 8 -tool-version: 0.13.1 +mode: UPDATE +section-count: 6 +tool-version: 0.14.1 diff --git a/lectures/numba.md b/lectures/numba.md index e6739c7..765ae7d 100644 --- a/lectures/numba.md +++ b/lectures/numba.md @@ -14,11 +14,9 @@ translation: headings: Overview: 概述 Compiling Functions: 编译函数 - Compiling Functions::An Example: 一个示例 - Compiling Functions::How and When it Works: 工作原理及适用场景 - Decorator Notation: 装饰器语法 + Compiling Functions::An Example: 示例 + Compiling Functions::How and When it Works: 工作原理与适用时机 Type Inference: 类型推断 - Compiling Classes: 编译类 Dangers and Limitations: 危险与局限 Dangers and Limitations::Limitations: 局限性 'Dangers and Limitations::A Gotcha: Global Variables': 一个陷阱:全局变量 @@ -26,7 +24,7 @@ translation: Exercises: 练习 --- -(speed)= +(numba_lecture)= ```{raw} jupyter
@@ -62,29 +60,30 @@ mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n ## 概述 -在{doc}`之前的讲座 `中,我们学习了向量化,这是提高数值计算速度和效率的一种方法。 +在 {doc}`之前的讲座 ` 中,我们讨论了向量化,它通过将数组处理操作批量发送到高效的底层代码来提高执行速度。 -向量化将数组处理操作批量发送到高效的底层代码。 +然而,正如 {ref}`在那次讲座中所讨论的 `,传统的向量化方案(如 MATLAB 和 NumPy 中的方案)存在一些弱点。 -然而,正如{ref}`之前所讨论的 `,向量化有几个弱点。 +* 对于复合数组操作,内存消耗极大 +* 对某些算法无效或无法实现。 -其一是在处理大量数据时内存消耗极大。 +绕过这些问题的一种方法是使用 [Numba](https://numba.pydata.org/),这是一个面向数值计算的 Python **即时(JIT)编译器**。 -其二是能够完全向量化的算法集并非通用的。 +Numba 在运行时将函数编译为本地机器码指令。 -事实上,对于某些算法,向量化是无效的。 +编译成功后,其性能可与编译后的 C 或 Fortran 相媲美。 -幸运的是,一个名为 [Numba](https://numba.pydata.org/) 的新 Python 库解决了许多这些问题。 +此外,Numba 还可以完成其他有用的技巧,例如 {ref}`多线程 ` 或通过 `numba.cuda` 与 GPU 交互。 -它通过一种称为**即时(JIT)编译**的技术来实现这一点。 +Numba 的 JIT 编译器在许多方面与 Julia 中的 JIT 编译器类似。 -核心思想是在运行时将函数编译为本地机器码指令。 +主要区别在于它没有那么雄心勃勃,只尝试编译语言的一个较小子集。 -编译成功后,编译后的代码速度极快。 +虽然这听起来像是一个缺陷,但在某些方面这也是一个优势。 -除了编译带来的速度提升之外,Numba 还专门为数值计算而设计,并且还可以完成其他技巧,例如{ref}`多线程 `。 +Numba 精简、易用,并且在其所做的事情上非常出色。 -本讲座将介绍主要思路。 +本讲座将介绍核心思路。 (numba_link)= ## {index}`编译函数 ` @@ -92,26 +91,25 @@ mpl.rcParams['font.family'] = ['Source Han Serif SC'] # i18n ```{index} single: Python; Numba ``` -如上所述,Numba 的主要用途是在运行时将函数编译为快速的本地机器码。 (quad_map_eg)= -### 一个示例 +### 示例 -让我们考虑一个难以向量化的问题:给定初始条件,生成差分方程的轨迹。 +让我们考虑一个难以向量化(即交由数组处理操作完成)的问题。 -我们将采用的差分方程是二次映射 +该问题涉及通过二次映射生成轨迹 $$ -x_{t+1} = \alpha x_t (1 - x_t) + x_{t+1} = \alpha x_t (1 - x_t) $$ -在下文中,我们设定 +在下文中,我们设置 ```{code-cell} ipython3 α = 4.0 ``` -以下是一条典型轨迹的图像,从 $x_0 = 0.1$ 开始,以 $t$ 为横轴 +以下是一条典型轨迹的图像,从 $x_0 = 0.1$ 开始,x 轴表示 $t$ ```{code-cell} ipython3 def qm(x0, n): @@ -129,7 +127,7 @@ ax.set_ylabel('$x_{t}$', fontsize = 12) plt.show() ``` -要使用 Numba 加速函数 `qm`,我们的第一步是 +要使用 Numba 加速函数 `qm`,第一步是 ```{code-cell} ipython3 from numba import jit @@ -141,7 +139,7 @@ qm_numba = jit(qm) 我们稍后将解释这意味着什么。 -让我们对这两个版本进行相同的函数调用计时和比较,首先从原始函数 `qm` 开始: +让我们对这两个版本进行相同的函数调用计时并比较,首先从原始函数 `qm` 开始: ```{code-cell} ipython3 n = 10_000_000 @@ -151,7 +149,7 @@ with qe.Timer() as timer1: time1 = timer1.elapsed ``` -现在让我们试试 qm_numba +现在让我们尝试 qm_numba ```{code-cell} ipython3 with qe.Timer() as timer2: @@ -159,9 +157,9 @@ with qe.Timer() as timer2: time2 = timer2.elapsed ``` -这已经是一个非常大的速度提升。 +这已经是非常大的速度提升。 -事实上,下次及之后每次运行时速度都会更快,因为函数已经被编译并存储在内存中: +实际上,下一次及之后所有运行会更快,因为函数已经被编译并存储在内存中: (qm_numba_result)= @@ -172,80 +170,38 @@ time3 = timer3.elapsed ``` ```{code-cell} ipython3 -time1 / time3 # 计算速度提升倍数 +time1 / time3 # Calculate speed gain ``` -相对于修改的简单和清晰程度,这种速度提升令人印象深刻。 +### 工作原理与适用时机 -### 工作原理及适用场景 +Numba 尝试利用 [LLVM Project](https://llvm.org/) 提供的基础设施生成快速机器码。 -Numba 尝试使用 [LLVM 项目](https://llvm.org/) 提供的基础设施生成快速的机器码。 +它通过动态推断类型信息来实现这一目标。 -它通过即时推断类型信息来实现这一点。 - -(有关类型的讨论,请参阅我们{doc}`之前关于科学计算的讲座 `。) +(参见我们 {doc}`早前的讲座 `,其中讨论了科学计算中的类型问题。) 基本思路如下: -* Python 非常灵活,因此我们可以用多种类型调用函数 `qm`。 +* Python 非常灵活,因此我们可以用多种类型调用函数 qm。 * 例如,`x0` 可以是 NumPy 数组或列表,`n` 可以是整数或浮点数,等等。 -* 这使得*预*编译函数(即在运行时之前编译)变得困难。 -* 然而,当我们实际调用函数时,比如运行 `qm(0.5, 10)`,`x0` 和 `n` 的类型就变得清晰了。 -* 此外,一旦知道输入类型,`qm` 中其他变量的类型也可以被推断出来。 -* 因此,Numba 和其他 JIT 编译器的策略是等到这一时刻,*然后*再编译函数。 - -这就是为什么它被称为"即时"编译。 - -请注意,如果您调用 `qm(0.5, 10)`,然后紧跟着调用 `qm(0.9, 20)`,编译只发生在第一次调用时。 +* 这使得*提前*(即在运行时之前)生成高效机器码非常困难。 +* 然而,当我们实际*调用*函数时,例如运行 `qm(0.5, 10)`,`x0` 和 `n` 的类型就变得明确了。 +* 此外,一旦输入类型已知,`qm` 中*其他变量*的类型也*可以被推断出来*。 +* 因此,Numba 和其他 JIT 编译器的策略是*等到函数被调用时再进行编译*。 -编译后的代码会被缓存并按需复用。 +这被称为"即时"编译。 -这就是为什么在上面的代码中,`time3` 比 `time2` 小。 +注意,如果你先调用 `qm(0.5, 10)`,然后再调用 `qm(0.9, 20)`,编译只在第一次调用时发生。 -## 装饰器语法 +这是因为编译后的代码会被缓存并在需要时重复使用。 -在上面的代码中,我们通过调用以下方式创建了 `qm` 的 JIT 编译版本 +这就是为什么在上面的代码中,`time3` 小于 `time2`。 -```{code-cell} ipython3 -qm_numba = jit(qm) +```{admonition} 备注 +在实践中,我们不直接写 `qm_numba = jit(qm)`,而是使用*装饰器*语法,在函数定义前加上 `@jit`。这等价于在定义之后添加 `qm = jit(qm)`。我们在本讲座的其余部分都使用这种语法。(有关装饰器的更多内容,请参见 {doc}`python_advanced_features`。) ``` -在实践中,这通常使用另一种*装饰器*语法来完成。 - -(我们在{doc}`单独的讲座 `中讨论装饰器,但在此阶段您可以跳过细节。) - -让我们看看这是如何完成的。 - -要将函数定向为 JIT 编译,我们可以在函数定义前放置 `@jit`。 - -以下是 `qm` 的写法 - -```{code-cell} ipython3 -@jit -def qm(x0, n): - x = np.empty(n+1) - x[0] = x0 - for t in range(n): - x[t+1] = α * x[t] * (1 - x[t]) - return x -``` - -这等价于在函数定义后添加 `qm = jit(qm)`。 - -以下代码现在使用 JIT 编译版本: - -```{code-cell} ipython3 -with qe.Timer(precision=4): - qm(0.1, 100_000) -``` - -```{code-cell} ipython3 -with qe.Timer(precision=4): - qm(0.1, 100_000) -``` - -Numba 还为装饰器提供了几个参数以加速计算和缓存函数——请参阅[这里](https://numba.readthedocs.io/en/stable/user/performance-tips.html)。 - ## 类型推断 成功的类型推断是 JIT 编译的关键部分。 @@ -258,153 +214,37 @@ Numba 也与 NumPy 数组配合良好,因为它们具有明确定义的类型 这使它能够生成本地机器码,而无需调用 Python 运行时环境。 -在这种情况下,Numba 将与低级语言的机器码相媲美。 - 当 Numba 无法推断所有类型信息时,它将引发错误。 -例如,在下面这个(人为的)示例中,Numba 在编译函数 `bootstrap` 时无法确定函数 `mean` 的类型 +例如,在下面的示例中,Numba 在编译 `iterate` 时无法确定函数 `g` 的类型 ```{code-cell} ipython3 @jit -def bootstrap(data, statistics, n): - bootstrap_stat = np.empty(n) - n = len(data) - for i in range(n_resamples): - resample = np.random.choice(data, size=n, replace=True) - bootstrap_stat[i] = statistics(resample) - return bootstrap_stat - -# 这里没有装饰器。 -def mean(data): - return np.mean(data) +def iterate(f, x0, n): + x = x0 + for t in range(n): + x = f(x) + return x -data = np.array((2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2)) -n_resamples = 10 +# 未经 jit 编译 +def g(x): + return np.cos(x) - 2 * np.sin(x) # 这段代码会抛出错误 try: - bootstrap(data, mean, n_resamples) + iterate(g, 0.5, 100) except Exception as e: print(e) ``` -在这种情况下,我们可以通过编译 `mean` 来轻松修复这个错误。 +我们可以通过编译 `g` 来轻松修复这个错误。 ```{code-cell} ipython3 @jit -def mean(data): - return np.mean(data) - -with qe.Timer(): - bootstrap(data, mean, n_resamples) -``` - -## 编译类 - -如上所述,目前 Numba 只能编译 Python 的一个子集。 - -然而,这个子集一直在扩展。 - -值得注意的是,Numba 现在在编译类方面相当有效。 - -如果一个类被成功编译,那么它的方法就像 JIT 编译的函数一样运行。 +def g(x): + return np.cos(x) - 2 * np.sin(x) -举一个例子,让我们考虑在{doc}`本讲座 `中创建的用于分析索洛-斯旺增长模型的类。 - -要编译这个类,我们使用 `@jitclass` 装饰器: - -```{code-cell} ipython3 -from numba import float64 -from numba.experimental import jitclass -``` - -注意,我们还导入了一个叫做 `float64` 的东西。 - -这是一种表示标准浮点数的数据类型。 - -我们在这里导入它是因为当 Numba 尝试处理类时,它需要一些关于类型的额外帮助。 - -以下是我们的代码: - -```{code-cell} ipython3 -solow_data = [ - ('n', float64), - ('s', float64), - ('δ', float64), - ('α', float64), - ('z', float64), - ('k', float64) -] - -@jitclass(solow_data) -class Solow: - r""" - Implements the Solow growth model with the update rule - - k_{t+1} = [(s z k^α_t) + (1 - δ)k_t] /(1 + n) - - """ - def __init__(self, n=0.05, # population growth rate - s=0.25, # savings rate - δ=0.1, # depreciation rate - α=0.3, # share of labor - z=2.0, # productivity - k=1.0): # current capital stock - - self.n, self.s, self.δ, self.α, self.z = n, s, δ, α, z - self.k = k - - def h(self): - "Evaluate the h function" - # 解包参数(去掉 self 以简化符号) - n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z - # 应用更新规则 - return (s * z * self.k**α + (1 - δ) * self.k) / (1 + n) - - def update(self): - "Update the current state (i.e., the capital stock)." - self.k = self.h() - - def steady_state(self): - "Compute the steady state value of capital." - # 解包参数(去掉 self 以简化符号) - n, s, δ, α, z = self.n, self.s, self.δ, self.α, self.z - # 计算并返回稳态 - return ((s * z) / (n + δ))**(1 / (1 - α)) - - def generate_sequence(self, t): - "Generate and return a time series of length t" - path = [] - for i in range(t): - path.append(self.k) - self.update() - return path -``` - -首先,我们在 `solow_data` 中指定了类的实例数据类型。 - -之后,将类定向为 JIT 编译只需在类定义前添加 `@jitclass(solow_data)` 即可。 - -当我们调用类中的方法时,这些方法就像函数一样被即时编译。 - -```{code-cell} ipython3 -s1 = Solow() -s2 = Solow(k=8.0) - -T = 60 -fig, ax = plt.subplots() - -# 绘制共同的稳态资本值 -ax.plot([s1.steady_state()]*T, 'k-', label='稳态') - -# 为每个经济体绘制时间序列 -for s in s1, s2: - lb = f'从初始状态 {s.k} 出发的资本序列' - ax.plot(s.generate_sequence(T), 'o-', lw=2, alpha=0.6, label=lb) -ax.set_ylabel('$k_{t}$', fontsize=12) -ax.set_xlabel('$t$', fontsize=12) -ax.legend() -plt.show() +iterate(g, 0.5, 100) ``` ## 危险与局限 @@ -449,16 +289,14 @@ print(add_a(10)) 当 Numba 为函数编译机器码时,它将全局变量视为常量,以确保类型稳定性。 +为了避免这种情况,请将值作为函数参数传递,而不是依赖全局变量。 + (multithreading)= ## Numba 中的多线程循环 -除了 JIT 编译之外,Numba 还为 CPU 上的并行计算提供了强大支持。 - -通过在多个 CPU 核心上分配计算任务,我们可以为许多数值算法实现显著的速度提升。 +除了 JIT 编译之外,Numba 还为 CPU 和 GPU 上的并行计算提供支持。 -Numba 中并行化的关键工具是 `prange` 函数,它告诉 Numba 在可用的 CPU 核心上并行执行循环迭代。 - -这种多线程方法适用于科学计算和定量经济学中的广泛问题。 +Numba 中 CPU 并行化的关键工具是 `prange` 函数,它告诉 Numba 在可用的 CPU 核心上并行执行循环迭代。 为了说明,让我们首先看一个简单的单线程(即非并行化)代码片段。 @@ -479,18 +317,15 @@ $$ 以下是代码: ```{code-cell} ipython3 -from numpy.random import randn -from numba import njit - -@njit +@jit def h(w, r=0.1, s=0.3, v1=0.1, v2=1.0): """ Updates household wealth. """ # 抽取冲击 - R = np.exp(v1 * randn()) * (1 + r) - y = np.exp(v2 * randn()) + R = np.exp(v1 * np.random.randn()) * (1 + r) + y = np.exp(v2 * np.random.randn()) # 更新财富 w = R * s * w + y @@ -516,29 +351,15 @@ plt.show() 现在,假设我们有一个庞大的家庭群体,并且想知道中位财富将是多少。 -这个问题很难用纸笔求解,因此我们将使用模拟。 - -具体来说,我们将模拟大量家庭,然后计算该群体的中位财富。 +这个问题很难用纸笔求解,因此我们将使用模拟: -假设我们对这一中位数随时间变化的长期平均值感兴趣。 - -事实证明,对于我们上面选择的参数规格,我们可以通过截取长时间模拟结束时群体中位财富的一个单期快照来计算这个值。 - -此外,只要模拟期足够长,初始条件就不重要。 - -* 这是由于一种称为遍历性的性质,我们将在[后面](https://python.quantecon.org/finite_markov.html#id15)讨论。 - -因此,总结来说,我们将通过以下方式模拟 50,000 个家庭: - -1. 任意设定初始财富为 1,以及 -1. 向前模拟 1,000 个时期。 - -然后我们将计算最终时期的中位财富。 +1. 向前模拟大量家庭 +2. 计算中位财富 以下是代码: ```{code-cell} ipython3 -@njit +@jit def compute_long_run_median(w0=1, T=1000, num_reps=50_000): obs = np.empty(num_reps) @@ -565,7 +386,7 @@ with qe.Timer(): ```{code-cell} ipython3 from numba import prange -@njit(parallel=True) +@jit(parallel=True) def compute_long_run_median_parallel(w0=1, T=1000, num_reps=50_000): obs = np.empty(num_reps) @@ -587,6 +408,10 @@ with qe.Timer(): 速度提升非常显著。 +请注意,我们是跨家庭而非跨时间进行并行化——单个家庭跨时期的更新本质上是顺序的。 + +关于基于 GPU 的并行化,请参阅我们的 {doc}`JAX 相关讲座 `。 + ## 练习 ```{exercise} @@ -922,4 +747,4 @@ def compute_call_price_parallel(β=β, 如果您使用的是具有多个 CPU 的机器,差异应该很显著。 ```{solution-end} -``` \ No newline at end of file +```