Commit 6691bb6

Add top-p

1 parent 760675a commit 6691bb6

10 files changed: 105 additions & 1 deletion

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -818,6 +818,12 @@ xla::Shape GetTensorShape(const at::Tensor& tensor,
   return CreateComputationShapeFromTensor(tensor, &device);
 }

+at::Tensor TopPMask(const at::Tensor& input, float p, int64_t dim) {
+  auto result = tensor_methods::topp_mask(bridge::GetXlaTensor(input), p, dim,
+                                          /*stable=*/false);
+  return bridge::AtenFromXlaTensor(std::move(result));
+}
+
 py::dict GetMemoryInfo(const std::string& device_str) {
   runtime::ComputationClient::MemoryInfo mem_info;
   {
@@ -3008,6 +3014,9 @@ void InitXlaModuleBindings(py::module m) {
       [](std::string name, std::shared_ptr<const runtime::PjRtPlugin> plugin) {
         runtime::RegisterPjRtPlugin(name, plugin);
       });
+  m.def("_xla_topp_mask", [](const at::Tensor& input, float p, int64_t dim) {
+    return TopPMask(input, p, dim);
+  });
   py::class_<runtime::PjRtPlugin, PyPjRtPlugin,
              std::shared_ptr<runtime::PjRtPlugin>>(m, "PjRtPlugin")
       .def(py::init<>())
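
This exposes the op to Python as torch_xla._XLAC._xla_topp_mask, forwarding to tensor_methods::topp_mask with stable=false hard-coded. A minimal sketch of a direct call (hypothetical usage against a build of this commit; the input must already live on an XLA device so bridge::GetXlaTensor can resolve it, and since the lowering further down is still a stub, the result is not yet a real mask):

import torch
import torch_xla  # provides the _XLAC extension module
import torch_xla.core.xla_model as xm

logits = torch.randn(4, 16, device=xm.xla_device())
# Arguments: input tensor, p (nucleus mass), dim (axis to mask along).
mask = torch_xla._XLAC._xla_topp_mask(logits, 0.9, -1)

The experimental wrapper at the end of this commit (torch_xla/experimental/topp_mask.py) is the friendlier entry point; it validates p and defaults dim to the last axis.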

torch_xla/csrc/ops/topp_mask.cpp

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#include "torch_xla/csrc/ops/topp_mask.h"
+
+#include "torch_xla/csrc/lowering_context.h"
+#include "torch_xla/csrc/ops/xla_ops.h"
+#include "torch_xla/csrc/xla_lower_util.h"
+
+namespace torch_xla {
+
+TopPMask::TopPMask(const torch::lazy::Value& input, float p, int64_t dim,
+                   bool stable)
+    : XlaNode(xla_topp_mask, {input}, GetXlaShape(input),
+              /*num_outputs=*/1, torch::lazy::MHash(p, dim, stable)),
+      p_(p),
+      dim_(dim),
+      stable_(stable) {}
+
+torch::lazy::NodePtr TopPMask::Clone(torch::lazy::OpList operands) const {
+  return torch_xla::MakeNode<TopPMask>(operands.at(0), p_, dim_, stable_);
+}
+
+XlaOpVector TopPMask::Lower(LoweringContext* loctx) const {
+  xla::XlaOp input = loctx->GetOutputOp(operand(0));
+  xla::XlaOp output = CreateTopPMask(input, p_, dim_, stable_);
+  return ReturnOp(output, loctx);
+}
+
+std::string TopPMask::ToString() const {
+  std::stringstream ss;
+  ss << XlaNode::ToString() << ", p=" << p_ << ", dim=" << dim_
+     << ", stable=" << stable_;
+  return ss.str();
+}
+
+}  // namespace torch_xla

torch_xla/csrc/ops/topp_mask.h

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+#ifndef XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_
+#define XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_
+
+#include "torch_xla/csrc/ir.h"
+
+namespace torch_xla {
+
+class TopPMask : public XlaNode {
+ public:
+  TopPMask(const torch::lazy::Value& input, float p, int64_t dim, bool stable);
+
+  std::string ToString() const override;
+
+  torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;
+
+  XlaOpVector Lower(LoweringContext* loctx) const override;
+
+  float p() const { return p_; }
+
+  int64_t dim() const { return dim_; }
+
+  bool stable() const { return stable_; }
+
+ private:
+  float p_;
+  int64_t dim_;
+  bool stable_;
+};
+
+}  // namespace torch_xla
+
+#endif  // XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_

torch_xla/csrc/ops/xla_ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -40,5 +40,6 @@ const OpKindWrapper xla_update_slice("xla::update_slice");
 const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
 const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
 const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");
+const OpKindWrapper xla_topp_mask("xla::topp_mask");

 }  // namespace torch_xla

torch_xla/csrc/ops/xla_ops.h

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ extern const OpKindWrapper xla_update_slice;
 extern const OpKindWrapper xla_custom_sharding;
 extern const OpKindWrapper xla_tpu_custom_call;
 extern const OpKindWrapper xla_gpu_custom_call;
+extern const OpKindWrapper xla_topp_mask;

 }  // namespace torch_xla

-#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
\ No newline at end of file
+#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_

torch_xla/csrc/tensor_methods.cpp

Lines changed: 9 additions & 0 deletions
@@ -132,6 +132,7 @@
 #include "torch_xla/csrc/ops/threshold.h"
 #include "torch_xla/csrc/ops/threshold_backward.h"
 #include "torch_xla/csrc/ops/topk.h"
+#include "torch_xla/csrc/ops/topp_mask.h"
 #include "torch_xla/csrc/ops/tpu_custom_call.h"
 #include "torch_xla/csrc/ops/triangular_solve.h"
 #include "torch_xla/csrc/ops/uniform.h"
@@ -3438,6 +3439,14 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
   return std::make_tuple(t1, t2);
 }

+XLATensorPtr topp_mask(const XLATensorPtr& input, float p, int64_t dim,
+                       bool stable) {
+  return input->CreateFrom(torch_xla::MakeNode<TopPMask>(
+      input->GetIrValue(), p,
+      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
+      stable));
+}
+
 XLATensorPtr trace(const XLATensorPtr& input) {
   auto input_shape_ref = input->shape();
   XLA_CHECK_EQ((*input_shape_ref).rank(), 2)
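
Note that tensor_methods::topp_mask canonicalizes dim with torch::lazy::GetCanonicalDimensionIndex before building the IR node, so a negative axis yields the same node (and hash) as its positive equivalent. A hypothetical Python rendering of that canonicalization, for illustration only:

def canonical_dim(dim, rank):
  # Wrap a negative axis once and bounds-check: valid dims lie in [-rank, rank).
  assert -rank <= dim < rank, "dim out of range"
  return dim + rank if dim < 0 else dim

assert canonical_dim(-1, 3) == 2  # last axis of a rank-3 tensor
assert canonical_dim(1, 3) == 1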

torch_xla/csrc/tensor_methods.h

Lines changed: 3 additions & 0 deletions
@@ -973,6 +973,9 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
                                             bool largest, bool sorted,
                                             bool stable);

+XLATensorPtr topp_mask(const XLATensorPtr& input, float p, int64_t dim,
+                       bool stable);
+
 // Returns the sum of the elements of the diagonal of the input 2-D matrix.
 XLATensorPtr trace(const XLATensorPtr& input);
torch_xla/csrc/xla_lower_util.cpp

Lines changed: 5 additions & 0 deletions
@@ -390,6 +390,11 @@ std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,
                                        xla::PrimitiveType::S64))};
 }

+xla::XlaOp CreateTopPMask(xla::XlaOp input, float p, int64_t dim, bool stable) {
+  // TODO: implement
+  return input;
+}
+
 xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) {
   // Expand cases in https://pytorch.org/docs/stable/torch.html#torch.matmul
   xla::Shape lhs_shape = ShapeHelper::ShapeOfXlaOp(lhs);
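
CreateTopPMask is intentionally left as a stub in this commit. For reference, here is a sketch of the semantics the lowering would presumably reproduce — top-p (nucleus) masking, which keeps the smallest set of entries whose softmax mass reaches p — written as an eager PyTorch reference rather than XLA builder calls; the function name and exact formulation are illustrative, not part of this commit:

import torch

def topp_mask_reference(logits, p, dim=-1):
  # Probabilities sorted along `dim`, largest first.
  probs = torch.softmax(logits, dim=dim)
  sorted_probs, sorted_idx = probs.sort(dim=dim, descending=True)
  # Mass strictly before each position; keep a position while that mass is
  # still < p, so the kept set is the smallest one whose total reaches p.
  cum = sorted_probs.cumsum(dim=dim)
  keep_sorted = (cum - sorted_probs) < p
  # Scatter the sorted mask back to the original element order.
  keep = torch.zeros_like(probs).scatter(dim, sorted_idx,
                                         keep_sorted.to(probs.dtype))
  return keep > 0

An XLA lowering would express the same computation with sort, cumulative-sum, and select ops; the stable flag threaded through these signatures would make the sort deterministic for tied logits.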

torch_xla/csrc/xla_lower_util.h

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,8 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, int64_t k, int64_t dim,
 std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,
                                    bool largest, bool stable);

+xla::XlaOp CreateTopPMask(xla::XlaOp input, float p, int64_t dim, bool stable);
+
 xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);

 xla::XlaOp BuildMatMul(xla::XlaOp lhs, xla::XlaOp rhs, xla::XlaOp bias);

torch_xla/experimental/topp_mask.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import torch_xla
+
+def topp_mask(logits, p, dim=None):
+  assert 0 <= p <= 1.0, "p must be in [0, 1]."
+  if dim is None:
+    dim = -1
+  return torch_xla._XLAC._xla_topp_mask(logits, p, dim)
+
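
A hypothetical end-to-end use of this wrapper (assuming the module is importable from the path above; as noted, CreateTopPMask is still a TODO stub at this commit, so the call compiles and runs but returns its input unchanged):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.topp_mask import topp_mask

device = xm.xla_device()
logits = torch.randn(2, 32000, device=device)
mask = topp_mask(logits, p=0.9)  # dim defaults to the last axis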
