From 8c655616329e7f5dfc64ff36338210c86e528bdb Mon Sep 17 00:00:00 2001
From: John Stachurski
Date: Fri, 7 Mar 2025 16:40:02 +0900
Subject: [PATCH] misc

---
 lectures/numba.md | 97 ++++++++++++++++-------------------------------
 1 file changed, 33 insertions(+), 64 deletions(-)

diff --git a/lectures/numba.md b/lectures/numba.md
index cc8794c8..fb458daf 100644
--- a/lectures/numba.md
+++ b/lectures/numba.md
@@ -4,7 +4,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.14.4
+    jupytext_version: 1.16.7
 kernelspec:
   display_name: Python 3 (ipykernel)
   language: python
@@ -118,9 +118,9 @@ plt.show()
 To speed the function `qm` up using Numba, our first step is

 ```{code-cell} ipython3
-from numba import njit
+from numba import jit

-qm_numba = njit(qm)
+qm_numba = jit(qm)
 ```

 The function `qm_numba` is a version of `qm` that is "targeted" for
@@ -146,7 +146,7 @@ qm_numba(0.1, int(n))
 time2 = qe.toc()
 ```

-This is already a massive speed gain.
+This is already a very large speed gain.

 In fact, the next time and all subsequent times it runs even faster as the function has been compiled and is in memory:

@@ -162,7 +162,7 @@ time3 = qe.toc()
 time1 / time3  # Calculate speed gain
 ```

-This kind of speed gain is huge relative to how simple and clear the implementation is.
+This kind of speed gain is impressive relative to how simple and clear the modification is.

 ### How and When it Works

@@ -177,10 +177,10 @@ The basic idea is this:

 * Python is very flexible and hence we could call the function qm with many types.
     * e.g., `x0` could be a NumPy array or a list, `n` could be an integer or a float, etc.
-* This makes it hard to *pre*-compile the function.
-* However, when we do actually call the function, say by executing `qm(0.5, 10)`,
+* This makes it hard to *pre*-compile the function (i.e., compile before runtime).
+* However, when we do actually call the function, say by running `qm(0.5, 10)`,
   the types of `x0` and `n` become clear.
-* Moreover, the types of other variables in `qm` can be inferred once the input is known.
+* Moreover, the types of other variables in `qm` can be inferred once the input types are known.
 * So the strategy of Numba and other JIT compilers is to wait until this moment, and *then* compile the function.

@@ -190,26 +190,28 @@ Note that, if you make the call `qm(0.5, 10)` and then follow it with `qm(0.9, 2

 The compiled code is then cached and recycled as required.

+This is why, in the code above, `time3` is smaller than `time2`.
+
 ## Decorator Notation

 In the code above we created a JIT compiled version of `qm` via the call

 ```{code-cell} ipython3
-qm_numba = njit(qm)
+qm_numba = jit(qm)
 ```

 In practice this would typically be done using an alternative *decorator* syntax.

-(We will explain all about decorators in a {doc}`later lecture ` but you can skip the details at this stage.)
+(We discuss decorators in a {doc}`separate lecture `, but you can skip the details at this stage.)

 Let's see how this is done.

-To target a function for JIT compilation we can put `@njit` before the function definition.
+To target a function for JIT compilation, we can put `@jit` before the function definition.

 Here's what this looks like for `qm`

 ```{code-cell} ipython3
-@njit
+@jit
 def qm(x0, n):
     x = np.empty(n+1)
     x[0] = x0
@@ -218,7 +220,7 @@ def qm(x0, n):
     return x
 ```

-This is equivalent to `qm = njit(qm)`.
+This is equivalent to adding `qm = jit(qm)` after the function definition.

 The following now uses the jitted version:

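The hunks above switch the lecture from `njit` to the plain `@jit` decorator and note that compilation happens on the first call only, after which the compiled code is reused. A minimal stand-alone sketch of that behaviour follows; the function name and loop body are illustrative stand-ins, not the lecture's exact `qm`.

```python
import time

import numpy as np
from numba import jit

# Decorator form: equivalent to writing `quadratic_map = jit(quadratic_map)`
# after the function definition.
@jit
def quadratic_map(x0, n):
    x = np.empty(n + 1)
    x[0] = x0
    for t in range(n):
        x[t + 1] = 4.0 * x[t] * (1 - x[t])
    return x

n = 10_000_000
for label in ("first call (compiles, then runs)",
              "second call (reuses cached machine code)"):
    start = time.time()
    quadratic_map(0.1, n)
    print(label, time.time() - start)
```

The second timing should be much smaller than the first, which is the `time3` versus `time2` comparison referred to in the added text above.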
@@ -228,13 +230,19 @@ The following now uses the jitted version:
 qm(0.1, 100_000)
 ```

-Numba provides several arguments for decorators to accelerate computation and cache functions [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).
+```{code-cell} ipython3
+%%time
+
+qm(0.1, 100_000)
+```
+
+Numba also provides several decorator arguments that accelerate computation and cache compiled functions -- see [here](https://numba.readthedocs.io/en/stable/user/performance-tips.html).

 In the [following lecture on parallelization](parallel), we will discuss how to use the `parallel` argument to achieve automatic parallelization.

 ## Type Inference

-Clearly type inference is a key part of JIT compilation.
+Successful type inference is a key part of JIT compilation.

 As you can imagine, inferring types is easier for simple Python objects (e.g., simple scalar data types such as floats and integers).

@@ -248,10 +256,10 @@ In such a setting, Numba will be on par with machine code from low-level languag

 When Numba cannot infer all type information, it will raise an error.

-For example, in the case below, Numba is unable to determine the type of function `mean` when compiling the function `bootstrap`
+For example, in the (artificial) setting below, Numba is unable to determine the type of the function `mean` when compiling the function `bootstrap`

 ```{code-cell} ipython3
-@njit
+@jit
 def bootstrap(data, statistics, n):
     bootstrap_stat = np.empty(n)
     n = len(data)
@@ -260,69 +268,30 @@ def bootstrap(data, statistics, n):
         bootstrap_stat[i] = statistics(resample)
     return bootstrap_stat

+# No decorator here.
 def mean(data):
     return np.mean(data)

-data = np.array([2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2])
+data = np.array((2.3, 3.1, 4.3, 5.9, 2.1, 3.8, 2.2))
 n_resamples = 10

-print('Type of function:', type(mean))
-
-#Error
+# This code raises an error
 try:
     bootstrap(data, mean, n_resamples)
 except Exception as e:
     print(e)
 ```

-But Numba recognizes JIT-compiled functions
+We can fix this error easily by JIT-compiling `mean` as well.

 ```{code-cell} ipython3
-@njit
+@jit
 def mean(data):
     return np.mean(data)

-print('Type of function:', type(mean))
-
 %time bootstrap(data, mean, n_resamples)
 ```

-We can check the signature of the JIT-compiled function
-
-```{code-cell} ipython3
-bootstrap.signatures
-```
-
-The function `bootstrap` takes one `float64` floating point array, one function called `mean` and an `int64` integer.
-
-Now let's see what happens when we change the inputs.
-
-Running it again with a larger integer for `n` and a different set of data does not change the signature of the function.
-
-```{code-cell} ipython3
-data = np.array([4.1, 1.1, 2.3, 1.9, 0.1, 2.8, 1.2])
-%time bootstrap(data, mean, 100)
-bootstrap.signatures
-```
-
-As expected, the second run is much faster.
-
-Let's try to change the data again and use an integer array as data
-
-```{code-cell} ipython3
-data = np.array([1, 2, 3, 4, 5], dtype=np.int64)
-%time bootstrap(data, mean, 100)
-bootstrap.signatures
-```
-
-Note that a second signature is added.
-
-It also takes longer to run, suggesting that Numba recompiles this function as the type changes.
-
-Overall, type inference helps Numba to achieve its performance, but it also limits what Numba supports and sometimes requires careful type checks.
-
-You can refer to the list of supported Python and Numpy features [here](https://numba.pydata.org/numba-doc/dev/reference/pysupported.html).
-
 ## Compiling Classes

 As mentioned above, at present Numba can only compile a subset of Python.
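The `bootstrap`/`mean` hunk above shows that the typing error disappears once `mean` is itself JIT-compiled. Here is a minimal stand-alone sketch of the same point, with illustrative names that are not from the lecture: a JIT-compiled function can be passed as an argument to another JIT-compiled function because Numba can infer its type, whereas a plain Python function cannot.

```python
import numpy as np
from numba import jit

@jit
def apply_to_grid(f, grid):
    # `f` must be something Numba can type, e.g. another jit-compiled function.
    out = np.empty_like(grid)
    for i in range(grid.shape[0]):
        out[i] = f(grid[i])
    return out

@jit
def square(x):
    return x**2

print(apply_to_grid(square, np.linspace(0, 1, 5)))
```

Passing an undecorated `square` here would raise the same kind of typing error as the `bootstrap` example above.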
@@ -509,7 +478,7 @@ Consider the following example
 ```{code-cell} ipython3
 a = 1

-@njit
+@jit
 def add_a(x):
     return a + x

@@ -549,7 +518,7 @@ Here is one solution:
 ```{code-cell} ipython3
 from random import uniform

-@njit
+@jit
 def calculate_pi(n=1_000_000):
     count = 0
     for i in range(n):
@@ -677,7 +646,7 @@ qe.toc()
 Next let's implement a Numba version, which is easy

 ```{code-cell} ipython3
-compute_series_numba = njit(compute_series)
+compute_series_numba = jit(compute_series)
 ```

 Let's check we still get the right numbers