Skip to content

Commit 18662e6

Browse files
committed
Update translation: lectures/numpy_vs_numba_vs_jax.md
1 parent aeed159 commit 18662e6

1 file changed

Lines changed: 63 additions & 46 deletions

File tree

lectures/numpy_vs_numba_vs_jax.md

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ translation:
1313
Vectorized operations: عملیات برداری شده
1414
Vectorized operations::Problem Statement: بیان مسئله
1515
Vectorized operations::NumPy vectorization: برداری‌سازی NumPy
16+
Vectorized operations::Memory Issues: مشکلات حافظه
1617
Vectorized operations::A Comparison with Numba: مقایسه با Numba
1718
Vectorized operations::Parallelized Numba: Numba موازی شده
1819
Vectorized operations::Vectorized code with JAX: کد برداری شده با JAX
@@ -146,16 +147,33 @@ for x in grid:
146147

147148
بیایید به NumPy تغییر دهیم و از یک شبکه بزرگتر استفاده کنیم
148149

150+
```{code-cell} ipython3
151+
grid = np.linspace(-3, 3, 3_000) # Large grid
152+
```
153+
154+
به عنوان اولین گام برداری‌سازی ممکن است چیزی شبیه به این امتحان کنیم
155+
156+
```{code-cell} ipython3
157+
# Large grid
158+
z = np.max(f(grid, grid)) # This is wrong!
159+
```
160+
161+
مشکل اینجا این است که `f(grid, grid)` از حلقه تودرتو پیروی نمی‌کند.
162+
163+
از نظر شکل بالا، این کد فقط مقادیر `f` را روی قطر محاسبه می‌کند.
164+
165+
برای اینکه NumPy را مجبور کنیم `f(x,y)` را روی هر جفت `x,y` محاسبه کند، باید از `np.meshgrid` استفاده کنیم.
166+
149167
در اینجا از `np.meshgrid` برای ایجاد شبکه‌های ورودی دوبعدی `x` و `y` استفاده می‌کنیم به گونه‌ای که `f(x, y)` تمام ارزیابی‌ها را روی شبکه حاصلضرب تولید می‌کند.
150168

151169
```{code-cell} ipython3
152170
# Large grid
153171
grid = np.linspace(-3, 3, 3_000)
154172
155-
x, y = np.meshgrid(grid, grid) # MATLAB style meshgrid
173+
x_mesh, y_mesh = np.meshgrid(grid, grid) # MATLAB style meshgrid
156174
157175
with qe.Timer():
158-
z_max_numpy = np.max(f(x, y))
176+
z_max_numpy = np.max(f(x_mesh, y_mesh)) # This works
159177
```
160178

