You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
JAX as a NumPy Replacement: JAX به عنوان جایگزین NumPy
16
16
JAX as a NumPy Replacement::Similarities: شباهتها
17
17
JAX as a NumPy Replacement::Differences: تفاوتها
18
-
JAX as a NumPy Replacement::Differences::Precision: دقت
19
-
JAX as a NumPy Replacement::Differences::Immutability: تغییرناپذیری
20
-
JAX as a NumPy Replacement::Differences::A workaround: راهحل جایگزین
18
+
JAX as a NumPy Replacement::Differences::Speed!: دقت
19
+
JAX as a NumPy Replacement::Differences::Precision: تغییرناپذیری
20
+
JAX as a NumPy Replacement::Differences::Immutability: راهحل جایگزین
21
21
Functional Programming: برنامهنویسی تابعی
22
22
Functional Programming::Pure functions: توابع خالص
23
23
Functional Programming::Examples: مثالها
@@ -26,18 +26,12 @@ translation:
26
26
Random numbers::Why explicit random state?: چرا وضعیت تصادفی صریح؟
27
27
Random numbers::Why explicit random state?::NumPy's approach: رویکرد NumPy
28
28
Random numbers::Why explicit random state?::JAX's approach: رویکرد JAX
29
-
JIT compilation: کامپایل JIT
30
-
JIT compilation::A simple example: یک مثال ساده
31
-
JIT compilation::A simple example::With NumPy: با NumPy
32
-
JIT compilation::A simple example::With JAX: با JAX
33
-
JIT compilation::A simple example::Changing array sizes: تغییر اندازه آرایهها
34
-
JIT compilation::Evaluating a more complicated function: ارزیابی یک تابع پیچیدهتر
35
-
JIT compilation::Evaluating a more complicated function::With NumPy: با NumPy
36
-
JIT compilation::Evaluating a more complicated function::With JAX: با JAX
37
-
JIT compilation::How JIT compilation works: نحوه کار کامپایل JIT
38
-
JIT compilation::Compiling the whole function: کامپایل کل تابع
39
-
JIT compilation::Compiling non-pure functions: کامپایل توابع غیرخالص
40
-
JIT compilation::Summary: خلاصه
29
+
JIT Compilation: کامپایل JIT
30
+
JIT Compilation::With NumPy: با NumPy
31
+
JIT Compilation::With JAX: با JAX
32
+
JIT Compilation::Compiling the Whole Function: کامپایل کل تابع
33
+
JIT Compilation::How JIT compilation works: نحوه کار کامپایل JIT
34
+
JIT Compilation::Compiling non-pure functions: کامپایل توابع غیرخالص
41
35
Vectorization with `vmap`: برداریسازی با `vmap`
42
36
Vectorization with `vmap`::A simple example: یک مثال ساده
43
37
Vectorization with `vmap`::Combining transformations: ترکیب تبدیلها
@@ -541,119 +535,17 @@ random_sum_jax(key)
541
535
542
536
کامپایلر just-in-time (JIT) JAX اجرا را با تولید کد ماشین کارآمد که با هم اندازه وظیفه و هم سختافزار متفاوت است، تسریع میکند.
543
537
544
-
### یک مثال ساده
545
-
546
-
فرض کنید میخواهیم تابع کسینوس را در نقاط بسیاری ارزیابی کنیم.
547
-
548
-
```{code-cell}
549
-
n = 50_000_000
550
-
x = np.linspace(0, 10, n)
551
-
```
552
-
553
-
#### با NumPy
554
-
555
-
بیایید ابتدا با NumPy امتحان کنیم
556
-
557
-
```{code-cell}
558
-
with qe.Timer():
559
-
y = np.cos(x)
560
-
```
561
-
562
-
و یک بار دیگر.
563
-
564
-
```{code-cell}
565
-
with qe.Timer():
566
-
y = np.cos(x)
567
-
```
568
-
569
-
در اینجا NumPy از یک فایل باینری از پیش ساخته شده، کامپایل شده از کد سطح پایین نوشته شده با دقت، برای اعمال کسینوس به یک آرایه از اعداد اعشاری استفاده میکند.
570
-
571
-
این فایل باینری با NumPy ارسال میشود.
572
-
573
-
#### با JAX
574
-
575
-
اکنون بیایید با JAX امتحان کنیم.
576
-
577
-
```{code-cell}
578
-
x = jnp.linspace(0, 10, n)
579
-
```
580
-
581
-
بیایید همان روش را زمانبندی کنیم.
582
-
583
-
```{code-cell}
584
-
with qe.Timer():
585
-
y = jnp.cos(x)
586
-
jax.block_until_ready(y);
587
-
```
588
-
589
-
```{note}
590
-
در اینجا، به منظور اندازهگیری سرعت واقعی، از متد `block_until_ready` استفاده میکنیم تا مفسر را تا زمانی که نتایج محاسبات برگردانده شوند، نگه دارد.
591
-
592
-
این امر ضروری است زیرا JAX از ارسال ناهمزمان استفاده میکند که به مفسر Python اجازه میدهد از محاسبات عددی جلوتر برود.
593
-
594
-
برای کدهای زمانبندی نشده، میتوانید خط حاوی `block_until_ready` را حذف کنید.
595
-
```
596
-
597
-
598
-
و بیایید دوباره آن را زمانبندی کنیم.
599
-
600
-
601
-
```{code-cell}
602
-
with qe.Timer():
603
-
y = jnp.cos(x)
604
-
jax.block_until_ready(y);
605
-
```
606
-
607
-
روی GPU، این کد بسیار سریعتر از معادل NumPy آن اجرا میشود.
608
-
609
-
همچنین، معمولاً، اجرای دوم سریعتر از اولین اجرا به دلیل کامپایل JIT است.
610
-
611
-
این به این دلیل است که حتی توابع داخلی مانند `jnp.cos` نیز JIT-کامپایل میشوند --- و اجرای اول شامل زمان کامپایل است.
612
-
613
-
چرا JAX میخواهد توابع داخلی مانند `jnp.cos` را به جای ارائه نسخههای از پیش کامپایل شده، مانند NumPy، JIT-کامپایل کند؟
614
-
615
-
دلیل این است که کامپایلر JIT میخواهد روی *اندازه* آرایه در حال استفاده (و همچنین نوع داده) تخصصی شود.
616
-
617
-
اندازه برای تولید کد بهینه شده اهمیت دارد زیرا موازیسازی کارآمد نیاز به تطبیق اندازه وظیفه با سختافزار موجود دارد.
618
-
619
-
به همین دلیل است که JAX منتظر میماند تا اندازه آرایه را قبل از کامپایل ببیند --- که نیاز به یک رویکرد JIT-کامپایل شده به جای ارائه باینریهای از پیش کامپایل شده دارد.
620
-
621
-
#### تغییر اندازه آرایهها
622
-
623
-
در اینجا اندازه ورودی را تغییر میدهیم و زمانهای اجرا را مشاهده میکنیم.
624
-
625
-
```{code-cell}
626
-
x = jnp.linspace(0, 10, n + 1)
627
-
```
628
-
629
-
```{code-cell}
630
-
with qe.Timer():
631
-
y = jnp.cos(x)
632
-
jax.block_until_ready(y);
633
-
```
634
-
635
-
636
-
```{code-cell}
637
-
with qe.Timer():
638
-
y = jnp.cos(x)
639
-
jax.block_until_ready(y);
640
-
```
641
-
642
-
معمولاً، زمان اجرا افزایش مییابد و سپس دوباره کاهش مییابد (این روی GPU واضحتر خواهد بود).
538
+
قدرت کامپایلر JIT JAX در ترکیب با سختافزار موازی را {ref}`بالاتر <jax_speed>` دیدیم، هنگامی که `cos` را روی یک آرایه بزرگ اعمال کردیم.
643
539
644
-
این به این دلیل است که کامپایلر JIT روی اندازه آرایه تخصصی میشود تا موازیسازی را بهرهبرداری کند --- و از این رو کد کامپایل شده جدیدی را هنگام تغییر اندازه آرایه تولید میکند.
645
-
646
-
### ارزیابی یک تابع پیچیدهتر
647
-
648
-
بیایید همان کار را با یک تابع پیچیدهتر امتحان کنیم.
540
+
بیایید همان کار را با یک تابع پیچیدهتر امتحان کنیم:
نتیجه مشابه مثال `cos` است --- JAX سریعتر است، به ویژه در اجرای دوم پس از کامپایل JIT.
701
597
702
-
علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- میتوانیم *کل* تابع را JIT-کامپایل کنیم، نه فقط عملیاتهای منفرد.
703
-
704
-
### نحوه کار کامپایل JIT
705
-
706
-
هنگامی که `jax.jit` را به یک تابع اعمال میکنیم، JAX آن را *ردیابی* میکند: به جای اجرای فوری عملیاتها، دنباله عملیاتها را به صورت یک گراف محاسباتی ثبت میکند و آن گراف را به کامپایلر [XLA](https://openxla.org/xla) تحویل میدهد.
707
-
708
-
سپس XLA عملیاتها را در یک هسته کامپایل شده واحد بهینهسازی و ادغام میکند که متناسب با سختافزار موجود (CPU، GPU، یا TPU) طراحی شده است.
709
-
710
-
نمودار زیر این خط لوله را برای یک تابع ساده نشان میدهد:
اولین فراخوانی به یک تابع JIT-کامپایل شده سربار کامپایل دارد، اما فراخوانیهای بعدی با همان شکلها و نوعهای ورودی از کد کامپایل شده کششده استفاده میکنند و با سرعت کامل اجرا میشوند.
598
+
علاوه بر این، با JAX، ترفند دیگری در آستین داریم --- میتوانیم کل تابع را JIT-کامپایل کنیم، نه فقط عملیاتهای منفرد.
756
599
757
600
### کامپایل کل تابع
758
601
759
-
کامپایلر just-in-time (JIT) JAX میتواند اجرا را در درون توابع با ادغام عملیات جبر خطی در یک هسته بهینه شده واحد تسریع کند.
602
+
کامپایلر just-in-time (JIT) JAX میتواند اجرا را در درون توابع با ادغام عملیات آرایهای در یک هسته بهینه شده واحد تسریع کند.
760
603
761
604
بیایید این را با تابع `f` امتحان کنیم:
762
605
@@ -766,21 +609,24 @@ f_jax = jax.jit(f)
766
609
767
610
```{code-cell}
768
611
with qe.Timer():
612
+
# First run
769
613
y = f_jax(x)
614
+
# Hold interpreter
770
615
jax.block_until_ready(y);
771
616
```
772
617
773
618
```{code-cell}
774
619
with qe.Timer():
620
+
# Second run
775
621
y = f_jax(x)
622
+
# Hold interpreter
776
623
jax.block_until_ready(y);
777
624
```
778
625
779
626
زمان اجرا دوباره بهبود یافته است --- اکنون به این دلیل که تمام عملیات را ادغام کردیم و به کامپایلر اجازه دادیم به طور تهاجمیتری بهینهسازی کند.
780
627
781
628
برای مثال، کامپایلر میتواند چندین فراخوانی به شتابدهنده سختافزاری و ایجاد تعدادی آرایه میانی را حذف کند.
782
629
783
-
784
630
اتفاقاً، نحو رایجتر هنگام هدف قرار دادن یک تابع برای کامپایلر JIT این است
785
631
786
632
```{code-cell} ipython3
@@ -789,6 +635,14 @@ def f(x):
789
635
pass # put function body here
790
636
```
791
637
638
+
### نحوه کار کامپایل JIT
639
+
640
+
هنگامی که `jax.jit` را به یک تابع اعمال میکنیم، JAX آن را *ردیابی* میکند: به جای اجرای فوری عملیاتها، دنباله عملیاتها را به صورت یک گراف محاسباتی ثبت میکند و آن گراف را به کامپایلر [XLA](https://openxla.org/xla) تحویل میدهد.
641
+
642
+
سپس XLA عملیاتها را در یک هسته کامپایل شده واحد بهینهسازی و ادغام میکند که متناسب با سختافزار موجود (CPU، GPU، یا TPU) طراحی شده است.
643
+
644
+
اولین فراخوانی به یک تابع JIT-کامپایل شده سربار کامپایل دارد، اما فراخوانیهای بعدی با همان شکلها و نوعهای ورودی از کد کامپایل شده کششده استفاده میکنند و با سرعت کامل اجرا میشوند.
645
+
792
646
### کامپایل توابع غیرخالص
793
647
794
648
اکنون که دیدیم کامپایل JIT چقدر قدرتمند میتواند باشد، درک رابطه آن با توابع خالص مهم است.
@@ -837,22 +691,6 @@ f(x)
837
691
838
692
درس اخلاقی داستان: هنگام استفاده از JAX، توابع خالص بنویسید!
839
693
840
-
### خلاصه
841
-
842
-
اکنون میتوانیم ببینیم که چرا هم توسعهدهندگان و هم کامپایلرها از توابع خالص بهره میبرند.
843
-
844
-
ما توابع خالص را دوست داریم زیرا آنها
845
-
846
-
* به تست کمک میکنند: هر تابع میتواند به صورت جداگانه عمل کند
847
-
* رفتار قطعی و از این رو تکرارپذیری را ترویج میکنند
848
-
* از باگهایی که از تغییر وضعیت مشترک ناشی میشوند، جلوگیری میکنند
849
-
850
-
کامپایلر توابع خالص و برنامهنویسی تابعی را دوست دارد زیرا
851
-
852
-
* وابستگیهای داده صریح هستند، که به بهینهسازی محاسبات پیچیده کمک میکند
853
-
* توابع خالص راحتتر قابل تمایز هستند (autodiff)
854
-
* توابع خالص راحتتر موازیسازی و بهینهسازی میشوند (به وضعیت قابل تغییر مشترک وابسته نیستند)
855
-
856
694
## برداریسازی با `vmap`
857
695
858
696
یکی دیگر از تبدیلهای قدرتمند JAX، `jax.vmap` است که بهطور خودکار
0 commit comments