Skip to content

Commit 4fb7e89

Browse files
leobois67bakpaulhugtalbot
authored
Example of a ForceField implemented with JAX (#557)
* Add the example * Igne jax example on main CI * Fix ignoring the jax examples * Fix addKToMatrix() with different options * Apply suggestions from code review --------- Co-authored-by: Paul Baksic <paul.baksic@outlook.fr> Co-authored-by: Hugo <hugo.talbot@sofa-framework.org>
1 parent b3a79c6 commit 4fb7e89

2 files changed

Lines changed: 183 additions & 0 deletions

File tree

examples/.scene-tests

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ iterations "access_stiffness_matrix.py" "3"
1010

1111
# Ignore additional examples
1212
ignore "additional-examples/.*"
13+
14+
# Ignore jax examples
15+
ignore "jax/*"

examples/jax/forcefield.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
Toy example of a force field leveraging autodiff with JAX.
3+
4+
JAX can be installed via e.g. `pip install -U jax[cuda12]`
5+
"""
6+
import jax
7+
import jax.numpy as jnp
8+
import numpy as np
9+
10+
import Sofa
11+
12+
13+
# Some configuration for JAX: device and precision
14+
# jax.config.update('jax_default_device', jax.devices('cpu')[0])
15+
jax.config.update("jax_default_device", jax.devices("gpu")[0]) # default "gpu"
16+
jax.config.update("jax_enable_x64", True) # default False (ie use float32)
17+
18+
19+
@jax.jit # JIT (just-in-time compilation) for better performance
20+
def get_force(position, length, stiffness):
21+
"""
22+
Spring between the origin and the given position.
23+
24+
position: array of shape (n_particles, n_dimensions)
25+
length: scalar or array of shape (n_particles, 1)
26+
stiffness: scalar or array of shape (n_particles, 1)
27+
"""
28+
distance = jnp.sqrt(jnp.sum(position**2, axis=1, keepdims=True))
29+
direction = position / distance
30+
return - stiffness * (distance - length) * direction
31+
32+
33+
@jax.jit # JIT (just-in-time compilation) for better performance
34+
def get_dforce(position, length, stiffness, vector):
35+
"""
36+
Compute the jacobian-vector product (jvp) using autodiff
37+
"""
38+
def get_force_from_position(x):
39+
return get_force(x, length, stiffness)
40+
# Differentiate get_force() as a function of the position
41+
return jax.jvp(get_force_from_position, (position,), (vector,))[1]
42+
43+
44+
@jax.jit # JIT (just-in-time compilation) for better performance
45+
def get_kmatrix(position, length, stiffness):
46+
"""
47+
Compute the jacobian using autodiff
48+
49+
Warning: The jacobian computed this way is a dense matrix.
50+
Check `sparsejac` if you are interested in sparse jacobian with JAX.
51+
"""
52+
def get_force_from_position(x):
53+
return get_force(x, length, stiffness)
54+
# Differentiate get_force() as a function of the position
55+
return jax.jacrev(get_force_from_position)(position)
56+
57+
58+
class JaxForceField(Sofa.Core.ForceFieldVec3d):
59+
60+
def __init__(self, length, stiffness, *args, **kwargs):
61+
Sofa.Core.ForceFieldVec3d.__init__(self, *args, **kwargs)
62+
self.length = length
63+
self.stiffness = stiffness
64+
self.dense_to_sparse = None
65+
66+
def addForce(self, mechanical_parameters, out_force, position, velocity):
67+
with out_force.writeableArray() as wa:
68+
wa[:] += get_force(position.value, self.length, self.stiffness)
69+
70+
def addDForce(self, mechanical_parameters, df, dx):
71+
with df.writeableArray() as wa:
72+
wa[:] += get_dforce(self.mstate.position.value, self.length, self.stiffness, dx.value) * mechanical_parameters['kFactor']
73+
74+
# Option 1: Return the jacobian as a dense array (must have shape (n, n, 1) to be interpreted as such).
75+
# Note: Very slow for big sparse matrices.
76+
# def addKToMatrix(self, mechanical_parameters, n_particles, n_dimensions):
77+
# jacobian = get_kmatrix(self.mstate.position.value, self.length, self.stiffness)
78+
# return np.array(jacobian).reshape((n_particles*n_dimensions, n_particles*n_dimensions, 1))
79+
80+
# Option 2: Return the non-zero coefficients of the jacobian as an array with rows (i, j, value).
81+
# Note: The extraction of the non-zero coefficients is faster with JAX on GPU.
82+
# def addKToMatrix(self, mechanical_parameters, n_particles, n_dimensions):
83+
# jacobian = get_kmatrix(self.mstate.position.value, self.length, self.stiffness)
84+
# jacobian = jacobian.reshape((n_particles*n_dimensions, n_particles*n_dimensions))
85+
# i, j = jacobian.nonzero()
86+
# sparse_jacobian = jnp.stack([i, j, jacobian[i, j]], axis=1)
87+
# return np.array(sparse_jacobian)
88+
89+
# Option 2 optimization: We know the sparsity of the jacobian in advance (diagonal by 3x3 blocks).
90+
def addKToMatrix(self, mechanical_parameters, n_particles, n_dimensions):
91+
if self.dense_to_sparse is None:
92+
# i = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, ...]
93+
# j = [0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, ...]
94+
i = jnp.repeat(jnp.arange(n_particles*n_dimensions), 3)
95+
j = jnp.repeat(jnp.arange(n_particles*n_dimensions).reshape((-1, 3)), 3, axis=0).reshape(-1)
96+
self.dense_to_sparse = lambda jac: jnp.stack([i, j, jac[i, j]], axis=1)
97+
self.dense_to_sparse = jax.jit(self.dense_to_sparse) # slightly faster with jit
98+
99+
jacobian = get_kmatrix(self.mstate.position.value, self.length, self.stiffness)
100+
jacobian = jacobian.reshape((n_particles*n_dimensions, n_particles*n_dimensions))
101+
sparse_jacobian = self.dense_to_sparse(jacobian)
102+
sparse_jacobian = np.array(sparse_jacobian)
103+
# Note: with the computations optimized, the conversion below can account for
104+
# ~90% of the time spent in this function.
105+
return np.array(sparse_jacobian)
106+
107+
108+
def createScene(root, method="implicit-matrix-assembly", n_particles=1_000, use_sofa=False):
109+
root.dt = 1e-3
110+
root.gravity = (0, -9.8, 0)
111+
root.box = (-5, -5, -5, 5, 5, 5)
112+
root.addObject(
113+
"RequiredPlugin",
114+
pluginName=[
115+
'Sofa.Component.Visual',
116+
'Sofa.Component.ODESolver.Forward',
117+
'Sofa.Component.ODESolver.Backward',
118+
'Sofa.Component.LinearSolver.Iterative',
119+
'Sofa.Component.LinearSolver.Direct',
120+
'Sofa.Component.StateContainer',
121+
'Sofa.Component.Mass',
122+
'Sofa.Component.SolidMechanics.FEM.Elastic',
123+
'Sofa.Component.SolidMechanics.Spring',
124+
]
125+
)
126+
root.addObject("DefaultAnimationLoop")
127+
root.addObject("VisualStyle", displayFlags="showBehaviorModels showForceFields")
128+
129+
physics = root.addChild("Physics")
130+
131+
if method.lower() == "explicit": # Requires the implementation of 'addForce'
132+
physics.addObject("EulerExplicitSolver", name="eulerExplicit")
133+
elif method.lower() == "implicit-matrix-free": # Requires the implementation of 'addForce' and 'addDForce'
134+
physics.addObject("EulerImplicitSolver", name="eulerImplicit")
135+
physics.addObject("CGLinearSolver", template="GraphScattered", name="solver", iterations=50, tolerance=1e-5, threshold=1e-5)
136+
elif method == "implicit-matrix-assembly": # Requires the implementation of 'addForce', 'addDForce' and 'addKToMatrix'
137+
physics.addObject("EulerImplicitSolver", name="eulerImplicit")
138+
physics.addObject("SparseLDLSolver", name="solver", template="CompressedRowSparseMatrixd")
139+
140+
position = np.random.uniform(-1, 1, (n_particles, 3))
141+
velocity = np.zeros_like(position)
142+
length = np.random.uniform(0.8, 1.2, size=(n_particles, 1))
143+
stiffness = 100.0
144+
145+
particles = physics.addChild("Particles")
146+
particles.addObject("MechanicalObject", name="state", template="Vec3d", position=position, velocity=velocity, showObject=True)
147+
particles.addObject("UniformMass", name="mass", totalMass=n_particles)
148+
149+
if not use_sofa: # Use the force field implemented with JAX
150+
particles.addObject(JaxForceField(length=length, stiffness=stiffness))
151+
else: # Use a SOFA equivalent for comparison
152+
root.addObject("MechanicalObject", name="origin", template="Vec3d", position="0 0 0")
153+
particles.addObject("SpringForceField", name="force", object1="@/origin", object2="@/Physics/Particles/state", indices1=np.zeros(n_particles, dtype=np.int32), indices2=np.arange(n_particles), length=length, stiffness=stiffness*np.ones(n_particles), damping=np.zeros(n_particles))
154+
155+
156+
def main():
157+
import argparse
158+
import SofaRuntime
159+
import SofaImGui
160+
import Sofa.Gui
161+
162+
parser = argparse.ArgumentParser(description="Example of a scene using a ForceField implemented with JAX")
163+
parser.add_argument("--method", default="implicit-matrix-assembly", help="must be 'explicit', 'implicit-matrix-free' or 'implicit-matrix-assembly'")
164+
parser.add_argument("--particles", type=int, default=1000, help="number of particles (default 1000)")
165+
parser.add_argument("--use-sofa", action="store_true", help="use a force field from SOFA instead of the one implemented with JAX")
166+
args = parser.parse_args()
167+
168+
root=Sofa.Core.Node("root")
169+
createScene(root, method=args.method, n_particles=args.particles, use_sofa=args.use_sofa)
170+
Sofa.Simulation.initRoot(root)
171+
172+
Sofa.Gui.GUIManager.Init("myscene", "imgui")
173+
Sofa.Gui.GUIManager.createGUI(root, __file__)
174+
Sofa.Gui.GUIManager.SetDimension(1600, 900)
175+
Sofa.Gui.GUIManager.MainLoop(root)
176+
Sofa.Gui.GUIManager.closeGUI()
177+
178+
179+
if __name__ == "__main__":
180+
main()

0 commit comments

Comments
 (0)