@@ -36,9 +36,18 @@ def invoke(
36
36
{% if "momentum2_dev" in args .split_function_arg_names % }
37
37
momentum2 : Momentum ,
38
38
{% endif % }
39
+ {% if "prev_iter_dev" in args .split_function_arg_names % }
40
+ prev_iter : Momentum ,
41
+ {% endif % }
42
+ {% if "row_counter_dev" in args .split_function_arg_names % }
43
+ row_counter : Momentum ,
44
+ {% endif % }
39
45
{% if "iter" in args .split_function_arg_names % }
40
46
iter : int ,
41
47
{% endif % }
48
+ {% if "max_counter" in args .split_function_arg_names % }
49
+ max_counter : float ,
50
+ {% endif % }
42
51
) -> torch .Tensor :
43
52
if (common_args .host_weights .numel () > 0 ):
44
53
return torch .ops .fbgemm .split_embedding_codegen_lookup_ {{ optimizer }}_function_cpu (
@@ -84,6 +93,27 @@ def invoke(
84
93
{% if "momentum" in args .split_function_arg_names % }
85
94
momentum = optimizer_args .momentum ,
86
95
{% endif % }
96
+ {% if "counter_halflife" in args .split_function_arg_names % }
97
+ counter_halflife = optimizer_args .counter_halflife ,
98
+ {% endif % }
99
+ {% if "adjustment_iter" in args .split_function_arg_names % }
100
+ adjustment_iter = optimizer_args .adjustment_iter ,
101
+ {% endif % }
102
+ {% if "adjustment_ub" in args .split_function_arg_names % }
103
+ adjustment_ub = optimizer_args .adjustment_ub ,
104
+ {% endif % }
105
+ {% if "learning_rate_mode" in args .split_function_arg_names % }
106
+ learning_rate_mode = optimizer_args .learning_rate_mode ,
107
+ {% endif % }
108
+ {% if "grad_sum_decay" in args .split_function_arg_names % }
109
+ grad_sum_decay = optimizer_args .grad_sum_decay ,
110
+ {% endif % }
111
+ {% if "tail_id_threshold" in args .split_function_arg_names % }
112
+ tail_id_threshold = optimizer_args .tail_id_threshold ,
113
+ {% endif % }
114
+ {% if "is_tail_id_thresh_ratio" in args .split_function_arg_names % }
115
+ is_tail_id_thresh_ratio = optimizer_args .is_tail_id_thresh_ratio ,
116
+ {% endif % }
87
117
# momentum1
88
118
{% if "momentum1_dev" in args .split_function_arg_names % }
89
119
momentum1_host = momentum1 .host ,
@@ -96,10 +126,26 @@ def invoke(
96
126
momentum2_offsets = momentum2 .offsets ,
97
127
momentum2_placements = momentum2 .placements ,
98
128
{% endif % }
129
+ # prev_iter
130
+ {% if "prev_iter_dev" in args .split_function_arg_names % }
131
+ prev_iter_host = prev_iter .host ,
132
+ prev_iter_offsets = prev_iter .offsets ,
133
+ prev_iter_placements = prev_iter .placements ,
134
+ {% endif % }
135
+ # row_counter
136
+ {% if "row_counter_dev" in args .split_function_arg_names % }
137
+ row_counter_host = row_counter .host ,
138
+ row_counter_offsets = row_counter .offsets ,
139
+ row_counter_placements = row_counter .placements ,
140
+ {% endif % }
99
141
# iter
100
142
{% if "iter" in args .split_function_arg_names % }
101
143
iter = iter ,
102
144
{% endif % }
145
+ # max counter
146
+ {% if "max_counter" in args .split_function_arg_names % }
147
+ max_counter = max_counter ,
148
+ {% endif % }
103
149
)
104
150
else :
105
151
return torch .ops .fbgemm .split_embedding_codegen_lookup_ {{ optimizer }}_function (
@@ -151,6 +197,27 @@ def invoke(
151
197
{% if "momentum" in args .split_function_arg_names % }
152
198
momentum = optimizer_args .momentum ,
153
199
{% endif % }
200
+ {% if "counter_halflife" in args .split_function_arg_names % }
201
+ counter_halflife = optimizer_args .counter_halflife ,
202
+ {% endif % }
203
+ {% if "adjustment_iter" in args .split_function_arg_names % }
204
+ adjustment_iter = optimizer_args .adjustment_iter ,
205
+ {% endif % }
206
+ {% if "adjustment_ub" in args .split_function_arg_names % }
207
+ adjustment_ub = optimizer_args .adjustment_ub ,
208
+ {% endif % }
209
+ {% if "learning_rate_mode" in args .split_function_arg_names % }
210
+ learning_rate_mode = optimizer_args .learning_rate_mode ,
211
+ {% endif % }
212
+ {% if "grad_sum_decay" in args .split_function_arg_names % }
213
+ grad_sum_decay = optimizer_args .grad_sum_decay ,
214
+ {% endif % }
215
+ {% if "tail_id_threshold" in args .split_function_arg_names % }
216
+ tail_id_threshold = optimizer_args .tail_id_threshold ,
217
+ {% endif % }
218
+ {% if "is_tail_id_thresh_ratio" in args .split_function_arg_names % }
219
+ is_tail_id_thresh_ratio = optimizer_args .is_tail_id_thresh_ratio ,
220
+ {% endif % }
154
221
# momentum1
155
222
{% if "momentum1_dev" in args .split_function_arg_names % }
156
223
momentum1_dev = momentum1 .dev ,
@@ -165,9 +232,27 @@ def invoke(
165
232
momentum2_offsets = momentum2 .offsets ,
166
233
momentum2_placements = momentum2 .placements ,
167
234
{% endif % }
235
+ # prev_iter
236
+ {% if "prev_iter_dev" in args .split_function_arg_names % }
237
+ prev_iter_dev = prev_iter .dev ,
238
+ prev_iter_uvm = prev_iter .uvm ,
239
+ prev_iter_offsets = prev_iter .offsets ,
240
+ prev_iter_placements = prev_iter .placements ,
241
+ {% endif % }
242
+ # row_counter
243
+ {% if "row_counter_dev" in args .split_function_arg_names % }
244
+ row_counter_dev = row_counter .dev ,
245
+ row_counter_uvm = row_counter .uvm ,
246
+ row_counter_offsets = row_counter .offsets ,
247
+ row_counter_placements = row_counter .placements ,
248
+ {% endif % }
168
249
# iter
169
250
{% if "iter" in args .split_function_arg_names % }
170
251
iter = iter ,
171
252
{% endif % }
253
+ # max counter
254
+ {% if "max_counter" in args .split_function_arg_names % }
255
+ max_counter = max_counter ,
256
+ {% endif % }
172
257
output_dtype = common_args .output_dtype ,
173
258
)
0 commit comments