161179
در نسخه برداری شده، تمام حلقه‌ها در کد کامپایل شده انجام می‌شوند.
@@ -168,9 +186,29 @@ with qe.Timer():
168186
print(f"NumPy result: {z_max_numpy:.6f}")
169187
```
170188

189+
### مشکلات حافظه
190+
191+
پس ما راه‌حل صحیح را در زمان معقول داریم --- اما مصرف حافظه بسیار زیاد است.
192+
193+
در حالی که آرایه‌های تخت حافظه کمی دارند
194+
195+
```{code-cell} ipython3
196+
grid.nbytes
197+
```
198+
199+
شبکه‌های mesh دوبعدی هستند و از این رو از نظر حافظه بسیار فشرده‌اند
200+
201+
```{code-cell} ipython3
202+
x_mesh.nbytes + y_mesh.nbytes
203+
```
204+
205+
علاوه بر این، اجرای بلادرنگ NumPy آرایه‌های میانی زیادی با همان اندازه ایجاد می‌کند!
206+
207+
این نوع مصرف حافظه می‌تواند یک مشکل بزرگ در محاسبات تحقیقاتی واقعی باشد.
208+
171209
### مقایسه با Numba
172210

173-
حالا بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم.
211+
بیایید ببینیم آیا می‌توانیم با استفاده از Numba با یک حلقه ساده به عملکرد بهتری دست یابیم.
174212

175213
```{code-cell} ipython3
176214
@numba.jit
@@ -201,13 +239,13 @@ with qe.Timer():
201239
compute_max_numba(grid)
202240
```
203241

204-
بسته به دستگاه شما، نسخه Numba ممکن است کندتر یا سریعتر از NumPy باشد.
242+
توجه کنید که تقریباً هیچ حافظه‌ای استفاده نمی‌کنیم --- فقط به `grid` یک‌بعدی نیاز داریم.
205243

206-
در اکثر موارد، Numba کمی بهتر است.
244+
علاوه بر این، سرعت اجرا خوب است.
207245

208-
از یک طرف، NumPy محاسبات کارآمد را با مقداری چندنخی ترکیب می‌کند که مزیتی فراهم می‌کند.
246+
در اکثر دستگاه‌ها، نسخه Numba تا حدودی سریعتر از NumPy خواهد بود.
209247

210-
از طرف دیگر، روال Numba از حافظه بسیار کمتری استفاده می‌کند، زیرا ما فقط با یک شبکه یک‌بعدی کار می‌کنیم.
248+
دلیل آن کد ماشین کارآمد به علاوه خواندن و نوشتن کمتر حافظه است.
211249

212250
### Numba موازی شده
213251

@@ -301,39 +339,25 @@ with qe.Timer():
301339

302340
### JAX به علاوه vmap
303341

304-
یک مشکل با کد NumPy و کد JAX وجود دارد:
305-
306-
در حالی که آرایه‌های تخت حافظه کمی دارند
307-
308-
```{code-cell} ipython3
309-
grid.nbytes
310-
```
311-
312-
شبکه‌های mesh فشرده از نظر حافظه هستند
342+
چون از `jax.jit` در بالا استفاده کردیم، از ایجاد بسیاری از آرایه‌های میانی اجتناب کردیم.
313343

314-
```{code-cell} ipython3
315-
x_mesh.nbytes + y_mesh.nbytes
316-
```
344+
اما همچنان آرایه‌های بزرگ `z_max`، `x_mesh` و `y_mesh` را ایجاد می‌کنیم.
317345

318-
این استفاده اضافی از حافظه می‌تواند یک مشکل بزرگ در محاسبات تحقیقاتی واقعی باشد.
319-
320-
خوشبختانه، JAX رویکرد متفاوتی را با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) می‌پذیرد.
321-
322-
ایده `vmap` این است که برداری‌سازی را به مراحل تقسیم کند و تابعی که روی مقادیر تکی عمل می‌کند را به تابعی تبدیل کند که روی آرایه‌ها عمل می‌کند.
346+
خوشبختانه، می‌توانیم با استفاده از [jax.vmap](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html) از این اجتناب کنیم.
323347

324348
در اینجا نحوه اعمال آن به مسئله ما آمده است.
325349

326350
```{code-cell} ipython3
327351
@jax.jit
328352
def compute_max_vmap(grid):
329353
# Construct a function that takes the max over all x for given y
330-
f_vec_x_max = lambda y: jnp.max(f(grid, y))
354+
compute_column_max = lambda y: jnp.max(f(grid, y))
331355
# Vectorize the function so we can call on all y simultaneously
332-
f_vec_max = jax.vmap(f_vec_x_max)
333-
# Compute the max across x at every y
334-
maxes = f_vec_max(grid)
335-
# Compute the max of the maxes and return
336-
return jnp.max(maxes)
356+
vectorized_compute_column_max = jax.vmap(compute_column_max)
357+
# Compute the column max at every row
358+
column_maxes = vectorized_compute_column_max(grid)
359+
# Compute the max of the column maxes and return
360+
return jnp.max(column_maxes)
337361
```
338362

339363
توجه کنید که هرگز
@@ -344,6 +368,8 @@ def compute_max_vmap(grid):
344368

345369
را نمی‌سازیم.
346370

371+
مانند Numba، فقط از آرایه تخت `grid` استفاده می‌کنیم.
372+
347373
و چون همه چیز زیر یک `@jax.jit` واحد قرار دارد، کامپایلر می‌تواند تمام عملیات را در یک kernel بهینه ادغام کند.
348374

349375
بیایید آن را امتحان کنیم.
@@ -374,13 +400,11 @@ with qe.Timer():
374400

375401
هم از نظر سرعت (از طریق JIT-compilation و موازی‌سازی) و هم از نظر کارایی حافظه (از طریق vmap) بر NumPy غلبه می‌کند.
376402

377-
علاوه بر این، رویکرد `vmap` گاهی اوقات می‌تواند منجر به کد به طور قابل توجهی واضح‌تری شود.
403+
همچنین هنگام اجرا روی GPU بر Numba نیز غلبه می‌کند.
378404

379-
در حالی که Numba چشمگیر است، زیبایی JAX این است که با عملیات کاملاً برداری شده، می‌توانیم دقیقاً همان کد را روی دستگاه‌های با شتاب‌دهنده سخت‌افزاری اجرا کنیم و بدون تلاش اضافی از تمام مزایا بهره‌مند شویم.
380-
381-
علاوه بر این، JAX قبلاً می‌داند چگونه بسیاری از عملیات آرایه رایج را به طور مؤثر موازی کند، که کلید اجرای سریع است.
382-
383-
برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان کدنویسی دستی کنیم.
405+
```{note}
406+
Numba می‌تواند برنامه‌نویسی GPU را از طریق `numba.cuda` پشتیبانی کند، اما در آن صورت باید موازی‌سازی را به صورت دستی انجام دهیم. برای اکثر موارد مواجه شده در اقتصاد، اقتصادسنجی و امور مالی، بسیار بهتر است که برای موازی‌سازی کارآمد به کامپایلر JAX تحویل دهیم تا اینکه سعی کنیم این روال‌ها را خودمان به صورت دستی کدنویسی کنیم.
407+
```
384408

385409
## عملیات ترتیبی
386410

@@ -530,8 +554,6 @@ with qe.Timer():
530554

531555
در حالی که سینتکس `at[t].set` در JAX به‌روزرسانی عنصر به عنصر را ممکن می‌سازد، کد کلی همچنان سخت‌تر از معادل Numba برای خواندن است.
532556

533-
برای این نوع عملیات ترتیبی، Numba برنده واضح از نظر وضوح کد و سهولت پیاده‌سازی است.
534-
535557
## توصیه‌های کلی
536558

537559
حال قدمی به عقب بر می‌داریم و مبادلات را خلاصه می‌کنیم.
@@ -544,17 +566,12 @@ with qe.Timer():
544566

545567
علاوه بر این، توابع JAX به‌صورت خودکار مشتق‌پذیر هستند، همان‌طور که در {doc}`autodiff` بررسی می‌کنیم.
546568

547-
برای **عملیات ترتیبی**، Numba مزایای آشکاری دارد.
569+
برای **عملیات ترتیبی**، Numba نحو بهتری دارد.
548570

549571
کد طبیعی و خوانا است --- صرفاً یک حلقه پایتون با یک decorator --- و کارایی آن عالی است.
550572

551573
JAX می‌تواند مسائل ترتیبی را از طریق `lax.fori_loop` یا `lax.scan` مدیریت کند، اما نحو آن کمتر شهودی است.
552574

553-
```{note}
554-
یک مزیت مهم `lax.fori_loop` و `lax.scan` این است که از مشتق‌گیری خودکار در طول حلقه پشتیبانی می‌کنند، که Numba قادر به انجام آن نیست.
555-
اگر نیاز دارید از طریق یک محاسبه ترتیبی مشتق بگیرید (مثلاً محاسبه حساسیت‌های یک مسیر نسبت به پارامترهای مدل)، JAX علی‌رغم نحو کمتر طبیعی‌اش، انتخاب بهتری است.
556-
```
557-
558-
در عمل، بسیاری از مسائل ترکیبی از هر دو الگو هستند.
575+
از سوی دیگر، نسخه‌های JAX از مشتق‌گیری خودکار پشتیبانی می‌کنند.
559576

560-
یک قاعده سرانگشتی مناسب: برای پروژه‌های جدید، به‌ویژه زمانی که شتاب‌دهی سخت‌افزاری یا مشتق‌پذیری ممکن است مفید باشد، به‌طور پیش‌فرض از JAX استفاده کنید، و هنگامی که یک حلقه ترتیبی فشرده نیاز به سرعت و خوانایی دارد، به Numba متوسل شوید.
577+
این ممکن است جالب توجه باشد اگر، برای مثال، بخواهیم حساسیت‌های یک مسیر را نسبت به پارامترهای مدل محاسبه کنیم.

0 commit comments

Comments
 (0)