Skip to content

Commit 48251cf

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent 279646c commit 48251cf

1 file changed

Lines changed: 54 additions & 76 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 54 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ translation:
2121
Sequential operations: عملیات ترتیبی
2222
Sequential operations::Numba Version: نسخه Numba
2323
Sequential operations::JAX Version: نسخه JAX
24+
Sequential operations::JAX Version::First Attempt: تلاش اول
25+
Sequential operations::JAX Version::Second Attempt: تلاش دوم
2426
Sequential operations::Summary: خلاصه
2527
Overall recommendations: توصیه‌های کلی
2628
---
@@ -137,33 +139,34 @@ m = -np.inf
137139
for x in grid:
138140
for y in grid:
139141
z = f(x, y)
140-
if z > m:
141-
m = z
142+
m = max(m, z)
142143
```
143144

144145
### برداری‌سازی NumPy
145146

146-
اگر به برداری‌سازی به سبک NumPy تغییر دهیم، می‌توانیم از یک شبکه بسیار بزرگتر استفاده کنیم و کد نسبتاً سریع اجرا می‌شود.
147+
بیایید به NumPy تغییر دهیم و از یک شبکه بزرگتر استفاده کنیم
147148

148149
در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند.
149150

150-
(این استراتژی به Matlab بازمی‌گردد.)
151-
152151
```{code-cell} ipython3
152+
# Large grid
153153
grid = np.linspace(-3, 3, 3_000)
154-
x, y = np.meshgrid(grid, grid)
154+
155+
x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
155156
156157
with qe.Timer():
157158
z_max_numpy = np.max(f(x, y))
158-
159-
print(f"NumPy result: {z_max_numpy:.6f}")
160159
```
161160

162161
در نسخه برداری شده، تمام حلقه‌ها در کد کامپایل شده انجام می‌شوند.
163162

164-
علاوه بر این، NumPy از چندنخی ضمنی استفاده می‌کند، به طوری که حداقل مقداری موازی‌سازی رخ می‌دهد.
163+
استفاده از `meshgrid` به ما امکان می‌دهد حلقه for تودرتو را تکرار کنیم.
165164

166-
(موازی‌سازی نمی‌تواند بسیار کارآمد باشد زیرا فایل باینری قبل از اینکه اندازه آرایه‌های `x` و `y` را ببیند کامپایل می‌شود.)
165+
خروجی باید نزدیک به یک باشد:
166+
167+
```{code-cell} ipython3
168+
print(f"NumPy result: {z_max_numpy:.6f}")
169+
```
167170

168171
### مقایسه با Numba
169172

@@ -188,8 +191,6 @@ grid = np.linspace(-3, 3, 3_000)
188191
with qe.Timer():
189192
# First run
190193
z_max_numba = compute_max_numba(grid)
191-
192-
print(f"Numba result: {z_max_numba:.6f}")
193194
```
194195

195196
بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود.
@@ -232,8 +233,6 @@ def compute_max_numba_parallel(grid):
232233
with qe.Timer():
233234
# First run
234235
z_max_parallel = compute_max_numba_parallel(grid)
235-
236-
print(f"Numba result: {z_max_parallel:.6f}")
237236
```
238237

239238
در اینجا زمان‌بندی برای نسخه از پیش کامپایل شده آمده است.
@@ -244,27 +243,30 @@ with qe.Timer():
244243
compute_max_numba_parallel(grid)
245244
```
246245

247-
اگر چندین هسته دارید، باید حداقل برخی مزایا را از موازی‌سازی در اینجا ببینید.
246+
اگر چندین هسته دارید، باید مزایایی از موازی‌سازی در اینجا ببینید.
248247

249-
برای دستگاه‌های قدرتمندتر و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت قابل توجهی ایجاد کند، حتی روی CPU.
248+
بیایید مطمئن شویم که نتیجه صحیح را به دست می‌آوریم (نزدیک به یک):
250249

251-
### کد برداری شده با JAX
250+
```{code-cell} ipython3
251+
print(f"Numba result: {z_max_parallel:.6f}")
252+
```
253+
254+
برای دستگاه‌های قدرتمند و اندازه‌های شبکه بزرگتر، موازی‌سازی می‌تواند افزایش سرعت مفیدی ایجاد کند، حتی روی CPU.
252255

253-
در ظاهر، کد برداری شده در JAX شبیه به کد NumPy است.
256+
### کد برداری شده با JAX
254257

255-
اما تفاوت‌هایی نیز وجود دارد که در اینجا آنها را برجسته می‌کنیم.
258+
بیایید رویکرد برداری شده NumPy را با JAX تکرار کنیم.
256259

257260
بیایید با تابع شروع کنیم که `np` را به `jnp` تغییر می‌دهد و `jax.jit` را اضافه می‌کند.
258261

