Skip to content

Commit 7578e10

Browse files
bythew3i authored and
jax authors committed
[XLA:Mosaic] Support dynamic indices in strided load/store.
PiperOrigin-RevId: 615931990
1 parent ac41032 commit 7578e10

File tree

5 files changed

+46
-35
lines changed

5 files changed

+46
-35
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -197,12 +197,12 @@ def TPU_LoadOp : TPU_Op<"load"> {
197197
def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
198198
let arguments = (ins
199199
AnyMemRef:$base,
200-
DenseI32ArrayAttr:$indices,
200+
Variadic<Index>:$indices,
201201
DenseI32ArrayAttr:$strides
202202
);
203203
let results = (outs AnyVector:$result);
204204
let assemblyFormat = [{
205-
$base attr-dict `:` type($base) `,` type($result)
205+
$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)
206206
}];
207207
let hasVerifier = 1;
208208
}
@@ -211,12 +211,12 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
211211
let arguments = (ins
212212
AnyVector:$valueToStore,
213213
AnyMemRef:$base,
214-
DenseI32ArrayAttr:$indices,
214+
Variadic<Index>:$indices,
215215
DenseI32ArrayAttr:$strides
216216
);
217217
let results = (outs);
218218
let assemblyFormat = [{
219-
$base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
219+
$base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
220220
}];
221221
let hasVerifier = 1;
222222
}

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -202,21 +202,10 @@ LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
202202
return failure();
203203
}
204204
for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
205-
if (indices[i] < 0 && indices[i] >= memref_ty.getDimSize(i)) {
206-
op.emitError("Indices[")
207-
<< i << "]=" << indices[i] << " is out of range [0, "
208-
<< memref_ty.getDimSize(i) << ")";
209-
return failure();
210-
}
211205
if (strides[i] < 1) {
212206
op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
213207
return failure();
214208
}
215-
if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) >
216-
memref_ty.getDimSize(i)) {
217-
op.emitError() << "Strided slice is out of range at dim " << i;
218-
return failure();
219-
}
220209
}
221210
return success();
222211
}

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1159,9 +1159,9 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
11591159
}
11601160

11611161
LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
1162-
Value base_ref, const VectorType &vty,
1162+
Value base_ref, ValueRange indices,
1163+
const VectorType &vty,
11631164
const VectorLayout &layout,
1164-
const ArrayRef<int32_t> &indices,
11651165
const ArrayRef<int32_t> &strides) {
11661166
if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
11671167
return op.emitOpError("Not implemented: Unsupported strided op")
@@ -1198,7 +1198,10 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
11981198
if (strides[rank - 1] != 1) {
11991199
return op.emitOpError("Not Implemented: Stride on last dim is not 1");
12001200
}
1201-
if (indices[rank - 1] != 0) {
1201+
auto last_idx = getIntConst(indices[rank - 1], /*silent=*/true);
1202+
if (failed(last_idx)) {
1203+
return op.emitOpError("Not Implemented: Dynamic index on last dim");
1204+
} else if (last_idx.value() != 0) {
12021205
return op.emitOpError("Not Implemented: Index on last dim is not 0");
12031206
}
12041207
ImplicitLocOpBuilder builder(op.getLoc(), &op);
@@ -1224,8 +1227,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
12241227
int64_t stride = (i < rank - 2)
12251228
? strides[i]
12261229
: (strides[i] * ctx.target_shape[i - rank + 2]);
1227-
idxs[i] =
1228-
IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
1230+
idxs[i] = builder.create<arith::AddIOp>(
1231+
indices[i], IdxConst(tile_idxs[i] * stride, builder, op.getLoc()));
12291232
}
12301233
SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
12311234
int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
@@ -1264,12 +1267,9 @@ LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
12641267
TPU_ASSERT_OP(layouts_out.front().has_value());
12651268
const VectorLayout &layout_out = *layouts_out.front();
12661269
auto load_op = cast<tpu::StridedLoadOp>(op);
1267-
const auto base_ref = load_op.getBase();
1268-
const auto indices = load_op.getIndices();
1269-
const auto strides = load_op.getStrides();
12701270
const auto vty = cast<VectorType>(load_op.getResult().getType());
1271-
return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
1272-
strides);
1271+
return strided_op_rule_impl(ctx, op, load_op.getBase(), load_op.getIndices(),
1272+
vty, layout_out, load_op.getStrides());
12731273
}
12741274

12751275
// TODO(jevinjiang): maybe unify with vector store?
@@ -1283,12 +1283,10 @@ LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,
12831283

