
Ways to name constants when dumping the hlo module via xla_computation? #17775

Answered by froystig
louis-shao asked this question in Q&A

We don't have a mechanism for annotating constants directly. Would it work for you to wrap the constant in a named function instead? I'm thinking along the following lines:

```python
import jax

def g(x):
  # Wrapping the constant in its own jitted function gives it a name
  # (my_const) that shows up as a separate function in the lowered module.
  @jax.jit
  def my_const(): return 3.

  return x + my_const()
```

```
>>> print(jax.jit(g).lower(1.).as_text())
module @jit_g attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<f32> {mhlo.sharding = "{replicated}"}) -> (tensor<f32> {jax.result_info = ""}) {
    %0 = call @my_const() : () -> tensor<f32>
    %1 = stablehlo.add %arg0, %0 : tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @my_const() -> tensor<f32> {
    %0 = stablehlo.constant dense<3…
```
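If you need several named constants, the same trick can be factored into a small helper. The sketch below is illustrative only: `named_constant` is a hypothetical helper, not part of JAX, and it assumes that `jax.jit` takes the lowered function's name from the wrapped function's `__name__`, which matches what the output for `my_const` suggests but is not a documented contract. The names are visible in the lowered StableHLO text; whether they survive XLA's optimization passes in a compiled HLO dump is not guaranteed.

```python
import jax
import jax.numpy as jnp


def named_constant(name, value):
    """Hypothetical helper: wrap `value` in a zero-argument jitted function
    whose __name__ is `name`, so it lowers to its own named private func."""
    def const_fn():
        return jnp.asarray(value, dtype=jnp.float32)

    # Assumption: jax.jit uses the function's __name__ as the name of the
    # lowered function, so rename it before wrapping it with jit.
    const_fn.__name__ = name
    const_fn.__qualname__ = name
    return jax.jit(const_fn)


scale = named_constant("scale", 2.0)
offset = named_constant("offset", 0.5)


def g(x):
    # Each call lowers to a `call @<name>()` in the outer function.
    return x * scale() + offset()


text = jax.jit(g).lower(1.0).as_text()
print(text)
```

Searching the printed module for `@scale` and `@offset` should locate each constant's definition, in the same way `@my_const` appears above.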

Answer selected by louis-shao