|
| 1 | +""" |
| 2 | +Toy example of a force field leveraging autodiff with JAX. |
| 3 | +
|
| 4 | +JAX can be installed via e.g. `pip install -U jax[cuda12]` |
| 5 | +""" |
| 6 | +import jax |
| 7 | +import jax.numpy as jnp |
| 8 | +import numpy as np |
| 9 | + |
| 10 | +import Sofa |
| 11 | + |
| 12 | + |
| 13 | +# Some configuration for JAX: device and precision |
| 14 | +# jax.config.update('jax_default_device', jax.devices('cpu')[0]) |
| 15 | +jax.config.update("jax_default_device", jax.devices("gpu")[0]) # default "gpu" |
| 16 | +jax.config.update("jax_enable_x64", True) # default False (ie use float32) |
| 17 | + |
| 18 | + |
| 19 | +@jax.jit # JIT (just-in-time compilation) for better performance |
| 20 | +def get_force(position, length, stiffness): |
| 21 | + """ |
| 22 | + Spring between the origin and the given position. |
| 23 | +
|
| 24 | + position: array of shape (n_particles, n_dimensions) |
| 25 | + length: scalar or array of shape (n_particles, 1) |
| 26 | + stiffness: scalar or array of shape (n_particles, 1) |
| 27 | + """ |
| 28 | + distance = jnp.sqrt(jnp.sum(position**2, axis=1, keepdims=True)) |
| 29 | + direction = position / distance |
| 30 | + return - stiffness * (distance - length) * direction |
| 31 | + |
| 32 | + |
| 33 | +@jax.jit # JIT (just-in-time compilation) for better performance |
| 34 | +def get_dforce(position, length, stiffness, vector): |
| 35 | + """ |
| 36 | + Compute the jacobian-vector product (jvp) using autodiff |
| 37 | + """ |
| 38 | + def get_force_from_position(x): |
| 39 | + return get_force(x, length, stiffness) |
| 40 | + # Differentiate get_force() as a function of the position |
| 41 | + return jax.jvp(get_force_from_position, (position,), (vector,))[1] |
| 42 | + |
| 43 | + |
| 44 | +@jax.jit # JIT (just-in-time compilation) for better performance |
| 45 | +def get_kmatrix(position, length, stiffness): |
| 46 | + """ |
| 47 | + Compute the jacobian using autodiff |
| 48 | +
|
| 49 | + Warning: The jacobian computed this way is a dense matrix. |
| 50 | + Check `sparsejac` if you are interested in sparse jacobian with JAX. |
| 51 | + """ |
| 52 | + def get_force_from_position(x): |
| 53 | + return get_force(x, length, stiffness) |
| 54 | + # Differentiate get_force() as a function of the position |
| 55 | + return jax.jacrev(get_force_from_position)(position) |
| 56 | + |
| 57 | + |
| 58 | +class JaxForceField(Sofa.Core.ForceFieldVec3d): |
| 59 | + |
| 60 | + def __init__(self, length, stiffness, *args, **kwargs): |
| 61 | + Sofa.Core.ForceFieldVec3d.__init__(self, *args, **kwargs) |
| 62 | + self.length = length |
| 63 | + self.stiffness = stiffness |
| 64 | + self.dense_to_sparse = None |
| 65 | + |
| 66 | + def addForce(self, mechanical_parameters, out_force, position, velocity): |
| 67 | + with out_force.writeableArray() as wa: |
| 68 | + wa[:] += get_force(position.value, self.length, self.stiffness) |
| 69 | + |
| 70 | + def addDForce(self, mechanical_parameters, df, dx): |
| 71 | + with df.writeableArray() as wa: |
| 72 | + wa[:] += get_dforce(self.mstate.position.value, self.length, self.stiffness, dx.value) * mechanical_parameters['kFactor'] |
| 73 | + |
| 74 | + # Option 1: Return the jacobian as a dense array (must have shape (n, n, 1) to be interpreted as such). |
| 75 | + # Note: Very slow for big sparse matrices. |
| 76 | + # def addKToMatrix(self, mechanical_parameters, n_particles, n_dimensions): |
| 77 | + # jacobian = get_kmatrix(self.mstate.position.value, self.length, self.stiffness) |
| 78 | + # return np.array(jacobian).reshape((n_particles*n_dimensions, n_particles*n_dimensions, 1)) |
| 79 | + |
| 80 | + # Option 2: Return the non-zero coefficients of the jacobian as an array with rows (i, j, value). |
| 81 | + # Note: The extraction of the non-zero coefficients is faster with JAX on GPU. |
| 82 | + # def addKToMatrix(self, mechanical_parameters, n_particles, n_dimensions): |
| 83 | + # jacobian = get_kmatrix(self.mstate.position.value, self.length, self.stiffness) |
| 84 | + # jacobian = jacobian.reshape((n_particles*n_dimensions, n_particles*n_dimensions)) |
| 85 | + # i, j = jacobian.nonzero() |
| 86 | + # sparse_jacobian = jnp.stack([i, j, jacobian[i, j]], axis=1) |
| 87 | + # return np.array(sparse_jacobian) |
| 88 | + |
| 89 | + # Option 2 optimization: We know the sparsity of the jacobian in advance (diagonal by 3x3 blocks). |
| 90 | + def addKToMatrix(self, mechanical_parameters, n_particles, n_dimensions): |
| 91 | + if self.dense_to_sparse is None: |
| 92 | + # i = [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, ...] |
| 93 | + # j = [0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 4, 5, ...] |
| 94 | + i = jnp.repeat(jnp.arange(n_particles*n_dimensions), 3) |
| 95 | + j = jnp.repeat(jnp.arange(n_particles*n_dimensions).reshape((-1, 3)), 3, axis=0).reshape(-1) |
| 96 | + self.dense_to_sparse = lambda jac: jnp.stack([i, j, jac[i, j]], axis=1) |
| 97 | + self.dense_to_sparse = jax.jit(self.dense_to_sparse) # slightly faster with jit |
| 98 | + |
| 99 | + jacobian = get_kmatrix(self.mstate.position.value, self.length, self.stiffness) |
| 100 | + jacobian = jacobian.reshape((n_particles*n_dimensions, n_particles*n_dimensions)) |
| 101 | + sparse_jacobian = self.dense_to_sparse(jacobian) |
| 102 | + sparse_jacobian = np.array(sparse_jacobian) |
| 103 | + # Note: with the computations optimized, the conversion below can account for |
| 104 | + # ~90% of the time spent in this function. |
| 105 | + return np.array(sparse_jacobian) |
| 106 | + |
| 107 | + |
| 108 | +def createScene(root, method="implicit-matrix-assembly", n_particles=1_000, use_sofa=False): |
| 109 | + root.dt = 1e-3 |
| 110 | + root.gravity = (0, -9.8, 0) |
| 111 | + root.box = (-5, -5, -5, 5, 5, 5) |
| 112 | + root.addObject( |
| 113 | + "RequiredPlugin", |
| 114 | + pluginName=[ |
| 115 | + 'Sofa.Component.Visual', |
| 116 | + 'Sofa.Component.ODESolver.Forward', |
| 117 | + 'Sofa.Component.ODESolver.Backward', |
| 118 | + 'Sofa.Component.LinearSolver.Iterative', |
| 119 | + 'Sofa.Component.LinearSolver.Direct', |
| 120 | + 'Sofa.Component.StateContainer', |
| 121 | + 'Sofa.Component.Mass', |
| 122 | + 'Sofa.Component.SolidMechanics.FEM.Elastic', |
| 123 | + 'Sofa.Component.SolidMechanics.Spring', |
| 124 | + ] |
| 125 | + ) |
| 126 | + root.addObject("DefaultAnimationLoop") |
| 127 | + root.addObject("VisualStyle", displayFlags="showBehaviorModels showForceFields") |
| 128 | + |
| 129 | + physics = root.addChild("Physics") |
| 130 | + |
| 131 | + if method.lower() == "explicit": # Requires the implementation of 'addForce' |
| 132 | + physics.addObject("EulerExplicitSolver", name="eulerExplicit") |
| 133 | + elif method.lower() == "implicit-matrix-free": # Requires the implementation of 'addForce' and 'addDForce' |
| 134 | + physics.addObject("EulerImplicitSolver", name="eulerImplicit") |
| 135 | + physics.addObject("CGLinearSolver", template="GraphScattered", name="solver", iterations=50, tolerance=1e-5, threshold=1e-5) |
| 136 | + elif method == "implicit-matrix-assembly": # Requires the implementation of 'addForce', 'addDForce' and 'addKToMatrix' |
| 137 | + physics.addObject("EulerImplicitSolver", name="eulerImplicit") |
| 138 | + physics.addObject("SparseLDLSolver", name="solver", template="CompressedRowSparseMatrixd") |
| 139 | + |
| 140 | + position = np.random.uniform(-1, 1, (n_particles, 3)) |
| 141 | + velocity = np.zeros_like(position) |
| 142 | + length = np.random.uniform(0.8, 1.2, size=(n_particles, 1)) |
| 143 | + stiffness = 100.0 |
| 144 | + |
| 145 | + particles = physics.addChild("Particles") |
| 146 | + particles.addObject("MechanicalObject", name="state", template="Vec3d", position=position, velocity=velocity, showObject=True) |
| 147 | + particles.addObject("UniformMass", name="mass", totalMass=n_particles) |
| 148 | + |
| 149 | + if not use_sofa: # Use the force field implemented with JAX |
| 150 | + particles.addObject(JaxForceField(length=length, stiffness=stiffness)) |
| 151 | + else: # Use a SOFA equivalent for comparison |
| 152 | + root.addObject("MechanicalObject", name="origin", template="Vec3d", position="0 0 0") |
| 153 | + particles.addObject("SpringForceField", name="force", object1="@/origin", object2="@/Physics/Particles/state", indices1=np.zeros(n_particles, dtype=np.int32), indices2=np.arange(n_particles), length=length, stiffness=stiffness*np.ones(n_particles), damping=np.zeros(n_particles)) |
| 154 | + |
| 155 | + |
| 156 | +def main(): |
| 157 | + import argparse |
| 158 | + import SofaRuntime |
| 159 | + import SofaImGui |
| 160 | + import Sofa.Gui |
| 161 | + |
| 162 | + parser = argparse.ArgumentParser(description="Example of a scene using a ForceField implemented with JAX") |
| 163 | + parser.add_argument("--method", default="implicit-matrix-assembly", help="must be 'explicit', 'implicit-matrix-free' or 'implicit-matrix-assembly'") |
| 164 | + parser.add_argument("--particles", type=int, default=1000, help="number of particles (default 1000)") |
| 165 | + parser.add_argument("--use-sofa", action="store_true", help="use a force field from SOFA instead of the one implemented with JAX") |
| 166 | + args = parser.parse_args() |
| 167 | + |
| 168 | + root=Sofa.Core.Node("root") |
| 169 | + createScene(root, method=args.method, n_particles=args.particles, use_sofa=args.use_sofa) |
| 170 | + Sofa.Simulation.initRoot(root) |
| 171 | + |
| 172 | + Sofa.Gui.GUIManager.Init("myscene", "imgui") |
| 173 | + Sofa.Gui.GUIManager.createGUI(root, __file__) |
| 174 | + Sofa.Gui.GUIManager.SetDimension(1600, 900) |
| 175 | + Sofa.Gui.GUIManager.MainLoop(root) |
| 176 | + Sofa.Gui.GUIManager.closeGUI() |
| 177 | + |
| 178 | + |
| 179 | +if __name__ == "__main__": |
| 180 | + main() |
0 commit comments