Skip to content

Commit 5996a80

Browse files
committed
Update translation: lectures/jax_intro.md
1 parent a7ccc94 commit 5996a80

1 file changed

Lines changed: 34 additions & 196 deletions

File tree

lectures/jax_intro.md

Lines changed: 34 additions & 196 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ translation:
1515
JAX as a NumPy Replacement: JAX به عنوان جایگزین NumPy
1616
JAX as a NumPy Replacement::Similarities: شباهت‌ها
1717
JAX as a NumPy Replacement::Differences: تفاوت‌ها
18-
JAX as a NumPy Replacement::Differences::Precision: دقت
19-
JAX as a NumPy Replacement::Differences::Immutability: تغییرناپذیری
20-
JAX as a NumPy Replacement::Differences::A workaround: راه‌حل جایگزین
18+
JAX as a NumPy Replacement::Differences::Speed!: دقت
19+
JAX as a NumPy Replacement::Differences::Precision: تغییرناپذیری
20+
JAX as a NumPy Replacement::Differences::Immutability: راه‌حل جایگزین
2121
Functional Programming: برنامه‌نویسی تابعی
2222
Functional Programming::Pure functions: توابع خالص
2323
Functional Programming::Examples: مثال‌ها
@@ -26,18 +26,12 @@ translation:
2626
Random numbers::Why explicit random state?: چرا وضعیت تصادفی صریح؟
2727
Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy
2828
Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX
29-
JIT compilation: کامپایل JIT
30-
JIT compilation::A simple example: یک مثال ساده
31-
JIT compilation::A simple example::With NumPy: با NumPy
32-
JIT compilation::A simple example::With JAX: با JAX
33-
JIT compilation::A simple example::Changing array sizes: تغییر اندازه آرایه‌ها
34-
JIT compilation::Evaluating a more complicated function: ارزیابی یک تابع پیچیده‌تر
35-
JIT compilation::Evaluating a more complicated function::With NumPy: با NumPy
36-
JIT compilation::Evaluating a more complicated function::With JAX: با JAX
37-
JIT compilation::How JIT compilation works: نحوه کار کامپایل JIT
38-
JIT compilation::Compiling the whole function: کامپایل کل تابع
39-
JIT compilation::Compiling non-pure functions: کامپایل توابع غیرخالص
40-
JIT compilation::Summary: خلاصه
29+
JIT Compilation: کامپایل JIT
30+
JIT Compilation::With NumPy: با NumPy
31+
JIT Compilation::With JAX: با JAX
32+
JIT Compilation::Compiling the Whole Function: کامپایل کل تابع
33+
JIT Compilation::How JIT compilation works: نحوه کار کامپایل JIT
34+
JIT Compilation::Compiling non-pure functions: کامپایل توابع غیرخالص
4135
Vectorization with `vmap`: برداری‌سازی با `vmap`
4236
Vectorization with `vmap`::A simple example: یک مثال ساده
4337
Vectorization with `vmap`::Combining transformations: ترکیب تبدیل‌ها
@@ -541,119 +535,17 @@ random_sum_jax(key)
541535

542536
کامپایلر just-in-time (JIT) JAX اجرا را با تولید کد ماشین کارآمد که با هم اندازه وظیفه و هم سخت‌افزار متفاوت است، تسریع می‌کند.
543537

