
Commit 54bacbe

Author: Torax team

Merge pull request #804 from google-deepmind:onnx-docs

PiperOrigin-RevId: 760773200
2 parents: 2d88c1e + 158135f

File tree: 3 files changed (+309, -1 lines)


docs/conf.py

Lines changed: 3 additions & 0 deletions
@@ -310,6 +310,9 @@
 .. |QLKNN| replace:: `QLKNN <QLKNN_target_>`_
 .. _QLKNN_target: {github_base_url}/torax/_src/transport_model/qlknn_transport_model.py

+.. |QuaLiKiz| replace:: `QuaLiKiz <qualikiz_target_>`_
+.. _qualikiz_target: {github_base_url}/torax/_src/transport_model/qualikiz_transport_model.py
+
 .. |transport_model| replace:: `transport_model <torax_src_transport_model_target_>`_
 .. _torax_src_transport_model_target: {github_base_url}/torax/_src/transport_model

docs/interfacing_with_surrogates.rst

Lines changed: 298 additions & 0 deletions
.. _interfacing_with_surrogates:

JAX-compatible interfaces with ML-surrogates of physics models
##############################################################
This section discusses a variety of options for building JAX-friendly interfaces
to surrogate models.

As an illustrative example, suppose we have a new neural network surrogate
transport model that we would like to use in TORAX. Assume that all the
boilerplate described in the previous sections has been taken care of, as well
as the definition of some functions to convert between TORAX structures and
tensors for the neural network.

.. code-block:: python

  class MyCustomSurrogateTransportModel(TransportModel):
    ...

    def _call_implementation(
        self,
        dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
        geo: geometry.Geometry,
        core_profiles: state.CoreProfiles,
    ) -> state.CoreTransport:
      input_tensor = self._prepare_input(dynamic_runtime_params_slice, geo, core_profiles)

      output_tensor = self._call_surrogate_model(input_tensor)

      chi_i, chi_e, d_e, v_e = self._parse_output(output_tensor)

      return state.CoreTransport(
          chi_face_ion=chi_i,
          chi_face_electron=chi_e,
          d_e=d_e,
          v_e=v_e,
      )
In this guide, we explore a few options for how you could implement the
``_call_surrogate_model`` function for an existing surrogate, while maintaining
the full power of JAX:

1. **Manually reimplementing the model in JAX**.
2. **Converting a PyTorch model to a JAX model**.
3. **Using an ONNX model**.
.. note::
  These conversion methods are necessary to make an external model compatible
  with JAX's autodiff and JIT functionality, which is required for using
  TORAX's gradient-driven nonlinear solvers (e.g. Newton-Raphson). Interfacing
  with non-differentiable, non-JITtable models is possible if the linear solver
  is used (for an example, see the |QuaLiKiz| transport model implementation).
  However, note that if such a model is called within the step function, JIT
  will need to be disabled with ``TORAX_COMPILATION_ENABLED=0``.
Option 1: manually reimplementing the model in JAX
==================================================

If the architecture of the surrogate is sufficiently simple, you might consider
reimplementing the model in JAX. The surrogates in TORAX are mostly implemented
using `Flax Linen`_, and can be found in the |fusion_surrogates|_ repository.
If you're not familiar with Flax, you can check out the `Flax documentation`_
on how to define your own models.
Consider a PyTorch neural network,

.. code-block:: python

  import torch

  class PyTorchMLP(torch.nn.Module):

    def __init__(self, hidden_dim: int, n_hidden: int, output_dim: int, input_dim: int):
      super().__init__()
      self.model = torch.nn.Sequential(
          torch.nn.Linear(input_dim, hidden_dim),
          torch.nn.ReLU(),
          *[torch.nn.Sequential(
              torch.nn.Linear(hidden_dim, hidden_dim),
              torch.nn.ReLU()
          ) for _ in range(n_hidden)],
          torch.nn.Linear(hidden_dim, output_dim)
      )

    def forward(self, x):
      return self.model(x)

  torch_model = PyTorchMLP(hidden_dim, n_hidden, output_dim, input_dim)
This model can be replicated in Flax as follows:

.. code-block:: python

  from flax import linen

  class FlaxMLP(linen.Module):
    hidden_dim: int
    n_hidden: int
    output_dim: int
    input_dim: int

    @linen.compact
    def __call__(self, x):
      x = linen.Dense(self.hidden_dim)(x)
      x = linen.relu(x)
      for _ in range(self.n_hidden):
        x = linen.Dense(self.hidden_dim)(x)
        x = linen.relu(x)
      x = linen.Dense(self.output_dim)(x)
      return x

  flax_model = FlaxMLP(hidden_dim, n_hidden, output_dim, input_dim)
As this is only the model architecture, we need to load the trained weights
separately. This can be a bit fiddly, as you have to map from the parameter
names in the weights checkpoint file to the parameter names in the Flax model.

For loading weights from a PyTorch checkpoint, you might do something like:

.. code-block:: python

  import jax.numpy as jnp
  import torch

  state_dict = torch.load(PYTORCH_CHECKPOINT_PATH)

  # The key names below assume a flat stack of alternating Linear/ReLU layers;
  # adjust the mapping to match the keys in your own state dict.
  params = {}
  for i in range(n_hidden_layers):
    layer_dict = {
        "kernel": jnp.array(state_dict[f"model.{i*2}.weight"]).T,
        "bias": jnp.array(state_dict[f"model.{i*2}.bias"]),
    }
    params[f"Dense_{i}"] = layer_dict

  params = {'params': params}
The model can then be called like any Flax model,

.. code-block:: python

  output_tensor = jax.jit(flax_model.apply)(params, input_tensor)
.. warning::
  You need to be very careful when loading from a PyTorch state dict, as
  Flax and PyTorch may have slightly different representations of the weights
  (for example, one may be the transpose of the other). It is worth validating
  the output of your PyTorch model against your JAX model to make sure the two
  agree.
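For example, a minimal consistency check, assuming the ``torch_model``,
``flax_model`` and ``params`` defined above (the batch size here is arbitrary),
might look like:

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import numpy as np
  import torch

  # Evaluate both models on the same random batch and compare the outputs.
  x = np.random.randn(8, input_dim).astype(np.float32)

  with torch.no_grad():
    torch_output = torch_model(torch.from_numpy(x)).numpy()
  flax_output = jax.jit(flax_model.apply)(params, jnp.asarray(x))

  np.testing.assert_allclose(
      torch_output, np.asarray(flax_output), rtol=1e-5, atol=1e-6
  )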
Option 2: converting a PyTorch model to a JAX model
===================================================

.. warning::
  The `torch_xla2`_ package is still evolving, which means there may be
  unexpected breaking changes. Some of the methods described in this section
  may become deprecated with little warning.

If your model is in PyTorch, you could also consider using the `torch_xla2`_
package to do the conversion to JAX automatically.
.. code-block:: python

  import torch
  import torch_xla2 as tx

  # weights_only=False loads the full pickled model rather than just its state dict.
  trained_model = torch.load(PYTORCH_MODEL_PATH, weights_only=False)
  params, jax_model_from_torch = tx.extract_jax(trained_model)
The model can then be called as a pure JAX function:

.. code-block:: python

  output_tensor = jax.jit(jax_model_from_torch)(params, input_tensor)
To avoid performing the conversion every time the model is loaded, you might
want to save a JAX-compatible version of the weights and model to disk:

.. code-block:: python

  import jax
  import numpy as np

  # jax.export uses StableHLO to serialize the model to a binary format.
  # export() is called with example arguments so that the function can be traced.
  exported_model = jax.export.export(jax.jit(jax_model_from_torch))(params, input_tensor)
  with open("model.hlo", "wb") as f:
    f.write(exported_model.serialize())

  # The weights can be saved as numpy arrays
  np.savez("weights.npz", *params)
The model can then be loaded and run as follows:

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import numpy as np

  # Load the HLO checkpoint
  with open('model.hlo', 'rb') as f:
    model_as_bytes = f.read()
  model = jax.export.deserialize(model_as_bytes)

  # Load the weights
  weights_as_npz = np.load('weights.npz')
  weights = [jnp.array(v) for v in weights_as_npz.values()]

  # Run the deserialized model; the weights pytree must match the structure
  # that was used when the model was exported.
  output_tensor = model.call(weights, input_tensor)
Option 3: using an ONNX model
=============================

The `Open Neural Network Exchange`_ (ONNX) is a highly interoperable format for
sharing neural network models. ONNX files bundle the model architecture and
weights together.
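If your surrogate is not already available as an ONNX file, it can usually be
exported from its original framework. As a sketch (the file name and the
``input``/``output`` node names are illustrative, not fixed by TORAX), the
PyTorch model from Option 1 could be exported with ``torch.onnx.export``:

.. code-block:: python

  import torch

  # An example input is needed to trace the model; its shape must match the
  # model's expected input. The node names chosen here must match the names
  # used when the ONNX file is loaded later.
  example_input = torch.randn(1, input_dim)
  torch.onnx.export(
      torch_model,
      example_input,
      "my_surrogate.onnx",
      input_names=["input"],
      output_names=["output"],
  )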
An ONNX model can be loaded and called as follows, making sure to specify the
correct input and output node names for your specific model:

.. code-block:: python

  import numpy as np
  import onnxruntime as ort

  s = ort.InferenceSession(ONNX_MODEL_PATH)
  onnx_output_tensor = s.run(
      # Output node names
      ['output1', 'output2'],
      # Mapping from input node names to input tensors
      # NOTE: input tensors must have the correct dtype for your specific model
      {'input': np.asarray(input_tensor, dtype=np.float32)},
  )
However, JAX will not be able to differentiate through the ``InferenceSession``.
To convert the ONNX model to a JAX representation, you can use the
`jaxonnxruntime`_ package:

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import onnx
  from jaxonnxruntime.backend import Backend as ONNXJaxBackend

  onnx_model = onnx.load_model(ONNX_MODEL_PATH)

  jax_model_from_onnx = ONNXJaxBackend.prepare(onnx_model)
  # NOTE: run() returns a list of output tensors, in order of the output nodes
  output_tensors = jax.jit(jax_model_from_onnx.run)(
      {"input": jnp.asarray(input_tensor, dtype=jnp.float32)}
  )
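The returned ``output_tensors`` can then be unpacked inside ``_parse_output``.
As a sketch (the output layout here is hypothetical and depends on how your
network was built), a single output node stacking the four transport
coefficients along the last axis could be split as:

.. code-block:: python

  # Hypothetical layout: one output node of shape (n_face, 4) holding
  # chi_i, chi_e, d_e and v_e stacked along the last axis.
  chi_i, chi_e, d_e, v_e = [
      jnp.squeeze(t, axis=-1)
      for t in jnp.split(output_tensors[0], 4, axis=-1)
  ]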
Best practices
==============

**Caching and lazy loading**: Ideally, the model should be constructed and its
weights loaded only once, on the first call to the function. The loaded model
should then be cached and reused for subsequent calls.
For example, in the ``_combined`` function of the QLKNN transport model (the
function that actually evaluates the model), we have:

.. code-block:: python

  model = get_model(self._model_path)
  ...
  model_output = model.predict(...)

where

.. code-block:: python

  @functools.lru_cache(maxsize=1)
  def get_model(path: str) -> base_qlknn_model.BaseQLKNNModel:
    """Load the model."""
    ...
    return qlknn_10d.QLKNN10D(path)

By decorating with ``functools.lru_cache(maxsize=1)``, the result of this
function (the loaded model) is stored in the cache and is only re-loaded if
the function is called with a different ``path``.
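The same pattern can be applied to the custom surrogate sketched at the start
of this section. Below is a minimal sketch, where the loader, the
``_load_params_from_checkpoint`` helper and the ``_checkpoint_path`` attribute
are hypothetical rather than part of the TORAX API:

.. code-block:: python

  import functools

  @functools.lru_cache(maxsize=1)
  def _load_my_surrogate(checkpoint_path: str):
    """Builds the model and loads its weights once; later calls hit the cache."""
    params = _load_params_from_checkpoint(checkpoint_path)  # hypothetical helper
    model = FlaxMLP(hidden_dim, n_hidden, output_dim, input_dim)
    return model, params

  class MyCustomSurrogateTransportModel(TransportModel):
    ...

    def _call_surrogate_model(self, input_tensor):
      # _checkpoint_path is a hypothetical attribute set in __init__.
      model, params = _load_my_surrogate(self._checkpoint_path)
      return jax.jit(model.apply)(params, input_tensor)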
**JITting model calls**: In general, you should make sure that the forward call
of the model is JITted:

.. code-block:: python

  output_tensor = jax.jit(flax_model.apply)(params, input_tensor)  # Good
  output_tensor = flax_model.apply(params, input_tensor)  # Bad

This is vital for fast performance.
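One way to ensure this in practice (a sketch, not a required pattern) is to
wrap the forward call with ``jax.jit`` once, for example at construction time,
and reuse the jitted callable for every evaluation:

.. code-block:: python

  # Jit once...
  jitted_apply = jax.jit(flax_model.apply)

  # ...then reuse the same callable; recompilation only happens if the input
  # shapes or dtypes change.
  output_tensor = jitted_apply(params, input_tensor)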
.. _Flax Linen: https://flax-linen.readthedocs.io/en/latest/index.html
.. _Flax documentation: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html#defining-your-own-models
.. _torch_xla2: https://pytorch.org/xla/master/features/stablehlo.html
.. _Open Neural Network Exchange: https://onnx.ai/
.. _jaxonnxruntime: https://github.com/google/jaxonnxruntime
.. |fusion_surrogates| replace:: ``google-deepmind/fusion_surrogates``
.. _fusion_surrogates: https://github.com/google-deepmind/fusion_surrogates

docs/model_integration.rst

Lines changed: 8 additions & 1 deletion
@@ -12,4 +12,11 @@ exposed this as part of the TORAX API.

 If you would like to use this please reach out to us. We aim to expose this
 functionality as part of the TORAX API in the very near future to further
-facilitate the integration of custom models.
+facilitate the integration of custom models, and further expand the
+documentation.
+
+.. toctree::
+   :maxdepth: 1
+   :caption: Model Integration Topics
+
+   interfacing_with_surrogates
