Skip to content

Commit aef82bd

Browse files
Address reviewer feedback on rng usage in numba.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 397a891 commit aef82bd

1 file changed

Lines changed: 18 additions & 12 deletions

File tree

lectures/numba.md

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -454,13 +454,17 @@ Compare speed with and without Numba when the sample size is large.
454454
Here is one solution:
455455

456456
```{code-cell} ipython3
457+
n = 1_000_000
457458
rng = np.random.default_rng()
459+
u_draws = rng.uniform(size=n)
460+
v_draws = rng.uniform(size=n)
458461
459462
@jit
460-
def calculate_pi(rng, n=1_000_000):
463+
def calculate_pi(u_draws, v_draws):
464+
n = len(u_draws)
461465
count = 0
462466
for i in range(n):
463-
u, v = rng.uniform(0, 1), rng.uniform(0, 1)
467+
u, v = u_draws[i], v_draws[i]
464468
d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2)
465469
if d < 0.5:
466470
count += 1
@@ -473,12 +477,12 @@ Now let's see how fast it runs:
473477

474478
```{code-cell} ipython3
475479
with qe.Timer():
476-
calculate_pi(rng)
480+
calculate_pi(u_draws, v_draws)
477481
```
478482

479483
```{code-cell} ipython3
480484
with qe.Timer():
481-
calculate_pi(rng)
485+
calculate_pi(u_draws, v_draws)
482486
```
483487

484488
If we switch off JIT compilation by removing `@jit`, the code takes around
@@ -552,12 +556,13 @@ p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
552556
Here's a pure Python version of the function
553557

554558
```{code-cell} ipython3
559+
n = 1_000_000
555560
rng = np.random.default_rng()
561+
U = rng.uniform(0, 1, size=n)
556562
557-
def compute_series(n, rng):
563+
def compute_series(n, U):
558564
x = np.empty(n, dtype=np.int64)
559565
x[0] = 1 # Start in state 1
560-
U = rng.uniform(0, 1, size=n)
561566
for t in range(1, n):
562567
current_x = x[t-1]
563568
if current_x == 0:
@@ -571,8 +576,7 @@ Let's run this code and check that the fraction of time spent in the low
571576
state is about 0.666
572577

573578
```{code-cell} ipython3
574-
n = 1_000_000
575-
x = compute_series(n, rng)
579+
x = compute_series(n, U)
576580
print(np.mean(x == 0)) # Fraction of time x is in state 0
577581
```
578582

@@ -582,7 +586,7 @@ Now let's time it:
582586

583587
```{code-cell} ipython3
584588
with qe.Timer():
585-
compute_series(n, rng)
589+
compute_series(n, U)
586590
```
587591

588592
Next let's implement a Numba version, which is easy
@@ -594,15 +598,15 @@ compute_series_numba = jit(compute_series)
594598
Let's check we still get the right numbers
595599

596600
```{code-cell} ipython3
597-
x = compute_series_numba(n, rng)
601+
x = compute_series_numba(n, U)
598602
print(np.mean(x == 0))
599603
```
600604

601605
Let's see the time
602606

603607
```{code-cell} ipython3
604608
with qe.Timer():
605-
compute_series_numba(n, rng)
609+
compute_series_numba(n, U)
606610
```
607611

608612
This is a nice speed improvement for one line of code!
@@ -760,6 +764,9 @@ $$
760764

761765
Using this fact, the solution can be written as follows.
762766

767+
Note that random draws are kept inside the inner loop rather than pre-allocated,
768+
to avoid creating large shock arrays of size `M * n`.
769+
763770

764771
```{code-cell} ipython3
765772
M = 10_000_000
@@ -783,7 +790,6 @@ def compute_call_price_parallel(β=β,
783790
s = np.log(S0)
784791
h = h0
785792
# Simulate forward in time
786-
# Draws are kept inside the loop to avoid pre-allocating large shock arrays.
787793
for t in range(n):
788794
s = s + μ + np.exp(h) * np.random.randn()
789795
h = ρ * h + ν * np.random.randn()

0 commit comments

Comments
 (0)