Skip to content

Revise JAX intro lecture and add autodiff lecture#513

Merged
mmcky merged 5 commits intomainfrom
jax-intro-revisions
Apr 8, 2026
Merged

Revise JAX intro lecture and add autodiff lecture#513
mmcky merged 5 commits intomainfrom
jax-intro-revisions

Conversation

@jstac
Copy link
Copy Markdown
Contributor

@jstac jstac commented Apr 5, 2026

Summary

  • Bug fixes: Fixed coefficient mismatch (0.1 * x**2 vs x**2) in NumPy/JAX function comparison; updated deprecated jax.random.PRNGKey to jax.random.key throughout
  • New figures: Added code-generated PRNG key splitting tree diagram and JIT compilation pipeline diagram
  • New section: Added vmap section with mean/median example showing why Python loops are inefficient with JAX
  • Autodiff preview: Reworked the gradients section as a brief preview with forward reference to the new autodiff lecture
  • New lecture: Added "Adventures with Autodiff" (adapted from lecture-jax), covering finite differences vs symbolic vs autodiff, gradient descent with Barzilai-Borwein, and OLS regression
  • Housekeeping: Consolidated all imports into initial cell, added (jax_intro)= reference label, updated _toc.yml

Test plan

  • CI build passes (executes all notebooks)
  • Verify PRNG key splitting tree figure renders correctly
  • Verify JIT pipeline figure renders correctly
  • Verify autodiff lecture cross-references resolve
  • Check vmap example output is clear

🤖 Generated with Claude Code

jstac and others added 4 commits April 5, 2026 14:27
- Fix coefficient mismatch in NumPy vs JAX function comparison
- Update jax.random.PRNGKey to jax.random.key throughout
- Add code-generated figures (PRNG key splitting tree, JIT pipeline)
- Add vmap section with examples and transformation composition
- Rework gradients section as autodiff preview with forward reference
- Add autodiff lecture (adapted from lecture-jax) to TOC
- Consolidate all imports into initial cell

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Replace artificial sum-of-squares vmap example with mean/median statistics
- Add explanation of why Python loops are inefficient with JAX
- Move all imports to initial cell per lecture conventions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix variable name typo (z_max_numpy → z_max_numba)
- Fix vmap v2 print label
- Fix garbled em dash
- Consolidate imports to top of lecture
- Use jnp.meshgrid instead of np.meshgrid for JAX arrays
- Replace cm.jet with cm.viridis
- Qualify JAX speed claim re GPU
- Add overall recommendations section synthesizing trade-offs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@jstac
Copy link
Copy Markdown
Contributor Author

jstac commented Apr 5, 2026

Detailed changelog

jax_intro.md

Bug fixes:

  • Fixed coefficient mismatch in NumPy vs JAX function comparison (0.1 * x**2 vs x**2 — the two versions of f were computing different functions)
  • Updated jax.random.PRNGKeyjax.random.key (5 occurrences) — PRNGKey is the legacy API

New content:

  • Added code-generated PRNG key splitting tree diagram illustrating how split produces independent keys
  • Added JIT compilation pipeline diagram (Python function → trace → XLA → execution)
  • Added new vmap section with a mean/median example showing why Python loops are inefficient with JAX, and how transformations compose (jit(vmap(...)))
  • Reworked "Gradients" section into "Automatic differentiation: a preview" with forward reference to the new autodiff lecture

Housekeeping:

  • Consolidated all imports into the initial cell (was scattered across 4 locations)
  • Added (jax_intro)= reference label for cross-referencing
  • Rephrased transition text after import consolidation

autodiff.md (new lecture)

  • Adapted from lecture-jax/lectures/autodiff.md
  • Covers: finite differences vs symbolic calculus vs autodiff, differentiating through control flow and interpolation, gradient descent with Barzilai-Borwein, OLS regression example, polynomial regression exercise
  • Updated jax.random.PRNGKeyjax.random.key
  • Moved sympy import to top, dropped unused diff import
  • Removed nvidia-smi cell (GPU admonition covers this)
  • Added pip install cell and intro referencing the jax_intro lecture

numpy_vs_numba_vs_jax.md

