You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: lectures/numpy_vs_numba_vs_jax.md
+54-76Lines changed: 54 additions & 76 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -21,6 +21,8 @@ translation:
21
21
Sequential operations: عملیات ترتیبی
22
22
Sequential operations::Numba Version: نسخه Numba
23
23
Sequential operations::JAX Version: نسخه JAX
24
+
Sequential operations::JAX Version::First Attempt: تلاش اول
25
+
Sequential operations::JAX Version::Second Attempt: تلاش دوم
24
26
Sequential operations::Summary: خلاصه
25
27
Overall recommendations: توصیههای کلی
26
28
---
@@ -137,33 +139,34 @@ m = -np.inf
137
139
for x in grid:
138
140
for y in grid:
139
141
z = f(x, y)
140
-
if z > m:
141
-
m = z
142
+
m = max(m, z)
142
143
```
143
144
144
145
### برداریسازی NumPy
145
146
146
-
اگر به برداریسازی به سبک NumPy تغییر دهیم، میتوانیم از یک شبکه بسیار بزرگتر استفاده کنیم و کد نسبتاً سریع اجرا میشود.
147
+
بیایید به NumPy تغییر دهیم و از یک شبکه بزرگتر استفاده کنیم
147
148
148
149
در اینجا از `np.meshgrid` برای ایجاد شبکههای ورودی دوبعدی `x` و `y` استفاده میکنیم به گونهای که `f(x, y)` تمام ارزیابیها را روی شبکه حاصلضرب تولید میکند.
149
150
150
-
(این استراتژی به Matlab بازمیگردد.)
151
-
152
151
```{code-cell} ipython3
152
+
# Large grid
153
153
grid = np.linspace(-3, 3, 3_000)
154
-
x, y = np.meshgrid(grid, grid)
154
+
155
+
x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
155
156
156
157
with qe.Timer():
157
158
z_max_numpy = np.max(f(x, y))
158
-
159
-
print(f"NumPy result: {z_max_numpy:.6f}")
160
159
```
161
160
162
161
در نسخه برداری شده، تمام حلقهها در کد کامپایل شده انجام میشوند.
163
162
164
-
علاوه بر این، NumPy از چندنخی ضمنی استفاده میکند، به طوری که حداقل مقداری موازیسازی رخ میدهد.
163
+
استفاده از `meshgrid` به ما امکان میدهد حلقه for تودرتو را تکرار کنیم.
165
164
166
-
(موازیسازی نمیتواند بسیار کارآمد باشد زیرا فایل باینری قبل از اینکه اندازه آرایههای `x` و `y` را ببیند کامپایل میشود.)
Numba این عملیات ترتیبی را به طور بسیار کارآمد مدیریت میکند.
450
424
451
-
توجه کنید که اجرای دوم پس از تکمیل کامپایل JIT به طور قابل توجهی سریعتر است.
425
+
### نسخه JAX
426
+
427
+
ما نمیتوانیم مستقیماً `numba.jit` را با `jax.jit` جایگزین کنیم زیرا آرایههای JAX تغییرناپذیر هستند.
452
428
453
-
کامپایل Numba معمولاً بسیار سریع است و عملکرد کد حاصل برای عملیات ترتیبی مانند این عالی است.
429
+
اما میتوانیم این عملیات را پیادهسازی کنیم.
454
430
455
-
###نسخه JAX
431
+
#### تلاش اول
456
432
457
-
حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set`ایجاد کنیم که، همانطور که {ref}`در درس JAX بحث شد <jax_at_workaround>`، راهحلی برای آرایههای تغییرناپذیر فراهم میکند.
433
+
در اینجا یک راهحل با استفاده از سینتکس `at[t].set`ارائه میشود که{ref}`در درس JAX بحث شد <jax_at_workaround>`.
458
434
459
435
ما از `lax.fori_loop` استفاده میکنیم که نسخهای از حلقه for است که میتواند توسط XLA کامپایل شود.
460
436
@@ -477,7 +453,7 @@ def qm_jax_fori(x0, n, α=4.0):
477
453
* ما `n` را ایستا نگه میداریم زیرا بر اندازه آرایه تأثیر میگذارد و از این رو JAX میخواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
478
454
* ما به CPU از طریق `device=cpu` متصل میمانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازیسازی GPU باقی میگذارد.
479
455
480
-
اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد میکند، در داخل یک تابع کامپایلشده با JIT، کامپایلر تشخیص میدهد که آرایه قدیمی دیگر مورد نیاز نیست و بهروزرسانی را در جا انجام میدهد.
456
+
مهم: اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد میکند، در داخل یک تابع کامپایلشده با JIT، کامپایلر تشخیص میدهد که آرایه قدیمی دیگر مورد نیاز نیست و بهروزرسانی را در جا انجام میدهد!
481
457
482
458
بیایید آن را با همان پارامترها زمانبندی کنیم:
483
459
@@ -499,7 +475,9 @@ with qe.Timer():
499
475
x_jax.block_until_ready()
500
476
```
501
477
502
-
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
478
+
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است!
479
+
480
+
#### تلاش دوم
503
481
504
482
روش دیگری برای پیادهسازی حلقه وجود دارد که از `lax.scan` استفاده میکند.
505
483
@@ -538,11 +516,11 @@ with qe.Timer():
538
516
x_jax.block_until_ready()
539
517
```
540
518
541
-
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه میدهند.
519
+
شگفتانگیز است که JAX نیز پس از کامپایل عملکرد قوی ارائه میدهد.
542
520
543
521
### خلاصه
544
522
545
-
در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه میدهند، *تفاوتهای قابل توجهی در خوانایی کد و سهولت استفاده وجود دارد*.
523
+
در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه میدهند، تفاوتهایی در خوانایی کد و سهولت استفاده وجود دارد.
546
524
547
525
نسخه Numba ساده و طبیعی برای خواندن است: ما به سادگی یک آرایه اختصاص میدهیم و آن را عنصر به عنصر با استفاده از یک حلقه استاندارد Python پر میکنیم.
0 commit comments