Skip to content

Commit 2aa11cc

Browse files
authored
Add lowering for bitwise_right_shift. (#8866)
1 parent f2bdecf commit 2aa11cc

File tree

5 files changed

+36
-0
lines changed

5 files changed

+36
-0
lines changed

codegen/xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ full_codegen:
3030
- binary_cross_entropy_backward
3131
- bitwise_not
3232
- bitwise_left_shift.Tensor
33+
- bitwise_right_shift.Tensor
3334
- ceil
3435
- cholesky
3536
- clamp.Tensor

test/test_operations.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,6 +2411,11 @@ def test_bitwise_left_shift_no_fallback(self):
24112411
t2 = torch.randint(0, 10, (2,))
24122412
self._test_no_fallback(torch.bitwise_left_shift, (t1, t2))
24132413

2414+
def test_bitwise_right_shift_no_fallback(self):
2415+
t1 = torch.randint(0, 10, (2, 2))
2416+
t2 = torch.randint(0, 10, (2,))
2417+
self._test_no_fallback(torch.bitwise_right_shift, (t1, t2))
2418+
24142419

24152420
class MNISTComparator(nn.Module):
24162421

torch_xla/csrc/ops/ops_lower_fn.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,25 @@ torch_xla::XlaOpVector BitwiseLeftShiftTensor::Lower(
291291
loctx);
292292
}
293293

294+
torch_xla::XlaOpVector BitwiseRightShiftTensor::Lower(
295+
LoweringContext* loctx) const {
296+
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
297+
xla::XlaOp xla_other_input = loctx->GetOutputOp(operand(1));
298+
return ReturnOp(
299+
XlaHelpers::PromotedBinaryOp(
300+
xla_input, xla_other_input,
301+
[](xla::XlaOp one, xla::XlaOp two) {
302+
auto broadcast_dims = XlaHelpers::getBroadcastDimensions(one, two);
303+
if (xla::primitive_util::IsSignedIntegralType(
304+
XlaHelpers::TypeOfXlaOp(one))) {
305+
return xla::ShiftRightArithmetic(one, two, broadcast_dims);
306+
} else {
307+
return xla::ShiftRightLogical(one, two, broadcast_dims);
308+
}
309+
}),
310+
loctx);
311+
}
312+
294313
torch_xla::XlaOpVector Ceil::Lower(LoweringContext* loctx) const {
295314
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
296315
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,14 @@ xla::Shape BitwiseLeftShiftTensorOutputShape(const torch::lazy::Value& input,
368368
});
369369
}
370370

371+
xla::Shape BitwiseRightShiftTensorOutputShape(const torch::lazy::Value& input,
372+
const torch::lazy::Value& other) {
373+
return InferBinaryOpShape(input, other, [](xla::XlaOp one, xla::XlaOp two) {
374+
return xla::ShiftRightArithmetic(
375+
one, two, XlaHelpers::getBroadcastDimensions(one, two));
376+
});
377+
}
378+
371379
xla::Shape CeilOutputShape(const torch::lazy::Value& input) {
372380
return GetXlaShape(input);
373381
}

torch_xla/csrc/ops/ops_xla_shape_fn.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ xla::Shape BitwiseXorTensorOutputShape(const torch::lazy::Value& input,
9696
xla::Shape BitwiseLeftShiftTensorOutputShape(const torch::lazy::Value& input,
9797
const torch::lazy::Value& other);
9898

99+
xla::Shape BitwiseRightShiftTensorOutputShape(const torch::lazy::Value& input,
100+
const torch::lazy::Value& other);
101+
99102
xla::Shape CeilOutputShape(const torch::lazy::Value& input);
100103

101104
xla::Shape CholeskyOutputShape(const torch::lazy::Value& input,

0 commit comments

Comments
 (0)