
Commit b8241da

Sabin Devkota authored and facebook-github-bot committed
Update API interface and reroute backend for exact_rowwise_adagrad FE when using freq based methods (#1352)
Summary: Pull Request resolved: #1352

1. Update the interface to accommodate rowwise_adagrad_with_counter.
2. Route the backend for rowwise_adagrad to the new rowwise_adagrad_with_counter when freq-based methods (e.g. freq SGD, counter-adjusted regularization) are used; a sketch of this routing follows the change stats below.

Reviewed By: csmiler

Differential Revision: D36788395

fbshipit-source-id: 8eb5da8a5c8b52bc1e237af1054aac9f7245c443
1 parent 30833fa · commit b8241da

7 files changed: +618 −74 lines changed
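The routing in point 2 happens on the host side: the frontend keeps the exact_rowwise_adagrad name, but when freq-based options are enabled it dispatches to the counter-aware invoker registered in __init__.template below. A minimal Python sketch of that dispatch rule; the helper name select_invoker and the uses_freq_methods test are illustrative assumptions, not the actual FBGEMM host code:

    # Illustrative sketch only -- FBGEMM performs an equivalent check inside
    # its host-side lookup code; this helper just restates the routing rule.
    def select_invoker(optimizer: str, counter_halflife: int, adjustment_iter: int) -> str:
        # Assumption: freq-based methods count as "on" when either knob is
        # positive (counter_halflife > 0 gates the freq estimate in the
        # kernel code shown further down).
        uses_freq_methods = counter_halflife > 0 or adjustment_iter > 0
        if optimizer == "exact_rowwise_adagrad" and uses_freq_methods:
            return "lookup_rowwise_adagrad_with_counter"  # counter-aware backend
        return "lookup_rowwise_adagrad"                   # plain backend

    assert select_invoker("exact_rowwise_adagrad", 3, 0) == "lookup_rowwise_adagrad_with_counter"
    assert select_invoker("exact_rowwise_adagrad", -1, -1) == "lookup_rowwise_adagrad"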

fbgemm_gpu/codegen/__init__.template (2 additions, 0 deletions)

@@ -13,7 +13,9 @@ import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_lars_sgd as loo
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_adam as lookup_partial_rowwise_adam # noqa: F401
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_partial_rowwise_lamb as lookup_partial_rowwise_lamb # noqa: F401
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad as lookup_rowwise_adagrad # noqa: F401
+import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_adagrad_with_counter as lookup_rowwise_adagrad_with_counter # noqa: F401
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_sgd as lookup_sgd # noqa: F401
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_approx_sgd as lookup_approx_sgd # noqa: F401
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_approx_rowwise_adagrad as lookup_approx_rowwise_adagrad # noqa: F401
+import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_approx_rowwise_adagrad_with_counter as lookup_approx_rowwise_adagrad_with_counter # noqa: F401
 import fbgemm_gpu.split_embedding_codegen_lookup_invokers.lookup_rowwise_weighted_adagrad as lookup_rowwise_weighted_adagrad # noqa: F401

fbgemm_gpu/codegen/embedding_backward_code_generator.py (7 additions, 4 deletions)

@@ -646,6 +646,11 @@ def rowwise_adagrad_with_counter() -> None:
     split_precomputation = """
     at::acc_type<cache_t, true> freq = 1.0;
     at::acc_type<cache_t, true> l2_wd = 0.0;
+    at::acc_type<cache_t, true> tail_id_threshold_val = tail_id_threshold;
+    CUDA_KERNEL_ASSERT(max_counter > 0.0); // avoid divide by zero error
+    if (is_tail_id_thresh_ratio == 1){
+        tail_id_threshold_val = floorf(tail_id_threshold * max_counter);
+    }
     if (counter_halflife > 0 && threadIdx.x == 0) {
         // if id occurs multiple times in a batch, iter_delta=1
         const auto iter_delta = prev_iter[idx] == 0 ? 1.0 : iter * 1.0 - prev_iter[idx];
@@ -660,6 +665,7 @@ def rowwise_adagrad_with_counter() -> None:
     }
     freq = SHFL_SYNC(freq, 0);
     l2_wd = SHFL_SYNC(l2_wd, 0);
+    tail_id_threshold_val = SHFL_SYNC(tail_id_threshold_val, 0);

     at::acc_type<cache_t, true> g_local_sum_square = 0.0;
@@ -682,10 +688,7 @@ def rowwise_adagrad_with_counter() -> None:
     at::acc_type<cache_t, true> multiplier;
     at::acc_type<cache_t, true> adjusted_multiplier;
     at::acc_type<cache_t, true> exp_reg_correction;
-    at::acc_type<cache_t, true> tail_id_threshold_val = tail_id_threshold;
-    if (is_tail_id_thresh_ratio == 1){
-        tail_id_threshold_val = floorf(tail_id_threshold * max_counter);
-    }
+
     if (threadIdx.x == 0) {
         at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
         momentum1[idx] = new_sum_square_grads;
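The block moved into split_precomputation converts tail_id_threshold from a ratio into an absolute count once, ahead of the per-thread work, which is why the new CUDA_KERNEL_ASSERT guards max_counter and why the result is then broadcast with SHFL_SYNC. A plain-Python restatement of that arithmetic; the names mirror the kernel variables, but this is a sketch, not the generated CUDA:

    import math

    def effective_tail_id_threshold(tail_id_threshold: float,
                                    max_counter: float,
                                    is_tail_id_thresh_ratio: int) -> float:
        # When the threshold is expressed as a ratio, scale it by the largest
        # observed row counter to get an absolute occurrence count.
        assert max_counter > 0.0  # mirrors CUDA_KERNEL_ASSERT(max_counter > 0.0)
        if is_tail_id_thresh_ratio == 1:
            return math.floor(tail_id_threshold * max_counter)
        return tail_id_threshold

    # e.g. a 10% ratio against max_counter=250 yields an absolute threshold of 25
    assert effective_tail_id_threshold(0.1, 250.0, 1) == 25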

fbgemm_gpu/codegen/lookup_args.py (7 additions, 0 deletions)

@@ -44,6 +44,13 @@ class OptimizerArgs(NamedTuple):
     weight_decay_mode: int
     eta: float
     momentum: float
+    counter_halflife: int
+    adjustment_iter: int
+    adjustment_ub: float
+    learning_rate_mode: int
+    grad_sum_decay: int
+    tail_id_threshold: float
+    is_tail_id_thresh_ratio: int


 class Momentum(NamedTuple):
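All seven new fields ride along in OptimizerArgs even for optimizers that ignore them, since the NamedTuple is shared across invokers. As a hedged illustration, here is one plausible way a caller could populate them so the counter-based logic stays inert; the specific sentinel values are assumptions (only counter_halflife <= 0 disabling the freq path is visible in the kernel code above):

    # Hypothetical "feature off" settings for the new fields; values are
    # illustrative assumptions, not FBGEMM defaults.
    counter_based_off = dict(
        counter_halflife=-1,         # <= 0 skips the freq estimate in the kernel
        adjustment_iter=-1,          # no counter-adjusted regularization step
        adjustment_ub=1.0,
        learning_rate_mode=-1,
        grad_sum_decay=-1,
        tail_id_threshold=0.0,
        is_tail_id_thresh_ratio=0,   # threshold is an absolute count, not a ratio
    )
    # optimizer_args = OptimizerArgs(..., **counter_based_off)  # remaining fields elided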

fbgemm_gpu/codegen/split_embedding_codegen_lookup_invoker.template (85 additions, 0 deletions)

@@ -36,9 +36,18 @@ def invoke(
     {% if "momentum2_dev" in args.split_function_arg_names %}
     momentum2: Momentum,
     {% endif %}
+    {% if "prev_iter_dev" in args.split_function_arg_names %}
+    prev_iter: Momentum,
+    {% endif %}
+    {% if "row_counter_dev" in args.split_function_arg_names %}
+    row_counter: Momentum,
+    {% endif %}
     {% if "iter" in args.split_function_arg_names %}
     iter: int,
     {% endif %}
+    {% if "max_counter" in args.split_function_arg_names %}
+    max_counter: float,
+    {% endif %}
 ) -> torch.Tensor:
     if (common_args.host_weights.numel() > 0):
         return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
@@ -84,6 +93,27 @@ def invoke(
             {% if "momentum" in args.split_function_arg_names %}
             momentum=optimizer_args.momentum,
             {% endif %}
+            {% if "counter_halflife" in args.split_function_arg_names %}
+            counter_halflife=optimizer_args.counter_halflife,
+            {% endif %}
+            {% if "adjustment_iter" in args.split_function_arg_names %}
+            adjustment_iter=optimizer_args.adjustment_iter,
+            {% endif %}
+            {% if "adjustment_ub" in args.split_function_arg_names %}
+            adjustment_ub=optimizer_args.adjustment_ub,
+            {% endif %}
+            {% if "learning_rate_mode" in args.split_function_arg_names %}
+            learning_rate_mode=optimizer_args.learning_rate_mode,
+            {% endif %}
+            {% if "grad_sum_decay" in args.split_function_arg_names %}
+            grad_sum_decay=optimizer_args.grad_sum_decay,
+            {% endif %}
+            {% if "tail_id_threshold" in args.split_function_arg_names %}
+            tail_id_threshold=optimizer_args.tail_id_threshold,
+            {% endif %}
+            {% if "is_tail_id_thresh_ratio" in args.split_function_arg_names %}
+            is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
+            {% endif %}
             # momentum1
             {% if "momentum1_dev" in args.split_function_arg_names %}
             momentum1_host=momentum1.host,
@@ -96,10 +126,26 @@ def invoke(
             momentum2_offsets=momentum2.offsets,
             momentum2_placements=momentum2.placements,
             {% endif %}
+            # prev_iter
+            {% if "prev_iter_dev" in args.split_function_arg_names %}
+            prev_iter_host=prev_iter.host,
+            prev_iter_offsets=prev_iter.offsets,
+            prev_iter_placements=prev_iter.placements,
+            {% endif %}
+            # row_counter
+            {% if "row_counter_dev" in args.split_function_arg_names %}
+            row_counter_host=row_counter.host,
+            row_counter_offsets=row_counter.offsets,
+            row_counter_placements=row_counter.placements,
+            {% endif %}
             # iter
             {% if "iter" in args.split_function_arg_names %}
             iter=iter,
             {% endif %}
+            # max counter
+            {% if "max_counter" in args.split_function_arg_names %}
+            max_counter=max_counter,
+            {% endif %}
         )
     else:
         return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function(
@@ -151,6 +197,27 @@ def invoke(
             {% if "momentum" in args.split_function_arg_names %}
             momentum=optimizer_args.momentum,
             {% endif %}
+            {% if "counter_halflife" in args.split_function_arg_names %}
+            counter_halflife=optimizer_args.counter_halflife,
+            {% endif %}
+            {% if "adjustment_iter" in args.split_function_arg_names %}
+            adjustment_iter=optimizer_args.adjustment_iter,
+            {% endif %}
+            {% if "adjustment_ub" in args.split_function_arg_names %}
+            adjustment_ub=optimizer_args.adjustment_ub,
+            {% endif %}
+            {% if "learning_rate_mode" in args.split_function_arg_names %}
+            learning_rate_mode=optimizer_args.learning_rate_mode,
+            {% endif %}
+            {% if "grad_sum_decay" in args.split_function_arg_names %}
+            grad_sum_decay=optimizer_args.grad_sum_decay,
+            {% endif %}
+            {% if "tail_id_threshold" in args.split_function_arg_names %}
+            tail_id_threshold=optimizer_args.tail_id_threshold,
+            {% endif %}
+            {% if "is_tail_id_thresh_ratio" in args.split_function_arg_names %}
+            is_tail_id_thresh_ratio=optimizer_args.is_tail_id_thresh_ratio,
+            {% endif %}
             # momentum1
             {% if "momentum1_dev" in args.split_function_arg_names %}
             momentum1_dev=momentum1.dev,
@@ -165,9 +232,27 @@ def invoke(
             momentum2_offsets=momentum2.offsets,
             momentum2_placements=momentum2.placements,
             {% endif %}
+            # prev_iter
+            {% if "prev_iter_dev" in args.split_function_arg_names %}
+            prev_iter_dev=prev_iter.dev,
+            prev_iter_uvm=prev_iter.uvm,
+            prev_iter_offsets=prev_iter.offsets,
+            prev_iter_placements=prev_iter.placements,
+            {% endif %}
+            # row_counter
+            {% if "row_counter_dev" in args.split_function_arg_names %}
+            row_counter_dev=row_counter.dev,
+            row_counter_uvm=row_counter.uvm,
+            row_counter_offsets=row_counter.offsets,
+            row_counter_placements=row_counter.placements,
+            {% endif %}
             # iter
             {% if "iter" in args.split_function_arg_names %}
             iter=iter,
             {% endif %}
+            # max counter
+            {% if "max_counter" in args.split_function_arg_names %}
+            max_counter=max_counter,
+            {% endif %}
             output_dtype=common_args.output_dtype,
         )
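After this change, a counter-based invocation threads two extra per-row states (prev_iter and row_counter, reusing the Momentum container) plus the scalar max_counter through invoke(). The sketch below builds those states standalone; the Momentum field names follow the attribute accesses in the template, while the field order, tensor shapes, and the reading of max_counter as the running maximum of row_counter are assumptions:

    from typing import NamedTuple
    import torch

    class Momentum(NamedTuple):  # stand-in mirroring fbgemm_gpu/codegen/lookup_args.py
        dev: torch.Tensor
        host: torch.Tensor
        uvm: torch.Tensor
        offsets: torch.Tensor
        placements: torch.Tensor

    rows, tables = 1000, 2          # illustrative sizes
    empty = torch.empty(0)

    # prev_iter records the last iteration each row was touched; row_counter
    # holds per-row occurrence counts. Both are passed to invoke() the same
    # way momentum1/momentum2 are.
    prev_iter = Momentum(dev=torch.zeros(rows), host=empty, uvm=empty,
                         offsets=torch.zeros(tables, dtype=torch.int64),
                         placements=torch.zeros(tables, dtype=torch.int32))
    row_counter = prev_iter._replace(dev=torch.zeros(rows))

    # max_counter must be strictly positive (the kernel asserts this), so clamp.
    max_counter = float(row_counter.dev.max().clamp(min=1.0))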