544-
### یک مثال ساده
545-
546-
فرض کنید می‌خواهیم تابع کسینوس را در نقاط بسیاری ارزیابی کنیم.
547-
548-
```{code-cell}
549-
n = 50_000_000
550-
x = np.linspace(0, 10, n)
551-
```
552-
553-
#### با NumPy
554-
555-
بیایید ابتدا با NumPy امتحان کنیم
556-
557-
```{code-cell}
558-
with qe.Timer():
559-
y = np.cos(x)
560-
```
561-
562-
و یک بار دیگر.
563-
564-
```{code-cell}
565-
with qe.Timer():
566-
y = np.cos(x)
567-
```
568-
569-
در اینجا NumPy از یک فایل باینری از پیش ساخته شده، کامپایل شده از کد سطح پایین نوشته شده با دقت، برای اعمال کسینوس به یک آرایه از اعداد اعشاری استفاده می‌کند.
570-
571-
این فایل باینری با NumPy ارسال می‌شود.
572-
573-
#### با JAX
574-
575-
اکنون بیایید با JAX امتحان کنیم.
576-
577-
```{code-cell}
578-
x = jnp.linspace(0, 10, n)
579-
```
580-
581-
بیایید همان روش را زمان‌بندی کنیم.
582-
583-
```{code-cell}
584-
with qe.Timer():
585-
y = jnp.cos(x)
586-
jax.block_until_ready(y);
587-
```
588-
589-
```{note}
590-
در اینجا، به منظور اندازه‌گیری سرعت واقعی، از متد `block_until_ready` استفاده می‌کنیم تا مفسر را تا زمانی که نتایج محاسبات برگردانده شوند، نگه دارد.
591-
592-
این امر ضروری است زیرا JAX از ارسال ناهمزمان استفاده می‌کند که به مفسر Python اجازه می‌دهد از محاسبات عددی جلوتر برود.
593-
594-
برای کدهای زمان‌بندی نشده، می‌توانید خط حاوی `block_until_ready` را حذف کنید.
595-
```
596-
597-
598-
و بیایید دوباره آن را زمان‌بندی کنیم.
599-
600-
601-
```{code-cell}
602-
with qe.Timer():
603-
y = jnp.cos(x)
604-
jax.block_until_ready(y);
605-
```
606-
607-
روی GPU، این کد بسیار سریع‌تر از معادل NumPy آن اجرا می‌شود.
608-
609-
همچنین، معمولاً، اجرای دوم سریع‌تر از اولین اجرا به دلیل کامپایل JIT است.
610-
611-
این به این دلیل است که حتی توابع داخلی مانند `jnp.cos` نیز JIT-کامپایل می‌شوند --- و اجرای اول شامل زمان کامپایل است.
612-
613-
چرا JAX می‌خواهد توابع داخلی مانند `jnp.cos` را به جای ارائه نسخه‌های از پیش کامپایل شده، مانند NumPy، JIT-کامپایل کند؟
614-
615-
دلیل این است که کامپایلر JIT می‌خواهد روی *اندازه* آرایه در حال استفاده (و همچنین نوع داده) تخصصی شود.
616-
617-
اندازه برای تولید کد بهینه شده اهمیت دارد زیرا موازی‌سازی کارآمد نیاز به تطبیق اندازه وظیفه با سخت‌افزار موجود دارد.
618-
619-
به همین دلیل است که JAX منتظر می‌ماند تا اندازه آرایه را قبل از کامپایل ببیند --- که نیاز به یک رویکرد JIT-کامپایل شده به جای ارائه باینری‌های از پیش کامپایل شده دارد.
620-
621-
#### تغییر اندازه آرایه‌ها
622-
623-
در اینجا اندازه ورودی را تغییر می‌دهیم و زمان‌های اجرا را مشاهده می‌کنیم.
624-
625-
```{code-cell}
626-
x = jnp.linspace(0, 10, n + 1)
627-
```
628-
629-
```{code-cell}
630-
with qe.Timer():
631-
y = jnp.cos(x)
632-
jax.block_until_ready(y);
633-
```
634-
635-
636-
```{code-cell}
637-
with qe.Timer():
638-
y = jnp.cos(x)
639-
jax.block_until_ready(y);
640-
```
641-
642-
معمولاً، زمان اجرا افزایش می‌یابد و سپس دوباره کاهش می‌یابد (این روی GPU واضح‌تر خواهد بود).
538+
قدرت کامپایلر JIT JAX در ترکیب با سخت‌افزار موازی را {ref}`بالاتر <jax_speed>` دیدیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم.
643539

644-
این به این دلیل است که کامپایلر JIT روی اندازه آرایه تخصصی می‌شود تا موازی‌سازی را بهره‌برداری کند --- و از این رو کد کامپایل شده جدیدی را هنگام تغییر اندازه آرایه تولید می‌کند.
645-
646-
### ارزیابی یک تابع پیچیده‌تر
647-
648-
بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم.
540+
بیایید همان کار را با یک تابع پیچیده‌تر امتحان کنیم:
649541

650542
```{code-cell}
651543
def f(x):
652544
y = np.cos(2 * x**2) + np.sqrt(np.abs(x)) + 2 * np.sin(x**4) - x**2
653545
return y
654546
```
655547

656-
#### با NumPy
548+
### با NumPy
657549

658550
ابتدا با NumPy امتحان خواهیم کرد
659551

