@@ -30,38 +30,30 @@ def parse_op_args(args: List[str]):
30
30
class TorchLMHeadCE(torch.nn.Module):
    """Ground truth implementation of the linear fused with torch based cross entropy loss.

    The projection weight is supplied at call time (shared with the Liger
    benchmark variant) rather than owned by this module.

    :param ignore_index: target value that is ignored when computing the loss
    :param reduction: reduction applied to the per-token losses
        ("mean", "sum", or "none"); defaults to "mean"
    """

    def __init__(self, ignore_index: int = -100, reduction: str = "mean"):
        super().__init__()
        # The original hard-coded reduction="mean" while the class docstring
        # documented a reduction parameter; expose it (default unchanged) so
        # the API matches the documentation and callers are unaffected.
        self.ce_loss = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index, reduction=reduction
        )

    def forward(self, input, weight, target):
        # Project hidden states to vocab logits, then apply cross entropy.
        logits = torch.nn.functional.linear(input, weight)
        return self.ce_loss(logits, target)
class LigerLMHeadCE(torch.nn.Module):
    """LM-head cross entropy computed with Liger's fused linear + CE loss.

    The projection weight is passed into ``forward`` (shared with the torch
    baseline) instead of being owned by the module, so the fused kernel can
    consume it directly.

    :param ignore_index: target value that is ignored when computing the loss
    """

    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ce_loss = LigerFusedLinearCrossEntropyLoss(
            ignore_index=ignore_index, reduction="mean"
        )

    def forward(self, input, weight, target):
        # Liger's fused loss takes the projection weight first, then the
        # activations, then the targets.
        return self.ce_loss(weight, input, target)
class Operator (BenchmarkOperator ):
@@ -72,12 +64,17 @@ def __init__(
72
64
op_args = parse_op_args (self .extra_args )
73
65
self .hidden_size = op_args .hidden_size
74
66
self .vocab_size = op_args .vocab_size
75
- self .baseline_model = TorchLMHeadCE (
76
- H = self .hidden_size , V = self .vocab_size , dtype = self .dtype
77
- ).to (self .device )
78
- self .liger_model = LigerLMHeadCE (
79
- H = self .hidden_size , V = self .vocab_size , dtype = self .dtype
80
- ).to (self .device )
67
+ # Create the shared weight tensor
68
+ self .weight = torch .randn (
69
+ self .vocab_size ,
70
+ self .hidden_size ,
71
+ dtype = self .dtype ,
72
+ device = self .device ,
73
+ requires_grad = True ,
74
+ )
75
+
76
+ self .baseline_model = TorchLMHeadCE ().to (self .device )
77
+ self .liger_model = LigerLMHeadCE ().to (self .device )
81
78
82
79
def get_input_iter (self ) -> Generator :
83
80
for BT in [2 ** i for i in range (12 , 16 )]:
@@ -91,20 +88,20 @@ def get_input_iter(self) -> Generator:
91
88
target = torch .randint (
92
89
self .vocab_size , (BT , 1 ), dtype = torch .long , device = self .device
93
90
).squeeze (1 )
94
- yield _input , target
91
+ yield _input , self . weight , target
95
92
96
93
@register_benchmark(baseline=True)
def torch_lm_head_ce(self, input, weight, target) -> Callable:
    """Return a zero-arg callable running the eager PyTorch baseline model."""

    def run():
        return self.baseline_model(input, weight, target)

    return run
@register_benchmark()
def liger_lm_head_ce(self, input, weight, target) -> Callable:
    """Return a zero-arg callable running the Liger fused-kernel model."""

    def run():
        return self.liger_model(input, weight, target)

    return run
@register_benchmark()
def inductor_fused_linear_cross_entropy(self, input, weight, target) -> Callable:
    """Compile the eager baseline with torch.compile and return a zero-arg
    callable over the compiled model.

    NOTE(review): compilation is triggered per benchmark call here, so the
    first invocation of the returned callable pays the compile cost.
    """
    compiled_model = torch.compile(self.baseline_model)

    def run():
        return compiled_model(input, weight, target)

    return run
@register_x_val (label = "(B*T, H)" )
110
107
def get_x_val (self , example_inputs ) -> Tuple [int , int ]:
0 commit comments