Ways to name constants when dumping the hlo module via xla_computation? #17775
-
Hi, when using `xla_computation` to dump the HLO module, is there a way to name constants? My use case is that I have some data that needs to be prepared dynamically at runtime, but I don't want it fed as an input to the function on every call, since that would add inference time (which is important to us). I'm hoping that if the constants are named, I can put a placeholder constant in the function and substitute the values in the proto manually before loading the function for execution. I tried …
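For concreteness, here is a minimal sketch (not from the original post; the function name and placeholder value are illustrative) of why an inline constant is hard to find and replace once lowered, since it appears as an anonymous `stablehlo.constant` with no name tying it back to the Python value:

```python
import jax

PLACEHOLDER = 3.0  # hypothetical value we would like to swap out before execution

def f(x):
    # The constant gets baked into the computation rather than passed as an input.
    return x + PLACEHOLDER

# The lowered StableHLO contains an anonymous constant, e.g.
#   %0 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
# with nothing identifying it as PLACEHOLDER.
print(jax.jit(f).lower(1.).as_text())
```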
-
We don't have a mechanism for annotating constants directly. Would it work for you to name a constant function? I'm thinking along the following lines:

```python
import jax

def g(x):
  @jax.jit
  def my_const(): return 3.
  return x + my_const()
```

```
>>> print(jax.jit(g).lower(1.).as_text())
module @jit_g attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = call @my_const() : () -> tensor<f32>
    %1 = stablehlo.add %arg0, %0 : tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @my_const() -> tensor<f32> {
    %0 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
    return %0 : tensor<f32>
  }
}
```
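To connect this to the substitution idea from the question, here is a hedged sketch (not from the original thread) of patching the constant in the lowered text. The new value and the regex are illustrative assumptions; the exact lowered text can differ across JAX versions, and compiling the patched text for execution is not covered here.

```python
import re

import jax

def g(x):
    @jax.jit
    def my_const():
        return 3.  # placeholder value to be patched later
    return x + my_const()

text = jax.jit(g).lower(1.).as_text()

# Because the constant lives in a named private function (@my_const), it can
# be located in the StableHLO text and its dense<...> literal rewritten.
new_value = 7.0
patched = re.sub(
    r"(func\.func private @my_const\(\) -> tensor<f32> \{\s*"
    r"%0 = stablehlo\.constant dense<)[^>]+(>)",
    lambda m: f"{m.group(1)}{new_value:e}{m.group(2)}",
    text,
)
print(patched)
```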