Commit 46a7990

Add lowering for bitwise_left_shift. (#8865)

1 parent: 8030f05

File tree

6 files changed (+44, -8 lines)

codegen/xla_native_functions.yaml

1 addition & 0 deletions

```diff
@@ -29,6 +29,7 @@ full_codegen:
 - binary_cross_entropy
 - binary_cross_entropy_backward
 - bitwise_not
+- bitwise_left_shift.Tensor
 - ceil
 - cholesky
 - clamp.Tensor
```
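Registering `bitwise_left_shift.Tensor` under `full_codegen` makes the codegen emit the IR node and dispatcher glue; as the rest of this commit shows, the hand-written pieces are the `Lower` and output-shape functions. For reference, a minimal sketch of the PyTorch-level semantics the lowering has to match, an elementwise shift with the usual broadcasting:

```python
import torch

# Elementwise left shift; `b` broadcasts across the rows of `a`.
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([1, 2])
print(torch.bitwise_left_shift(a, b))
# tensor([[ 2,  8],
#         [ 6, 16]])
```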

docs/source/contribute/codegen_migration.md

1 addition & 0 deletions

```diff
@@ -330,6 +330,7 @@ being incremented correctly.
 
 ## Sample PRs
 
+- Lowering of `bitwise_left_shift` <https://github.com/pytorch/xla/pull/8865>
 - Unary/Binary OP -\> Codegen erf, erfc, erfinv, and exp
   (<https://github.com/pytorch/xla/pull/3659>)
 - OP with optional -\> Codegen binary_cross_entropy/backward
```

test/test_operations.py

17 additions & 8 deletions

```diff
@@ -35,6 +35,7 @@
     all_types_and_complex_and,
     all_types_and,
 )
+import torch.utils._pytree as pytree
 import torch_xla
 import torch_xla.core.xla_builder as xb
 import torch_xla.core.xla_op_registry as xor
@@ -2384,24 +2385,32 @@ def test_cummax_0_sized_dimension(self):
 
     self.assertEqual(actual, expected)
 
-  def test_conj(self):
-    # Leave the factory out of the fallback count.
-    tensor = torch.rand(2, 2, dtype=torch.complex64)
-
+  def _test_no_fallback(self, runf, args):
     met.clear_all()
 
     def run(device):
-      return torch.conj(tensor.to(device))
+      args_ = pytree.tree_map_only(torch.Tensor,
+                                   lambda t: t.clone().detach().to(device),
+                                   args)
+      return runf(*args_)
 
     actual = run("cpu")
     expected = run(xm.xla_device())
 
-    self.assertEqual(
-        met.executed_fallback_ops(), [],
-        message="expected no fallback operations.")
+    self.assertFalse(
+        met.executed_fallback_ops(), msg="expected no fallback operations.")
     self.assertEqual(
         actual, expected.cpu(), message="XLA results should match CPU results.")
 
+  def test_conj_no_fallback(self):
+    tensor = torch.rand(2, 2, dtype=torch.complex64)
+    self._test_no_fallback(torch.conj, (tensor,))
+
+  def test_bitwise_left_shift_no_fallback(self):
+    t1 = torch.randint(0, 10, (2, 2))
+    t2 = torch.randint(0, 10, (2,))
+    self._test_no_fallback(torch.bitwise_left_shift, (t1, t2))
+
 
 class MNISTComparator(nn.Module):
```
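The refactor folds the old `test_conj` into a reusable `_test_no_fallback` helper: `pytree.tree_map_only` clones every tensor argument onto the target device while leaving non-tensor arguments untouched, so any op can be checked for fallbacks with one call. A standalone sketch of that mapping pattern (the `to_device` name is illustrative, not part of the commit):

```python
import torch
import torch.utils._pytree as pytree

def to_device(args, device):
    # Clone/detach each tensor leaf and move it; other leaves pass through.
    return pytree.tree_map_only(
        torch.Tensor, lambda t: t.clone().detach().to(device), args)

args = (torch.randint(0, 10, (2, 2)), torch.randint(0, 10, (2,)))
print(torch.bitwise_left_shift(*to_device(args, "cpu")))
```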

torch_xla/csrc/ops/ops_lower_fn.cpp

14 additions & 0 deletions

```diff
@@ -277,6 +277,20 @@ torch_xla::XlaOpVector BitwiseXorTensor::Lower(LoweringContext* loctx) const {
                   loctx);
 }
 
+torch_xla::XlaOpVector BitwiseLeftShiftTensor::Lower(
+    LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_other_input = loctx->GetOutputOp(operand(1));
+  return ReturnOp(XlaHelpers::PromotedBinaryOp(
+                      xla_input, xla_other_input,
+                      [](xla::XlaOp one, xla::XlaOp two) {
+                        return xla::ShiftLeft(
+                            one, two,
+                            XlaHelpers::getBroadcastDimensions(one, two));
+                      }),
+                  loctx);
+}
+
 torch_xla::XlaOpVector Ceil::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
```
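`XlaHelpers::PromotedBinaryOp` promotes both operands to a common type before invoking the lambda, and `getBroadcastDimensions` supplies the dimension mapping `xla::ShiftLeft` needs when the operand ranks differ. Assuming a working `torch_xla` install, a quick check that the op now lowers rather than falling back (mirroring the new unit test):

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
met.clear_all()
out = torch.bitwise_left_shift(
    torch.randint(0, 10, (2, 2), device=device),
    torch.randint(0, 10, (2,), device=device))
xm.mark_step()  # materialize the pending graph
assert not met.executed_fallback_ops(), "expected no fallback operations"
```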

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

8 additions & 0 deletions (file mode changed from 100755 to 100644)

```diff
@@ -360,6 +360,14 @@ xla::Shape BitwiseXorTensorOutputShape(const torch::lazy::Value& input,
   });
 }
 
+xla::Shape BitwiseLeftShiftTensorOutputShape(const torch::lazy::Value& input,
+                                             const torch::lazy::Value& other) {
+  return InferBinaryOpShape(input, other, [](xla::XlaOp one, xla::XlaOp two) {
+    return xla::ShiftLeft(one, two,
+                          XlaHelpers::getBroadcastDimensions(one, two));
+  });
+}
+
 xla::Shape CeilOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
```
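`InferBinaryOpShape` derives the result shape by building the same `xla::ShiftLeft` op on the operand shapes, so the inferred output follows standard broadcasting rules. A small sanity check of the expected behavior from the Python side:

```python
import torch

# Broadcasting (2, 2) with (2,) yields shape (2, 2).
a = torch.zeros(2, 2, dtype=torch.int32)
b = torch.ones(2, dtype=torch.int32)
assert torch.bitwise_left_shift(a, b).shape == torch.Size([2, 2])
```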

torch_xla/csrc/ops/ops_xla_shape_fn.h

3 additions & 0 deletions

```diff
@@ -93,6 +93,9 @@ xla::Shape BitwiseOrTensorOutputShape(const torch::lazy::Value& input,
 xla::Shape BitwiseXorTensorOutputShape(const torch::lazy::Value& input,
                                        const torch::lazy::Value& other);
 
+xla::Shape BitwiseLeftShiftTensorOutputShape(const torch::lazy::Value& input,
+                                             const torch::lazy::Value& other);
+
 xla::Shape CeilOutputShape(const torch::lazy::Value& input);
 
 xla::Shape CholeskyOutputShape(const torch::lazy::Value& input,
```
