@@ -1158,6 +1158,139 @@ LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
   return success();
 }
 
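+// Lowers tpu.strided_load / tpu.strided_store: the vector is split into
+// native vregs, and one tpu.load / tpu.store is emitted per vreg at the
+// corresponding strided offsets into the base memref.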
+LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op,
+                                   Value base_ref, const VectorType &vty,
+                                   const VectorLayout &layout,
+                                   const ArrayRef<int32_t> &indices,
+                                   const ArrayRef<int32_t> &strides) {
+  if (!isa<tpu::StridedLoadOp, tpu::StridedStoreOp>(op)) {
+    return op.emitOpError("Not implemented: Unsupported strided op")
+           << op.getName();
+  }
+  if (layout != VectorLayout(32, {0, 0}, ctx.target_shape,
+                             VectorLayout::ImplicitDim::kNone)) {
+    return op.emitOpError("Not implemented: Unsupported vector layout in ")
+           << op.getName();
+  }
+  const auto base_ty = getMemRefType(base_ref);
+  auto rank = base_ty.getRank();
+  CHECK_EQ(rank, indices.size());
+  CHECK_EQ(rank, strides.size());
+  CHECK_EQ(rank, vty.getShape().size());
+  if (rank < 2) {
+    return op.emitOpError("Not implemented: Stride on 1D vector");
+  }
+  auto mem_layout = dyn_cast<TiledLayoutAttr>(base_ty.getLayout());
+  if (!mem_layout) {
+    return op.emitOpError("Expected a tiled memref");
+  }
+  auto tile_strides = mem_layout.getTileStrides();
+
+  // We currently require that the last dim of the memref is exactly the lane
+  // count of a native vreg and that the memref has never been sliced on the
+  // last dim. In other words, the original base memref's shape must be
+  // (..., target_shape[1]).
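+  // For example (illustrative, assuming target_shape is {8, 128}): a base
+  // memref of shape (..., 128) with trailing tile strides {1, 1} passes the
+  // check below, while one with last dim 256 does not.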
+  if (base_ty.getShape()[rank - 1] != ctx.target_shape[1] ||
+      tile_strides.take_back(2) != ArrayRef<int64_t>{1, 1}) {
+    return op.emitOpError("Not implemented: The last dim size is not ")
+           << ctx.target_shape[1] << " in original base memref";
+  }
+  if (strides[rank - 1] != 1) {
+    return op.emitOpError("Not implemented: Stride on last dim is not 1");
+  }
+  if (indices[rank - 1] != 0) {
+    return op.emitOpError("Not implemented: Index on last dim is not 0");
+  }
+  ImplicitLocOpBuilder builder(op.getLoc(), &op);
+
+  FAILUREOR_ASSIGN_OR_RETURN(
+      VectorType vreg_ty,
+      getNativeVregType(vty.getElementType(), ctx.target_shape));
+
+  bool is_load_op = true;
+  xla::Array<Value> tiles(
+      layout.tileArrayShape(vty.getShape(), ctx.target_shape));
+  if (auto store_op = dyn_cast<tpu::StridedStoreOp>(op)) {
+    is_load_op = false;
+    FAILUREOR_ASSIGN_OR_RETURN(
+        tiles, disassemble(builder, layout, store_op.getValueToStore(),
+                           ctx.target_shape));
+  }
+
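+  // For each vreg tile, compute the starting indices of its slice of the
+  // memref. Along the minor two (tiled) dims, stepping to the next vreg
+  // advances a full native vreg (target_shape[0] sublanes, target_shape[1]
+  // lanes), so the element stride is scaled by the vreg shape there.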
+  tiles.Each([&](absl::Span<const int64_t> tile_idxs, Value *v) {
+    CHECK_EQ(tile_idxs.size(), rank);
+    SmallVector<Value> idxs(rank);
+    for (int64_t i = 0; i < rank; ++i) {
+      int64_t stride = (i < rank - 2)
+                           ? strides[i]
+                           : (strides[i] * ctx.target_shape[i - rank + 2]);
+      idxs[i] =
+          IdxConst(indices[i] + tile_idxs[i] * stride, builder, op.getLoc());
+    }
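+    // When the second-minor vector dim is not a multiple of the sublane
+    // count, mask off the trailing sublanes of the last row of vregs.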
+    SmallVector<bool> sublane_mask(ctx.target_shape[0], true);
+    int64_t sublane_rem = vty.getDimSize(rank - 2) % ctx.target_shape[0];
+    if (sublane_rem > 0 && tile_idxs[rank - 2] == tiles.dim(rank - 2) - 1) {
+      for (int64_t i = sublane_rem; i < ctx.target_shape[0]; ++i) {
+        sublane_mask[i] = false;
+      }
+    }
+    const auto sublane_mask_attr =
+        DenseBoolArrayAttr::get(op.getContext(), sublane_mask);
+    if (is_load_op) {
+      *v = builder.create<tpu::LoadOp>(
+          vreg_ty, base_ref, idxs, sublane_mask_attr,
+          builder.getI32IntegerAttr(strides[rank - 2]));
+    } else {
+      builder.create<tpu::StoreOp>(
+          *v, base_ref, idxs, sublane_mask_attr,
+          /*mask=*/nullptr, builder.getI32IntegerAttr(strides[rank - 2]));
+    }
+  });
+  if (is_load_op) {
+    op.replaceAllUsesWith(
+        assemble(builder, vty, layout, std::move(tiles), ctx.target_shape));
+  }
+  op.erase();
+  return success();
+}
+
+// TODO(jevinjiang): maybe unify with vector load?
+LogicalResult tpu_strided_load_rule(RewriteContext &ctx, Operation &op,
+                                    const ArrayRef<Layout> layouts_in,
+                                    const ArrayRef<Layout> layouts_out) {
+  TPU_ASSERT_OP(llvm::none_of(layouts_in,
+                              [&](const Layout &l) { return l.has_value(); }));
+  TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
+  TPU_ASSERT_OP(layouts_out.front().has_value());
+  const VectorLayout &layout_out = *layouts_out.front();
+  auto load_op = cast<tpu::StridedLoadOp>(op);
+  const auto base_ref = load_op.getBase();
+  const auto indices = load_op.getIndices();
+  const auto strides = load_op.getStrides();
+  const auto vty = cast<VectorType>(load_op.getResult().getType());
+  return strided_op_rule_impl(ctx, op, base_ref, vty, layout_out, indices,
+                              strides);
+}
+
+// TODO(jevinjiang): maybe unify with vector store?
+LogicalResult tpu_strided_store_rule(RewriteContext &ctx, Operation &op,
+                                     const ArrayRef<Layout> layouts_in,
+                                     const ArrayRef<Layout> layouts_out) {
+  TPU_ASSERT_OP(layouts_in.front().has_value());
+  TPU_ASSERT_OP(llvm::none_of(layouts_in.drop_front(),
+                              [&](const Layout &l) { return l.has_value(); }));
+  TPU_ASSERT_EQ_OP(layouts_out.size(), 0);
+
+  const VectorLayout &to_store_layout = *layouts_in.front();
+  auto store_op = cast<tpu::StridedStoreOp>(op);
+  const auto base_ref = store_op.getBase();
+  const auto indices = store_op.getIndices();
+  const auto strides = store_op.getStrides();
+  const auto vty = store_op.getValueToStore().getType();
+  return strided_op_rule_impl(ctx, op, base_ref, vty, to_store_layout, indices,
+                              strides);
+}
+
 LogicalResult matmul_rule_impl(RewriteContext &ctx, Operation &op,
                                const bool transpose_lhs,
                                const bool transpose_rhs,
@@ -3510,10 +3643,12 @@ const llvm::StringMap<rule_type> &rules() {
       {tpu::IotaOp::getOperationName(), tpu_iota_rule},
       {tpu::GatherOp::getOperationName(), tpu_gather_rule},
       {tpu::LoadOp::getOperationName(), tpu_load_rule},
+      {tpu::StoreOp::getOperationName(), tpu_store_rule},
+      {tpu::StridedLoadOp::getOperationName(), tpu_strided_load_rule},
+      {tpu::StridedStoreOp::getOperationName(), tpu_strided_store_rule},
       {tpu::MatmulOp::getOperationName(), tpu_matmul_rule},
       {tpu::RegionOp::getOperationName(), tpu_region_rule},
       {tpu::RepeatOp::getOperationName(), tpu_repeat_rule},
-      {tpu::StoreOp::getOperationName(), tpu_store_rule},
       {tpu::BitcastOp::getOperationName(), tpu_bitcast_rule},
       {tpu::TraceOp::getOperationName(), tpu_trace_rule},
       {tpu::AssumeLayoutOp::getOperationName(), tpu_assume_layout_rule},