feat: add Differentiable LBM example for inverse problem solving (JAX) #156
Description
As requested by @hsalehipour, I added an example that demonstrates gradient-based optimization of initial conditions
to achieve various target density patterns.
This PR adds a new differentiable LBM example (examples/cfd/differentiable_lbm.py) that demonstrates inverse problem solving using automatic differentiation. The example optimizes initial conditions to achieve a target density pattern after simulation.

Note: This example requires the JAX backend for automatic differentiation. The Warp backend does not currently support gradient propagation through the stepper (see Autodiff Limitations below).
New Example: differentiable_lbm.py

What it does
The example solves an inverse problem: given a target density pattern, find the initial distribution function f_0 such that after N simulation steps, the density matches the target.
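To make the structure concrete, here is a minimal, self-contained sketch of the same optimization pattern on a toy D1Q3 lattice over a periodic 1D domain. This is illustrative only, not the code in examples/cfd/differentiable_lbm.py; the relaxation time, step count, target pattern, and plain gradient descent are all assumptions made for the sketch.

```python
import jax
import jax.numpy as jnp

# Toy D1Q3 lattice: velocities {0, +1, -1}, weights {2/3, 1/6, 1/6}.
C = jnp.array([0.0, 1.0, -1.0])
W = jnp.array([2.0 / 3.0, 1.0 / 6.0, 1.0 / 6.0])
SHIFTS = [0, 1, -1]   # integer streaming shifts matching C
TAU = 0.8             # BGK relaxation time (assumed)
N_STEPS = 50          # number of simulation steps (assumed)
N_CELLS = 64          # periodic 1D domain size (assumed)

def equilibrium(rho, u):
    # Standard low-Mach equilibrium populations for D1Q3.
    cu = C[:, None] * u[None, :]
    return W[:, None] * rho[None, :] * (1.0 + 3.0 * cu + 4.5 * cu**2 - 1.5 * u[None, :] ** 2)

def step(f):
    rho = f.sum(axis=0)                        # density moment
    u = (C[:, None] * f).sum(axis=0) / rho     # velocity moment
    f = f - (f - equilibrium(rho, u)) / TAU    # BGK collision
    return jnp.stack([jnp.roll(f[i], SHIFTS[i]) for i in range(3)])  # streaming

def loss_fn(f0, target_rho):
    # Roll out N_STEPS differentiable steps, then compare final density to target.
    f_final, _ = jax.lax.scan(lambda f, _: (step(f), None), f0, None, length=N_STEPS)
    return jnp.mean((f_final.sum(axis=0) - target_rho) ** 2)

# Target: a small Gaussian density bump on a uniform background (assumed pattern).
x = jnp.arange(N_CELLS)
target_rho = 1.0 + 0.05 * jnp.exp(-((x - N_CELLS / 2) ** 2) / 50.0)

# Initial guess: uniform rest state; optimize f0 by plain gradient descent.
f0 = equilibrium(jnp.ones(N_CELLS), jnp.zeros(N_CELLS))
value_and_grad = jax.jit(jax.value_and_grad(loss_fn))
for it in range(200):
    loss, grad = value_and_grad(f0, target_rho)
    f0 = f0 - 0.5 * grad
print(f"final loss: {loss:.3e}")
```

Using jax.lax.scan for the rollout keeps the N-step simulation reverse-mode differentiable without unrolling a Python loop into the trace.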
Results

Convergence:

First iteration (initial uniform density):

Final iteration (optimized initial conditions):
Autodiff Limitations: Warp vs JAX
Why JAX is required
This example uses JAX because XLB's Warp stepper does not propagate gradients. This is a fundamental limitation of how the Warp kernels are currently implemented.
Root cause
Warp's autodiff (wp.Tape) requires either:

- kernels simple enough for Warp to generate adjoints automatically, or
- custom @wp.func_grad implementations, which are required for complex kernels.

XLB's stepper kernel (xlb/operator/stepper/nse_stepper.py) has characteristics that prevent automatic adjoint generation, notably early returns inside conditionals (e.g., if _boundary_id == wp.uint8(255): return). A minimal wp.Tape gradient check on a simple kernel is sketched below for contrast.
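To illustrate the mechanics, here is a minimal wp.Tape gradient check on a simple reduction kernel. This is a hypothetical stand-in, not XLB code; kernels of this shape auto-differentiate, while kernels with early returns like the stepper's do not.

```python
import warp as wp

wp.init()

@wp.kernel
def sum_populations(f: wp.array(dtype=float), out: wp.array(dtype=float)):
    # A simple reduction, analogous in spirit to the Macroscopic operator:
    # Warp can generate this kernel's adjoint automatically.
    i = wp.tid()
    wp.atomic_add(out, 0, f[i])

f = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
out = wp.zeros(1, dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
    wp.launch(sum_populations, dim=3, inputs=[f, out])

tape.backward(loss=out)
print(f.grad.numpy())  # [1. 1. 1.] -- gradients reach the input
```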
The Macroscopic operator works because it's a simple summation kernel that Warp can auto-differentiate, but the stepper (collision + streaming) does not.

Test: test_stepper_autodiff.py
A new test script (examples/cfd/test_stepper_autodiff.py) demonstrates this limitation by running identical gradient tests on both backends.
Key finding from the test: the gradient flow analysis shows exactly where gradients stop in Warp, using per-array checks of the kind sketched below.
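A rough sketch of such a per-array check. The array names in the usage comment are hypothetical, and it assumes the arrays were created with requires_grad=True and that tape.backward(...) has already run:

```python
import numpy as np
import warp as wp

def report_grad_flow(named_arrays):
    """Print the max adjoint magnitude for each Warp array to locate
    where the gradient signal dies in the chain."""
    for name, arr in named_arrays:
        g = np.abs(arr.grad.numpy())
        status = "flows" if g.max() > 0.0 else "STOPS"
        print(f"{name:>16s}: max |grad| = {g.max():.3e}  ({status})")

# Usage after tape.backward(loss=...), with hypothetical array names:
# report_grad_flow([("rho", rho), ("f_post_stream", f_post), ("f_initial", f0)])
```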
What would be needed to fix Warp autodiff
To enable Warp autodiff through the stepper, XLB would need @wp.func_grad adjoint implementations. A sketch of what such an adjoint looks like is below.
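For reference, Warp's custom-gradient mechanism looks roughly like this. The example is an illustrative single-population BGK relaxation with hypothetical names, not the actual stepper:

```python
import warp as wp

@wp.func
def bgk_relax(f: float, feq: float, omega: float):
    # Forward: relax one population toward its equilibrium value.
    return f - omega * (f - feq)

# Hand-written adjoint registered via @wp.func_grad: it receives the forward
# arguments plus the adjoint of the return value, and accumulates into
# wp.adjoint[...] for each differentiable input.
@wp.func_grad(bgk_relax)
def adj_bgk_relax(f: float, feq: float, omega: float, adj_ret: float):
    wp.adjoint[f] += (1.0 - omega) * adj_ret
    wp.adjoint[feq] += omega * adj_ret
    wp.adjoint[omega] += -(f - feq) * adj_ret
```

Each such adjoint must be written and maintained by hand for every non-trivial kernel, which is why this path is a larger undertaking than switching the example to JAX.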
Type of change

How Has This Been Tested?
Linting and Code Formatting
Make sure the code follows the project's linting and formatting standards. This project uses Ruff for linting.
To run Ruff, execute the following command from the root of the repository:
ruff check .