@@ -17,8 +17,6 @@ translation:
1717 Vectorized operations::Parallelized Numba : Numba موازی شده
1818 Vectorized operations::Vectorized code with JAX : کد برداری شده با JAX
1919 Vectorized operations::JAX plus vmap : JAX به علاوه vmap
20- Vectorized operations::JAX plus vmap::Version 1 : نسخه 1
21- Vectorized operations::vmap version 2 : نسخه 2 vmap
2220 Vectorized operations::Summary : خلاصه
2321 Sequential operations : عملیات ترتیبی
2422 Sequential operations::Numba Version : نسخه Numba
@@ -27,7 +25,7 @@ translation:
2725 Overall recommendations : توصیههای کلی
2826---
2927
30- (parallel )=
28+ (numpy_numba_jax )=
3129``` {raw} jupyter
3230<div id="qe-notebook-header" align="right" style="text-align:right;">
3331 <a href="https://quantecon.org/" title="quantecon.org">
@@ -156,7 +154,7 @@ for x in grid:
156154grid = np.linspace(-3, 3, 3_000)
157155x, y = np.meshgrid(grid, grid)
158156
159- with qe.Timer(precision=8 ):
157+ with qe.Timer():
160158 z_max_numpy = np.max(f(x, y))
161159
162160print(f"NumPy result: {z_max_numpy:.6f}")
@@ -179,13 +177,17 @@ def compute_max_numba(grid):
179177 for x in grid:
180178 for y in grid:
181179 z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
182- if z > m:
183- m = z
180+ m = max(m, z)
184181 return m
182+ ```
183+
184+ بیایید آن را آزمایش کنیم:
185185
186+ ``` {code-cell} ipython3
186187grid = np.linspace(-3, 3, 3_000)
187188
188- with qe.Timer(precision=8):
189+ with qe.Timer():
190+ # First run
189191 z_max_numba = compute_max_numba(grid)
190192
191193print(f"Numba result: {z_max_numba:.6f}")
@@ -194,22 +196,23 @@ print(f"Numba result: {z_max_numba:.6f}")
194196بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود.
195197
196198``` {code-cell} ipython3
197- with qe.Timer(precision=8):
199+ with qe.Timer():
200+ # Second run
198201 compute_max_numba(grid)
199202```
200203
201- بسته به دستگاه شما، نسخه Numba میتواند کمی کندتر یا کمی سریعتر از NumPy باشد.
204+ بسته به دستگاه شما، نسخه Numba ممکن است کندتر یا سریعتر از NumPy باشد.
202205
203- از یک طرف، NumPy محاسبات کارآمد (مانند Numba) را با مقداری چندنخی (برخلاف این کد Numba) ترکیب میکند که مزیتی فراهم میکند.
206+ در اکثر موارد، Numba کمی بهتر است.
207+
208+ از یک طرف، NumPy محاسبات کارآمد را با مقداری چندنخی ترکیب میکند که مزیتی فراهم میکند.
204209
205210از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده میکند، زیرا ما فقط با یک شبکه یکبعدی کار میکنیم.
206211
207212### Numba موازی شده
208213
209214حالا بیایید موازیسازی با Numba را با استفاده از ` prange ` امتحان کنیم:
210215
211- در اینجا یک تلاش ساده و * نادرست* آمده است.
212-
213216``` {code-cell} ipython3
214217@numba.jit(parallel=True)
215218def compute_max_numba_parallel(grid):
@@ -220,57 +223,25 @@ def compute_max_numba_parallel(grid):
220223 x = grid[i]
221224 y = grid[j]
222225 z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
223- if z > m:
224- m = z
226+ m = max(m, z)
225227 return m
226-
227228```
228229
229- معمولاً این نتیجه نادرستی برمیگرداند:
230-
231- ``` {code-cell} ipython3
232- z_max_parallel_incorrect = compute_max_numba_parallel(grid)
233- print(f"Numba result: {z_max_parallel_incorrect} 😱")
234- ```
235-
236- دلیل این است که متغیر ` m ` بین نخها مشترک است و به درستی کنترل نمیشود.
237-
238- وقتی چندین نخ سعی میکنند همزمان ` m ` را بخوانند و بنویسند، با یکدیگر تداخل میکنند.
239-
240- نخها مقادیر قدیمی ` m ` را میخوانند یا بهروزرسانیهای یکدیگر را بازنویسی میکنند --- یا ` m ` هرگز از مقدار اولیه خود بهروزرسانی نمیشود.
241-
242- در اینجا یک نسخه با دقت بیشتری نوشته شده است.
230+ در اینجا یک اجرای گرمکننده و آزمایش آمده است.
243231
244232``` {code-cell} ipython3
245- @numba.jit(parallel=True)
246- def compute_max_numba_parallel(grid):
247- n = len(grid)
248- row_maxes = np.empty(n)
249- for i in numba.prange(n):
250- row_max = -np.inf
251- for j in range(n):
252- x = grid[i]
253- y = grid[j]
254- z = np.cos(x**2 + y**2) / (1 + x**2 + y**2)
255- if z > row_max:
256- row_max = z
257- row_maxes[i] = row_max
258- return np.max(row_maxes)
259- ```
233+ with qe.Timer():
234+ # First run
235+ z_max_parallel = compute_max_numba_parallel(grid)
260236
261- اکنون بلوک کدی که ` for i in numba.prange(n) ` روی آن عمل میکند بین ` i ` ها مستقل است.
262-
263- هر نخ به یک عنصر جداگانه از آرایه ` row_maxes ` مینویسد و موازیسازی ایمن است.
264-
265- ``` {code-cell} ipython3
266- z_max_parallel = compute_max_numba_parallel(grid)
267237print(f"Numba result: {z_max_parallel:.6f}")
268238```
269239
270- در اینجا زمانبندی آمده است.
240+ در اینجا زمانبندی برای نسخه از پیش کامپایل شده آمده است.
271241
272242``` {code-cell} ipython3
273- with qe.Timer(precision=8):
243+ with qe.Timer():
244+ # Second run
274245 compute_max_numba_parallel(grid)
275246```
276247
@@ -284,8 +255,7 @@ with qe.Timer(precision=8):
284255
285256اما تفاوتهایی نیز وجود دارد که در اینجا آنها را برجسته میکنیم.
286257
287- بیایید با تابع شروع کنیم.
288-
258+ بیایید با تابع شروع کنیم که ` np ` را به ` jnp ` تغییر میدهد و ` jax.jit ` را اضافه میکند.
289259
290260``` {code-cell} ipython3
291261@jax.jit
@@ -299,9 +269,15 @@ def f(x, y):
299269``` {code-cell} ipython3
300270grid = jnp.linspace(-3, 3, 3_000)
301271x_mesh, y_mesh = jnp.meshgrid(grid, grid)
272+ ```
273+
274+ حالا بیایید اجرا و زمانبندی کنیم
302275
303- with qe.Timer(precision=8):
276+ ``` {code-cell} ipython3
277+ with qe.Timer():
278+ # First run
304279 z_max = jnp.max(f(x_mesh, y_mesh))
280+ # Hold interpreter
305281 z_max.block_until_ready()
306282
307283print(f"Plain vanilla JAX result: {z_max:.6f}")
@@ -310,8 +286,10 @@ print(f"Plain vanilla JAX result: {z_max:.6f}")
310286بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود.
311287
312288``` {code-cell} ipython3
313- with qe.Timer(precision=8):
289+ with qe.Timer():
290+ # Second run
314291 z_max = jnp.max(f(x_mesh, y_mesh))
292+ # Hold interpreter
315293 z_max.block_until_ready()
316294```
317295
@@ -339,14 +317,14 @@ x_mesh.nbytes + y_mesh.nbytes
339317
340318خوشبختانه، JAX رویکرد متفاوتی را با استفاده از [ jax.vmap] ( https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html ) میپذیرد.
341319
342- #### نسخه 1
320+ ایده ` vmap ` این است که برداریسازی را به مراحل تقسیم کند و تابعی که روی مقادیر منفرد عمل میکند را به تابعی تبدیل کند که روی آرایهها عمل میکند.
343321
344- در اینجا یک راه برای اعمال ` vmap ` آمده است.
322+ در اینجا نحوه اعمال آن به مسئله ما آمده است.
345323
346324``` {code-cell} ipython3
347- # f را تنظیم کنید تا f(x, y) را در هر x برای هر y داده شده محاسبه کند
325+ # Set up f to compute f(x, y) at every x for any given y
348326f_vec_x = lambda y: f(grid, y)
349- # یک تابع دوم ایجاد کنید که این عملیات را روی تمام y برداری کند
327+ # Create a second function that vectorizes this operation over all y
350328f_vec = jax.vmap(f_vec_x)
351329```
352330
@@ -355,49 +333,37 @@ f_vec = jax.vmap(f_vec_x)
355333بیایید زمانبندی را ببینیم:
356334
357335``` {code-cell} ipython3
358- with qe.Timer(precision=8 ):
336+ with qe.Timer():
359337 z_max = jnp.max(f_vec(grid))
360338 z_max.block_until_ready()
361339
362340print(f"JAX vmap v1 result: {z_max:.6f}")
363341```
364342
365343``` {code-cell} ipython3
366- with qe.Timer(precision=8 ):
344+ with qe.Timer():
367345 z_max = jnp.max(f_vec(grid))
368346 z_max.block_until_ready()
369347```
370348
371- با اجتناب از آرایههای ورودی بزرگ ` x_mesh ` و ` y_mesh ` ، این نسخه ` vmap ` از حافظه بسیار کمتری استفاده میکند.
372-
373- وقتی روی CPU اجرا میشود، زمان اجرای آن شبیه به نسخه meshgrid است.
374-
375- وقتی روی GPU اجرا میشود، معمولاً به طور قابل توجهی سریعتر است.
376-
377- در واقع، استفاده از ` vmap ` مزیت دیگری دارد: به ما اجازه میدهد برداریسازی را به مراحل تقسیم کنیم.
378-
379- این منجر به کدی میشود که اغلب راحتتر از کد برداری شده سنتی قابل درک است.
349+ با اجتناب از آرایههای ورودی بزرگ ` x_mesh ` و ` y_mesh ` ، این نسخه ` vmap ` از حافظه بسیار کمتری استفاده میکند و زمان اجرا نیز تغییر چندانی نمیکند.
380350
381- ما این ایدهها را بیشتر هنگام حل مسائل بزرگتر بررسی خواهیم کرد.
351+ این خوب است --- اما هنوز از دستاوردهای سرعت بهره نمیبریم!
382352
383- ### نسخه 2 vmap
353+ اول توجه کنید که کد بالا آرایه دوبعدی کامل ` f(x,y) ` را محاسبه میکند که پیش از گرفتن حداکثر، سربارهایی ایجاد میکند.
384354
385- میتوانیم با استفاده از vmap همچنان کارآمدتر از نظر حافظه باشیم .
355+ دوم، فراخوانی ` jnp.max ` خارج از تابع JIT-compiled شده ` f ` قرار دارد، بنابراین کامپایلر نمیتواند این عملیات را در یک kernel واحد ادغام کند .
386356
387- در حالی که در نسخه قبلی از آرایههای ورودی بزرگ اجتناب میکنیم، هنوز آرایه خروجی بزرگ ` f(x,y) ` را قبل از محاسبه حداکثر ایجاد میکنیم.
388-
389- بیایید یک رویکرد کمی متفاوت را امتحان کنیم که max را به داخل میبرد.
390-
391- به دلیل این تغییر، ما هرگز آرایه دوبعدی ` f(x,y) ` را محاسبه نمیکنیم.
357+ میتوانیم هر دو مشکل را با انتقال max به داخل و پوشش دادن همه چیز در یک ` @jax.jit ` واحد برطرف کنیم:
392358
393359``` {code-cell} ipython3
394360@jax.jit
395- def compute_max_vmap_v2 (grid):
396- # یک تابع بسازید که حداکثر را در امتداد هر سطر بگیرد
361+ def compute_max_vmap (grid):
362+ # Construct a function that takes the max along each row
397363 f_vec_x_max = lambda y: jnp.max(f(grid, y))
398- # تابع را برداری کنید تا بتوانیم روی تمام سطرها همزمان فراخوانی کنیم
364+ # Vectorize the function so we can call on all rows simultaneously
399365 f_vec_max = jax.vmap(f_vec_x_max)
400- # تابع برداری شده را فراخوانی کنید و حداکثر را بگیرید
366+ # Call the vectorized function and take the max
401367 return jnp.max(f_vec_max(grid))
402368```
403369
@@ -408,24 +374,32 @@ def compute_max_vmap_v2(grid):
408374
409375ما این تابع را روی تمام سطرها اعمال میکنیم و سپس حداکثر max های سطر را میگیریم.
410376
377+ از آنجایی که max را به داخل انتقال میدهیم، هرگز آرایه دوبعدی کامل ` f(x,y) ` را نمیسازیم و حافظه بیشتری ذخیره میکنیم.
378+
379+ و از آنجایی که همه چیز زیر یک ` @jax.jit ` واحد است، کامپایلر میتواند تمام عملیات را در یک kernel بهینه ادغام کند.
380+
411381بیایید آن را امتحان کنیم.
412382
413383``` {code-cell} ipython3
414- with qe.Timer(precision=8):
415- z_max = compute_max_vmap_v2(grid).block_until_ready()
384+ with qe.Timer():
385+ # First run
386+ z_max = compute_max_vmap(grid)
387+ # Hold interpreter
388+ z_max.block_until_ready()
416389
417- print(f"JAX vmap v2 result: {z_max:.6f}")
390+ print(f"JAX vmap result: {z_max:.6f}")
418391```
419392
420393بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود:
421394
422395``` {code-cell} ipython3
423- with qe.Timer(precision=8):
424- z_max = compute_max_vmap_v2(grid).block_until_ready()
396+ with qe.Timer():
397+ # Second run
398+ z_max = compute_max_vmap(grid)
399+ # Hold interpreter
400+ z_max.block_until_ready()
425401```
426402
427- اگر این را روی GPU اجرا میکنید، همانطور که ما این کار را میکنیم، باید افزایش سرعت قابل توجه دیگری را ببینید.
428-
429403### خلاصه
430404
431405به نظر ما، JAX برنده برای عملیات برداری شده است.
@@ -467,14 +441,16 @@ def qm(x0, n, α=4.0):
467441``` {code-cell} ipython3
468442n = 10_000_000
469443
470- with qe.Timer(precision=8):
444+ with qe.Timer():
445+ # First run
471446 x = qm(0.1, n)
472447```
473448
474449بیایید دوباره اجرا کنیم تا زمان کامپایل حذف شود:
475450
476451``` {code-cell} ipython3
477- with qe.Timer(precision=8):
452+ with qe.Timer():
453+ # Second run
478454 x = qm(0.1, n)
479455```
480456
@@ -493,7 +469,7 @@ Numba این عملیات ترتیبی را به طور بسیار کارآمد
493469``` {code-cell} ipython3
494470cpu = jax.devices("cpu")[0]
495471
496- @partial(jax.jit, static_argnums=(1 ,), device=cpu)
472+ @partial(jax.jit, static_argnames=('n' ,), device=cpu)
497473def qm_jax(x0, n, α=4.0):
498474 def update(x, t):
499475 x_new = α * x * (1 - x)
@@ -506,32 +482,32 @@ def qm_jax(x0, n, α=4.0):
506482این کد خواندن آسانی ندارد اما، در اصل، ` lax.scan ` به طور مکرر ` update ` را فراخوانی میکند و بازگشتهای ` x_new ` را در یک آرایه جمع میکند.
507483
508484``` {note}
509- خوانندگان تیزبین متوجه خواهند شد که ما `device=cpu` را در decorator `jax.jit` مشخص میکنیم.
510-
511- محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهرهبرداری GPU از موازیسازی باقی میگذارد.
512-
513- در نتیجه، سربار راهاندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسبتر برای این بار کاری میکند.
514-
515- خوانندگان کنجکاو میتوانند حذف این گزینه را امتحان کنند تا ببینند چگونه عملکرد تغییر میکند.
485+ ما `device=cpu` را در decorator `jax.jit` مشخص میکنیم زیرا این محاسبه از بسیاری عملیات ترتیبی کوچک تشکیل شده است که فرصت کمی برای بهرهبرداری GPU از موازیسازی باقی میگذارد. در نتیجه، سربار راهاندازی kernel تمایل دارد روی GPU غالب شود و CPU را متناسبتر برای این بار کاری میکند.
516486```
517487
518488بیایید آن را با همان پارامترها زمانبندی کنیم:
519489
520490``` {code-cell} ipython3
521- with qe.Timer(precision=8):
522- x_jax = qm_jax(0.1, n).block_until_ready()
491+ with qe.Timer():
492+ # First run
493+ x_jax = qm_jax(0.1, n)
494+ # Hold interpreter
495+ x_jax.block_until_ready()
523496```
524497
525498بیایید دوباره اجرا کنیم تا سربار کامپایل حذف شود:
526499
527500``` {code-cell} ipython3
528- with qe.Timer(precision=8):
529- x_jax = qm_jax(0.1, n).block_until_ready()
501+ with qe.Timer():
502+ # Second run
503+ x_jax = qm_jax(0.1, n)
504+ # Hold interpreter
505+ x_jax.block_until_ready()
530506```
531507
532508JAX نیز برای این عملیات ترتیبی کاملاً کارآمد است.
533509
534- هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه میدهند، با این که Numba معمولاً (اما نه همیشه) سرعتهای کمی بهتری در عملیات کاملاً ترتیبی ارائه میدهد .
510+ هم JAX و هم Numba عملکرد قوی پس از کامپایل ارائه میدهند.
535511
536512### خلاصه
537513
@@ -545,7 +521,7 @@ JAX نیز برای این عملیات ترتیبی کاملاً کارآمد
545521
546522علاوه بر این، آرایههای تغییرناپذیر JAX به این معنی است که نمیتوانیم به سادگی عناصر آرایه را در جا بهروزرسانی کنیم و تکرار مستقیم الگوریتم مورد استفاده توسط Numba را سخت میکند.
547523
548- برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیادهسازی، و همچنین عملکرد بالا است.
524+ برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیادهسازی است.
549525
550526## توصیههای کلی
551527
0 commit comments