Skip to content

Commit f3a2447

Browse files
authored
🌐 [translation-sync] Improve and simplify JAX lectures (#109)
* Update translation: lectures/jax_intro.md * Update translation: .translate/state/jax_intro.md.yml * Update translation: lectures/numpy_vs_numba_vs_jax.md * Update translation: .translate/state/numpy_vs_numba_vs_jax.md.yml
1 parent 4d9ab6e commit f3a2447

4 files changed

Lines changed: 119 additions & 192 deletions

File tree

.translate/state/jax_intro.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
2-
synced-at: "2026-04-13"
1+
source-sha: 450bafecd23db638602150b47f4272b98aad3146
2+
synced-at: "2026-04-14"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 7

.translate/state/numpy_vs_numba_vs_jax.md.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
source-sha: 11e7d823f7f355f5025d40cab40bf801b3262e56
2-
synced-at: "2026-04-13"
1+
source-sha: 450bafecd23db638602150b47f4272b98aad3146
2+
synced-at: "2026-04-14"
33
model: claude-sonnet-4-6
44
mode: UPDATE
55
section-count: 3

lectures/jax_intro.md

Lines changed: 61 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ translation:
2727
Functional Programming::Examples: مثال‌ها
2828
Functional Programming::Why Functional Programming?: چرا برنامه‌نویسی تابعی؟
2929
Random numbers: اعداد تصادفی
30-
Random numbers::Random number generation: تولید اعداد تصادفی
31-
Random numbers::Why explicit random state?: چرا وضعیت تصادفی صریح؟
32-
Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy
33-
Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX
30+
Random numbers::NumPy / MATLAB Approach: رویکرد NumPy / MATLAB
31+
Random numbers::JAX: JAX
32+
Random numbers::Benefits: مزایا
3433
JIT Compilation: کامپایل JIT
3534
JIT Compilation::With NumPy: با NumPy
3635
JIT Compilation::With JAX: با JAX
@@ -416,15 +415,31 @@ JAX از سبک برنامه‌نویسی تابعی استفاده می‌کن
416415

417416
## اعداد تصادفی
418417

419-
اعداد تصادفی در JAX نسبت به آنچه در NumPy یا Matlab می‌یابید بسیار متفاوت هستند.
418+
اعداد تصادفی در JAX نسبت به آنچه در NumPy یا MATLAB می‌یابید بسیار متفاوت هستند.
420419

421-
در ابتدا ممکن است نحو را نسبتاً مفصل بیابید.
420+
### رویکرد NumPy / MATLAB
422421

423-
اما به زودی متوجه خواهید شد که نحو و معناشناسی برای حفظ سبک برنامه‌نویسی تابعی که به تازگی مورد بحث قرار دادیم، ضروری است.
422+
در NumPy / MATLAB، تولید اعداد تصادفی با حفظ وضعیت سراسری پنهان کار می‌کند.
424423

425-
علاوه بر این، کنترل کامل وضعیت تصادفی برای برنامه‌نویسی موازی، مانند زمانی که می‌خواهیم آزمایش‌های مستقل را در چندین رشته اجرا کنیم، ضروری است.
424+
```{code-cell} ipython3
425+
np.random.seed(42)
426+
print(np.random.randn(2))
427+
```
428+
429+
هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، وضعیت پنهان به‌روزرسانی می‌شود:
430+
431+
```{code-cell} ipython3
432+
print(np.random.randn(2))
433+
```
434+
435+
این تابع *خالص نیست* زیرا:
426436

427-
### تولید اعداد تصادفی
437+
* غیرقطعی است: ورودی‌های یکسان، خروجی‌های متفاوت
438+
* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد
439+
440+
در موازی‌سازی خطرناک است --- باید با دقت کنترل کرد که در هر رشته چه اتفاقی می‌افتد!
441+
442+
### JAX
428443

429444
در JAX، وضعیت مولد اعداد تصادفی به صورت صریح کنترل می‌شود.
430445

@@ -545,119 +560,48 @@ def gen_random_matrices(key, n=2, k=3):
545560
key, subkey = jax.random.split(key)
546561
A = jax.random.uniform(subkey, (n, n))
547562
matrices.append(A)
548-
print(A)
549563
return matrices
550564
```
551565

552566
```{code-cell} ipython3
553567
seed = 42
554568
key = jax.random.key(seed)
555-
matrices = gen_random_matrices(key)
556-
```
557-
558-
همچنین می‌توانیم هنگام تکرار در یک حلقه از `fold_in` استفاده کنیم:
559-
560-
```{code-cell} ipython3
561-
def gen_random_matrices(key, n=2, k=3):
562-
matrices = []
563-
for i in range(k):
564-
step_key = jax.random.fold_in(key, i)
565-
A = jax.random.uniform(step_key, (n, n))
566-
matrices.append(A)
567-
print(A)
568-
return matrices
569-
```
570-
571-
```{code-cell} ipython3
572-
key = jax.random.key(seed)
573-
matrices = gen_random_matrices(key)
574-
```
575-
576-
### چرا وضعیت تصادفی صریح؟
577-
578-
چرا JAX به این رویکرد نسبتاً مفصل برای تولید اعداد تصادفی نیاز دارد؟
579-
580-
یکی از دلایل حفظ توابع خالص است.
581-
582-
بیایید ببینیم که چگونه تولید اعداد تصادفی با توابع خالص با مقایسه NumPy و JAX مرتبط است.
583-
584-
#### رویکرد NumPy
585-
586-
در NumPy، تولید اعداد تصادفی با حفظ وضعیت سراسری پنهان کار می‌کند.
587-
588-
هر بار که یک تابع تصادفی را فراخوانی می‌کنیم، این وضعیت به‌روزرسانی می‌شود:
589-
590-
```{code-cell} ipython3
591-
np.random.seed(42)
592-
print(np.random.randn()) # Updates state of random number generator
593-
print(np.random.randn()) # Updates state of random number generator
569+
gen_random_matrices(key)
594570
```
595571

596-
هر فراخوانی یک مقدار متفاوت را برمی‌گرداند، حتی اگر ما همان تابع را با همان ورودی‌ها (بدون آرگومان، در این مورد) فراخوانی می‌کنیم.
572+
این تابع *خالص* است
597573

598-
این تابع *خالص نیست* زیرا:
599-
600-
* غیرقطعی است: ورودی‌های یکسان (در این مورد هیچ) خروجی‌های متفاوت می‌دهند
601-
* دارای عوارض جانبی است: وضعیت مولد اعداد تصادفی سراسری را تغییر می‌دهد
602-
603-
#### رویکرد JAX
604-
605-
همانطور که در بالا دیدیم، JAX رویکرد متفاوتی اتخاذ می‌کند و تصادفی بودن را از طریق کلیدها صریح می‌کند.
606-
607-
برای مثال،
608-
609-
```{code-cell} ipython3
610-
def random_sum_jax(key):
611-
key1, key2 = jax.random.split(key)
612-
x = jax.random.normal(key1)
613-
y = jax.random.normal(key2)
614-
return x + y
615-
```
616-
617-
با همان کلید، همیشه نتیجه یکسانی دریافت می‌کنیم:
618-
619-
```{code-cell} ipython3
620-
key = jax.random.key(42)
621-
random_sum_jax(key)
622-
```
623-
624-
```{code-cell} ipython3
625-
random_sum_jax(key)
626-
```
627-
628-
برای دریافت نمونه‌های جدید باید یک کلید جدید ارائه دهیم.
629-
630-
تابع `random_sum_jax` خالص است زیرا:
631-
632-
* قطعی است: کلید یکسان همیشه خروجی یکسان تولید می‌کند
574+
* قطعی است: ورودی‌های یکسان، خروجی یکسان
633575
* بدون عوارض جانبی: هیچ وضعیت پنهانی تغییر نمی‌کند
634576

577+
### مزایا
578+
635579
صریح بودن JAX مزایای قابل توجهی به همراه دارد:
636580

637581
* تکرارپذیری: با استفاده مجدد از کلیدها، تکرار نتایج آسان است
638-
* موازی‌سازی: هر رشته می‌تواند کلید خاص خود را بدون تضاد داشته باشد
639-
* اشکال‌زدایی: نبود وضعیت پنهان استدلال در مورد کد را آسان‌تر می‌کند
582+
* موازی‌سازی: کنترل آنچه در رشته‌های جداگانه اتفاق می‌افتد
583+
* اشکال‌زدایی: نبود وضعیت پنهان آزمایش کد را آسان‌تر می‌کند
640584
* سازگاری با JIT: کامپایلر می‌تواند توابع خالص را به طور تهاجمی‌تری بهینه کند
641585

642-
نکته آخر در بخش بعدی گسترش داده می‌شود.
643-
644586
## کامپایل JIT
645587

646588
کامپایلر just-in-time (JIT) JAX اجرا را با تولید کد ماشین کارآمد که با هم اندازه وظیفه و هم سخت‌افزار متفاوت است، تسریع می‌کند.
647589

648590
ما قدرت کامپایلر JIT JAX را در ترکیب با سخت‌افزار موازی {ref}`در بالا <jax_speed>` مشاهده کردیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم.
649591

650-
بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم:
592+
در اینجا کامپایل JIT را برای توابع پیچیده‌تر بررسی می‌کنیم.
593+
594+
### با NumPy
595+
596+
ابتدا با NumPy امتحان خواهیم کرد، با استفاده از
651597

652598
```{code-cell}
653599
def f(x):
654600
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
655601
return y
656602
```
657603

658-
### با NumPy
659-
660-
ابتدا با NumPy امتحان خواهیم کرد
604+
بیایید با `x` بزرگ اجرا کنیم
661605

662606
```{code-cell}
663607
n = 50_000_000
@@ -670,9 +614,17 @@ with qe.Timer():
670614
y = f(x)
671615
```
672616

673-
### با JAX
617+
مدل اجرای **Eager**
618+
619+
* هر عملیات بلافاصله هنگامی که با آن مواجه می‌شود اجرا می‌شود و نتیجه آن را قبل از شروع عملیات بعدی مادی‌سازی می‌کند.
620+
621+
معایب
622+
623+
* موازی‌سازی حداقل
624+
* ردپای حافظه سنگین --- آرایه‌های میانی زیادی تولید می‌کند
625+
* خواندن/نوشتن حافظه زیاد
674626

675-
اکنون بیایید دوباره با JAX امتحان کنیم.
627+
### با JAX
676628

677629
به عنوان اولین مرحله، `np` را در همه جا با `jnp` جایگزین می‌کنیم:
678630

@@ -703,14 +655,15 @@ with qe.Timer():
703655
jax.block_until_ready(y);
704656
```
705657

706-
نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در
707-
اجرای دوم پس از کامپایل JIT.
658+
نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT.
708659

709-
علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد.
660+
اما همچنان از اجرای eager استفاده می‌کنیم --- حافظه و خواندن/نوشتن زیاد.
710661

711662
### کامپایل کل تابع
712663

713-
کامپایلر just-in-time (JIT) JAX می‌تواند اجرا را در درون توابع با ادغام عملیات آرایه‌ای در یک هسته بهینه شده واحد تسریع کند.
664+
خوشبختانه، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد.
665+
666+
کامپایلر تمام عملیات آرایه‌ای را در یک هسته بهینه‌شده واحد ادغام می‌کند.
714667

715668
بیایید این را با تابع `f` امتحان کنیم:
716669

@@ -734,9 +687,11 @@ with qe.Timer():
734687
jax.block_until_ready(y);
735688
```
736689

737-
زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم و به کامپایلر اجازه دادیم به طور تهاجمی‌تری بهینه‌سازی کند.
690+
زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم.
738691

739-
برای مثال، کامپایلر می‌تواند چندین فراخوانی به شتاب‌دهنده سخت‌افزاری و ایجاد تعدادی آرایه میانی را حذف کند.
692+
* بهینه‌سازی تهاجمی بر اساس کل دنباله محاسباتی
693+
* حذف چندین فراخوانی به شتاب‌دهنده سخت‌افزاری
694+
* عدم ایجاد آرایه‌های میانی
740695

741696
اتفاقاً، نحو رایج‌تر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است
742697

@@ -756,11 +711,9 @@ def f(x):
756711

757712
### کامپایل توابع غیرخالص
758713

759-
اکنون که دیدیم کامپایل JIT چقدر قدرتمند می‌تواند باشد، درک رابطه آن با توابع خالص مهم است.
714+
در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود!
760715

761-
در حالی که JAX معمولاً هنگام کامپایل توابع ناخالص خطا نمی‌دهد، اجرا غیرقابل پیش‌بینی می‌شود.
762-
763-
در اینجا تصویری از این واقعیت با استفاده از متغیرهای سراسری آورده شده است:
716+
در اینجا تصویری از این واقعیت آورده شده است:
764717

765718
```{code-cell} ipython3
766719
a = 1 # global
@@ -840,17 +793,13 @@ for row in X:
840793

841794
با این حال، حلقه‌های Python کُند هستند و نمی‌توانند به‌طور کارآمد توسط JAX کامپایل یا موازی‌سازی شوند.
842795

843-
استفاده از `vmap` محاسبه را روی شتاب‌دهنده نگه می‌دارد و با سایر
844-
تبدیل‌های JAX مانند `jit` و `grad` ترکیب می‌شود:
796+
با استفاده از `vmap`، می‌توانیم از حلقه‌ها اجتناب کنیم و محاسبه را روی شتاب‌دهنده نگه داریم:
845797

846798
```{code-cell} ipython3
847-
batch_mm_diff = jax.vmap(mm_diff)
848-
batch_mm_diff(X)
799+
batch_mm_diff = jax.vmap(mm_diff) # Create a new "vectorized" version
800+
batch_mm_diff(X) # Apply to each row of X
849801
```
850802

851-
تابع `mm_diff` برای یک آرایه منفرد نوشته شده بود، و `vmap` به‌طور خودکار
852-
آن را برای عمل سطربه‌سطر روی یک ماتریس ارتقا داد --- بدون حلقه، بدون تغییر شکل.
853-
854803
### ترکیب تبدیل‌ها
855804

856805
یکی از نقاط قوت JAX این است که تبدیل‌ها به‌طور طبیعی با هم ترکیب می‌شوند.

0 commit comments

Comments
 (0)