  return %0 : !torch.vtensor<[1,12,5,5],f32>
}

// -----

// Regression test: aten.where.self with three inputs of differing rank
// (rank-2 condition, rank-0 scalar tensor, rank-6 tensor). The lowering must
// rank-broadcast the lower-rank operands via tosa.reshape (prepending size-1
// dims) before emitting tosa.select.
// CHECK-LABEL: func.func @torch.aten.where.self_differing_rank_inputs(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4],i1>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[1,3,1,1,5,4],f32>) -> !torch.vtensor<[1,3,1,1,5,4],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_2]] : !torch.vtensor<[1,3,1,1,5,4],f32> -> tensor<1x3x1x1x5x4xf32>
// CHECK: %[[VAL_4:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4],i1> -> tensor<5x4xi1>
// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_6]] : (tensor<f32>, !tosa.shape<2>) -> tensor<1x1xf32>
// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 1, 5, 4]> : tensor<6xindex>} : () -> !tosa.shape<6>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_8]] : (tensor<5x4xi1>, !tosa.shape<6>) -> tensor<1x1x1x1x5x4xi1>
// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_10]] : (tensor<1x1xf32>, !tosa.shape<6>) -> tensor<1x1x1x1x1x1xf32>
// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_9]], %[[VAL_11]], %[[VAL_3]] : (tensor<1x1x1x1x5x4xi1>, tensor<1x1x1x1x1x1xf32>, tensor<1x3x1x1x5x4xf32>) -> tensor<1x3x1x1x5x4xf32>
// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x3x1x1x5x4xf32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
// CHECK: return %[[VAL_13]]
func.func @torch.aten.where.self_differing_rank_inputs(%40: !torch.vtensor<[5,4],i1>, %41: !torch.vtensor<[],f32>, %38: !torch.vtensor<[1,3,1,1,5,4],f32>) -> (!torch.vtensor<[1,3,1,1,5,4],f32>) {
  %42 = torch.aten.where.self %40, %41, %38 : !torch.vtensor<[5,4],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[1,3,1,1,5,4],f32> -> !torch.vtensor<[1,3,1,1,5,4],f32>
  return %42 : !torch.vtensor<[1,3,1,1,5,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {