[runtime] Automatic reduction with @Reduce fails when input is a persisted object from previous task #672

@mikepapadim

Description

When a task uses Tornado's @Reduce annotation and one of its parameters is an object persisted on the device by a task in another TaskGraph, the runtime fails to update the object's state. This appears to be an unsupported use case.

Current Behavior

The runtime fails when a reduction operates on data that was persisted by a previous TaskGraph execution.

Expected Behavior

Reduction operations should work with persisted objects from previous tasks, allowing for multi-stage pipelines where intermediate results are persisted between TaskGraphs.

Reproduction Code

private static void reduce(@Reduce FloatArray output, FloatArray x) {
    output.set(0, 0.0f);
    for (@Parallel int i = 0; i < x.getSize(); i++) {
        float val = x.get(i) * x.get(i);
        output.set(0, output.get(0) + val);
    }
}

private static void singleNorm(FloatArray output, int size) {
    // .. some compute
}

private static void preProcess(FloatArray x) {
    // .. some compute to update x
}

// First TaskGraph that persists data
TaskGraph initTaskGraph = new TaskGraph("pre-process")
    .transferToDevice(DataTransferMode.FIRST_EXECUTION, x)
    .task("preProcess", RMSNorm::preProcess, x)
    .persistOnDevice(x);

// Second TaskGraph that attempts to use persisted data in a reduction
TaskGraph taskGraphLoop = new TaskGraph("benchmark")
    .consumeFromDevice(x)
    .transferToDevice(DataTransferMode.FIRST_EXECUTION, weights)
    .task("reduce", RMSNorm::reduce, output, x)  // This fails because x is persisted from another TaskGraph
    .task("singleNorm", RMSNorm::singleNorm, output, size);
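
To check the value the `reduce` task should produce once this case works, a sequential reference can be run on the host. The following is only a sketch: it uses a plain `float[]` in place of `FloatArray` so it runs on any JVM, and the class name `ReduceReference`, the `matches` helper, and the relative tolerance are hypothetical names and choices that are not part of TornadoVM or of this report. It assumes the caller passes in the contents of `x` as they are after `preProcess`.

```java
// Sequential reference for the sum-of-squares reduction in the report.
// Plain float[] stands in for FloatArray (assumption made for portability).
final class ReduceReference {

    // Mirrors reduce(): accumulates x[i] * x[i] into a single value.
    static float sumOfSquares(float[] x) {
        float acc = 0.0f;
        for (int i = 0; i < x.length; i++) {
            acc += x[i] * x[i];
        }
        return acc;
    }

    // Hypothetical helper: compares a device result against the reference
    // within a relative tolerance, since a parallel reduction may add the
    // partial sums in a different order than this loop does.
    static boolean matches(float deviceResult, float[] x, float relTol) {
        float expected = sumOfSquares(x);
        float scale = Math.max(1.0f, Math.abs(expected));
        return Math.abs(deviceResult - expected) <= relTol * scale;
    }

    public static void main(String[] args) {
        float[] x = {1.0f, 2.0f, 3.0f};
        System.out.println(sumOfSquares(x)); // prints 14.0
    }
}
```

Comparing with a tolerance rather than exact equality is a deliberate choice here, because floating-point addition is not associative and the device-side reduction order is not specified in this report.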

Metadata

Labels

bug, runtime
