Skip to content

Commit 397a891

Browse files
Update rng usage in numba.md
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 3213613 commit 397a891

1 file changed

Lines changed: 25 additions & 14 deletions

File tree

lectures/numba.md

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,13 @@ Compare speed with and without Numba when the sample size is large.
454454
Here is one solution:
455455

456456
```{code-cell} ipython3
457+
rng = np.random.default_rng()
458+
457459
@jit
458-
def calculate_pi(n=1_000_000):
460+
def calculate_pi(rng, n=1_000_000):
459461
count = 0
460462
for i in range(n):
461-
u, v = np.random.uniform(0, 1), np.random.uniform(0, 1)
463+
u, v = rng.uniform(0, 1), rng.uniform(0, 1)
462464
d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2)
463465
if d < 0.5:
464466
count += 1
@@ -471,12 +473,12 @@ Now let's see how fast it runs:
471473

472474
```{code-cell} ipython3
473475
with qe.Timer():
474-
calculate_pi()
476+
calculate_pi(rng)
475477
```
476478

477479
```{code-cell} ipython3
478480
with qe.Timer():
479-
calculate_pi()
481+
calculate_pi(rng)
480482
```
481483

482484
If we switch off JIT compilation by removing `@jit`, the code takes around
@@ -550,10 +552,12 @@ p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
550552
Here's a pure Python version of the function
551553

552554
```{code-cell} ipython3
553-
def compute_series(n):
555+
rng = np.random.default_rng()
556+
557+
def compute_series(n, rng):
554558
x = np.empty(n, dtype=np.int64)
555559
x[0] = 1 # Start in state 1
556-
U = np.random.uniform(0, 1, size=n)
560+
U = rng.uniform(0, 1, size=n)
557561
for t in range(1, n):
558562
current_x = x[t-1]
559563
if current_x == 0:
@@ -568,7 +572,7 @@ state is about 0.666
568572

569573
```{code-cell} ipython3
570574
n = 1_000_000
571-
x = compute_series(n)
575+
x = compute_series(n, rng)
572576
print(np.mean(x == 0)) # Fraction of time x is in state 0
573577
```
574578

@@ -578,7 +582,7 @@ Now let's time it:
578582

579583
```{code-cell} ipython3
580584
with qe.Timer():
581-
compute_series(n)
585+
compute_series(n, rng)
582586
```
583587

584588
Next let's implement a Numba version, which is easy
@@ -590,15 +594,15 @@ compute_series_numba = jit(compute_series)
590594
Let's check we still get the right numbers
591595

592596
```{code-cell} ipython3
593-
x = compute_series_numba(n)
597+
x = compute_series_numba(n, rng)
594598
print(np.mean(x == 0))
595599
```
596600

597601
Let's see the time
598602

599603
```{code-cell} ipython3
600604
with qe.Timer():
601-
compute_series_numba(n)
605+
compute_series_numba(n, rng)
602606
```
603607

604608
This is a nice speed improvement for one line of code!
@@ -636,11 +640,17 @@ For the size of the Monte Carlo simulation, use something substantial, such as
636640
Here is one solution:
637641

638642
```{code-cell} ipython3
643+
n = 1_000_000
644+
rng = np.random.default_rng()
645+
u_draws = rng.uniform(size=n)
646+
v_draws = rng.uniform(size=n)
647+
639648
@jit(parallel=True)
640-
def calculate_pi(n=1_000_000):
649+
def calculate_pi(u_draws, v_draws):
650+
n = len(u_draws)
641651
count = 0
642652
for i in prange(n):
643-
u, v = np.random.uniform(0, 1), np.random.uniform(0, 1)
653+
u, v = u_draws[i], v_draws[i]
644654
d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2)
645655
if d < 0.5:
646656
count += 1
@@ -653,12 +663,12 @@ Now let's see how fast it runs:
653663

654664
```{code-cell} ipython3
655665
with qe.Timer():
656-
calculate_pi()
666+
calculate_pi(u_draws, v_draws)
657667
```
658668

659669
```{code-cell} ipython3
660670
with qe.Timer():
661-
calculate_pi()
671+
calculate_pi(u_draws, v_draws)
662672
```
663673

664674
By switching parallelization on and off (selecting `True` or
@@ -773,6 +783,7 @@ def compute_call_price_parallel(β=β,
773783
s = np.log(S0)
774784
h = h0
775785
# Simulate forward in time
786+
# Draws are kept inside the loop to avoid pre-allocating large shock arrays.
776787
for t in range(n):
777788
s = s + μ + np.exp(h) * np.random.randn()
778789
h = ρ * h + ν * np.random.randn()

0 commit comments

Comments
 (0)