Commit a5823c2

Add top-p
1 parent 760675a

11 files changed: 229 additions, 1 deletion

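This commit adds a top-p (nucleus) masking op: given logits and a threshold p, `topp_mask` returns a 0/1 mask that keeps, along one dimension, the smallest set of highest-probability entries whose cumulative probability reaches p, always including the single largest entry. The eager PyTorch sketch below is not part of the commit; it only illustrates the semantics implied by the tests and by the XLA lowering further down (the name topp_mask_reference is made up here):

import torch

def topp_mask_reference(logits, p, dim=-1):
  # Sketch of the intended semantics, eager and CPU-only.
  probs = torch.softmax(logits, dim=dim)
  # Sort ascending and remember where each probability came from.
  sorted_probs, sorted_idx = torch.sort(probs, dim=dim)
  cumprobs = torch.cumsum(sorted_probs, dim=dim)
  # Keep an entry when the mass of the entries above it is still below p,
  # i.e. cumprobs > 1 - p.
  keep_sorted = (cumprobs > 1.0 - p).to(logits.dtype)
  # Always keep the largest entry (last slot after the ascending sort).
  last = torch.tensor([keep_sorted.size(dim) - 1])
  keep_sorted = keep_sorted.index_fill(dim, last, 1.0)
  # Scatter the mask back to the original, unsorted positions.
  mask = torch.zeros_like(keep_sorted)
  mask.scatter_(dim, sorted_idx, keep_sorted)
  return mask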

test/topp/test_topp_mask.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import torch
import torch_xla
from torch_xla import xm
import unittest
from torch_xla.experimental.topp_mask import topp_mask
import math


class TestTopPMask(unittest.TestCase):

  def setUp(self):
    self.device = torch_xla.device()

  def test_invalid_p(self):
    # Test that an invalid p throws an assertion error.
    logits = torch.tensor([[0.2, 0.5, 0.3], [0.1, 0.7, 0.2]],
                          dtype=torch.float32,
                          device=self.device)

    with self.assertRaises(AssertionError):
      topp_mask(logits, -0.1)  # p < 0

    with self.assertRaises(AssertionError):
      topp_mask(logits, 1.1)  # p > 1

  def test_basic(self):
    logits = torch.tensor(
        [[math.log(0.2), math.log(0.3), math.log(0.5)],
         [math.log(0.5), math.log(0.2), math.log(0.3)]],
        dtype=torch.float32,
        device=self.device)
    mask = topp_mask(logits, 0.79)

    expected_mask = torch.tensor([[0, 1, 1], [1, 0, 1]],
                                 dtype=torch.float32,
                                 device=self.device)
    self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))

  def test_dim(self):
    logits = torch.tensor([[math.log(0.2), math.log(0.3)],
                           [math.log(0.3), math.log(0.2)],
                           [math.log(0.5), math.log(0.5)]],
                          dtype=torch.float32,
                          device=self.device)
    mask = topp_mask(logits, 0.79, dim=0)

    print(mask)

    expected_mask = torch.tensor([[0, 1], [1, 0], [1, 1]],
                                 dtype=torch.float32,
                                 device=self.device)
    self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))

  def test_p_is_zero(self):
    logits = torch.tensor([[0.2, 0.5, 5], [0.1, 2, 0.2]],
                          dtype=torch.float32,
                          device=self.device)
    mask = topp_mask(logits, 0.0)

    expected_mask = torch.tensor([[0, 0, 1], [0, 1, 0]],
                                 dtype=torch.float32,
                                 device=self.device)
    self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))

  def test_p_is_one(self):
    logits = torch.tensor([[0.2, 0.5, 5], [0.1, 2, 0.2]],
                          dtype=torch.float32,
                          device=self.device)
    mask = topp_mask(logits, 1.0)

    # All elements should be selected.
    expected_mask = torch.tensor([[1, 1, 1], [1, 1, 1]],
                                 dtype=torch.float32,
                                 device=self.device)
    self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))


if __name__ == '__main__':
  unittest.main()
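
For context, and not part of this change: a mask like the one returned above would typically be consumed by dropping the excluded logits before sampling, roughly along these lines (the 0.9 threshold is arbitrary):

mask = topp_mask(logits, 0.9)
filtered = logits.masked_fill(mask == 0, float('-inf'))
probs = torch.softmax(filtered, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)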

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -818,6 +818,12 @@ xla::Shape GetTensorShape(const at::Tensor& tensor,
   return CreateComputationShapeFromTensor(tensor, &device);
 }
 
+at::Tensor TopPMask(const at::Tensor& input, float p, int64_t dim) {
+  auto result = tensor_methods::topp_mask(bridge::GetXlaTensor(input), p, dim,
+                                          /*stable=*/false);
+  return bridge::AtenFromXlaTensor(std::move(result));
+}
+
 py::dict GetMemoryInfo(const std::string& device_str) {
   runtime::ComputationClient::MemoryInfo mem_info;
   {
@@ -3008,6 +3014,9 @@ void InitXlaModuleBindings(py::module m) {
         [](std::string name, std::shared_ptr<const runtime::PjRtPlugin> plugin) {
           runtime::RegisterPjRtPlugin(name, plugin);
         });
+  m.def("_xla_topp_mask", [](const at::Tensor& input, float p, int64_t dim) {
+    return TopPMask(input, p, dim);
+  });
   py::class_<runtime::PjRtPlugin, PyPjRtPlugin,
              std::shared_ptr<runtime::PjRtPlugin>>(m, "PjRtPlugin")
       .def(py::init<>())
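
The binding above exposes the op as torch_xla._XLAC._xla_topp_mask. The Python wrapper torch_xla/experimental/topp_mask.py imported by the test is among the 11 changed files but is not shown in this excerpt; a minimal sketch of what it presumably does, assuming it only validates p (the test expects an AssertionError) and forwards to the binding with dim defaulting to the last dimension:

import torch
import torch_xla

def topp_mask(logits: torch.Tensor, p: float, dim: int = -1) -> torch.Tensor:
  # Hypothetical wrapper: validate p, then dispatch to the C++ binding.
  assert 0.0 <= p <= 1.0, "p must be in [0, 1]"
  return torch_xla._XLAC._xla_topp_mask(logits, p, dim)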

torch_xla/csrc/ops/topp_mask.cpp

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
#include "torch_xla/csrc/ops/topp_mask.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {

TopPMask::TopPMask(const torch::lazy::Value& input, float p, int64_t dim,
                   bool stable)
    : XlaNode(torch::lazy::OpKind(at::aten::topk), {input}, GetXlaShape(input),
              /*num_outputs=*/1, torch::lazy::MHash(p, dim, stable)),
      p_(p),
      dim_(dim),
      stable_(stable) {}

torch::lazy::NodePtr TopPMask::Clone(torch::lazy::OpList operands) const {
  return torch_xla::MakeNode<TopPMask>(operands.at(0), p_, dim_, stable_);
}

XlaOpVector TopPMask::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  return ReturnOp(CreateTopPMask(loctx->device(), input, p_, dim_, stable_),
                  loctx);
}

std::string TopPMask::ToString() const {
  std::stringstream ss;
  ss << XlaNode::ToString() << ", p=" << p_ << ", dim=" << dim_
     << ", stable=" << stable_;
  return ss.str();
}

}  // namespace torch_xla

torch_xla/csrc/ops/topp_mask.h

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_
#define XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class TopPMask : public XlaNode {
 public:
  TopPMask(const torch::lazy::Value& input, float p, int64_t dim, bool stable);

  std::string ToString() const override;

  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  float p() const { return p_; }

  int64_t dim() const { return dim_; }

  bool stable() const { return stable_; }

 private:
  float p_;
  int64_t dim_;
  bool stable_;
};

}  // namespace torch_xla

#endif  // XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_

torch_xla/csrc/ops/xla_ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -40,5 +40,6 @@ const OpKindWrapper xla_update_slice("xla::update_slice");
 const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
 const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
 const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");
+const OpKindWrapper xla_topp_mask("xla::topp_mask");
 
 }  // namespace torch_xla

torch_xla/csrc/ops/xla_ops.h

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ extern const OpKindWrapper xla_update_slice;
 extern const OpKindWrapper xla_custom_sharding;
 extern const OpKindWrapper xla_tpu_custom_call;
 extern const OpKindWrapper xla_gpu_custom_call;
+extern const OpKindWrapper xla_topp_mask;
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
+#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_

torch_xla/csrc/tensor_methods.cpp

Lines changed: 9 additions & 0 deletions
@@ -132,6 +132,7 @@
 #include "torch_xla/csrc/ops/threshold.h"
 #include "torch_xla/csrc/ops/threshold_backward.h"
 #include "torch_xla/csrc/ops/topk.h"
+#include "torch_xla/csrc/ops/topp_mask.h"
 #include "torch_xla/csrc/ops/tpu_custom_call.h"
 #include "torch_xla/csrc/ops/triangular_solve.h"
 #include "torch_xla/csrc/ops/uniform.h"
@@ -3438,6 +3439,14 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
   return std::make_tuple(t1, t2);
 }
 
+XLATensorPtr topp_mask(const XLATensorPtr& input, float p, int64_t dim,
+                       bool stable) {
+  return input->CreateFrom(torch_xla::MakeNode<TopPMask>(
+      input->GetIrValue(), p,
+      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
+      stable));
+}
+
 XLATensorPtr trace(const XLATensorPtr& input) {
   auto input_shape_ref = input->shape();
   XLA_CHECK_EQ((*input_shape_ref).rank(), 2)

torch_xla/csrc/tensor_methods.h

Lines changed: 3 additions & 0 deletions
@@ -973,6 +973,9 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
                                             bool largest, bool sorted,
                                             bool stable);
 
+XLATensorPtr topp_mask(const XLATensorPtr& input, float p, int64_t dim,
+                       bool stable);
+
 // Returns the sum of the elements of the diagonal of the input 2-D matrix.
 XLATensorPtr trace(const XLATensorPtr& input);

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 47 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/softmax_builder.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "xla/client/lib/arithmetic.h"
 #include "xla/client/lib/comparators.h"
@@ -390,6 +391,52 @@ std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,
                                    xla::PrimitiveType::S64))};
 }
 
+xla::XlaOp CreateTopPMask(const torch::lazy::BackendDevice& device,
+                          xla::XlaOp logits, float p, int64_t dim,
+                          bool stable) {
+  // Convert logits to probabilities.
+  xla::XlaOp probs = BuildSoftmax(logits, dim);
+  const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(probs);
+
+  // Sort the probabilities in ascending order and keep track of the sorted
+  // indices.
+  xla::Shape iota_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
+  xla::XlaOp iota = xla::Iota(probs.builder(), iota_shape, dim);
+  xla::XlaComputation comparator = xla::CreateScalarLtComputation(
+      {shape.element_type(), xla::PrimitiveType::S32}, probs.builder());
+  xla::XlaOp sort_result = xla::Sort({probs, iota}, comparator, dim, stable);
+
+  // Compute cumulative probabilities.
+  xla::XlaOp zero = xla::Zero(probs.builder(), shape.element_type());
+  xla::XlaComputation reducer =
+      XlaHelpers::CreateAddComputation(shape.element_type());
+  xla::XlaOp cumprobs = BuildCumulativeComputation(
+      xla::GetTupleElement(sort_result, 0), dim, reducer, zero);
+
+  // Create a mask for the "p-set" elements.
+  xla::XlaOp one = xla::One(probs.builder(), shape.element_type());
+  xla::XlaOp p_mask = BuildThreshold(cumprobs, one, 1 - p, 0);
+
+  // The largest element should always be included.
+  std::vector<int64_t> sizes = XlaHelpers::SizesOfXlaOp(p_mask);
+  sizes[dim] = 1;
+  xla::XlaOp ones = xla::Broadcast(one, sizes);
+  std::vector<int64_t> index_to_update = XlaHelpers::SizesOfXlaOp(p_mask);
+  for (int i = 0; i < index_to_update.size(); ++i) {
+    if (i != dim) index_to_update[i] = 0;
+  }
+  xla::XlaOp p_mask_updated = BuildUpdateSlice(p_mask, ones, index_to_update);
+
+  // Re-order the mask back to pre-sorted order.
+  xla::XlaOp sorted_indices = xla::GetTupleElement(sort_result, 1);
+  ScatterOptions options(/*combiner=*/nullptr);
+  xla::XlaOp p_mask_reordered = CreateScatter(
+      device, p_mask_updated, sorted_indices, p_mask_updated, dim, options);
+
+  return p_mask_reordered;
+}
+
 xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) {
   // Expand cases in https://pytorch.org/docs/stable/torch.html#torch.matmul
   xla::Shape lhs_shape = ShapeHelper::ShapeOfXlaOp(lhs);
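
A worked example of this lowering, using the first row of test_basic (p = 0.79, logits = [log 0.2, log 0.3, log 0.5]):

  softmax        -> probs    = [0.2, 0.3, 0.5]
  ascending sort -> sorted   = [0.2, 0.3, 0.5], indices = [0, 1, 2]
  cumulative sum -> cumprobs = [0.2, 0.5, 1.0]
  threshold at 1 - p = 0.21 -> mask over sorted values = [0, 1, 1]
  (the largest entry is already in the set, so the forced update is a no-op here)
  scatter back via the sorted indices -> [0, 1, 1]

which matches expected_mask in test_basic.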

torch_xla/csrc/xla_lower_util.h

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, int64_t k, int64_t dim,
 std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,
                                    bool largest, bool stable);
 
+xla::XlaOp CreateTopPMask(const torch::lazy::BackendDevice& device,
+                          xla::XlaOp input, float p, int64_t dim, bool stable);
+
 xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);
 
 xla::XlaOp BuildMatMul(xla::XlaOp lhs, xla::XlaOp rhs, xla::XlaOp bias);
