You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lectures/numpy_vs_numba_vs_jax.md
+60-21Lines changed: 60 additions & 21 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -36,23 +36,23 @@ translation:
36
36
37
37
# NumPy در مقابل Numba در مقابل JAX
38
38
39
-
در سخنرانیهای قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
39
+
در درسهای قبلی، سه کتابخانه اصلی برای محاسبات علمی و عددی را بحث کردیم:
40
40
41
41
*[NumPy](numpy)
42
42
*[Numba](numba)
43
43
*[JAX](jax_intro)
44
44
45
45
کدام یک را باید در هر موقعیت استفاده کنیم؟
46
46
47
-
این سخنرانی به آن سؤال پاسخ میدهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
47
+
این درس به آن سؤال پاسخ میدهد، حداقل تا حدی، با بحث در مورد برخی موارد استفاده.
48
48
49
49
قبل از شروع، توجه میکنیم که دو مورد اول یک جفت طبیعی هستند: NumPy و Numba به خوبی با هم کار میکنند.
50
50
51
51
JAX، از سوی دیگر، به تنهایی میایستد.
52
52
53
53
هنگام بررسی هر رویکرد، نه تنها کارایی و رد پای حافظه، بلکه وضوح و سهولت استفاده را نیز در نظر خواهیم گرفت.
54
54
55
-
علاوه بر آنچه در Anaconda موجود است، این سخنرانی به کتابخانههای زیر نیاز دارد:
55
+
علاوه بر آنچه در Anaconda موجود است، این درس به کتابخانههای زیر نیاز دارد:
56
56
57
57
```{code-cell} ipython3
58
58
---
@@ -67,7 +67,6 @@ tags: [hide-output]
67
67
ما از import های زیر استفاده خواهیم کرد.
68
68
69
69
```{code-cell} ipython3
70
-
import random
71
70
from functools import partial
72
71
73
72
import numpy as np
@@ -455,15 +454,60 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
455
454
456
455
### نسخه JAX
457
456
458
-
حالا بیایید یک نسخه JAX با استفاده از `lax.scan` ایجاد کنیم:
457
+
حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همانطور که {ref}`در درس JAX بحث شد <jax_at_workaround>`، راهحلی برای آرایههای تغییرناپذیر فراهم میکند.
459
458
460
-
(ما `n` را ایستا نگه میداریم زیرا بر اندازه آرایه تأثیر میگذارد و از این رو JAX میخواهد روی مقدار آن در کد کامپایل شده تخصصی شود.)
459
+
ما از `lax.fori_loop` استفاده میکنیم که نسخهای از حلقه for است که میتواند توسط XLA کامپایل شود.
* ما `n` را ایستا نگه میداریم زیرا بر اندازه آرایه تأثیر میگذارد و از این رو JAX میخواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
478
+
* ما به CPU از طریق `device=cpu` متصل میمانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازیسازی GPU باقی میگذارد.
479
+
480
+
اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد میکند، در داخل یک تابع کامپایلشده با JIT، کامپایلر تشخیص میدهد که آرایه قدیمی دیگر مورد نیاز نیست و بهروزرسانی را در جا انجام میدهد.
481
+
482
+
بیایید آن را با همان پارامترها زمانبندی کنیم:
483
+
484
+
```{code-cell} ipython3
485
+
with qe.Timer():
486
+
# First run
487
+
x_jax = qm_jax_fori(0.1, n)
488
+
# Hold interpreter
489
+
x_jax.block_until_ready()
490
+
```
491
+
492
+
بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود:
493
+
494
+
```{code-cell} ipython3
495
+
with qe.Timer():
496
+
# Second run
497
+
x_jax = qm_jax_fori(0.1, n)
498
+
# Hold interpreter
499
+
x_jax.block_until_ready()
500
+
```
501
+
502
+
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
503
+
504
+
روش دیگری برای پیادهسازی حلقه وجود دارد که از `lax.scan` استفاده میکند.
505
+
506
+
این روش جایگزین، به طور قابل بحث، بیشتر با رویکرد تابعی JAX همسو است --- اگرچه سینتکس آن به خاطر سپردن دشواری دارد.
این کد خواندن آسانی ندارد اما، در اصل، `lax.scan` به طور مکرر `update` را فراخوانی میکند و بازگشتهای `x_new` را در یک آرایه جمع میکند.
476
520
477
-
```{note}
478
-
ما `device=cpu` را در decorator `jax.jit` مشخص میکنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهرهبرداری GPU از موازیسازی باقی میگذارد. در نتیجه، سربار راهاندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسبتر برای این بار کاری میکند.
479
-
```
480
-
481
521
بیایید آن را با همان پارامترها زمانبندی کنیم:
482
522
483
523
```{code-cell} ipython3
484
524
with qe.Timer():
485
525
# First run
486
-
x_jax = qm_jax(0.1, n)
526
+
x_jax = qm_jax_scan(0.1, n)
487
527
# Hold interpreter
488
528
x_jax.block_until_ready()
489
529
```
@@ -493,13 +533,11 @@ with qe.Timer():
493
533
```{code-cell} ipython3
494
534
with qe.Timer():
495
535
# Second run
496
-
x_jax = qm_jax(0.1, n)
536
+
x_jax = qm_jax_scan(0.1, n)
497
537
# Hold interpreter
498
538
x_jax.block_until_ready()
499
539
```
500
540
501
-
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
502
-
503
541
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه میدهند.
504
542
505
543
### خلاصه
@@ -510,9 +548,9 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
510
548
511
549
این دقیقاً نحوه تفکر اکثر برنامهنویسان در مورد الگوریتم است.
512
550
513
-
نسخه JAX، از سوی دیگر، نیاز به استفاده از `lax.scan`دارد که به طور قابل توجهی کمتر شهودی است.
551
+
نسخههای JAX، از سوی دیگر، نیاز به استفاده از `lax.fori_loop` یا `lax.scan`دارند که هر دو کمتر شهودی از یک حلقه استاندارد Python هستند.
514
552
515
-
علاوه بر این، آرایههای تغییرناپذیر JAX به این معنی است که نمیتوانیم به سادگی عناصر آرایه را در جا بهروزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت میکند.
553
+
در حالی که سینتکس `at[t].set` در JAX بهروزرسانی عنصر به عنصر را ممکن میسازد، کد کلی همچنان سختتر از معادل Numba برای خواندن است.
516
554
517
555
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیادهسازی است.
518
556
@@ -532,11 +570,12 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
532
570
533
571
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
534
572
535
-
JAX میتواند مسائل ترتیبی را از طریق `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است و برای کارهای کاملاً ترتیبی، بهرهوری اضافی ناچیز است.
536
-
537
-
با این حال، `lax.scan` یک مزیت مهم دارد: از مشتقگیری خودکار در طول حلقه پشتیبانی میکند، که Numba قادر به انجام آن نیست.
573
+
JAX میتواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است.
538
574
575
+
```{note}
576
+
یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتقگیری خودکار در طول حلقه پشتیبانی میکنند، که Numba قادر به انجام آن نیست.
539
577
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیتهای یک مسیر نسبت به پارامترهای مدل)، JAX علیرغم نحو کمتر طبیعیاش، انتخاب بهتری است.
578
+
```
540
579
541
580
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
0 commit comments