Skip to content

Commit 29404a1

Browse files
authored
Pass in weight tensor via kernel function arg for fused_linear_cross_entropy (#283)
1 parent 95bdbbe commit 29404a1

File tree

1 file changed

+24
-27
lines changed
  • tritonbench/operators/fused_linear_cross_entropy

1 file changed

+24
-27
lines changed

tritonbench/operators/fused_linear_cross_entropy/operator.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,38 +30,30 @@ def parse_op_args(args: List[str]):
3030
class TorchLMHeadCE(torch.nn.Module):
3131
"""Ground truth implementation of the linear fused with torch based cross entropy loss.
3232
33-
:param H: hidden size
34-
:param V: vocab size
3533
:param ignore_index: index to ignore
3634
:param reduction: reduction method
3735
"""
3836

39-
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
37+
def __init__(self, ignore_index: int = -100):
4038
super().__init__()
41-
self.lin = torch.nn.Linear(
42-
in_features=H, out_features=V, bias=False, dtype=dtype
43-
)
4439
self.ce_loss = torch.nn.CrossEntropyLoss(
4540
ignore_index=ignore_index, reduction="mean"
4641
)
4742

48-
def forward(self, input, target):
49-
logits = self.lin(input)
43+
def forward(self, input, weight, target):
44+
logits = torch.nn.functional.linear(input, weight)
5045
return self.ce_loss(logits, target)
5146

5247

5348
class LigerLMHeadCE(torch.nn.Module):
54-
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
49+
def __init__(self, ignore_index: int = -100):
5550
super().__init__()
56-
self.lin = torch.nn.Linear(
57-
in_features=H, out_features=V, bias=False, dtype=dtype
58-
)
5951
self.ce_loss = LigerFusedLinearCrossEntropyLoss(
6052
ignore_index=ignore_index, reduction="mean"
6153
)
6254

63-
def forward(self, input, target):
64-
return self.ce_loss(self.lin.weight, input, target)
55+
def forward(self, input, weight, target):
56+
return self.ce_loss(weight, input, target)
6557

6658

6759
class Operator(BenchmarkOperator):
@@ -72,12 +64,17 @@ def __init__(
7264
op_args = parse_op_args(self.extra_args)
7365
self.hidden_size = op_args.hidden_size
7466
self.vocab_size = op_args.vocab_size
75-
self.baseline_model = TorchLMHeadCE(
76-
H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
77-
).to(self.device)
78-
self.liger_model = LigerLMHeadCE(
79-
H=self.hidden_size, V=self.vocab_size, dtype=self.dtype
80-
).to(self.device)
67+
# Create the shared weight tensor
68+
self.weight = torch.randn(
69+
self.vocab_size,
70+
self.hidden_size,
71+
dtype=self.dtype,
72+
device=self.device,
73+
requires_grad=True,
74+
)
75+
76+
self.baseline_model = TorchLMHeadCE().to(self.device)
77+
self.liger_model = LigerLMHeadCE().to(self.device)
8178

8279
def get_input_iter(self) -> Generator:
8380
for BT in [2**i for i in range(12, 16)]:
@@ -91,20 +88,20 @@ def get_input_iter(self) -> Generator:
9188
target = torch.randint(
9289
self.vocab_size, (BT, 1), dtype=torch.long, device=self.device
9390
).squeeze(1)
94-
yield _input, target
91+
yield _input, self.weight, target
9592

9693
@register_benchmark(baseline=True)
97-
def torch_lm_head_ce(self, input, target) -> Callable:
98-
return lambda: self.baseline_model(input, target)
94+
def torch_lm_head_ce(self, input, weight, target) -> Callable:
95+
return lambda: self.baseline_model(input, weight, target)
9996

10097
@register_benchmark()
101-
def liger_lm_head_ce(self, input, target) -> Callable:
102-
return lambda: self.liger_model(input, target)
98+
def liger_lm_head_ce(self, input, weight, target) -> Callable:
99+
return lambda: self.liger_model(input, weight, target)
103100

104101
@register_benchmark()
105-
def inductor_fused_linear_cross_entropy(self, input, target) -> Callable:
102+
def inductor_fused_linear_cross_entropy(self, input, weight, target) -> Callable:
106103
compiled = torch.compile(self.baseline_model)
107-
return lambda: compiled(input, target)
104+
return lambda: compiled(input, weight, target)
108105

109106
@register_x_val(label="(B*T, H)")
110107
def get_x_val(self, example_inputs) -> Tuple[int, int]:

0 commit comments

Comments
 (0)