Skip to content

Commit cba9ff9

Browse files
authored
Lower isneginf(). (#8912)
1 parent 985fa18 commit cba9ff9

File tree

6 files changed

+26
-0
lines changed

6 files changed

+26
-0
lines changed

codegen/xla_native_functions.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ full_codegen:
6262
- hardswish_backward
6363
- inverse
6464
- isnan
65+
- isneginf
6566
- leaky_relu
6667
- le.Scalar
6768
- le.Tensor

test/test_operations.py

+8
Original file line numberDiff line numberDiff line change
@@ -2416,6 +2416,14 @@ def test_bitwise_right_shift_no_fallback(self):
24162416
t2 = torch.randint(0, 10, (2,))
24172417
self._test_no_fallback(torch.bitwise_right_shift, (t1, t2))
24182418

2419+
def test_isneginf_no_fallback(self):
2420+
t = torch.rand(10)
2421+
# Scale the tensor elements.
2422+
t = t * 100_000
2423+
# Convert to a lower precision data-type so as to get a few infs.
2424+
t = t.to(torch.float16)
2425+
self._test_no_fallback(torch.isneginf, (t,))
2426+
24192427

24202428
class MNISTComparator(nn.Module):
24212429

test/test_ops.py

+1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def get_allowed_ops_map(
127127
AllowedOpInfoEntry('imag'),
128128
AllowedOpInfoEntry('inverse'),
129129
AllowedOpInfoEntry('isin'),
130+
AllowedOpInfoEntry('isneginf'),
130131
AllowedOpInfoEntry('le'),
131132
AllowedOpInfoEntry('linalg.cholesky'),
132133
AllowedOpInfoEntry('linalg.cholesky_ex'),

torch_xla/csrc/ops/ops_lower_fn.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "torch_xla/csrc/xla_lower_util.h"
1414
#include "xla/client/lib/math.h"
1515
#include "xla/client/lib/matrix.h"
16+
#include "xla/hlo/builder/lib/constants.h"
1617
#include "xla/hlo/builder/lib/logdet.h"
1718

1819
namespace torch_xla {
@@ -564,6 +565,13 @@ torch_xla::XlaOpVector Isnan::Lower(LoweringContext* loctx) const {
564565
return ReturnOp(xla::IsNan(xla_input), loctx);
565566
}
566567

568+
torch_xla::XlaOpVector Isneginf::Lower(LoweringContext* loctx) const {
569+
xla::XlaOp input = loctx->GetOutputOp(operand(0));
570+
return ReturnOp(xla::Eq(input, xla::MinValue(input.builder(),
571+
XlaHelpers::TypeOfXlaOp(input))),
572+
loctx);
573+
}
574+
567575
torch_xla::XlaOpVector LeakyRelu::Lower(LoweringContext* loctx) const {
568576
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
569577
xla::XlaOp negative_slope = loctx->GetOutputOp(operand(1));

torch_xla/csrc/ops/ops_xla_shape_fn.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,12 @@ xla::Shape IsnanOutputShape(const torch::lazy::Value& input) {
606606
return isnan_shape;
607607
}
608608

609+
xla::Shape IsneginfOutputShape(const torch::lazy::Value& input) {
610+
xla::Shape shape(GetXlaShape(input));
611+
shape.set_element_type(xla::PRED);
612+
return shape;
613+
}
614+
609615
xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
610616
const torch::lazy::Value& negative_slope) {
611617
auto lower_for_shape_fn =

torch_xla/csrc/ops/ops_xla_shape_fn.h

+2
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,8 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input);
180180

181181
xla::Shape IsnanOutputShape(const torch::lazy::Value& input);
182182

183+
xla::Shape IsneginfOutputShape(const torch::lazy::Value& input);
184+
183185
xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
184186
const torch::lazy::Value& negative_slope);
185187

0 commit comments

Comments
 (0)