@@ -664,10 +556,11 @@ x = np.linspace(0, 10, n)
664556

665557
```{code-cell}
666558
with qe.Timer():
559+
# Time NumPy code
667560
y = f(x)
668561
```
669562

670-
#### با JAX
563+
### با JAX
671564

672565
اکنون بیایید دوباره با JAX امتحان کنیم.
673566

@@ -677,86 +570,36 @@ with qe.Timer():
677570
def f(x):
678571
y = jnp.cos(2 * x**2) + jnp.sqrt(jnp.abs(x)) + 2 * jnp.sin(x**4) - x**2
679572
return y
680-
```
681573
682-
اکنون بیایید آن را زمان‌بندی کنیم.
683574
684-
```{code-cell}
685575
x = jnp.linspace(0, 10, n)
686576
```
687577

578+
اکنون بیایید آن را زمان‌بندی کنیم.
579+
688580
```{code-cell}
689581
with qe.Timer():
582+
# First call
690583
y = f(x)
584+
# Hold interpreter
691585
jax.block_until_ready(y);
692586
```
693587

694588
```{code-cell}
695589
with qe.Timer():
590+
# Second call
696591
y = f(x)
592+
# Hold interpreter
697593
jax.block_until_ready(y);
698594
```
699595

700596
نتیجه مشابه مثال `cos` است --- JAX سریع‌تر است، به ویژه در اجرای دوم پس از کامپایل JIT.
701597

702-
علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم *کل* تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد.
703-
704-
### نحوه کار کامپایل JIT
705-
706-
هنگامی که `jax.jit` را به یک تابع اعمال می‌کنیم، JAX آن را *ردیابی* می‌کند: به جای اجرای فوری عملیات‌ها، دنباله عملیات‌ها را به صورت یک گراف محاسباتی ثبت می‌کند و آن گراف را به کامپایلر [XLA](https://openxla.org/xla) تحویل می‌دهد.
707-
708-
سپس XLA عملیات‌ها را در یک هسته کامپایل شده واحد بهینه‌سازی و ادغام می‌کند که متناسب با سخت‌افزار موجود (CPU، GPU، یا TPU) طراحی شده است.
709-
710-
نمودار زیر این خط لوله را برای یک تابع ساده نشان می‌دهد:
711-
712-
```{code-cell} ipython3
713-
:tags: [hide-input]
714-
715-
fig, ax = plt.subplots(figsize=(7, 2))
716-
ax.set_xlim(-0.2, 7.2)
717-
ax.set_ylim(0.2, 2.2)
718-
ax.axis('off')
719-
720-
# Boxes for pipeline stages
721-
stages = [
722-
(0.7, 1.2, "Python\nfunction"),
723-
(2.6, 1.2, "computational\ngraph"),
724-
(4.5, 1.2, "optimized\nkernel"),
725-
(6.4, 1.2, "fast\nexecution"),
726-
]
727-
728-
colors = ["#e3f2fd", "#fff9c4", "#f3e5f5", "#d4edda"]
729-
730-
for (x, y, label), color in zip(stages, colors):
731-
box = mpatches.FancyBboxPatch(
732-
(x - 0.7, y - 0.5), 1.4, 1.0,
733-
boxstyle="round,pad=0.15",
734-
facecolor=color, edgecolor="black", linewidth=1.5)
735-
ax.add_patch(box)
736-
ax.text(x, y, label, ha='center', va='center', fontsize=9)
737-
738-
# Arrows with labels
739-
arrows = [
740-
(1.4, 1.9, "trace"),
741-
(3.3, 3.8, "XLA"),
742-
(5.2, 5.7, "run"),
743-
]
744-
745-
for x_start, x_end, label in arrows:
746-
ax.annotate("", xy=(x_end, 1.2), xytext=(x_start, 1.2),
747-
arrowprops=dict(arrowstyle="->", lw=1.5, color="gray"))
748-
ax.text((x_start + x_end) / 2, 1.55, label,
749-
ha='center', fontsize=8, color='gray')
750-
751-
plt.tight_layout()
752-
plt.show()
753-
```
754-
755-
اولین فراخوانی به یک تابع JIT-کامپایل شده سربار کامپایل دارد، اما فراخوانی‌های بعدی با همان شکل‌ها و نوع‌های ورودی از کد کامپایل شده کش‌شده استفاده می‌کنند و با سرعت کامل اجرا می‌شوند.
598+
علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- می‌توانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیات‌های منفرد.
756599

757600
### کامپایل کل تابع
758601

759-
کامپایلر just-in-time (JIT) JAX می‌تواند اجرا را در درون توابع با ادغام عملیات جبر خطی در یک هسته بهینه شده واحد تسریع کند.
602+
کامپایلر just-in-time (JIT) JAX می‌تواند اجرا را در درون توابع با ادغام عملیات آرایه‌ای در یک هسته بهینه شده واحد تسریع کند.
760603

761604
بیایید این را با تابع `f` امتحان کنیم:
762605

@@ -766,21 +609,24 @@ f_jax = jax.jit(f)
766609

767610
```{code-cell}
768611
with qe.Timer():
612+
# First run
769613
y = f_jax(x)
614+
# Hold interpreter
770615
jax.block_until_ready(y);
771616
```
772617

773618
```{code-cell}
774619
with qe.Timer():
620+
# Second run
775621
y = f_jax(x)
622+
# Hold interpreter
776623
jax.block_until_ready(y);
777624
```
778625

779626
زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم و به کامپایلر اجازه دادیم به طور تهاجمی‌تری بهینه‌سازی کند.
780627

781628
برای مثال، کامپایلر می‌تواند چندین فراخوانی به شتاب‌دهنده سخت‌افزاری و ایجاد تعدادی آرایه میانی را حذف کند.
782629

783-
784630
اتفاقاً، نحو رایج‌تر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است
785631

786632
```{code-cell} ipython3
@@ -789,6 +635,14 @@ def f(x):
789635
pass # put function body here
790636
```
791637

638+
### نحوه کار کامپایل JIT
639+
640+
هنگامی که `jax.jit` را به یک تابع اعمال می‌کنیم، JAX آن را *ردیابی* می‌کند: به جای اجرای فوری عملیات‌ها، دنباله عملیات‌ها را به صورت یک گراف محاسباتی ثبت می‌کند و آن گراف را به کامپایلر [XLA](https://openxla.org/xla) تحویل می‌دهد.
641+
642+
سپس XLA عملیات‌ها را در یک هسته کامپایل شده واحد بهینه‌سازی و ادغام می‌کند که متناسب با سخت‌افزار موجود (CPU، GPU، یا TPU) طراحی شده است.
643+
644+
اولین فراخوانی به یک تابع JIT-کامپایل شده سربار کامپایل دارد، اما فراخوانی‌های بعدی با همان شکل‌ها و نوع‌های ورودی از کد کامپایل شده کش‌شده استفاده می‌کنند و با سرعت کامل اجرا می‌شوند.
645+
792646
### کامپایل توابع غیرخالص
793647

794648
اکنون که دیدیم کامپایل JIT چقدر قدرتمند می‌تواند باشد، درک رابطه آن با توابع خالص مهم است.
@@ -837,22 +691,6 @@ f(x)
837691

838692
درس اخلاقی داستان: هنگام استفاده از JAX، توابع خالص بنویسید!
839693

840-
### خلاصه
841-
842-
اکنون می‌توانیم ببینیم که چرا هم توسعه‌دهندگان و هم کامپایلرها از توابع خالص بهره می‌برند.
843-
844-
ما توابع خالص را دوست داریم زیرا آنها
845-
846-
* به تست کمک می‌کنند: هر تابع می‌تواند به صورت جداگانه عمل کند
847-
* رفتار قطعی و از این رو تکرارپذیری را ترویج می‌کنند
848-
* از باگ‌هایی که از تغییر وضعیت مشترک ناشی می‌شوند، جلوگیری می‌کنند
849-
850-
کامپایلر توابع خالص و برنامه‌نویسی تابعی را دوست دارد زیرا
851-
852-
* وابستگی‌های داده صریح هستند، که به بهینه‌سازی محاسبات پیچیده کمک می‌کند
853-
* توابع خالص راحت‌تر قابل تمایز هستند (autodiff)
854-
* توابع خالص راحت‌تر موازی‌سازی و بهینه‌سازی می‌شوند (به وضعیت قابل تغییر مشترک وابسته نیستند)
855-
856694
## برداری‌سازی با `vmap`
857695

858696
یکی دیگر از تبدیل‌های قدرتمند JAX، `jax.vmap` است که به‌طور خودکار

0 commit comments

Comments
 (0)