Bug fixes:

  • Fixed variable name z_max_numpyz_max_numba for the Numba result
  • Fixed print label "JAX vmap v1" → "v2" in the vmap v2 section
  • Fixed garbled em dash (--—---)
  • Changed np.meshgridjnp.meshgrid when operating on JAX arrays (avoids silent host-to-device transfer)
  • Qualified "significantly faster due to GPU acceleration" → "significantly faster, especially on a GPU"

New content:

  • Added "Overall recommendations" synthesis section covering: JAX for vectorized work, Numba for sequential loops, lax.scan differentiability advantage, and a rule of thumb for choosing between them

Housekeeping:

  • Consolidated all imports to top (numba, lax, partial were mid-lecture)
  • Replaced cm.jet with cm.viridis (perceptually uniform, colorblind-friendly)

_toc.yml

  • Added autodiff after numpy_vs_numba_vs_jax in the "High Performance Computing" section

@jstac
Copy link
Copy Markdown
Contributor Author

jstac commented Apr 5, 2026

please also review this and merge when ready @mmcky .

please check that the figure showing the jax.jit compilation process in jax_intro came out okay.

Once these changes and those in #512 are merged, please make live

@jstac
Copy link
Copy Markdown
Contributor Author

jstac commented Apr 7, 2026

@mmcky Do you have time to look at this? It would be great to get it live.

Remove (jax_intro)= from jax_intro.md and (autodiff)= from autodiff.md.
These labels duplicate the automatic :std:doc: targets created from the
filenames, causing myst.xref_ambiguous warnings that fail CI with -W.
@mmcky
Copy link
Copy Markdown
Contributor

mmcky commented Apr 7, 2026

Reviewed all changes. The content additions look great — well-structured JAX intro revisions, solid new autodiff lecture, and useful improvements to the comparison lecture.

CI fix applied (commit 3abf22a): Removed redundant (jax_intro)= and (autodiff)= MyST labels from jax_intro.md and autodiff.md. These labels duplicated the automatic :std:doc: targets created from the filenames, causing the myst.xref_ambiguous warning that failed CI with -W.

Files reviewed:

_toc.yml — autodiff placed correctly after numpy_vs_numba_vs_jax
jax_intro.md — Import consolidation, PRNGKey→key migration, new PRNG tree/JIT pipeline diagrams, vmap section, autodiff preview all look good
autodiff.md — Clear progression (finite differences → symbolic → autodiff), good gradient descent + OLS examples, well-constructed exercise
numpy_vs_numba_vs_jax.md — Bug fixes (variable name, print label, meshgrid), viridis colormap, new "Overall recommendations" section is a useful synthesis
Note: The remaining LaTeX warnings (5 undefined hyperrefs like pyess_ex2, numba_ex4 + missing emoji glyph) are pre-existing and unrelated to this PR.

@github-actions
Copy link
Copy Markdown

github-actions bot commented Apr 7, 2026

@github-actions github-actions bot temporarily deployed to pull request April 7, 2026 22:25 Inactive
@mmcky
Copy link
Copy Markdown
Contributor

mmcky commented Apr 8, 2026

thanks @jstac this looks good to merge.

@mmcky mmcky merged commit 05ce956 into main Apr 8, 2026
5 checks passed
@mmcky mmcky deleted the jax-intro-revisions branch April 8, 2026 00:22
@mmcky
Copy link
Copy Markdown
Contributor

mmcky commented Apr 8, 2026

✅ Translation sync completed (zh-cn)

Target repo: QuantEcon/lecture-python-programming.zh-cn
Translation PR: QuantEcon/lecture-python-programming.zh-cn#16
Files synced (4):

  • lectures/autodiff.md
  • lectures/jax_intro.md
  • lectures/numpy_vs_numba_vs_jax.md
  • lectures/_toc.yml

@mmcky
Copy link
Copy Markdown
Contributor

mmcky commented Apr 8, 2026

✅ Translation sync completed (fa)

Target repo: QuantEcon/lecture-python-programming.fa
Translation PR: QuantEcon/lecture-python-programming.fa#87
Files synced (4):

  • lectures/autodiff.md
  • lectures/jax_intro.md
  • lectures/numpy_vs_numba_vs_jax.md
  • lectures/_toc.yml

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants