Description
When a task uses TornadoVM's `@Reduce` annotation and one of its parameters is an object persisted on the device by a task in another TaskGraph, the runtime fails to update the object's state. This appears to be an unsupported use case.
Current Behavior
The runtime fails when a reduction task receives data that a previous TaskGraph execution persisted on the device.
Expected Behavior
Reduction operations should work with persisted objects from previous tasks, allowing for multi-stage pipelines where intermediate results are persisted between TaskGraphs.
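For reference, the expected semantics can be modelled sequentially on the host: stage one updates `x` in place, and stage two must reduce over that updated `x`, not a stale copy. This is a minimal plain-Java sketch on `float[]`, not TornadoVM code. The body of `preProcess` is elided in this issue, so the in-place doubling below is a hypothetical placeholder, and the class name `PipelineModel` is likewise illustrative.

```java
import java.util.Arrays;

public class PipelineModel {

    // Hypothetical stand-in for preProcess: the real body is elided in the issue.
    static void preProcess(float[] x) {
        for (int i = 0; i < x.length; i++) {
            x[i] = x[i] * 2.0f;
        }
    }

    // Sequential equivalent of the reduce kernel: output[0] = sum of x[i]^2.
    static void reduce(float[] output, float[] x) {
        output[0] = 0.0f;
        for (int i = 0; i < x.length; i++) {
            output[0] += x[i] * x[i];
        }
    }

    public static void main(String[] args) {
        float[] x = {1.0f, 2.0f, 3.0f};
        float[] output = new float[1];

        // Stage 1 ("pre-process"): x is updated; on the device it would be persisted.
        preProcess(x);
        // Stage 2 ("benchmark"): the reduction must see the updated x = {2, 4, 6}.
        reduce(output, x);

        System.out.println(Arrays.toString(x) + " -> " + output[0]); // prints [2.0, 4.0, 6.0] -> 56.0
    }
}
```

Under this placeholder, a reduction that read the un-preprocessed `x` would yield 14.0 instead of 56.0, so comparing a device result against both values is one way to tell whether the persisted update was observed.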
Reproduction Code
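```java
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.annotations.Reduce;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class RMSNorm {
    // Reduction kernel: accumulates the sum of squares of x into output[0].
    private static void reduce(@Reduce FloatArray output, FloatArray x) {
        output.set(0, 0.0f);
        for (@Parallel int i = 0; i < x.getSize(); i++) {
            float val = x.get(i) * x.get(i);
            output.set(0, output.get(0) + val);
        }
    }

    private static void singleNorm(FloatArray output, int size) {
        // .. some compute
    }

    private static void preProcess(FloatArray x) {
        // .. some compute to update x
    }

    // Hypothetical wrapper method: in the issue, x, weights, output and size are declared elsewhere.
    static void buildGraphs(FloatArray x, FloatArray weights, FloatArray output, int size) {
        // First TaskGraph: updates x on the device and persists it there
        TaskGraph initTaskGraph = new TaskGraph("pre-process")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, x)
                .task("preProcess", RMSNorm::preProcess, x)
                .persistOnDevice(x);

        // Second TaskGraph: consumes the persisted x and uses it in a reduction
        TaskGraph taskGraphLoop = new TaskGraph("benchmark")
                .consumeFromDevice(x)
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, weights)
                .task("reduce", RMSNorm::reduce, output, x) // This fails because x is persisted from another TaskGraph
                .task("singleNorm", RMSNorm::singleNorm, output, size);
    }
}
```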
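Once the persisted-data case works, the value in `output.get(0)` could be checked against a host reference. A parallel reduction typically accumulates partial sums in a different order than a sequential loop, so float results may differ slightly from a sequential sum. The sketch below therefore compares with a relative tolerance, and uses a pairwise (tree-ordered) float sum to imitate a reordered accumulation. It is plain Java on `float[]`, not TornadoVM API; the class name `ReduceCheck`, its method names, and the `1e-5` tolerance are assumptions.

```java
public class ReduceCheck {

    // Sequential reference for the reduce kernel: sum of x[i]^2, accumulated in double.
    static double sumOfSquares(float[] x) {
        double acc = 0.0;
        for (float v : x) {
            acc += (double) v * v;
        }
        return acc;
    }

    // Pairwise (tree-ordered) float sum of squares over x[lo, hi), imitating the
    // reordered accumulation a parallel reduction may perform.
    static float pairwiseSumOfSquares(float[] x, int lo, int hi) {
        if (hi <= lo) {
            return 0.0f;
        }
        if (hi - lo == 1) {
            return x[lo] * x[lo];
        }
        int mid = (lo + hi) >>> 1;
        return pairwiseSumOfSquares(x, lo, mid) + pairwiseSumOfSquares(x, mid, hi);
    }

    // Relative comparison; exact == is too strict for reordered float sums.
    static boolean closeEnough(float actual, double expected, double relTol) {
        double scale = Math.max(1.0, Math.abs(expected));
        return Math.abs(actual - expected) <= relTol * scale;
    }

    public static void main(String[] args) {
        float[] x = new float[1 << 20];
        for (int i = 0; i < x.length; i++) {
            x[i] = (i % 100) * 0.01f;
        }
        double expected = sumOfSquares(x);
        float reordered = pairwiseSumOfSquares(x, 0, x.length);
        System.out.println("expected=" + expected + " pairwise=" + reordered
                + " ok=" + closeEnough(reordered, expected, 1e-5));
    }
}
```

In a real check, `actual` would be `output.get(0)` after `output` is copied back to the host, and `expected` would be `sumOfSquares` over a host copy of the pre-processed `x`.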