12841284
const VectorLayout &to_store_layout = *layouts_in.front();
12851285
auto store_op = cast<tpu::StridedStoreOp>(op);
1286-
const auto base_ref = store_op.getBase();
1287-
const auto indices = store_op.getIndices();
1288-
const auto strides = store_op.getStrides();
12891286
const auto vty = store_op.getValueToStore().getType();
1290-
return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
1291-
strides);
1287+
return strided_op_rule_impl(ctx, op, store_op.getBase(),
1288+
store_op.getIndices(), vty, to_store_layout,
1289+
store_op.getStrides());
12921290
}
12931291

12941292
LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,

jaxlib/mosaic/dialect/tpu/transforms/debug_assert_insertion.cc

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -55,9 +55,11 @@ rule_type as_generic_rule(void (*rule)(Op)) {
5555

5656
void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
5757
ArrayRef<int64_t> window_shape,
58-
ArrayRef<int64_t> full_shape) {
58+
ArrayRef<int64_t> full_shape,
59+
ArrayRef<int32_t> strides = {}) {
5960
if (base_indices.size() != window_shape.size() ||
60-
base_indices.size() != full_shape.size()) {
61+
base_indices.size() != full_shape.size() ||
62+
(!strides.empty() && base_indices.size() != strides.size())) {
6163
return; // Malformed op.
6264
}
6365
if (base_indices.empty()) {
@@ -68,14 +70,15 @@ void assertIsValidSubwindow(Operation *op, mlir::ValueRange base_indices,
6870
for (auto [dim, access] :
6971
llvm::enumerate(llvm::zip(base_indices, window_shape, full_shape))) {
7072
auto [idx, size, bound] = access;
73+
int64_t stride = strides.empty() ? 1 : strides[dim];
7174
Value positive = builder.create<arith::CmpIOp>(
7275
arith::CmpIPredicate::sge, idx,
7376
builder.create<arith::ConstantOp>(builder.getIntegerAttr(idx_type, 0)));
7477
Value in_bounds = builder.create<arith::CmpIOp>(
75-
arith::CmpIPredicate::sle,
78+
arith::CmpIPredicate::slt,
7679
builder.create<arith::AddIOp>(
7780
idx, builder.create<arith::ConstantOp>(
78-
builder.getIntegerAttr(idx_type, size))),
81+
builder.getIntegerAttr(idx_type, (size - 1) * stride))),
7982
builder.create<arith::ConstantOp>(
8083
builder.getIntegerAttr(idx_type, bound)));
8184
std::string msg;
@@ -107,13 +110,32 @@ void tpu_memref_slice_rule(tpu::MemRefSliceOp op) {
107110
/*full_shape=*/op.getMemRef().getType().getShape());
108111
}
109112

113+
void tpu_strided_load_rule(tpu::StridedLoadOp op) {
114+
assertIsValidSubwindow(op, op.getIndices(),
115+
/*window_shape=*/op.getResult().getType().getShape(),
116+
/*full_shape=*/op.getBase().getType().getShape(),
117+
/*strides=*/op.getStrides());
118+
}
119+
120+
void tpu_strided_store_rule(tpu::StridedStoreOp op) {
121+
assertIsValidSubwindow(
122+
op, op.getIndices(),
123+
/*window_shape=*/op.getValueToStore().getType().getShape(),
124+
/*full_shape=*/op.getBase().getType().getShape(),
125+
/*strides=*/op.getStrides());
126+
}
127+
110128
const llvm::StringMap<rule_type> &rules() {
111129
static auto rules = new llvm::StringMap<rule_type>{
112130
// TODO: tpu::LoadOp, tpu::StoreOp
113131
{vector::LoadOp::getOperationName(), as_generic_rule(vector_load_rule)},
114132
{vector::StoreOp::getOperationName(), as_generic_rule(vector_store_rule)},
115133
{tpu::MemRefSliceOp::getOperationName(),
116134
as_generic_rule(tpu_memref_slice_rule)},
135+
{tpu::StridedLoadOp::getOperationName(),
136+
as_generic_rule(tpu_strided_load_rule)},
137+
{tpu::StridedStoreOp::getOperationName(),
138+
as_generic_rule(tpu_strided_store_rule)},
117139
};
118140
return *rules;
119141
}

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -616,7 +616,9 @@ class VectorLayoutInferer {
616616
}
617617
auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
618618
ImplicitDim::kNone);
619-
setInLayout(op, {store_layout, kNoLayout});
619+
SmallVector<Layout, 5> in_layout{op->getNumOperands(), kNoLayout};
620+
in_layout[0] = store_layout;
621+
setInLayout(op, in_layout);
620622
return success();
621623
}
622624

0 commit comments

Comments
 (0)