Commit be6eeae: Small edits (#526)

* misc * misc

1 parent a897eb8 commit be6eeae

File tree

1 file changed: +106 -89 lines changed

lectures/numba.md

Lines changed: 106 additions & 89 deletions
@@ -42,37 +42,39 @@ import matplotlib.pyplot as plt
```

## Overview

In an {doc}`earlier lecture <need_for_speed>` we discussed vectorization, which can improve execution speed by sending array processing operations in batch to efficient low-level code.

However, as {ref}`discussed in that lecture <numba-p_c_vectorization>`, traditional vectorization schemes have weaknesses:

* Highly memory-intensive for compound array operations
* Ineffective or impossible for some algorithms
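To make the first bullet concrete, here is a small illustrative sketch (not part of the lecture; the array size is chosen arbitrarily) showing how a compound NumPy expression allocates temporaries that an explicit loop avoids:

```python
import numpy as np

n = 100_000
x = np.linspace(0, 1, n)

# Each intermediate result (x**2, then x**2 + x) allocates a fresh
# length-n temporary array before np.cos produces the final output.
y = np.cos(x**2 + x)

# An explicit loop avoids the temporaries but is slow in pure Python --
# exactly the gap a JIT compiler like Numba aims to close.
z = np.empty(n)
for i in range(n):
    z[i] = np.cos(x[i]**2 + x[i])
```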
One way to circumvent these problems is by using [Numba](https://numba.pydata.org/), a **just in time (JIT) compiler** for Python.

Numba compiles functions to native machine code instructions at runtime.

When it succeeds, the result is performance comparable to compiled C or Fortran.

In addition, Numba can do useful tricks such as {ref}`multithreading <multithreading>`.

This lecture introduces the core ideas.

```{note}
Some readers might be curious about the relationship between Numba and [Julia](https://julialang.org/), which contains its own JIT compiler. While the two compilers are similar in many ways, Numba is less ambitious, attempting only to compile a small subset of the Python language. Although this might sound like a deficiency, it is also a strength: the more restrictive nature of Numba makes it easy to use well and good at what it does.
```

(numba_link)=
## {index}`Compiling Functions <single: Compiling Functions>`
@@ -93,16 +95,14 @@ $$
x_{t+1} = \alpha x_t (1 - x_t)
$$

In what follows we set $\alpha = 4$.

#### Base Version

Here's the plot of a typical trajectory, starting from $x_0 = 0.1$, with $t$ on the x-axis

```{code-cell} ipython3
def qm(x0, n, α=4.0):
    x = np.empty(n+1)
    x[0] = x0
    for t in range(n):
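The diff truncates the body of `qm`; for readers who want to run it standalone, here is a sketch completing the loop from the map $x_{t+1} = \alpha x_t (1 - x_t)$ above (the completed lines are a reconstruction, not necessarily the file's exact code):

```python
import numpy as np

def qm(x0, n, α=4.0):
    # Iterate the quadratic (logistic) map x_{t+1} = α x_t (1 - x_t)
    x = np.empty(n + 1)
    x[0] = x0
    for t in range(n):
        x[t + 1] = α * x[t] * (1 - x[t])
    return x

traj = qm(0.1, 5)
```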
@@ -117,103 +117,119 @@ ax.set_ylabel('$x_{t}$', fontsize = 12)
plt.show()
```

Let's see how long this takes to run for large $n$

```{code-cell} ipython3
n = 10_000_000

with qe.Timer() as timer1:
    # Time Python base version
    x = qm(0.1, int(n))
```
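If `quantecon`'s `qe.Timer` is unavailable, a rough stand-in can be built from the standard library (a sketch only; `qe.Timer`'s actual interface may differ):

```python
import time
from contextlib import contextmanager

@contextmanager
def timer():
    # Minimal stand-in for qe.Timer: records elapsed wall-clock time
    # in the dict yielded to the caller.
    start = time.perf_counter()
    t = {}
    try:
        yield t
    finally:
        t['elapsed'] = time.perf_counter() - start

with timer() as t:
    sum(i * i for i in range(100_000))
print(t['elapsed'])
```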
#### Acceleration via Numba

To speed the function `qm` up using Numba, we first import the `jit` function

```{code-cell} ipython3
from numba import jit
```
Now we apply it to `qm`, producing a new function:

```{code-cell} ipython3
qm_numba = jit(qm)
```

The function `qm_numba` is a version of `qm` that is "targeted" for JIT-compilation.

We will explain what this means momentarily.

Let's time this new version:

```{code-cell} ipython3
with qe.Timer() as timer2:
    # Time jitted version
    x = qm_numba(0.1, int(n))
```

This is a large speed gain.

In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:

(qm_numba_result)=

```{code-cell} ipython3
with qe.Timer() as timer3:
    # Second run
    x = qm_numba(0.1, int(n))
```

Here's the speed gain

```{code-cell} ipython3
timer1.elapsed / timer3.elapsed
```
This is a big boost for a small modification to our original code.

Let's discuss how this works.

### How and When it Works

Numba attempts to generate fast machine code using the infrastructure provided by the [LLVM Project](https://llvm.org/).

It does this by inferring type information on the fly.

(See our {doc}`earlier lecture <need_for_speed>` on scientific computing for a discussion of types.)

The basic idea is this:

* Python is very flexible and hence we could call the function `qm` with many types.
* e.g., `x0` could be a NumPy array or a list, `n` could be an integer or a float, etc.
* This makes it very difficult to generate efficient machine code *ahead of time* (i.e., before runtime).
* However, when we do actually *call* the function, say by running `qm(0.5, 10)`, the types of `x0`, `α` and `n` are determined.
* Moreover, the types of *other variables* in `qm` *can be inferred once the input types are known*.
* So the strategy of Numba and other JIT compilers is to *wait until the function is called*, and then compile.

That is called "just-in-time" compilation.

Note that, if you make the call `qm_numba(0.5, 10)` and then follow it with `qm_numba(0.9, 20)`, compilation only takes place on the first call.

This is because compiled code is cached and reused as required.

This is why, in the code above, the second run of `qm_numba` is faster.

```{admonition} Remark
In practice, rather than writing `qm_numba = jit(qm)`, we typically use *decorator* syntax and put `@jit` before the function definition. This is equivalent to adding `qm = jit(qm)` after the definition.
```
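To illustrate the decorator equivalence without requiring Numba, here is a sketch using a hypothetical stand-in decorator `fake_jit` (with Numba installed you would write `from numba import jit` and use `@jit` the same way):

```python
def fake_jit(func):
    # Numba's jit returns a compiled wrapper; this stand-in simply
    # returns the function unchanged, to show the syntax only.
    return func

@fake_jit
def qm_scalar(x0, n, α=4.0):
    # Scalar version of the quadratic map iteration
    x = x0
    for _ in range(n):
        x = α * x * (1 - x)
    return x

# The decorator line above is equivalent to defining qm_scalar plainly
# and then writing: qm_scalar = fake_jit(qm_scalar)
```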
## Sharp Bits

Numba is relatively easy to use but not always seamless.

Let's review some of the issues users run into.

### Typing

Successful type inference is the key to JIT compilation.

In an ideal setting, Numba can infer all necessary type information.

When Numba *cannot* infer all type information, it will raise an error.

For example, in the setting below, Numba is unable to determine the type of the function `g` when compiling `iterate`

```{code-cell} ipython3
@jit
@@ -234,7 +250,7 @@ except Exception as e:
    print(e)
```

In the present case, we can fix this easily by compiling `g`.

```{code-cell} ipython3
@jit
@@ -244,28 +260,16 @@ def g(x):
iterate(g, 0.5, 100)
```
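The diff elides the definitions of `g` and `iterate`; a plain-Python reconstruction (an assumption about the elided code, taking `g` to be the quadratic map) looks roughly like this:

```python
def g(x):
    # A simple map to iterate, e.g. the quadratic map with α = 4
    return 4.0 * x * (1.0 - x)

def iterate(g, x, n):
    # Apply g to x, n times in succession
    for _ in range(n):
        x = g(x)
    return x

result = iterate(g, 0.5, 3)
```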
In other cases, such as when we want to use functions from external libraries such as `SciPy`, there might not be any easy workaround.

### Global Variables

Another thing to be careful about when using Numba is handling of global variables.

For example, consider the following code

```{code-cell} ipython3
a = 1
@@ -284,9 +288,10 @@ print(add_a(10))
```

Notice that changing the global had no effect on the value returned by the function 😱.

When Numba compiles machine code for functions, it treats global variables as constants to ensure type stability.

To avoid this, pass values as function arguments rather than relying on globals.
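Here is a sketch of the recommended pattern in plain Python (illustrative only; the freezing behavior itself appears only under Numba's `@jit`):

```python
a = 1

def add_a_glob(x):
    # Relies on the global `a`. Under Numba's @jit, `a` would be frozen
    # as a constant at compile time, so later changes to the global
    # would be silently ignored.
    return x + a

def add_a_arg(x, a):
    # Recommended: pass `a` explicitly, so the dependence is visible and
    # the (jitted) code always sees the current value.
    return x + a

print(add_a_arg(10, a))
a = 2
print(add_a_arg(10, a))
```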
@@ -320,15 +325,11 @@ Here's the code:

```{code-cell} ipython3
@jit
def update(w, r=0.1, s=0.3, v1=0.1, v2=1.0):
    " Updates household wealth. "
    # Draw shocks
    R = np.exp(v1 * np.random.randn()) * (1 + r)
    y = np.exp(v2 * np.random.randn())
    # Update wealth
    w = R * s * w + y
    return w
@@ -343,7 +344,7 @@ T = 100
w = np.empty(T)
w[0] = 5
for t in range(T-1):
    w[t+1] = update(w[t])

ax.plot(w)
ax.set_xlabel('$t$', fontsize=12)
@@ -365,21 +366,30 @@ Here's the code:

```{code-cell} ipython3
@jit
def compute_long_run_median(w0=1, T=1000, num_reps=50_000):
    obs = np.empty(num_reps)
    # For each household
    for i in range(num_reps):
        # Set the initial condition and run forward in time
        w = w0
        for t in range(T):
            w = update(w)
        # Record the final value
        obs[i] = w
    # Take the median of all final values
    return np.median(obs)
```

Let's see how fast this runs:

```{code-cell} ipython3
with qe.Timer():
    # Warm up
    compute_long_run_median()
```

```{code-cell} ipython3
with qe.Timer():
    # Second run
    compute_long_run_median()
```
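A scaled-down, plain-Python sketch of the same computation (smaller `T` and `num_reps` than the lecture, since there is no JIT here; `update` is restated so the snippet is self-contained):

```python
import numpy as np

def update(w, r=0.1, s=0.3, v1=0.1, v2=1.0):
    # One period of wealth dynamics: stochastic gross return on savings
    # plus stochastic labor income
    R = np.exp(v1 * np.random.randn()) * (1 + r)
    y = np.exp(v2 * np.random.randn())
    return R * s * w + y

def long_run_median(w0=1, T=200, num_reps=200):
    obs = np.empty(num_reps)
    for i in range(num_reps):
        w = w0
        for t in range(T):
            w = update(w)
        obs[i] = w
    return np.median(obs)

np.random.seed(42)
m = long_run_median()
```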
@@ -391,22 +401,29 @@ To do so, we add the `parallel=True` flag and change `range` to `prange`:
from numba import prange

@jit(parallel=True)
def compute_long_run_median_parallel(
        w0=1, T=1000, num_reps=50_000
    ):
    obs = np.empty(num_reps)
    for i in prange(num_reps):  # Parallelize over households
        w = w0
        for t in range(T):
            w = update(w)
        obs[i] = w
    return np.median(obs)
```

Let's look at the timing:

```{code-cell} ipython3
with qe.Timer():
    # Warm up
    compute_long_run_median_parallel()
```

```{code-cell} ipython3
with qe.Timer():
    # Second run
    compute_long_run_median_parallel()
```
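The parallelize-over-households pattern can also be sketched with the standard library (an analogy only: `prange` parallelizes iterations inside one compiled function, whereas this uses OS threads and, because of Python's GIL, gains no CPU speedup):

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor

def update(w, r=0.1, s=0.3, v1=0.1, v2=1.0):
    R = np.exp(v1 * np.random.randn()) * (1 + r)
    y = np.exp(v2 * np.random.randn())
    return R * s * w + y

def simulate_one(w0=1, T=200):
    # Final wealth of one household after T periods
    w = w0
    for t in range(T):
        w = update(w)
    return w

# Each household is independent -- the same structure that makes the
# prange loop in the lecture safe to parallelize.
with ThreadPoolExecutor() as ex:
    obs = list(ex.map(lambda _: simulate_one(), range(100)))

median = np.median(obs)
```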