@@ -454,13 +454,17 @@ Compare speed with and without Numba when the sample size is large.
454454Here is one solution:
455455
456456``` {code-cell} ipython3
457+ n = 1_000_000
457458rng = np.random.default_rng()
459+ u_draws = rng.uniform(size=n)
460+ v_draws = rng.uniform(size=n)
458461
459462@jit
460- def calculate_pi(rng, n=1_000_000):
463+ def calculate_pi(u_draws, v_draws):
464+ n = len(u_draws)
461465 count = 0
462466 for i in range(n):
463- u, v = rng.uniform(0, 1), rng.uniform(0, 1)
467+ u, v = u_draws[i], v_draws[i]
464468 d = np.sqrt((u - 0.5)**2 + (v - 0.5)**2)
465469 if d < 0.5:
466470 count += 1
@@ -473,12 +477,12 @@ Now let's see how fast it runs:
473477
474478``` {code-cell} ipython3
475479with qe.Timer():
476- calculate_pi(rng )
480+ calculate_pi(u_draws, v_draws )
477481```
478482
479483``` {code-cell} ipython3
480484with qe.Timer():
481- calculate_pi(rng )
485+ calculate_pi(u_draws, v_draws )
482486```
483487
484488If we switch off JIT compilation by removing ` @jit ` , the code takes around
@@ -552,12 +556,13 @@ p, q = 0.1, 0.2 # Prob of leaving low and high state respectively
552556Here's a pure Python version of the function
553557
554558``` {code-cell} ipython3
559+ n = 1_000_000
555560rng = np.random.default_rng()
561+ U = rng.uniform(0, 1, size=n)
556562
557- def compute_series(n, rng ):
563+ def compute_series(n, U ):
558564 x = np.empty(n, dtype=np.int64)
559565 x[0] = 1 # Start in state 1
560- U = rng.uniform(0, 1, size=n)
561566 for t in range(1, n):
562567 current_x = x[t-1]
563568 if current_x == 0:
@@ -571,8 +576,7 @@ Let's run this code and check that the fraction of time spent in the low
571576state is about 0.666
572577
573578``` {code-cell} ipython3
574- n = 1_000_000
575- x = compute_series(n, rng)
579+ x = compute_series(n, U)
576580print(np.mean(x == 0)) # Fraction of time x is in state 0
577581```
578582
@@ -582,7 +586,7 @@ Now let's time it:
582586
583587``` {code-cell} ipython3
584588with qe.Timer():
585- compute_series(n, rng )
589+ compute_series(n, U )
586590```
587591
588592Next let's implement a Numba version, which is easy
@@ -594,15 +598,15 @@ compute_series_numba = jit(compute_series)
594598Let's check we still get the right numbers
595599
596600``` {code-cell} ipython3
597- x = compute_series_numba(n, rng )
601+ x = compute_series_numba(n, U )
598602print(np.mean(x == 0))
599603```
600604
601605Let's see the time
602606
603607``` {code-cell} ipython3
604608with qe.Timer():
605- compute_series_numba(n, rng )
609+ compute_series_numba(n, U )
606610```
607611
608612This is a nice speed improvement for one line of code!
760764
761765Using this fact, the solution can be written as follows.
762766
767+ Note that random draws are kept inside the inner loop rather than pre-allocated,
768+ to avoid creating large shock arrays of size ` M * n ` .
769+
763770
764771``` {code-cell} ipython3
765772M = 10_000_000
@@ -783,7 +790,6 @@ def compute_call_price_parallel(β=β,
783790 s = np.log(S0)
784791 h = h0
785792 # Simulate forward in time
786- # Draws are kept inside the loop to avoid pre-allocating large shock arrays.
787793 for t in range(n):
788794 s = s + μ + np.exp(h) * np.random.randn()
789795 h = ρ * h + ν * np.random.randn()
0 commit comments