.. _interfacing_with_surrogates:

JAX-compatible interfaces with ML surrogates of physics models
##############################################################

This section discusses a variety of options for building JAX-friendly interfaces
to surrogate models.

As an illustrative example, suppose we have a new neural network surrogate
transport model that we would like to use in TORAX. Assume that all the
boilerplate described in the previous sections has been taken care of, as well
as the definition of some functions to convert between TORAX structures and
tensors for the neural network.

.. code-block:: python

    class MyCustomSurrogateTransportModel(TransportModel):
        ...
        def _call_implementation(
            self,
            dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
            geo: geometry.Geometry,
            core_profiles: state.CoreProfiles,
        ) -> state.CoreTransport:
            input_tensor = self._prepare_input(
                dynamic_runtime_params_slice, geo, core_profiles
            )

            output_tensor = self._call_surrogate_model(input_tensor)

            chi_i, chi_e, d_e, v_e = self._parse_output(output_tensor)

            return state.CoreTransport(
                chi_face_ion=chi_i,
                chi_face_electron=chi_e,
                d_e=d_e,
                v_e=v_e,
            )

In this guide, we explore a few options for implementing the
``_call_surrogate_model`` function for an existing surrogate while retaining
the full power of JAX:

1. **Manually reimplementing the model in JAX**.
2. **Converting a PyTorch model to a JAX model**.
3. **Using an ONNX model**.

.. note::
  These conversion methods are necessary to make an external model compatible
  with JAX's autodiff and JIT functionality, which is required for using
  TORAX's gradient-driven nonlinear solvers (e.g. Newton-Raphson).
  Interfacing with non-differentiable, non-JITtable models is possible if the
  linear solver is used (for an example, see the |QuaLiKiz| transport model
  implementation). However, note that if such a model is called within the
  step function, JIT will need to be disabled with
  ``TORAX_COMPILATION_ENABLED=0``.


Option 1: manually reimplementing the model in JAX
==================================================

If the architecture of the surrogate is sufficiently simple, you might consider
reimplementing the model in JAX. The surrogates in TORAX are mostly implemented
using `Flax Linen`_ and can be found in the |fusion_surrogates|_ repository.
If you are not familiar with Flax, see the `Flax documentation`_ for how to
define your own models.

Consider a PyTorch neural network:

.. code-block:: python

    import torch

    class PyTorchMLP(torch.nn.Module):
        def __init__(self, hidden_dim: int, n_hidden: int, output_dim: int, input_dim: int):
            super().__init__()
            # Keep the layers flat so that the Linear layers sit at even
            # indices (0, 2, 4, ...) of the Sequential module. This makes the
            # state dict keys easy to map to Flax parameter names below.
            layers = [torch.nn.Linear(input_dim, hidden_dim), torch.nn.ReLU()]
            for _ in range(n_hidden):
                layers += [torch.nn.Linear(hidden_dim, hidden_dim), torch.nn.ReLU()]
            layers.append(torch.nn.Linear(hidden_dim, output_dim))
            self.model = torch.nn.Sequential(*layers)

        def forward(self, x):
            return self.model(x)

    torch_model = PyTorchMLP(hidden_dim, n_hidden, output_dim, input_dim)

This model can be replicated in Flax as follows:

.. code-block:: python

    from flax import linen

    class FlaxMLP(linen.Module):
        hidden_dim: int
        n_hidden: int
        output_dim: int
        input_dim: int

        @linen.compact
        def __call__(self, x):
            x = linen.Dense(self.hidden_dim)(x)
            x = linen.relu(x)
            for _ in range(self.n_hidden):
                x = linen.Dense(self.hidden_dim)(x)
                x = linen.relu(x)
            x = linen.Dense(self.output_dim)(x)
            return x

    flax_model = FlaxMLP(hidden_dim, n_hidden, output_dim, input_dim)

As this is only the model architecture, the trained weights need to be loaded
separately. This can be a bit fiddly, as you have to map the parameter names in
the weights checkpoint file to the parameter names in the Flax model.

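One way to discover the parameter names the Flax model expects (``Dense_0``,
``Dense_1``, ..., each with ``kernel`` and ``bias`` leaves) is to initialize it
with a dummy input and inspect the resulting structure. A minimal sketch
(``template_params`` is just an illustrative name):

.. code-block:: python

    import jax
    import jax.numpy as jnp

    # Initialize with a dummy input to materialize the parameter structure.
    template_params = flax_model.init(jax.random.PRNGKey(0), jnp.ones((1, input_dim)))
    print(jax.tree_util.tree_map(jnp.shape, template_params))
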
For loading weights from a PyTorch checkpoint, you might do something like:

.. code-block:: python

    import jax.numpy as jnp
    import torch

    state_dict = torch.load(PYTORCH_CHECKPOINT_PATH)

    params = {}
    # One Dense layer per PyTorch Linear layer: the input layer, n_hidden
    # hidden layers, and the output layer. The Linear layers sit at even
    # indices of the Sequential module.
    for i in range(n_hidden + 2):
        params[f"Dense_{i}"] = {
            # PyTorch stores Linear weights as (out_features, in_features);
            # Flax Dense kernels are (in_features, out_features), hence the
            # transpose. Biases are 1-D and need no transpose.
            "kernel": jnp.array(state_dict[f"model.{i * 2}.weight"]).T,
            "bias": jnp.array(state_dict[f"model.{i * 2}.bias"]),
        }

    params = {'params': params}

The model can then be called like any Flax model:

.. code-block:: python

    output_tensor = jax.jit(flax_model.apply)(params, input_tensor)

.. warning::
  You need to be very careful when loading from a PyTorch state dict, as
  Flax and PyTorch may use slightly different representations of the weights
  (for example, one could be the transpose of the other). It's worth
  validating the output of your PyTorch model against your JAX model to make
  sure the conversion is correct.

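For example, a minimal validation sketch, assuming ``torch_model``,
``flax_model``, and ``params`` are defined as above:

.. code-block:: python

    import numpy as np
    import torch

    # Compare the two implementations on the same random inputs.
    x = np.random.randn(16, input_dim).astype(np.float32)
    with torch.no_grad():
        torch_out = torch_model(torch.from_numpy(x)).numpy()
    flax_out = flax_model.apply(params, x)
    np.testing.assert_allclose(torch_out, np.asarray(flax_out), rtol=1e-5, atol=1e-6)
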
Option 2: converting a PyTorch model to a JAX model
===================================================

.. warning::
  The `torch_xla2`_ package is still evolving, which means there may be
  unexpected breaking changes. Some of the methods described in this section
  may become deprecated with little warning.

If your model is in PyTorch, you could also consider using the `torch_xla2`_
package to perform the conversion to JAX automatically.

.. code-block:: python

    import torch
    import torch_xla2 as tx

    # weights_only=False loads the full pickled model, not just the weights.
    trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False)
    params, jax_model_from_torch = tx.extract_jax(trained_model)

The model can then be called as a pure JAX function:

.. code-block:: python

    output_tensor = jax.jit(jax_model_from_torch)(params, input_tensor)

To avoid performing the conversion every time the model is loaded, you might
want to save a JAX-compatible version of the weights and model to disk:

.. code-block:: python

    import jax
    import numpy as np
    from jax import export

    # jax.export uses StableHLO to serialize the model to a binary format.
    # Exporting traces the jitted function, so example arguments with the
    # right shapes and dtypes are required.
    exported_model = export.export(jax.jit(jax_model_from_torch))(params, input_tensor)
    with open("model.hlo", "wb") as f:
        f.write(exported_model.serialize())

    # The weights can be saved as numpy arrays
    np.savez("weights.npz", *params)

The model can then be loaded and run as follows:

.. code-block:: python

    import jax.numpy as jnp
    import numpy as np
    from jax import export

    # Load the serialized StableHLO checkpoint
    with open('model.hlo', 'rb') as f:
        model_as_bytes = f.read()
    model = export.deserialize(model_as_bytes)

    # Load the weights
    weights_as_npz = np.load('weights.npz')
    weights = [jnp.array(v) for v in weights_as_npz.values()]

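The deserialized model is invoked through its ``call`` method, with arguments
structured as they were at export time. A minimal sketch, assuming ``params``
was a flat list of arrays when it was saved:

.. code-block:: python

    # Arguments must match the structure used when the model was exported.
    output_tensor = model.call(weights, input_tensor)
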
Option 3: using an ONNX model
=============================

The `Open Neural Network Exchange`_ (ONNX) format is a highly interoperable
format for sharing neural network models. ONNX files bundle the model
architecture and weights together.

An ONNX model can be loaded and called as follows, making sure to specify the
correct input and output node names for your specific model:

.. code-block:: python

    import onnxruntime as ort
    import numpy as np

    s = ort.InferenceSession(ONNX_MODEL_PATH)
    onnx_output_tensor = s.run(
        # Output node names
        ['output1', 'output2'],
        # Mapping from input node names to input tensors.
        # NOTE: input tensors must have the correct dtype for your model.
        {'input': np.asarray(input_tensor, dtype=np.float32)},
    )

However, JAX will not be able to differentiate through the InferenceSession.
To convert the ONNX model to a JAX representation, you can use the
`jaxonnxruntime`_ package:

.. code-block:: python

    import jax
    import jax.numpy as jnp
    import onnx
    from jaxonnxruntime.backend import Backend as ONNXJaxBackend

    onnx_model = onnx.load_model(ONNX_MODEL_PATH)

    jax_model_from_onnx = ONNXJaxBackend.prepare(onnx_model)
    # NOTE: run() returns a list of output tensors, in order of the output nodes
    output_tensors = jax.jit(jax_model_from_onnx.run)(
        {"input": jnp.asarray(input_tensor, dtype=jnp.float32)}
    )

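Because the converted model is now pure JAX, it also supports autodiff. A
minimal sketch, assuming a single input node named ``input`` and taking the
first output tensor:

.. code-block:: python

    def surrogate_fn(x):
        return jax_model_from_onnx.run({"input": x})[0]

    # Jacobian of the surrogate outputs with respect to its inputs.
    jacobian = jax.jacfwd(surrogate_fn)(jnp.asarray(input_tensor, dtype=jnp.float32))
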
Best practices
==============

**Caching and lazy loading**: Ideally, the model should be constructed and its
weights loaded only once, on the first call to the function. The loaded model
should be cached and reused for subsequent calls.

For example, the ``_combined`` function of the QLKNN transport model (the
function that actually evaluates the model) contains:

.. code-block:: python

    model = get_model(self._model_path)
    ...
    model_output = model.predict(...)

where

.. code-block:: python

    @functools.lru_cache(maxsize=1)
    def get_model(path: str) -> base_qlknn_model.BaseQLKNNModel:
        """Load the model."""
        ...
        return qlknn_10d.QLKNN10D(path)

By decorating with ``functools.lru_cache(maxsize=1)``, the result of this
function (the loaded model) is stored in the cache and is only re-loaded if
the function is called with a different ``path``.

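The same pattern could be applied to the hypothetical surrogate from the start
of this guide. A minimal sketch, with all names illustrative:

.. code-block:: python

    import functools

    import jax

    @functools.lru_cache(maxsize=1)
    def _get_surrogate_model(path: str):
        """Construct the network and load its weights, once per path."""
        # e.g. build a FlaxMLP and load its params as in Option 1.
        ...
        return flax_model, params

    class MyCustomSurrogateTransportModel(TransportModel):
        ...
        def _call_surrogate_model(self, input_tensor):
            model, params = _get_surrogate_model(self._model_path)
            return jax.jit(model.apply)(params, input_tensor)
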
**JITting model calls**: In general, you should make sure that the forward call
of the model is JITted:

.. code-block:: python

    output_tensor = jax.jit(flax_model.apply)(params, input_tensor)  # Good
    output_tensor = flax_model.apply(params, input_tensor)  # Bad

This is vital to ensure fast performance.

.. _Flax Linen: https://flax-linen.readthedocs.io/en/latest/index.html
.. _Flax documentation: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#defining-your-own-models
.. _torch_xla2: https://pytorch.org/xla/master/features/stablehlo.html
.. _Open Neural Network Exchange: https://onnx.ai/
.. _jaxonnxruntime: https://github.com/google/jaxonnxruntime
.. |fusion_surrogates| replace:: ``google-deepmind/fusion_surrogates``
.. _fusion_surrogates: https://github.com/google-deepmind/fusion_surrogates