
Commit 30208fa

bythew3i authored and jax authors committed
[XLA:Mosaic] Support strided load/store memref with arbitrary shape as long as last dim size is 128 and dtype is 32bit.
PiperOrigin-RevId: 614862128
1 parent 6353877 · commit 30208fa

4 files changed: +255 −3 lines

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 27 additions & 0 deletions
@@ -194,6 +194,33 @@ def TPU_LoadOp : TPU_Op<"load"> {
   }];
 }
 
+def TPU_StridedLoadOp : TPU_Op<"strided_load"> {
+  let arguments = (ins
+    AnyMemRef:$base,
+    DenseI32ArrayAttr:$indices,
+    DenseI32ArrayAttr:$strides
+  );
+  let results = (outs AnyVector:$result);
+  let assemblyFormat = [{
+    $base attr-dict `:` type($base) `,` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def TPU_StridedStoreOp : TPU_Op<"strided_store"> {
+  let arguments = (ins
+    AnyVector:$valueToStore,
+    AnyMemRef:$base,
+    DenseI32ArrayAttr:$indices,
+    DenseI32ArrayAttr:$strides
+  );
+  let results = (outs);
+  let assemblyFormat = [{
+    $base `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore)
+  }];
+  let hasVerifier = 1;
+}
+
 def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> {
   let arguments = (ins
     AnyVector:$value,
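For reference, a textual form these ops might take, following the assemblyFormat above. This is a hypothetical sketch: the %base/%val SSA names and the 16x128 f32 shapes are made up, and the attribute spelling assumes the default DenseI32ArrayAttr syntax.

  %v = tpu.strided_load %base {indices = array<i32: 0, 0>, strides = array<i32: 2, 1>} : memref<16x128xf32>, vector<8x128xf32>
  tpu.strided_store %base, %val {indices = array<i32: 0, 0>, strides = array<i32: 2, 1>} : memref<16x128xf32>, vector<8x128xf32>

Each op carries a start index and a stride per memref dimension, while the result (or stored) vector type determines how many elements are touched along each dimension.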

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 51 additions & 0 deletions
@@ -24,6 +24,7 @@ limitations under the License.
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Support/LogicalResult.h"
+#include "mlir/include/mlir/IR/BuiltinTypes.h"
 #include "mlir/include/mlir/IR/IRMapping.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
 
@@ -180,6 +181,56 @@ LogicalResult MemRefSqueezeOp::canonicalize(MemRefSqueezeOp op,
   return success();
 }
 
+template <typename Op>
+LogicalResult verifyStridedOp(Op op, MemRefType memref_ty,
+                              VectorType vector_ty) {
+  auto indices = op.getIndices();
+  auto strides = op.getStrides();
+  if (memref_ty.getRank() != indices.size()) {
+    op.emitError("Base memref's rank and indices size do not match: ")
+        << memref_ty.getRank() << " vs " << indices.size();
+    return failure();
+  }
+  if (memref_ty.getRank() != strides.size()) {
+    op.emitError("Base memref's rank and strides size do not match: ")
+        << memref_ty.getRank() << " vs " << strides.size();
+    return failure();
+  }
+  if (memref_ty.getRank() != vector_ty.getRank()) {
+    op.emitError("Base memref's rank and result's rank do not match: ")
+        << memref_ty.getRank() << " vs " << vector_ty.getRank();
+    return failure();
+  }
+  for (int64_t i = 0; i < memref_ty.getRank(); ++i) {
+    if (indices[i] < 0 || indices[i] >= memref_ty.getDimSize(i)) {
+      op.emitError("Indices[")
+          << i << "]=" << indices[i] << " is out of range [0, "
+          << memref_ty.getDimSize(i) << ")";
+      return failure();
+    }
+    if (strides[i] < 1) {
+      op.emitError("Strides[") << i << "]=" << strides[i] << " must be >= 1";
+      return failure();
+    }
+    if ((indices[i] + (vector_ty.getDimSize(i) - 1) * strides[i]) >
+        memref_ty.getDimSize(i)) {
+      op.emitError() << "Strided slice is out of range at dim " << i;
+      return failure();
+    }
+  }
+  return success();
+}
+
+LogicalResult StridedLoadOp::verify() {
+  return verifyStridedOp<StridedLoadOp>(*this, getMemRefType(getBase()),
+                                        getType());
+}
+
+LogicalResult StridedStoreOp::verify() {
+  return verifyStridedOp<StridedStoreOp>(*this, getMemRefType(getBase()),
+                                         getValueToStore().getType());
+}
+
 LogicalResult ReinterpretCastOp::verify() {
   auto source_type = getMemRefType(getInput());
   auto target_type = getType();
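As a concrete illustration of the bounds check, here is a hypothetical op the verifier would reject (shapes made up): the last row touched along dim 0 is 0 + (8 - 1) * 3 = 21, which overruns the 16-row base, so verification fails with "Strided slice is out of range at dim 0".

  %v = tpu.strided_load %base {indices = array<i32: 0, 0>, strides = array<i32: 3, 1>} : memref<16x128xf32>, vector<8x128xf32>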

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 136 additions & 1 deletion
@@ -1158,6 +1158,139 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
+LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
+                                   Value base_ref, const VectorType &vty,
+                                   const VectorLayout &layout,
+                                   const ArrayRef<int32_t> &indices,
+                                   const ArrayRef<int32_t> &strides) {
+  if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
+    return op.emitOpError("Not implemented: Unsupported strided op")
+           << op.getName();
+  }
+  if (layout != VectorLayout(32, {0, 0}, ctx.target_shape,
+                             VectorLayout::ImplicitDim::kNone)) {
+    return op.emitOpError("Not implemented: Unsupported vector layout in ")
+           << op.getName();
+  }
+  const auto base_ty = getMemRefType(base_ref);
+  auto rank = base_ty.getRank();
+  CHECK_EQ(rank, indices.size());
+  CHECK_EQ(rank, strides.size());
+  CHECK_EQ(rank, vty.getShape().size());
+  if (rank < 2) {
+    return op.emitOpError("Not implemented: Stride on 1D vector");
+  }
+  auto mem_layout = dyn_cast<TiledLayoutAttr>(base_ty.getLayout());
+  if (!mem_layout) {
+    return op.emitOpError("Expected a tiled memref");
+  }
+  auto tile_strides = mem_layout.getTileStrides();
+
+  // Currently we hold constraints that the last dim size of memref needs to be
+  // exactly same as the lane size of native vreg and the memref has never
+  // been sliced before on the last dim. In other words, the original base
+  // memref's shape needs to be (..., target_shape[1]).
+  if (base_ty.getShape()[rank - 1] != ctx.target_shape[1] ||
+      tile_strides.take_back(2) != ArrayRef<int64_t>{1, 1}) {
+    return op.emitOpError("Not Implemented: The last dim size is not ")
+           << ctx.target_shape[1] << " in original base memref";
+  }
+  if (strides[rank - 1] != 1) {
+    return op.emitOpError("Not Implemented: Stride on last dim is not 1");
+  }
+  if (indices[rank - 1] != 0) {
+    return op.emitOpError("Not Implemented: Index on last dim is not 0");
+  }
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
+
+  FAILUREOR_ASSIGN_OR_RETURN(
+      VectorType vreg_ty,
+      getNativeVregType(vty.getElementType(), ctx.target_shape));
+
+  bool is_load_op = true;
+  xla::Array<Value> tiles(
+      layout.tileArrayShape(vty.getShape(), ctx.target_shape));
+  if (auto store_op = dyn_cast<tpu::StridedStoreOp>(op)) {
+    is_load_op = false;
+    FAILUREOR_ASSIGN_OR_RETURN(
+        tiles, disassemble(builder, layout, store_op.getValueToStore(),
+                           ctx.target_shape));
+  }
+
+  tiles.Each([&](absl::Span<const int64_t> tile_idxs, Value *v) {
+    CHECK_EQ(tile_idxs.size(), rank);
+    SmallVector<Value> idxs(rank);
+    for (int64_t i = 0; i < rank; ++i) {
+      int64_t stride = (i < rank - 2)
+                           ? strides[i]
+                           : (strides[i] * ctx.target_shape[i - rank + 2]);
+      idxs[i] =
+          IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
+    }
+    SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
+    int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
+    if (sublane_rem > 0 && tile_idxs[rank - 2] == tiles.dim(rank - 2) - 1) {
+      for (int64_t i = sublane_rem; i < ctx.target_shape[0]; ++i) {
+        sublane_mask[i] = false;
+      }
+    }
+    const auto sublane_mask_attr =
+        DenseBoolArrayAttr::get(op.getContext(), sublane_mask);
+    if (is_load_op) {
+      *v = builder.create<tpu::LoadOp>(
+          vreg_ty, base_ref, idxs, sublane_mask_attr,
+          builder.getI32IntegerAttr(strides[rank - 2]));
+    } else {
+      builder.create<tpu::StoreOp>(
+          *v, base_ref, idxs, sublane_mask_attr,
+          /*mask=*/nullptr, builder.getI32IntegerAttr(strides[rank - 2]));
+    }
+  });
+  if (is_load_op) {
+    op.replaceAllUsesWith(
+        assemble(builder, vty, layout, std::move(tiles), ctx.target_shape));
+  }
+  op.erase();
+  return success();
+}
+
+// TODO(jevinjiang): maybe unify with vector load?
+LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
+                                    const ArrayRef<Layout> layouts_in,
+                                    const ArrayRef<Layout> layouts_out) {
+  TPU_ASSERT_OP(llvm::none_of(layouts_in,
+                              [&](const Layout &l) { return l.has_value(); }));
+  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
+  TPU_ASSERT_OP(layouts_out.front().has_value());
+  const VectorLayout &layout_out = *layouts_out.front();
+  auto load_op = cast<tpu::StridedLoadOp>(op);
+  const auto base_ref = load_op.getBase();
+  const auto indices = load_op.getIndices();
+  const auto strides = load_op.getStrides();
+  const auto vty = cast<VectorType>(load_op.getResult().getType());
+  return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
+                              strides);
+}
+
+// TODO(jevinjiang): maybe unify with vector store?
+LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,
+                                     const ArrayRef<Layout> layouts_in,
+                                     const ArrayRef<Layout> layouts_out) {
+  TPU_ASSERT_OP(layouts_in.front().has_value());
+  TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
+                              [&](const Layout &l) { return l.has_value(); }));
+  TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
+
+  const VectorLayout &to_store_layout = *layouts_in.front();
+  auto store_op = cast<tpu::StridedStoreOp>(op);
+  const auto base_ref = store_op.getBase();
+  const auto indices = store_op.getIndices();
+  const auto strides = store_op.getStrides();
+  const auto vty = store_op.getValueToStore().getType();
+  return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
+                              strides);
+}
+
 LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,
                                const bool transpose_lhs,
                                const bool transpose_rhs,
@@ -3510,10 +3643,12 @@ const llvm::StringMap<rule_type> &rules() {
       {tpu::IotaOp::getOperationName(), tpu_iota_rule},
      {tpu::GatherOp::getOperationName(), tpu_gather_rule},
      {tpu::LoadOp::getOperationName(), tpu_load_rule},
+      {tpu::StoreOp::getOperationName(), tpu_store_rule},
+      {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule},
+      {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
      {tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
      {tpu::RegionOp::getOperationName(), tpu_region_rule},
      {tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
-      {tpu::StoreOp::getOperationName(), tpu_store_rule},
      {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
      {tpu::TraceOp::getOperationName(), tpu_trace_rule},
      {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},
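To make the per-vreg index arithmetic concrete, consider a hypothetical case assuming an (8, 128) native vreg shape (the shapes are illustrative, not taken from this change): loading vector<16x128xf32> from memref<64x128xf32> with strides [2, 1] splits into 16 / 8 = 2 vregs. Per the stride scaling above, the second-minor index for vreg j is 0 + j * (2 * 8), so the two emitted tpu.load ops start at rows 0 and 16, each with sublane stride 2 and a full sublane mask (since 16 % 8 == 0).

  %v = tpu.strided_load %base {indices = array<i32: 0, 0>, strides = array<i32: 2, 1>} : memref<64x128xf32>, vector<16x128xf32>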

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 41 additions & 2 deletions
@@ -232,11 +232,19 @@ class VectorLayoutInferer {
       if (infer(op).failed()) {
         return failure();
       }
-    } else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
+    } else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
       if (infer(op).failed()) {
         return failure();
       }
-    } else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
+    } else if (auto op = dyn_cast<tpu::StridedLoadOp>(any_op)) {
+      if (infer(op).failed()) {
+        return failure();
+      }
+    } else if (auto op = dyn_cast<tpu::StridedStoreOp>(any_op)) {
+      if (infer(op).failed()) {
+        return failure();
+      }
+    } else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
       if (infer(op).failed()) {
         return failure();
       }
@@ -581,6 +589,37 @@ class VectorLayoutInferer {
     return success();
   }
 
+  LogicalResult infer(tpu::StridedLoadOp op) {
+    auto vty = op.getResult().getType();
+    int8_t bitwidth = vty.getElementTypeBitWidth();
+    if (bitwidth != 32) {
+      NYI("Strided load with non 32-bit data");
+    }
+    if (vty.getRank() < 2) {
+      NYI("Strided load with 1D vector");
+    }
+    SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
+    setLayout(op, in_layout,
+              VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
+                           ImplicitDim::kNone));
+    return success();
+  }
+
+  LogicalResult infer(tpu::StridedStoreOp op) {
+    auto vty = op.getValueToStore().getType();
+    int8_t bitwidth = vty.getElementTypeBitWidth();
+    if (bitwidth != 32) {
+      NYI("Strided store with non 32-bit data");
+    }
+    if (vty.getRank() < 2) {
+      NYI("Strided store with 1D vector");
+    }
+    auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
+                                     ImplicitDim::kNone);
+    setInLayout(op, {store_layout, kNoLayout});
+    return success();
+  }
+
   LogicalResult infer(tpu::MatmulOp op) { return inferMatmul(op); }
 
   LogicalResult infer(tpu::StoreOp op) {
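Two hypothetical cases that the inference above bails out on rather than assigning a layout (again with made-up shapes): a non 32-bit element type, and a 1-D vector.

  %a = tpu.strided_load %base {indices = array<i32: 0, 0>, strides = array<i32: 1, 1>} : memref<16x128xbf16>, vector<8x128xbf16>
  %b = tpu.strided_load %ref {indices = array<i32: 0>, strides = array<i32: 1>} : memref<128xf32>, vector<128xf32>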
