
[DRAFT/WIP] Add top-p masking #8871

Draft · wants to merge 1 commit into master
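This draft adds an experimental top-p ("nucleus") masking op to torch_xla. Given logits and a threshold p, `topp_mask` returns a 0/1 mask marking the smallest set of highest-probability entries whose softmax mass along `dim` reaches p, with the arg-max always kept. A minimal usage sketch of the API added in this diff (values mirror the unit tests below; illustrative only):

import math
import torch
import torch_xla
from torch_xla.experimental.topp_mask import topp_mask

device = torch_xla.device()
logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)]],
                      dtype=torch.float32, device=device)
# The smallest set reaching p=0.79 is {0.5, 0.3}, so the mask is [[0, 1, 1]].
mask = topp_mask(logits, 0.79)
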
79 changes: 79 additions & 0 deletions test/topp/test_topp_mask.py
@@ -0,0 +1,79 @@
import torch
import torch_xla
import unittest
from torch_xla.experimental.topp_mask import topp_mask
import math


class TestTopPMask(unittest.TestCase):

def setUp(self):
self.device = torch_xla.device()

def test_invalid_p(self):
# Test that an invalid p throws an assertion error
logits = torch.tensor([[0.2, 0.5, 0.3], [0.1, 0.7, 0.2]],
dtype=torch.float32,
device=self.device)

with self.assertRaises(AssertionError):
topp_mask(logits, -0.1) # p < 0

with self.assertRaises(AssertionError):
topp_mask(logits, 1.1) # p > 1

def test_basic(self):
logits = torch.tensor(
[[math.log(0.2), math.log(0.3),
math.log(0.5)], [math.log(0.5),
math.log(0.2),
math.log(0.3)]],
dtype=torch.float32,
device=self.device)
mask = topp_mask(logits, 0.79)

expected_mask = torch.tensor([[0, 1, 1], [1, 0, 1]],
dtype=torch.float32,
device=self.device)
self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))

def test_dim(self):
logits = torch.tensor([[math.log(0.2), math.log(0.3)],
[math.log(0.3), math.log(0.2)],
[math.log(0.5), math.log(0.5)]],
dtype=torch.float32,
device=self.device)
mask = topp_mask(logits, 0.79, dim=0)

expected_mask = torch.tensor([[0, 1], [1, 0], [1, 1]],
dtype=torch.float32,
device=self.device)
self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))

def test_p_is_zero(self):
logits = torch.tensor([[0.2, 0.5, 5], [0.1, 2, 0.2]],
dtype=torch.float32,
device=self.device)
mask = topp_mask(logits, 0.0)

expected_mask = torch.tensor([[0, 0, 1], [0, 1, 0]],
dtype=torch.float32,
device=self.device)
self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))

def test_p_is_one(self):
logits = torch.tensor([[0.2, 0.5, 5], [0.1, 2, 0.2]],
dtype=torch.float32,
device=self.device)
mask = topp_mask(logits, 1.0)

# All elements should be selected.
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 1]],
dtype=torch.float32,
device=self.device)
self.assertTrue(torch.allclose(expected_mask, mask, atol=1e-6))


if __name__ == '__main__':
unittest.main()
9 changes: 9 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -818,6 +818,12 @@ xla::Shape GetTensorShape(const at::Tensor& tensor,
return CreateComputationShapeFromTensor(tensor, &device);
}

at::Tensor TopPMask(const at::Tensor& input, float p, int64_t dim) {
auto result = tensor_methods::topp_mask(bridge::GetXlaTensor(input), p, dim,
/*stable=*/false);
return bridge::AtenFromXlaTensor(std::move(result));
}

py::dict GetMemoryInfo(const std::string& device_str) {
runtime::ComputationClient::MemoryInfo mem_info;
{
@@ -3008,6 +3014,9 @@
[](std::string name, std::shared_ptr<const runtime::PjRtPlugin> plugin) {
runtime::RegisterPjRtPlugin(name, plugin);
});
m.def("_xla_topp_mask", [](const at::Tensor& input, float p, int64_t dim) {
return TopPMask(input, p, dim);
});
py::class_<runtime::PjRtPlugin, PyPjRtPlugin,
std::shared_ptr<runtime::PjRtPlugin>>(m, "PjRtPlugin")
.def(py::init<>())
34 changes: 34 additions & 0 deletions torch_xla/csrc/ops/topp_mask.cpp
@@ -0,0 +1,34 @@
#include "torch_xla/csrc/ops/topp_mask.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {

TopPMask::TopPMask(const torch::lazy::Value& input, float p, int64_t dim,
bool stable)
    : XlaNode(xla_topp_mask, {input}, GetXlaShape(input),
/*num_outputs=*/1, torch::lazy::MHash(p, dim, stable)),
p_(p),
dim_(dim),
stable_(stable) {}

torch::lazy::NodePtr TopPMask::Clone(torch::lazy::OpList operands) const {
return torch_xla::MakeNode<TopPMask>(operands.at(0), p_, dim_, stable_);
}

XlaOpVector TopPMask::Lower(LoweringContext* loctx) const {
xla::XlaOp input = loctx->GetOutputOp(operand(0));
return ReturnOp(CreateTopPMask(loctx->device(), input, p_, dim_, stable_),
loctx);
}

std::string TopPMask::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", p=" << p_ << ", dim=" << dim_
<< ", stable=" << stable_;
return ss.str();
}

} // namespace torch_xla
32 changes: 32 additions & 0 deletions torch_xla/csrc/ops/topp_mask.h
@@ -0,0 +1,32 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_
#define XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class TopPMask : public XlaNode {
public:
TopPMask(const torch::lazy::Value& input, float p, int64_t dim, bool stable);

std::string ToString() const override;

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

  float p() const { return p_; }

  int64_t dim() const { return dim_; }

  bool stable() const { return stable_; }

private:
float p_;
int64_t dim_;
bool stable_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_TOPP_MASK_H_
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.cpp
@@ -40,5 +40,6 @@ const OpKindWrapper xla_update_slice("xla::update_slice");
const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");
const OpKindWrapper xla_topp_mask("xla::topp_mask");

} // namespace torch_xla
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/xla_ops.h
@@ -65,7 +65,8 @@ extern const OpKindWrapper xla_update_slice;
extern const OpKindWrapper xla_custom_sharding;
extern const OpKindWrapper xla_tpu_custom_call;
extern const OpKindWrapper xla_gpu_custom_call;
extern const OpKindWrapper xla_topp_mask;

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
#endif // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
9 changes: 9 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -132,6 +132,7 @@
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/topp_mask.h"
#include "torch_xla/csrc/ops/tpu_custom_call.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/uniform.h"
@@ -3438,6 +3439,14 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
return std::make_tuple(t1, t2);
}

XLATensorPtr topp_mask(const XLATensorPtr& input, float p, int64_t dim,
bool stable) {
return input->CreateFrom(torch_xla::MakeNode<TopPMask>(
input->GetIrValue(), p,
torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
stable));
}

XLATensorPtr trace(const XLATensorPtr& input) {
auto input_shape_ref = input->shape();
XLA_CHECK_EQ((*input_shape_ref).rank(), 2)
3 changes: 3 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -973,6 +973,9 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
bool largest, bool sorted,
bool stable);

XLATensorPtr topp_mask(const XLATensorPtr& input, float p, int64_t dim,
bool stable);

// Returns the sum of the elements of the diagonal of the input 2-D matrix.
XLATensorPtr trace(const XLATensorPtr& input);

47 changes: 47 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -17,6 +17,7 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/softmax_builder.h"
#include "torch_xla/csrc/tensor_util.h"
#include "xla/client/lib/arithmetic.h"
#include "xla/client/lib/comparators.h"
@@ -390,6 +391,52 @@ std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,
xla::PrimitiveType::S64))};
}

xla::XlaOp CreateTopPMask(const torch::lazy::BackendDevice& device,
xla::XlaOp logits, float p, int64_t dim,
bool stable) {
// Convert logits to probabilities.
xla::XlaOp probs = BuildSoftmax(logits, dim);
const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(probs);

// Sort the probabilities in ascending order and keep track of the sorted
// indices.
xla::Shape iota_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, shape.dimensions());
xla::XlaOp iota = xla::Iota(probs.builder(), iota_shape, dim);
xla::XlaComputation comparator = xla::CreateScalarLtComputation(
{shape.element_type(), xla::PrimitiveType::S32}, probs.builder());
xla::XlaOp sort_result = xla::Sort({probs, iota}, comparator, dim, stable);

// Compute cumulative probabilities.
xla::XlaOp zero = xla::Zero(probs.builder(), shape.element_type());
xla::XlaComputation reducer =
XlaHelpers::CreateAddComputation(shape.element_type());
xla::XlaOp cumprobs = BuildCumulativeComputation(
xla::GetTupleElement(sort_result, 0), dim, reducer, zero);

// Create a mask for the "p-set" elements.
xla::XlaOp one = xla::One(probs.builder(), shape.element_type());
xla::XlaOp p_mask = BuildThreshold(cumprobs, one, 1 - p, 0);

  // The largest element should always be included. Since the values are
  // sorted in ascending order, it sits at the last index along `dim`, so
  // overwrite that slice of the mask with ones.
  std::vector<int64_t> sizes = XlaHelpers::SizesOfXlaOp(p_mask);
  std::vector<int64_t> index_to_update(sizes.size(), 0);
  index_to_update[dim] = sizes[dim] - 1;
  sizes[dim] = 1;
  xla::XlaOp ones = xla::Broadcast(one, sizes);
  xla::XlaOp p_mask_updated = BuildUpdateSlice(p_mask, ones, index_to_update);

// Re-order the mask back to pre-sorted order.
xla::XlaOp sorted_indices = xla::GetTupleElement(sort_result, 1);
ScatterOptions options(/*combiner=*/nullptr);
xla::XlaOp p_mask_reordered = CreateScatter(
device, p_mask_updated, sorted_indices, p_mask_updated, dim, options);

return p_mask_reordered;
}

xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) {
// Expand cases in https://pytorch.org/docs/stable/torch.html#torch.matmul
xla::Shape lhs_shape = ShapeHelper::ShapeOfXlaOp(lhs);
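For reference, the CreateTopPMask lowering above amounts to: softmax, ascending sort with tracked indices, cumulative sum, threshold at 1 - p, force-include the maximum, then scatter the mask back to the original order. A rough eager PyTorch sketch of that sequence (a reference for the algorithm only, not code from this PR):

import torch

def topp_mask_reference(logits, p, dim=-1):
  # 1 marks elements of the top-p set, matching CreateTopPMask.
  probs = torch.softmax(logits, dim=dim)
  sorted_probs, sorted_idx = torch.sort(probs, dim=dim)  # ascending
  cumprobs = torch.cumsum(sorted_probs, dim=dim)
  mask = (cumprobs > 1.0 - p).to(logits.dtype)
  # The largest element (last position after the ascending sort) is always kept.
  last = torch.tensor([mask.size(dim) - 1], device=logits.device)
  mask.index_fill_(dim, last, 1.0)
  # Undo the sort so the mask lines up with the original logits.
  return mask.scatter(dim, sorted_idx, mask)
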
3 changes: 3 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
@@ -20,6 +20,9 @@ std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, int64_t k, int64_t dim,
std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, int64_t k, int64_t dim,
bool largest, bool stable);

xla::XlaOp CreateTopPMask(const torch::lazy::BackendDevice& device,
xla::XlaOp input, float p, int64_t dim, bool stable);

xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);

xla::XlaOp BuildMatMul(xla::XlaOp lhs, xla::XlaOp rhs, xla::XlaOp bias);
8 changes: 8 additions & 0 deletions torch_xla/experimental/topp_mask.py
@@ -0,0 +1,8 @@
import torch_xla


def topp_mask(logits, p, dim=None):
  """Returns a 0/1 mask marking the top-p (nucleus) set of `logits`.

  An element is 1 iff its softmax probability along `dim` falls in the
  smallest set of largest probabilities whose sum is at least `p`. The
  largest element along `dim` is always kept, even when `p` is 0.
  """
  assert 0.0 <= p <= 1.0, "p must be in [0, 1]."
  if dim is None:
    dim = -1
  return torch_xla._XLAC._xla_topp_mask(logits, p, dim)
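
Downstream, the mask would typically be applied to the logits before sampling. A hedged sketch of that pattern (the masked_fill/multinomial calls are standard PyTorch and not part of this diff):

import torch
from torch_xla.experimental.topp_mask import topp_mask

def sample_top_p(logits, p):
  # Drop everything outside the top-p set by sending it to -inf, then sample.
  mask = topp_mask(logits, p)
  masked = logits.masked_fill(mask == 0, float('-inf'))
  probs = torch.softmax(masked, dim=-1)
  return torch.multinomial(probs, num_samples=1)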