Pallas outside Python #24453
Unanswered
krzysztofrusek
asked this question in
Q&A
Replies: 1 comment
-
It works with a TensorFlow SavedModel. I am posting an example for future readers.

Model:

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf
def add_vectors_kernel(x_ref, y_ref, o_ref):
    """Pallas kernel: store the elementwise sum of two input blocks in the output.

    Args:
        x_ref: ref to the first operand; ``[...]`` reads the whole block.
        y_ref: ref to the second operand.
        o_ref: ref to the output block; written in place, nothing is returned.
    """
    lhs = x_ref[...]
    rhs = y_ref[...]
    o_ref[...] = lhs + rhs
@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
    """Add ``x`` and ``y`` elementwise using the ``add_vectors_kernel`` Pallas kernel.

    The result's shape and dtype are taken from ``x``. No ``grid`` is passed to
    ``pallas_call``, so (per Pallas defaults) the kernel runs once over the
    full operands.
    """
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    launch = pl.pallas_call(add_vectors_kernel, out_shape=result_spec)
    return launch(x, y)
# Run once in JAX with the same shape/dtype as the exported signature below,
# as a quick check that the kernel compiles and runs on this backend.
add_vectors(jnp.arange(8, dtype=jnp.float32), jnp.arange(8, dtype=jnp.float32))

my_model = tf.Module()
# Export add_vectors as a TF function taking two float32 vectors of length 8.
# NOTE(review): the Pallas kernel lowers to a platform-specific kernel, so the
# SavedModel presumably only runs on the platform it was exported on (see
# jax2tf's `native_serialization_platforms`) — confirm before deploying.
my_model.f = tf.function(
    jax2tf.convert(add_vectors),
    autograph=False,
    input_signature=[
        tf.TensorSpec([8], tf.float32),
        tf.TensorSpec([8], tf.float32),
    ],
)
# Custom gradients are not saved; presumably no gradient is needed (or
# available) for the Pallas kernel, since only the forward pass is served.
tf.saved_model.save(
    my_model,
    "mod",  # must match the directory loaded by the server below
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
)

# --- Server (normally a separate program) ---
# Only TensorFlow is imported here: the restored model runs without JAX code.
import tensorflow as tf

restored_model = tf.saved_model.load("mod")
x = tf.convert_to_tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32)
y = tf.convert_to_tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=tf.float32)
print(restored_model.f(x, y))
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
Hello, is there a way to run Pallas kernels (Triton or Mosaic GPU) outside the Python ecosystem?
Another related discussion: #20508
Beta Was this translation helpful? Give feedback.
All reactions