Hello! Hopefully this is not a trivial question.

I need to define a custom primitive which has some multi-dimensional inputs (for the time being, only for CPU). For now, for simplicity, I'm just flattening the inputs on the Python side to pass them to `custom_call` and unflattening them in C++. On the C++ side I have an XLA custom call (the important parts are the ones associated with `pos`, `edgesx`, `edgesy`, `edgesz` and eventually also `field`), while the lowering function in Python is:
```python
import numpy as np
from jax.interpreters import mlir

# custom_call, unpack_hashable, to_mpi_handle and as_mhlo_constant are
# imported as in mpi4jax (imports not shown)


def _ppaint_lowering(ctx, pos, mass, Nmesh, edgesx, edgesy, edgesz, comm):
    comm = unpack_hashable(comm)

    # Extract the numpy type of the inputs
    pos_aval = ctx.avals_in[0]
    np_dtype = np.dtype(pos_aval.dtype)
    out_type = mlir.aval_to_ir_type(ctx.avals_out[0])

    # Number of particles is length of pos / 3
    Nparts = (np.prod(pos_aval.shape) / 3).astype(np.int64)

    # Dealing with comm as in mpi4jax, see for instance barrier.py in collective ops
    comm = as_mhlo_constant(to_mpi_handle(comm), np.uintp)

    # Dispatch a different call depending on the dtype
    if np_dtype == np.float32:
        op_name = "ppaint_f32"
    elif np_dtype == np.float64:
        op_name = "ppaint_f64"
    else:
        raise NotImplementedError(f"Unsupported dtype {np_dtype}")

    return custom_call(
        op_name,
        # Output types
        result_types=[out_type],
        # The inputs:
        operands=[pos, mlir.ir_constant(Nparts), mass, Nmesh, edgesx, edgesy, edgesz, comm],
    ).results
```
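On the C++ side, the order of the `operands` list in the lowering above is the order of the `in` array. As a minimal sketch, assuming the classic CPU custom-call signature `void(void* out, const void** in)` (which the `in[0]` in the question suggests) and that `mlir.ir_constant(Nparts)` arrives as a 64-bit integer scalar, the first two operands could be unpacked like this. `PosOperands` and `unpack_pos_operands` are hypothetical names, and the remaining operands are left out because their C++ types are not shown in the post.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (names are illustrative): the first two entries of
// `in` follow operands=[pos, mlir.ir_constant(Nparts), mass, ...].
template <typename T>
struct PosOperands {
    const T* pos;         // in[0]: flattened (Nparts, 3) positions
    std::int64_t nparts;  // in[1]: scalar built from the np.int64 Nparts
};

template <typename T>
PosOperands<T> unpack_pos_operands(const void** in) {
    assert(in != nullptr);
    PosOperands<T> ops;
    ops.pos = static_cast<const T*>(in[0]);
    ops.nparts = *static_cast<const std::int64_t*>(in[1]);
    return ops;
}

// "ppaint_f32" would instantiate this with T = float, "ppaint_f64" with T = double.
```

Keeping the unpacking templated on `T` mirrors the dtype dispatch in the lowering: each registered target name maps to one instantiation.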
For example, `pos` in Python originally has shape `(Nparts, 3)`, and I'm passing it to this lowering function flattened. On the C++ side I need to pass a `T**` to `ppaint` (it could also be a `T*[3]` if that's easier to deal with), and for now I'm just expecting a `T*` (`_pos`) from the XLA custom call, defining `pos` via:
```cpp
// Rebuild one pointer per particle from the flat (Nparts * 3) buffer _pos
T** pos = new T*[Nparts];
for (int i = 0; i < Nparts; i++) pos[i] = &(_pos[i * 3]);  // 3 is dimensionality
```
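A plain `reinterpret_cast<T**>(in[0])` cannot work because the buffer holds the `3 * Nparts` coordinate values themselves, not pointers to them. Assuming the operand is dense and row-major (so `_pos[i * 3 + d]` is coordinate `d` of particle `i`, as the loop above already assumes), the same memory can instead be viewed as a pointer to rows of three values, `const T (*)[3]`, which supports `pos[i][d]` without allocating anything. This cast is a widely used idiom that works in practice, although the C++ standard does not strictly guarantee it. A sketch, where `Row3`, `view_as_rows` and `build_row_pointers` are hypothetical names:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

template <typename T>
using Row3 = T[3];

// View a dense row-major (Nparts, 3) buffer as rows of three values:
// rows[i][d] refers to the same element as flat[i * 3 + d]; nothing is allocated.
template <typename T>
const Row3<T>* view_as_rows(const T* flat) {
    return reinterpret_cast<const Row3<T>*>(flat);
}

// The approach from the question, using std::vector so nothing leaks:
// one pointer per particle, each pointing at that particle's 3 coordinates.
template <typename T>
std::vector<const T*> build_row_pointers(const T* flat, std::size_t nparts) {
    assert(nparts == 0 || flat != nullptr);
    std::vector<const T*> rows(nparts);
    for (std::size_t i = 0; i < nparts; ++i) rows[i] = flat + i * 3;
    return rows;
}
```

If `ppaint` were declared to take `const T (*pos)[3]` (a hypothetical signature), it could receive `view_as_rows(_pos)` directly; with the current `T**` signature, a pointer table like the one above is still required.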
My question is: is there an easier way to do this? In other words, how should I pass, for instance, `pos` to the `custom_call` in Python (and what layout should I specify, if needed) so that on the C++ side I can directly write `T** pos = reinterpret_cast<T**>(in[0]);`?
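If `ppaint` can instead take three per-coordinate arrays (the `T*[3]` option, indexed as `pos[d][i]`), the per-particle loop disappears entirely, provided the buffer is coordinate-major: all `x` values, then all `y`, then all `z`. On the Python side this would presumably mean flattening the transpose of `pos` rather than `pos` itself before the custom call. A sketch under that assumption, where `coordinate_pointers` is a hypothetical helper:

```cpp
#include <array>
#include <cstddef>

// For a coordinate-major buffer, i.e. shape (3, Nparts) in row-major order,
// the x, y and z coordinates are each a contiguous run of nparts values,
// so three pointers suffice: ptrs[d][i] == flat[d * nparts + i].
template <typename T>
std::array<const T*, 3> coordinate_pointers(const T* flat, std::size_t nparts) {
    return std::array<const T*, 3>{{flat, flat + nparts, flat + 2 * nparts}};
}
```

`coordinate_pointers(_pos, Nparts).data()` has type `const T**`, which is what a `T*[3]` or `T**` parameter decays to, apart from the `const`.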
I hope this is clear enough, thank you in advance!