259-
260262
```{code-cell} ipython3
261263
@jax.jit
262264
def f(x, y):
263265
return jnp.cos(x**2 + y**2) / (1 + x**2 + y**2)
264266
265267
```
266268

267-
همانند NumPy، برای به دست آوردن شکل درست و محاسبه حلقه `for` تودرتوی صحیح، می‌توانیم از عملیات `meshgrid` طراحی شده برای این منظور استفاده کنیم:
269+
از رویکرد meshgrid به سبک NumPy استفاده می‌کنیم:
268270

269271
```{code-cell} ipython3
270272
grid = jnp.linspace(-3, 3, 3_000)
@@ -321,68 +323,37 @@ x_mesh.nbytes + y_mesh.nbytes
321323

322324
در اینجا نحوه اعمال آن به مسئله ما آمده است.
323325

324-
```{code-cell} ipython3
325-
# f را تنظیم کنید تا f(x, y) را در هر x برای هر y داده شده محاسبه کند
326-
f_vec_x = lambda y: f(grid, y)
327-
# یک تابع دوم ایجاد کنید که این عملیات را روی تمام y برداری کند
328-
f_vec = jax.vmap(f_vec_x)
329-
```
330-
331-
اکنون `f_vec` هنگام فراخوانی با آرایه تخت `grid`، `f(x,y)` را در هر `x,y` محاسبه می‌کند.
332-
333-
بیایید زمان‌بندی را ببینیم:
334-
335-
```{code-cell} ipython3
336-
with qe.Timer():
337-
z_max = jnp.max(f_vec(grid))
338-
z_max.block_until_ready()
339-
340-
print(f"JAX vmap v1 result: {z_max:.6f}")
341-
```
342-
343-
```{code-cell} ipython3
344-
with qe.Timer():
345-
z_max = jnp.max(f_vec(grid))
346-
z_max.block_until_ready()
347-
```
348-
349-
با اجتناب از آرایه‌های ورودی بزرگ `x_mesh` و `y_mesh`، این نسخه `vmap` از حافظه بسیار کمتری با زمان اجرای مشابه استفاده می‌کند.
350-
351-
اما هنوز برخی بهره‌های سرعت را از دست می‌دهیم.
352-
353-
کد فوق آرایه دوبعدی کامل `f(x,y)` را محاسبه می‌کند و سپس max را می‌گیرد.
354-
355-
علاوه بر این، فراخوانی `jnp.max` خارج از تابع JIT-کامپایل شده `f` قرار دارد، بنابراین کامپایلر نمی‌تواند این عملیات را در یک kernel واحد ادغام کند.
356-
357-
می‌توانیم هر دو مشکل را با انتقال max به داخل و پوشاندن همه چیز در یک `@jax.jit` واحد برطرف کنیم:
358-
359326
```{code-cell} ipython3
360327
@jax.jit
361328
def compute_max_vmap(grid):
362-
# یک تابع بسازید که حداکثر را در امتداد هر سطر بگیرد
329+
# Construct a function that takes the max over all x for given y
363330
f_vec_x_max = lambda y: jnp.max(f(grid, y))
364-
# تابع را برداری کنید تا بتوانیم روی تمام سطرها همزمان فراخوانی کنیم
331+
# Vectorize the function so we can call on all y simultaneously
365332
f_vec_max = jax.vmap(f_vec_x_max)
366-
# تابع برداری شده را فراخوانی کنید و حداکثر را بگیرید
367-
return jnp.max(f_vec_max(grid))
333+
# Compute the max across x at every y
334+
maxes = f_vec_max(grid)
335+
# Compute the max of the maxes and return
336+
return jnp.max(maxes)
368337
```
369338

370-
در اینجا
371-
372-
* `f_vec_x_max` حداکثر را در امتداد هر سطر داده شده محاسبه می‌کند
373-
* `f_vec_max` یک نسخه برداری شده است که می‌تواند حداکثر تمام سطرها را به صورت موازی محاسبه کند.
339+
توجه کنید که هرگز
374340

375-
ما این تابع را روی تمام سطرها اعمال می‌کنیم و سپس حداکثر max های سطر را می‌گیریم.
341+
* شبکه دوبعدی `x_mesh`
342+
* شبکه دوبعدی `y_mesh` یا
343+
* آرایه دوبعدی `f(x,y)`
376344

377-
چون max را به داخل منتقل می‌کنیم، هرگز آرایه دوبعدی کامل `f(x,y)` را نمی‌سازیم و حافظه بیشتری صرفه‌جویی می‌شود.
345+
را نمی‌سازیم.
378346

379347
و چون همه چیز زیر یک `@jax.jit` واحد قرار دارد، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند.
380348

381349
بیایید آن را امتحان کنیم.
382350

383351
```{code-cell} ipython3
384352
with qe.Timer():
385-
z_max = compute_max_vmap(grid).block_until_ready()
353+
# First run
354+
z_max = compute_max_vmap(grid)
355+
# Hold interpreter
356+
z_max.block_until_ready()
386357
387358
print(f"JAX vmap result: {z_max:.6f}")
388359
```
@@ -391,7 +362,10 @@ print(f"JAX vmap result: {z_max:.6f}")
391362

392363
```{code-cell} ipython3
393364
with qe.Timer():
394-
z_max = compute_max_vmap(grid).block_until_ready()
365+
# Second run
366+
z_max = compute_max_vmap(grid)
367+
# Hold interpreter
368+
z_max.block_until_ready()
395369
```
396370

397371
### خلاصه
@@ -448,13 +422,15 @@ with qe.Timer():
448422

449423
Numba این عملیات ترتیبی را به طور بسیار کارآمد مدیریت می‌کند.
450424

451-
توجه کنید که اجرای دوم پس از تکمیل کامپایل JIT به طور قابل توجهی سریعتر است.
425+
### نسخه JAX
426+
427+
ما نمی‌توانیم مستقیماً `numba.jit` را با `jax.jit` جایگزین کنیم زیرا آرایه‌های JAX تغییرناپذیر هستند.
452428

453-
کامپایل Numba معمولاً بسیار سریع است و عملکرد کد حاصل برای عملیات ترتیبی مانند این عالی است.
429+
اما می‌توانیم این عملیات را پیاده‌سازی کنیم.
454430

455-
### نسخه JAX
431+
#### تلاش اول
456432

457-
حالا بیایید یک نسخه JAX با استفاده از سینتکس `at[t].set` ایجاد کنیم که، همان‌طور که {ref}`در درس JAX بحث شد <jax_at_workaround>`، راه‌حلی برای آرایه‌های تغییرناپذیر فراهم می‌کند.
433+
در اینجا یک راه‌حل با استفاده از سینتکس `at[t].set` ارائه می‌شود که {ref}`در درس JAX بحث شد <jax_at_workaround>`.
458434

459435
ما از `lax.fori_loop` استفاده می‌کنیم که نسخه‌ای از حلقه for است که می‌تواند توسط XLA کامپایل شود.
460436

@@ -477,7 +453,7 @@ def qm_jax_fori(x0, n, α=4.0):
477453
* ما `n` را ایستا نگه می‌داریم زیرا بر اندازه آرایه تأثیر می‌گذارد و از این رو JAX می‌خواهد روی مقدار آن در کد کامپایل شده تخصصی شود.
478454
* ما به CPU از طریق `device=cpu` متصل می‌مانیم زیرا این بار کاری ترتیبی از بسیاری عملیات کوچک تشکیل شده است که فرصت کمی برای موازی‌سازی GPU باقی می‌گذارد.
479455

480-
اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد.
456+
مهم: اگرچه `at[t].set` در هر مرحله ظاهراً یک آرایه جدید ایجاد می‌کند، در داخل یک تابع کامپایل‌شده با JIT، کامپایلر تشخیص می‌دهد که آرایه قدیمی دیگر مورد نیاز نیست و به‌روزرسانی را در جا انجام می‌دهد!
481457

482458
بیایید آن را با همان پارامترها زمان‌بندی کنیم:
483459

@@ -499,7 +475,9 @@ with qe.Timer():
499475
x_jax.block_until_ready()
500476
```
501477

502-
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
478+
JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است!
479+
480+
#### تلاش دوم
503481

504482
روش دیگری برای پیاده‌سازی حلقه وجود دارد که از `lax.scan` استفاده می‌کند.
505483

@@ -538,11 +516,11 @@ with qe.Timer():
538516
x_jax.block_until_ready()
539517
```
540518

541-
هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه می‌دهند.
519+
شگفت‌انگیز است که JAX نیز پس از کامپایل عملکرد قوی ارائه می‌دهد.
542520

543521
### خلاصه
544522

545-
در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، *تفاوت‌های قابل توجهی در خوانایی کد و سهولت استفاده وجود دارد*.
523+
در حالی که هم Numba و هم JAX عملکرد قوی برای عملیات ترتیبی ارائه می‌دهند، تفاوت‌هایی در خوانایی کد و سهولت استفاده وجود دارد.
546524

547525
نسخه Numba ساده و طبیعی برای خواندن است: ما به سادگی یک آرایه اختصاص می‌دهیم و آن را عنصر به عنصر با استفاده از یک حلقه استاندارد Python پر می‌کنیم.
548526

0 commit comments

Comments
 (0)