Skip to content

Commit 16ca96c

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent 7f1c42c commit 16ca96c

1 file changed

Lines changed: 82 additions & 106 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 82 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ translation:
1717
Vectorized operations::Parallelized Numba: Numba موازی شده
1818
Vectorized operations::Vectorized code with JAX: کد برداری شده با JAX
1919
Vectorized operations::JAX plus vmap: JAX به علاوه vmap
20-
Vectorized operations::JAX plus vmap::Version 1: نسخه 1
21-
Vectorized operations::vmap version 2: نسخه 2 vmap
2220
Vectorized operations::Summary: خلاصه
2321
Sequential operations: عملیات ترتیبی
2422
Sequential operations::Numba Version: نسخه Numba
@@ -27,7 +25,7 @@ translation:
2725
Overall recommendations: توصیه‌های کلی
2826
---
2927

30-
(parallel)=
28+
(numpy_numba_jax)=
3129
```{raw} jupyter
3230
<div id="qe-notebook-header" align="right" style="text-align:right;">
3331
<a href="https://quantecon.org/" title="quantecon.org">
@@ -156,7 +154,7 @@ for x in grid:
156154
grid = np.linspace(-3, 3, 3_000)
157155
x, y = np.meshgrid(grid, grid)
158156
159-
with qe.Timer(precision=8):
157+
with qe.Timer():
160158
z_max_numpy = np.max(f(x, y))
161159
162160
print(f"NumPy result: {z_max_numpy:.6f}")
@@ -179,13 +177,17 @@ def compute_max_numba(grid):
179177
for x in grid:
180178
for y in grid:
181179
z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
182-
if z > m:
183-
m = z
180+
m = max(m, z)
184181
return m
182+
```
183+
184+
بیایید آن را آزمایش کنیم:
185185

186+
```{code-cell} ipython3
186187
grid = np.linspace(-3, 3, 3_000)
187188
188-
with qe.Timer(precision=8):
189+
with qe.Timer():
190+
# First run
189191
z_max_numba = compute_max_numba(grid)
190192
191193
print(f"Numba result: {z_max_numba:.6f}")
@@ -194,22 +196,23 @@ print(f"Numba result: {z_max_numba:.6f}")
194196
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود.
195197

196198
```{code-cell} ipython3
197-
with qe.Timer(precision=8):
199+
with qe.Timer():
200+
# Second run
198201
compute_max_numba(grid)
199202
```
200203

201-
بسته به دستگاه شما، نسخه Numba می‌تواند کمی کندتر یا کمی سریعتر از NumPy باشد.
204+
بسته به دستگاه شما، نسخه Numba ممکن است کندتر یا سریعتر از NumPy باشد.
202205

203-
از یک طرف، NumPy محاسبات کارآمد (مانند Numba) را با مقداری چندنخی (برخلاف این کد Numba) ترکیب می‌کند که مزیتی فراهم می‌کند.
206+
در اکثر موارد، Numba کمی بهتر است.
207+
208+
از یک طرف، NumPy محاسبات کارآمد را با مقداری چندنخی ترکیب می‌کند که مزیتی فراهم می‌کند.
204209

205210
از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده می‌کند، زیرا ما فقط با یک شبکه یک‌بعدی کار می‌کنیم.
206211

207212
### Numba موازی شده
208213

209214
حالا بیایید موازی‌سازی با Numba را با استفاده از `prange` امتحان کنیم:
210215

211-
در اینجا یک تلاش ساده و *نادرست* آمده است.
212-
213216
```{code-cell} ipython3
214217
@numba.jit(parallel=True)
215218
def compute_max_numba_parallel(grid):
@@ -220,57 +223,25 @@ def compute_max_numba_parallel(grid):
220223
x = grid[i]
221224
y = grid[j]
222225
z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
223-
if z > m:
224-
m = z
226+
m = max(m, z)
225227
return m
226-
227228
```
228229

229-
معمولاً این نتیجه نادرستی برمی‌گرداند:
230-
231-
```{code-cell} ipython3
232-
z_max_parallel_incorrect = compute_max_numba_parallel(grid)
233-
print(f"Numba result: {z_max_parallel_incorrect} 😱")
234-
```
235-
236-
دلیل این است که متغیر `m` بین نخ‌ها مشترک است و به درستی کنترل نمی‌شود.
237-
238-
وقتی چندین نخ سعی می‌کنند همزمان `m` را بخوانند و بنویسند، با یکدیگر تداخل می‌کنند.
239-
240-
نخ‌ها مقادیر قدیمی `m` را می‌خوانند یا به‌روزرسانی‌های یکدیگر را بازنویسی می‌کنند --- یا `m` هرگز از مقدار اولیه خود به‌روزرسانی نمی‌شود.
241-
242-
در اینجا یک نسخه با دقت بیشتری نوشته شده است.
230+
در اینجا یک اجرای گرم‌کننده و آزمایش آمده است.
243231

244232
```{code-cell} ipython3
245-
@numba.jit(parallel=True)
246-
def compute_max_numba_parallel(grid):
247-
n = len(grid)
248-
row_maxes = np.empty(n)
249-
for i in numba.prange(n):
250-
row_max = -np.inf
251-
for j in range(n):
252-
x = grid[i]
253-
y = grid[j]
254-
z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
255-
if z > row_max:
256-
row_max = z
257-
row_maxes[i] = row_max
258-
return np.max(row_maxes)
259-
```
233+
with qe.Timer():
234+
# First run
235+
z_max_parallel = compute_max_numba_parallel(grid)
260236
261-
اکنون بلوک کدی که `for i in numba.prange(n)` روی آن عمل می‌کند بین `i` ها مستقل است.
262-
263-
هر نخ به یک عنصر جداگانه از آرایه `row_maxes` می‌نویسد و موازی‌سازی ایمن است.
264-
265-
```{code-cell} ipython3
266-
z_max_parallel = compute_max_numba_parallel(grid)
267237
print(f"Numba result: {z_max_parallel:.6f}")
268238
```
269239

270-
در اینجا زمان‌بندی آمده است.
240+
در اینجا زمان‌بندی برای نسخه از پیش کامپایل شده آمده است.
271241

272242
```{code-cell} ipython3
273-
with qe.Timer(precision=8):
243+
with qe.Timer():
244+
# Second run
274245
compute_max_numba_parallel(grid)
275246
```
276247

@@ -284,8 +255,7 @@ with qe.Timer(precision=8):
284255

285256
اما تفاوت‌هایی نیز وجود دارد که در اینجا آنها را برجسته می‌کنیم.
286257

287-
بیایید با تابع شروع کنیم.
288-
258+
بیایید با تابع شروع کنیم که `np` را به `jnp` تغییر می‌دهد و `jax.jit` را اضافه می‌کند.
289259

290260
```{code-cell} ipython3
291261
@jax.jit
@@ -299,9 +269,15 @@ def f(x, y):
299269
```{code-cell} ipython3
300270
grid = jnp.linspace(-3, 3, 3_000)
301271
x_mesh, y_mesh = jnp.meshgrid(grid, grid)
272+
```
273+
274+
حالا بیایید اجرا و زمان‌بندی کنیم
302275

303-
with qe.Timer(precision=8):
276+
```{code-cell} ipython3
277+
with qe.Timer():
278+
# First run
304279
z_max = jnp.max(f(x_mesh, y_mesh))
280+
# Hold interpreter
305281
z_max.block_until_ready()
306282
307283
print(f"Plain vanilla JAX result: {z_max:.6f}")
@@ -310,8 +286,10 @@ print(f"Plain vanilla JAX result: {z_max:.6f}")
310286
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود.
311287

312288
```{code-cell} ipython3
313-
with qe.Timer(precision=8):
289+
with qe.Timer():
290+
# Second run
314291
z_max = jnp.max(f(x_mesh, y_mesh))
292+
# Hold interpreter
315293
z_max.block_until_ready()
316294
```
317295

@@ -339,14 +317,14 @@ x_mesh.nbytes + y_mesh.nbytes
339317

340318
خوشبختانه، JAX رویکرد متفاوتی را با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) می‌پذیرد.
341319

342-
#### نسخه 1
320+
ایده `vmap` این است که برداری‌سازی را به مراحل تقسیم کند و تابعی که روی مقادیر منفرد عمل می‌کند را به تابعی تبدیل کند که روی آرایه‌ها عمل می‌کند.
343321

344-
در اینجا یک راه برای اعمال `vmap` آمده است.
322+
در اینجا نحوه اعمال آن به مسئله ما آمده است.
345323

346324
```{code-cell} ipython3
347-
# f را تنظیم کنید تا f(x, y) را در هر x برای هر y داده شده محاسبه کند
325+
# Set up f to compute f(x, y) at every x for any given y
348326
f_vec_x = lambda y: f(grid, y)
349-
# یک تابع دوم ایجاد کنید که این عملیات را روی تمام y برداری کند
327+
# Create a second function that vectorizes this operation over all y
350328
f_vec = jax.vmap(f_vec_x)
351329
```
352330

@@ -355,49 +333,37 @@ f_vec = jax.vmap(f_vec_x)
355333
بیایید زمان‌بندی را ببینیم:
356334

357335
```{code-cell} ipython3
358-
with qe.Timer(precision=8):
336+
with qe.Timer():
359337
z_max = jnp.max(f_vec(grid))
360338
z_max.block_until_ready()
361339
362340
print(f"JAX vmap v1 result: {z_max:.6f}")
363341
```
364342

365343
```{code-cell} ipython3
366-
with qe.Timer(precision=8):
344+
with qe.Timer():
367345
z_max = jnp.max(f_vec(grid))
368346
z_max.block_until_ready()
369347
```
370348

371-
با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری استفاده می‌کند.
372-
373-
وقتی روی CPU اجرا می‌شود، زمان اجرای آن شبیه به نسخه meshgrid است.
374-
375-
وقتی روی GPU اجرا می‌شود، معمولاً به طور قابل توجهی سریعتر است.
376-
377-
در واقع، استفاده از `vmap` مزیت دیگری دارد: به ما اجازه می‌دهد برداری‌سازی را به مراحل تقسیم کنیم.
378-
379-
این منجر به کدی می‌شود که اغلب راحت‌تر از کد برداری شده سنتی قابل درک است.
349+
با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری استفاده می‌کند و زمان اجرا نیز تغییر چندانی نمی‌کند.
380350

381-
ما این ایده‌ها را بیشتر هنگام حل مسائل بزرگتر بررسی خواهیم کرد.
351+
این خوب است --- اما هنوز از دستاوردهای سرعت بهره نمی‌بریم!
382352

383-
### نسخه 2 vmap
353+
اول توجه کنید که کد بالا آرایه دوبعدی کامل `f(x,y)` را محاسبه می‌کند که پیش از گرفتن حداکثر، سربارهایی ایجاد می‌کند.
384354

385-
می‌توانیم با استفاده از vmap همچنان کارآمدتر از نظر حافظه باشیم.
355+
دوم، فراخوانی `jnp.max` خارج از تابع JIT-compiled شده `f` قرار دارد، بنابراین کامپایلر نمی‌تواند این عملیات را در یک kernel واحد ادغام کند.
386356

387-
در حالی که در نسخه قبلی از آرایه‌های ورودی بزرگ اجتناب می‌کنیم، هنوز آرایه خروجی بزرگ `f(x,y)` را قبل از محاسبه حداکثر ایجاد می‌کنیم.
388-
389-
بیایید یک رویکرد کمی متفاوت را امتحان کنیم که max را به داخل می‌برد.
390-
391-
به دلیل این تغییر، ما هرگز آرایه دوبعدی `f(x,y)` را محاسبه نمی‌کنیم.
357+
می‌توانیم هر دو مشکل را با انتقال max به داخل و پوشش دادن همه چیز در یک `@jax.jit` واحد برطرف کنیم:
392358

393359
```{code-cell} ipython3
394360
@jax.jit
395-
def compute_max_vmap_v2(grid):
396-
# یک تابع بسازید که حداکثر را در امتداد هر سطر بگیرد
361+
def compute_max_vmap(grid):
362+
# Construct a function that takes the max along each row
397363
f_vec_x_max = lambda y: jnp.max(f(grid, y))
398-
# تابع را برداری کنید تا بتوانیم روی تمام سطرها همزمان فراخوانی کنیم
364+
# Vectorize the function so we can call on all rows simultaneously
399365
f_vec_max = jax.vmap(f_vec_x_max)
400-
# تابع برداری شده را فراخوانی کنید و حداکثر را بگیرید
366+
# Call the vectorized function and take the max
401367
return jnp.max(f_vec_max(grid))
402368
```
403369

@@ -408,24 +374,32 @@ def compute_max_vmap_v2(grid):
408374

409375
ما این تابع را روی تمام سطرها اعمال می‌کنیم و سپس حداکثر max های سطر را می‌گیریم.
410376

377+
از آنجایی که max را به داخل انتقال می‌دهیم، هرگز آرایه دوبعدی کامل `f(x,y)` را نمی‌سازیم و حافظه بیشتری ذخیره می‌کنیم.
378+
379+
و از آنجایی که همه چیز زیر یک `@jax.jit` واحد است، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند.
380+
411381
بیایید آن را امتحان کنیم.
412382

413383
```{code-cell} ipython3
414-
with qe.Timer(precision=8):
415-
z_max = compute_max_vmap_v2(grid).block_until_ready()
384+
with qe.Timer():
385+
# First run
386+
z_max = compute_max_vmap(grid)
387+
# Hold interpreter
388+
z_max.block_until_ready()
416389
417-
print(f"JAX vmap v2 result: {z_max:.6f}")
390+
print(f"JAX vmap result: {z_max:.6f}")
418391
```
419392

420393
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود:
421394

422395
```{code-cell} ipython3
423-
with qe.Timer(precision=8):
424-
z_max = compute_max_vmap_v2(grid).block_until_ready()
396+
with qe.Timer():
397+
# Second run
398+
z_max = compute_max_vmap(grid)
399+
# Hold interpreter
400+
z_max.block_until_ready()
425401
```
426402

427-
اگر این را روی GPU اجرا می‌کنید، همانطور که ما این کار را می‌کنیم، باید افزایش سرعت قابل توجه دیگری را ببینید.
428-
429403
### خلاصه
430404

431405
به نظر ما، JAX برنده برای عملیات برداری شده است.
@@ -467,14 +441,16 @@ def qm(x0, n, α=4.0):
467441
```{code-cell} ipython3
468442
n = 10_000_000
469443
470-
with qe.Timer(precision=8):
444+
with qe.Timer():
445+
# First run
471446
x = qm(0.1, n)
472447
```
473448

474449
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود:
475450

476451
```{code-cell} ipython3
477-
with qe.Timer(precision=8):
452+
with qe.Timer():
453+
# Second run
478454
x = qm(0.1, n)
479455
```
480456

@@ -493,7 +469,7 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
493469
```{code-cell} ipython3
494470
cpu = jax.devices("cpu")[0]
495471
496-
@partial(jax.jit, static_argnums=(1,), device=cpu)
472+
@partial(jax.jit, static_argnames=('n',), device=cpu)
497473
def qm_jax(x0, n, α=4.0):
498474
def update(x, t):
499475
x_new = α * x * (1 - x)
@@ -506,32 +482,32 @@ def qm_jax(x0, n, α=4.0):
506482
این کد خواندن آسانی ندارد اما، در اصل، `lax.scan` به طور مکرر `update` را فراخوانی می‌کند و بازگشت‌های `x_new` را در یک آرایه جمع می‌کند.
507483

