Skip to content

Commit 777c279

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent d5205d6 commit 777c279

1 file changed

Lines changed: 60 additions & 21 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,23 @@ translation:
3636

3737
# NumPy در مقابل Numba در مقابل JAX
3838

39-
در سخنرانی‌های قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
39+
در درس‌های قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
4040

4141
* [NumPy](numpy)
4242
* [Numba](numba)
4343
* [JAX](jax_intro)
4444

4545
کدام یک را باید در هر موقعیت استفاده کنیم؟
4646

47-
این سخنرانی به آن سؤال پاسخ می‌دهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
47+
این درس به آن سؤال پاسخ می‌دهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
4848

4949
قبل از شروع، توجه می‌کنیم که دو مورد اول یک جفت طبیعی هستند: NumPy و Numba به خوبی با هم کار می‌کنند.
5050

5151
JAX، از سوی دیگر، به تنهایی می‌ایستد.
5252

5353
هنگام بررسی هر رویکرد، نه تنها کارایی و رد پای حافظه، بلکه وضوح و سهولت استفاده را نیز در نظر خواهیم گرفت.
5454

55-
علاوه بر آنچه در Anaconda موجود است، این سخنرانی به کتابخانه‌های زیر نیاز دارد:
55+
علاوه بر آنچه در Anaconda موجود است، این درس به کتابخانه‌های زیر نیاز دارد:
5656

5757
```{code-cell} ipython3
5858
---
@@ -67,7 +67,6 @@ tags: [hide-output]
6767
ما از import های زیر استفاده خواهیم کرد.
6868

6969
```{code-cell} ipython3
70-
import random
7170
from functools import partial
7271
7372
import numpy as np
@@ -455,15 +454,60 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
455454

456455
### نسخه JAX
457456

458-
حالا بیایید یک نسخه JAX با استفاده از `lax.scan` ایجاد کنیم:
457+
حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همان‌طور که {ref}`در درس JAX بحث شد <jax_at_workaround>`، راه‌حلی برای آرایه‌های تغییرناپذیر فراهم می‌کند.
459458

460-
(ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.)
459+
ما از `lax.fori_loop` استفاده می‌کنیم که نسخه‌ای از حلقه for است که می‌تواند توسط XLA کامپایل شود.
461460

462461
```{code-cell} ipython3
463462
cpu = jax.devices("cpu")[0]
464463
465-
@partial(jax.jit, static_argnames=('n',), device=cpu)
466-
def qm_jax(x0, n, α=4.0):
464+
@partial(jax.jit, static_argnames=("n",), device=cpu)
465+
def qm_jax_fori(x0, n, α=4.0):
466+
467+
x = jnp.empty(n + 1).at[0].set(x0)
468+
469+
def update(t, x):
470+
return x.at[t + 1].set(α * x[t] * (1 - x[t]))
471+
472+
x = lax.fori_loop(0, n, update, x)
473+
return x
474+
475+
```
476+
477+
* ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
478+
* ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد.
479+
480+
اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد.
481+
482+
بیایید آن را با همان پارامترها زمان‌بندی کنیم:
483+
484+
```{code-cell} ipython3
485+
with qe.Timer():
486+
# First run
487+
x_jax = qm_jax_fori(0.1, n)
488+
# Hold interpreter
489+
x_jax.block_until_ready()
490+
```
491+
492+
بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود:
493+
494+
```{code-cell} ipython3
495+
with qe.Timer():
496+
# Second run
497+
x_jax = qm_jax_fori(0.1, n)
498+
# Hold interpreter
499+
x_jax.block_until_ready()
500+
```
501+
502+
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
503+
504+
روش دیگری برای پیاده‌سازی حلقه وجود دارد که از `lax.scan` استفاده می‌کند.
505+
506+
این روش جایگزین، به طور قابل بحث، بیشتر با رویکرد تابعی JAX همسو است --- اگرچه سینتکس آن به خاطر سپردن دشواری دارد.
507+
508+
```{code-cell} ipython3
509+
@partial(jax.jit, static_argnames=("n",), device=cpu)
510+
def qm_jax_scan(x0, n, α=4.0):
467511
def update(x, t):
468512
x_new = α * x * (1 - x)
469513
return x_new, x_new
@@ -474,16 +518,12 @@ def qm_jax(x0, n, α=4.0):
474518

475519
این کد خواندن آسانی ندارد اما، در اصل، `lax.scan` به طور مکرر `update` را فراخوانی می‌کند و بازگشت‌های `x_new` را در یک آرایه جمع می‌کند.
476520

477-
```{note}
478-
ما `device=cpu` را در decorator `jax.jit` مشخص می‌کنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهره‌برداری GPU از موازی‌سازی باقی می‌گذارد. در نتیجه، سربار راه‌اندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسب‌تر برای این بار کاری می‌کند.
479-
```
480-
481521
بیایید آن را با همان پارامترها زمان‌بندی کنیم:
482522

483523
```{code-cell} ipython3
484524
with qe.Timer():
485525
# First run
486-
x_jax = qm_jax(0.1, n)
526+
x_jax = qm_jax_scan(0.1, n)
487527
# Hold interpreter
488528
x_jax.block_until_ready()
489529
```
@@ -493,13 +533,11 @@ with qe.Timer():
493533
```{code-cell} ipython3
494534
with qe.Timer():
495535
# Second run
496-
x_jax = qm_jax(0.1, n)
536+
x_jax = qm_jax_scan(0.1, n)
497537
# Hold interpreter
498538
x_jax.block_until_ready()
499539
```
500540

501-
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
502-
503541
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند.
504542

505543
### خلاصه
@@ -510,9 +548,9 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
510548

511549
این دقیقاً نحوه تفکر اکثر برنامه‌نویسان در مورد الگوریتم است.
512550

513-
نسخه JAX، از سوی دیگر، نیاز به استفاده از `lax.scan` دارد که به طور قابل توجهی کمتر شهودی است.
551+
نسخه‌های JAX، از سوی دیگر، نیاز به استفاده از `lax.fori_loop` یا `lax.scan` دارند که هر دو کمتر شهودی از یک حلقه استاندارد Python هستند.
514552

515-
علاوه بر این، آرایه‌های تغییرناپذیر JAX به این معنی است که نمی‌توانیم به سادگی عناصر آرایه را در جا به‌روزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت می‌کند.
553+
در حالی که سینتکس `at[t].set` در JAX به‌روزرسانی عنصر به عنصر را ممکن می‌سازد، کد کلی همچنان سخت‌تر از معادل Numba برای خواندن است.
516554

517555
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است.
518556

@@ -532,11 +570,12 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
532570

533571
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
534572

535-
JAX می‌تواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهره‌وری اضافی ناچیز است.
536-
537-
با این حال، `lax.scan` یک مزیت مهم دارد: از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کند، که Numba قادر به انجام آن نیست.
573+
JAX می‌تواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است.
538574

575+
```{note}
576+
یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کنند، که Numba قادر به انجام آن نیست.
539577
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است.
578+
```
540579

541580
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
542581

0 commit comments

Comments
 (0)