
Description

As requested by @hsalehipour, I added an example that demonstrates gradient-based optimization of initial conditions
to achieve various target density patterns.

This PR adds a new differentiable LBM example (examples/cfd/differentiable_lbm.py) that demonstrates inverse problem solving using automatic differentiation. The example optimizes initial conditions to achieve a target density pattern after simulation.

Note: This example requires the JAX backend for automatic differentiation. The Warp backend does not currently support gradient propagation through the stepper (see Autodiff Limitations below).


New Example: differentiable_lbm.py

What it does

The example solves an inverse problem: given a target density pattern, find the initial distribution function f_0 such that after N simulation steps, the density matches the target.
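
The sketch below shows the shape of that optimization loop, assuming the JAX backend. It is illustrative only: toy_step is a stand-in for XLB's differentiable stepper (not the actual xlb API), and the D2Q9 weights, target pattern, step count, and learning rate are placeholder choices rather than the configuration used in differentiable_lbm.py.

import jax
import jax.numpy as jnp

Q, NX, NY = 9, 32, 32
w = jnp.array([4/9] + [1/9] * 4 + [1/36] * 4).reshape(Q, 1, 1)  # D2Q9 weights

def toy_step(f, omega=1.0):
    # Stand-in for one differentiable LBM step: BGK-style relaxation toward
    # a zero-velocity equilibrium (streaming omitted for brevity).
    rho = jnp.sum(f, axis=0, keepdims=True)
    feq = w * rho
    return f + omega * (feq - f)

def loss_fn(f0, target_rho, num_steps=10):
    f = f0
    for _ in range(num_steps):
        f = toy_step(f)
    rho = jnp.sum(f, axis=0)                   # density after N steps
    return jnp.mean((rho - target_rho) ** 2)   # MSE against the target pattern

# Target: a Gaussian density bump; initial guess: uniform density.
x, y = jnp.meshgrid(jnp.arange(NX), jnp.arange(NY), indexing="ij")
target_rho = 1.0 + 0.1 * jnp.exp(-((x - NX / 2) ** 2 + (y - NY / 2) ** 2) / 20.0)
f0 = jnp.ones((Q, NX, NY)) * w

grad_fn = jax.jit(jax.value_and_grad(loss_fn))
lr = 50.0
for _ in range(150):
    loss, g = grad_fn(f0, target_rho)
    f0 = f0 - lr * g                           # plain gradient descent on f_0

The actual example replaces toy_step with XLB's JAX stepper, which is why gradient propagation through the stepper is the key requirement.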

Results

Convergence: [convergence plot]

First iteration (initial uniform density): [iteration_00000]

Final iteration (optimized initial conditions): [iteration_00149]

Autodiff Limitations: Warp vs JAX

Why JAX is required

This example uses JAX because XLB's Warp stepper does not propagate gradients. This is a fundamental limitation of how the Warp kernels are currently implemented.

Root cause

Warp's autodiff (wp.Tape) requires either:

  1. Automatic adjoint generation - works for simple kernels
  2. Manual @wp.func_grad implementations - required for complex kernels

XLB's stepper kernel (xlb/operator/stepper/nse_stepper.py) has characteristics that prevent automatic adjoint generation:

  • Early returns (if _boundary_id == wp.uint8(255): return)
  • Integer conditionals and mask operations

Warp does not raise an error when it cannot differentiate a branch; it silently returns zero gradients for the affected inputs.

The Macroscopic operator works because it is a simple summation kernel that Warp can auto-differentiate, but the stepper (collision + streaming) cannot be.
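
To make the failure mode concrete, here is a minimal sketch of the problematic pattern together with the gradient check, using the standard Warp tape API (wp.Tape, requires_grad arrays). The kernel is a made-up stand-in, not the XLB stepper; a zero gradient norm on f_in after the backward pass is what would indicate the adjoint did not propagate.

import warp as wp

wp.init()

@wp.kernel
def masked_scale(f_in: wp.array(dtype=float),
                 mask: wp.array(dtype=wp.uint8),
                 f_out: wp.array(dtype=float)):
    i = wp.tid()
    if mask[i] == wp.uint8(255):   # early return on an integer mask,
        return                     # the same pattern used in nse_stepper.py
    f_out[i] = 2.0 * f_in[i]

n = 8
f_in = wp.array([1.0] * n, dtype=float, requires_grad=True)
f_out = wp.zeros(n, dtype=float, requires_grad=True)
mask = wp.zeros(n, dtype=wp.uint8)

tape = wp.Tape()
with tape:
    wp.launch(masked_scale, dim=n, inputs=[f_in, mask, f_out])

# Seed the output adjoint, run the backward pass, then inspect whether any
# gradient reached the input array.
tape.backward(grads={f_out: wp.ones(n, dtype=float)})
print(f_in.grad.numpy())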

Test: test_stepper_autodiff.py

A new test script (examples/cfd/test_stepper_autodiff.py) demonstrates this limitation by running identical gradient tests on both backends:
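
The JAX half of the same check reduces to a single value_and_grad call. Sketch only: step_fn below is a placeholder for the XLB JAX stepper, not the actual API.

import jax
import jax.numpy as jnp

def step_fn(f):
    # Placeholder for one call of the XLB JAX stepper.
    return 0.5 * (f + jnp.roll(f, 1, axis=-1))

def loss_fn(f_in, rho_target):
    rho = jnp.sum(step_fn(f_in), axis=0)     # forward 1 step -> compute rho
    return jnp.sum((rho - rho_target) ** 2)  # -> MSE-style loss

f_in = jnp.ones((9, 32, 32))
rho_target = jnp.zeros((32, 32))
loss, grads = jax.value_and_grad(loss_fn)(f_in)
print(float(jnp.linalg.norm(grads)))         # nonzero when gradients propagate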

Test Output:
(venv) medy@medy-AI:~/projects/open_source/XLB$ python examples/cfd/test_stepper_autodiff.py 

======================================================================
XLB STEPPER AUTODIFF TEST
======================================================================

This test checks if gradients propagate through the LBM stepper.
We run the SAME test on both JAX and Warp backends and compare.

Warp 1.11.0 initialized:
   CUDA Toolkit 12.9, Driver 12.8
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "NVIDIA GeForce RTX 3060 Ti" (8 GiB, sm_86, mempool enabled)
   Kernel cache:
     /home/medy/.cache/warp/1.11.0
Warp DeprecationWarning: The symbol `warp.utils.ScopedTimer` will soon be removed from the public API. Use `warp.ScopedTimer` instead.
----------------------------------------------------------------------
TEST CONFIGURATION
----------------------------------------------------------------------
  Grid shape:       (32, 32)
  Omega:            1.8
  Precision:        FP32FP32
  Boundary:         FullwayBounceBackBC (walls)
  Collision:        BGK
  Test:             Forward 1 step -> Compute rho -> MSE Loss -> Backward

Warp DeprecationWarning: The symbol `warp.mat` will soon be removed from the public API. Use `warp.types.matrix` instead.
Warp DeprecationWarning: The symbol `warp.vec` will soon be removed from the public API. Use `warp.types.vector` instead.
registered bc FullwayBounceBackBC_8537495545848 with id 1
Module xlb.operator.equilibrium.quadratic_equilibrium f89485b load on device 'cuda:0' took 0.22 ms  (cached)
Module xlb.operator.boundary_masker.indices_boundary_masker 2a27561 load on device 'cuda:0' took 0.28 ms  (cached)
Single-GPU support is available: 1 GPU detected.
registered bc FullwayBounceBackBC_8537491724211 with id 2
Module xlb.operator.stepper.nse_stepper c5d2e58 load on device 'cuda:0' took 0.21 ms  (cached)
Module xlb.operator.macroscopic.macroscopic 58bc72f load on device 'cuda:0' took 0.17 ms  (cached)
Module __main__ 2f56003 load on device 'cuda:0' took 0.89 ms  (cached)
======================================================================
RESULTS: SIDE-BY-SIDE COMPARISON
======================================================================

Metric                              WARP            JAX            
-----------------------------------------------------------------
Loss value                          1024.0000       1024.0000      
d(Loss)/d(f_input) gradient norm    0.0000          192.0000       

----------------------------------------------------------------------
GRADIENT FLOW ANALYSIS (Warp only - to debug where gradients stop)
----------------------------------------------------------------------

  In Warp, we can check gradients at each stage of the computation:

    1. loss.grad (set manually)        : 1.0 (seed)
    2. d(loss)/d(rho) gradient norm    : 64.0000
    3. d(loss)/d(f_out) gradient norm  : 192.0000
    4. d(loss)/d(f_in) gradient norm   : 0.0000  <-- THIS IS THE PROBLEM

  Gradient flows: loss -> rho -> f_out (through Macroscopic) ✓
  Gradient STOPS: f_out -> f_in (through Stepper) ✗

Key finding from test:

Metric                              Warp        JAX
Loss value                          1024.00     1024.00
d(Loss)/d(f_input) gradient norm    0.00        192.00

The gradient flow analysis shows exactly where gradients stop in Warp:

Gradient flows: loss -> rho -> f_out (through Macroscopic) ✓
Gradient STOPS: f_out -> f_in (through Stepper) ✗

What would be needed to fix Warp autodiff

To enable Warp autodiff through the stepper, XLB would need manual @wp.func_grad adjoint implementations.
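
For reference, Warp's custom-adjoint mechanism looks roughly like the sketch below (scale and adj_scale are made-up illustrations; the actual stepper adjoints would be considerably more involved):

import warp as wp

@wp.func
def scale(x: float, a: float) -> float:
    return a * x

# Manual adjoint: accumulate d(loss)/dx and d(loss)/da from the adjoint of the
# return value. Each non-trivial function used by the stepper would need one.
@wp.func_grad(scale)
def adj_scale(x: float, a: float, adj_ret: float):
    wp.adjoint[x] += a * adj_ret
    wp.adjoint[a] += x * adj_ret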

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Documentation update

How Has This Been Tested?

  • All pytest tests pass

Linting and Code Formatting

Make sure the code follows the project's linting and formatting standards. This project uses Ruff for linting.

To run Ruff, execute the following command from the root of the repository:

ruff check .

  • Ruff passes