508484
```{note}
509-
خوانندگان تیزبین متوجه خواهند شد که ما `device=cpu` را در decorator `jax.jit` مشخص می‌کنیم.
510-
511-
محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهره‌برداری GPU از موازی‌سازی باقی می‌گذارد.
512-
513-
در نتیجه، سربار راه‌اندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسب‌تر برای این بار کاری می‌کند.
514-
515-
خوانندگان کنجکاو می‌توانند حذف این گزینه را امتحان کنند تا ببینند چگونه عملکرد تغییر می‌کند.
485+
ما `device=cpu` را در decorator `jax.jit` مشخص می‌کنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهره‌برداری GPU از موازی‌سازی باقی می‌گذارد. در نتیجه، سربار راه‌اندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسب‌تر برای این بار کاری می‌کند.
516486
```
517487

518488
بیایید آن را با همان پارامترها زمان‌بندی کنیم:
519489

520490
```{code-cell} ipython3
521-
with qe.Timer(precision=8):
522-
x_jax = qm_jax(0.1, n).block_until_ready()
491+
with qe.Timer():
492+
# First run
493+
x_jax = qm_jax(0.1, n)
494+
# Hold interpreter
495+
x_jax.block_until_ready()
523496
```
524497

525498
بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود:
526499

527500
```{code-cell} ipython3
528-
with qe.Timer(precision=8):
529-
x_jax = qm_jax(0.1, n).block_until_ready()
501+
with qe.Timer():
502+
# Second run
503+
x_jax = qm_jax(0.1, n)
504+
# Hold interpreter
505+
x_jax.block_until_ready()
530506
```
531507

532508
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
533509

534-
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند، با این که Numba معمولاً (اما نه همیشه) سرعت‌های کمی بهتری در عملیات کاملاً ترتیبی ارائه می‌دهد.
510+
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند.
535511

536512
### خلاصه
537513

@@ -545,7 +521,7 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
545521

546522
علاوه بر این، آرایه‌های تغییرناپذیر JAX به این معنی است که نمی‌توانیم به سادگی عناصر آرایه را در جا به‌روزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت می‌کند.
547523

548-
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی، و همچنین عملکرد بالا است.
524+
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است.
549525

550526
## توصیه‌های کلی
551527

0 commit comments

Comments
 (0)