Commit 311f6f9

q10 authored and facebook-github-bot committed

Update the rowwise adagrad optimizer to leverage optimizer state offloading, v4, frontend (#4249)

Summary:
X-link: facebookresearch/FBGEMM#1328
Pull Request resolved: #4249

- This diff follows up on D75329024 by plumbing the flag for enabling optimizer state offloading in TBE SSD from the backend C++ code all the way up to the frontend Python code

Reviewed By: spcyppt

Differential Revision: D75336208

fbshipit-source-id: 8291d9606800791e06b0ecd9a072e44b7c88aec4

1 parent db8941a · commit 311f6f9
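The end-to-end flow of the new flag: `enable_optimizer_offloading` is accepted by the SSD TBE module, carried through the codegen'd `CommonArgs` NamedTuple, and emitted into the aux-bool dict that the backend op consumes. A minimal sketch of that plumbing pattern, using simplified stand-in names rather than the actual FBGEMM classes:

```python
# Sketch only: stand-ins for the codegen'd CommonArgs and invoker, showing
# how a frontend flag rides the common-args structure into the aux-bool
# dict read by the backend. Names here are illustrative, not FBGEMM's.
from typing import Any, Dict, NamedTuple


class CommonArgsSketch(NamedTuple):
    use_homogeneous_placements: bool
    enable_optimizer_offloading: bool  # the field this diff adds (SSD only)


def invoke_sketch(common_args: CommonArgsSketch) -> Dict[str, Any]:
    # Before this diff the invoker hardcoded the flag to False; now it is
    # read straight off common_args.
    return {
        "use_homogeneous_placements": common_args.use_homogeneous_placements,
        "enable_optimizer_offloading": common_args.enable_optimizer_offloading,
    }


print(invoke_sketch(CommonArgsSketch(True, enable_optimizer_offloading=True)))
```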

File tree

- fbgemm_gpu/codegen/genscript/generate_backward_split.py
- fbgemm_gpu/codegen/training/python/lookup_args.template
- fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template
- fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
- fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

5 files changed: +42 −20 lines changed

fbgemm_gpu/codegen/genscript/generate_backward_split.py

Lines changed: 31 additions & 7 deletions

```diff
@@ -10,6 +10,7 @@

 import itertools
 import sys
+from copy import deepcopy
 from typing import List

 try:
@@ -164,6 +165,10 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
         if not kwargs.get("dense"):
             # Generate CUDA autograd

+            # Extract the aux_args and ssd_aux_args for later use
+            aux_args = kwargs["aux_args"]
+            ssd_aux_args = kwargs["ssd_aux_args"]
+
             for ssd in [True, False] if kwargs.get("has_ssd_support") else [False]:
                 template_filepath = (
                     "training/backward/embedding_backward_split_host_template.cpp"
@@ -195,6 +200,10 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
                 )

                 if kwargs.get("has_cpu_support") or kwargs.get("has_gpu_support"):
+                    # Since the template file only uses aux_args, reset the key
+                    # based on whether we are generating the SSD variant or not
+                    kwargs["aux_args"] = ssd_aux_args if ssd else aux_args
+
                     # Generates Python invoker for CUDA + CPU, and PT2
                     template = CodeTemplate.load(
                         "training/python/split_embedding_codegen_lookup_invoker.template"
@@ -433,28 +442,44 @@ def generate() -> None:
             "mixed_D",  # 6
         ],
     }
-    # ssd-specific argument
+
+    # SSD-specific arguments
     ssd_aux_bool = [
+        # When set to true, the per-row optimizer state will be offloaded to
+        # the end of each row in the SSD cache.
         "enable_optimizer_offloading",  # 7
     ]
+
     assert (
         list(aux_args.keys()) == aux_names
     ), f"{aux_names} must match {aux_args.keys()}"

+    ssd_aux_args = deepcopy(aux_args)
+    ssd_aux_args["aux_bool"].extend(ssd_aux_bool)
+
     all_optimizers = []
     ssd_optimizers = []

     for optimizer in optimizers:
         optim = optimizer["optimizer"]
+
         if (
             optimizer["has_cpu_support"] or optimizer["has_gpu_support"]
         ) and optim != "dense":
             all_optimizers.append(optim)
             if optimizer["has_ssd_support"]:
                 ssd_optimizers.append(optim)
+
         BackwardSplitGenerator.generate_backward_split(
-            ssd_tensors=ssd_tensors, aux_args=aux_args, **optimizer
+            ssd_tensors=ssd_tensors,
+            # Both aux_args and ssd_aux_args will be passed in, since
+            # generate_backward_split will generate both SSD and non-SSD
+            # variants
+            aux_args=aux_args,
+            ssd_aux_args=ssd_aux_args,
+            **optimizer,
         )
+
     BackwardSplitGenerator.generate_rocm_backward_split()

     # Generate common device kernels for backwards
@@ -465,11 +490,10 @@ def generate() -> None:
     BackwardSplitGenerator.generate_backward_indices()

     # Generate headers for backwards
-    BackwardSplitGenerator.generate_backward_header(aux_args, aux_names)
-    aux_args["aux_bool"].extend(ssd_aux_bool)
-    BackwardSplitGenerator.generate_backward_header(
-        aux_args, aux_names, is_ssd=True
-    )
+    for is_ssd in [True, False]:
+        BackwardSplitGenerator.generate_backward_header(
+            (ssd_aux_args if is_ssd else aux_args), aux_names, is_ssd=is_ssd
+        )

     BackwardSplitGenerator.generate_python_sources(all_optimizers, ssd_optimizers)
```
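The key change above is that `ssd_aux_args` is now a `deepcopy` of `aux_args`: the old code extended `aux_args["aux_bool"]` in place, which forced the SSD header to be generated after the non-SSD one. A toy illustration of why the copy keeps the two argument tables independent (values are made up; the real `aux_args` maps category names such as `"aux_bool"` to lists of argument names):

```python
# Toy illustration of the deepcopy pattern above, not the real aux_args.
from copy import deepcopy

aux_args = {"aux_bool": ["use_uniq_cache_locations_bwd", "mixed_D"]}
ssd_aux_bool = ["enable_optimizer_offloading"]

ssd_aux_args = deepcopy(aux_args)
ssd_aux_args["aux_bool"].extend(ssd_aux_bool)

# The non-SSD table is untouched, so generation order no longer matters.
assert "enable_optimizer_offloading" not in aux_args["aux_bool"]
assert "enable_optimizer_offloading" in ssd_aux_args["aux_bool"]
```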

fbgemm_gpu/codegen/training/python/lookup_args.template

Lines changed: 1 addition & 0 deletions

```diff
@@ -48,6 +48,7 @@ class CommonArgs(NamedTuple):
     use_homogeneous_placements: bool
     {%- if ssd %}
     ssd_tensors: Dict[str, torch.Tensor]
+    enable_optimizer_offloading: bool
     {%- endif %}
     learning_rate_tensor: torch.Tensor
     info_B_num_bits: int
```
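Because this file is a template, the new field only appears in the SSD-generated `CommonArgs`. A small self-contained rendering of a trimmed snippet (assuming jinja2 semantics for the `{%- if %}` guard, which is what FBGEMM's codegen templating is based on):

```python
# Renders a trimmed version of the template above for both variants to show
# that enable_optimizer_offloading exists only when ssd=True. Requires the
# jinja2 package; the snippet is abbreviated, not the full template.
import jinja2

snippet = """\
class CommonArgs(NamedTuple):
    use_homogeneous_placements: bool
{%- if ssd %}
    ssd_tensors: Dict[str, torch.Tensor]
    enable_optimizer_offloading: bool
{%- endif %}
    learning_rate_tensor: torch.Tensor
"""

for ssd in (False, True):
    print(f"--- ssd={ssd} ---")
    print(jinja2.Template(snippet).render(ssd=ssd))
```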

fbgemm_gpu/codegen/training/python/split_embedding_codegen_lookup_invoker.template

Lines changed: 3 additions & 5 deletions

```diff
@@ -92,6 +92,7 @@ torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
             "Please check the frontend and backend version. "
         )
     {{ arg_type }}.append(dict_{{ arg_type }}["{{ var }}"])
+
{%- endfor %}
{%- endmacro %}

@@ -203,12 +204,9 @@ def invoke(
         "use_uniq_cache_locations_bwd": common_args.use_uniq_cache_locations_bwd,
         "use_homogeneous_placements": common_args.use_homogeneous_placements,
         "apply_global_weight_decay": apply_global_weight_decay,
-        {%- if not ssd %}
-        "mixed_D": mixed_D
-        {%- else %}
         "mixed_D": mixed_D,
-        # TODO: Update this when frontend is ready to land
-        "enable_optimizer_offloading": False
+        {%- if ssd %}
+        "enable_optimizer_offloading": common_args.enable_optimizer_offloading,
         {%- endif %}
     }
     dict_optim_int: Dict[str, int] = {}
```
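The rewrite also simplifies the template: `"mixed_D"` is now emitted unconditionally for both variants, and only the SSD branch adds the offloading entry, read from `common_args` instead of being hardcoded to `False`. Hand-expanded, the rendered dict-building logic behaves like this sketch (the real invoker fills several more aux tables):

```python
# Hand-expanded behavior of the rewritten template block; a sketch, not the
# generated source. mixed_D is unconditional; the offloading flag is SSD-only.
from typing import Any, Dict


def build_dict_aux_bool(ssd: bool, mixed_D: bool, offloading: bool) -> Dict[str, Any]:
    d: Dict[str, Any] = {
        "mixed_D": mixed_D,  # emitted for both SSD and non-SSD variants
    }
    if ssd:
        # previously hardcoded to False pending this frontend diff
        d["enable_optimizer_offloading"] = offloading
    return d


print(build_dict_aux_bool(ssd=True, mixed_D=False, offloading=True))
print(build_dict_aux_bool(ssd=False, mixed_D=True, offloading=True))
```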

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -2142,6 +2142,7 @@ def forward(
                 "post_bwd_evicted_indices": post_bwd_evicted_indices_cpu,
                 "actions_count": actions_count_cpu,
             },
+            enable_optimizer_offloading=self.enable_optimizer_offloading,
             # pyre-fixme[6]: Expected `lookup_args_ssd.VBEMetadata` but got `lookup_args.VBEMetadata`
             vbe_metadata=vbe_metadata,
             learning_rate_tensor=self.learning_rate_tensor,
```
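For this one-line frontend change to work, the SSD TBE module must already hold `self.enable_optimizer_offloading` from its constructor (set up in the earlier backend diffs in this stack). A heavily reduced sketch of that shape, with hypothetical names where the diff does not show them:

```python
# Reduced sketch of the frontend module shape; forward() stands in for the
# codegen'd lookup invocation that receives CommonArgs. The constructor
# signature here is hypothetical; only self.enable_optimizer_offloading is
# visible in the diff itself.
from typing import Any, Dict


class SSDEmbeddingModuleSketch:
    def __init__(self, enable_optimizer_offloading: bool = False) -> None:
        self.enable_optimizer_offloading = enable_optimizer_offloading

    def forward(self) -> Dict[str, Any]:
        # The flag is forwarded on every lookup, as in the diff above.
        return {"enable_optimizer_offloading": self.enable_optimizer_offloading}


assert SSDEmbeddingModuleSketch(True).forward()["enable_optimizer_offloading"]
```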

fbgemm_gpu/test/tbe/ssd/ssd_split_tbe_training_test.py

Lines changed: 6 additions & 8 deletions

```diff
@@ -1590,9 +1590,7 @@ def test_kv_db_forward(
     @given(
         **default_st,
         num_buckets=st.integers(min_value=10, max_value=15),
-        opt_offloading=st.just(
-            False
-        ),  # make it st.booleans when Benson's opt offloading diff is landed
+        enable_optimizer_offloading=st.booleans(),
         backend_type=st.sampled_from([BackendType.SSD, BackendType.DRAM]),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
@@ -1612,7 +1610,7 @@ def test_kv_emb_state_dict(
         trigger_bounds_check: bool,
         mixed_B: bool,
         num_buckets: int,
-        opt_offloading: bool,
+        enable_optimizer_offloading: bool,
         backend_type: BackendType,
     ) -> None:
         # Constants
@@ -1648,7 +1646,7 @@ def test_kv_emb_state_dict(
             output_dtype=output_dtype,
             share_table=share_table,
             num_buckets=num_buckets,
-            enable_optimizer_offloading=opt_offloading,
+            enable_optimizer_offloading=enable_optimizer_offloading,
             backend_type=backend_type,
         )

@@ -1786,8 +1784,6 @@ def test_kv_emb_state_dict(
             self.assertLess(table_index, len(emb_state_dict_list))
             assert len(split_optimizer_states[table_index]) == num_ids
             opt = split_optimizer_states[table_index]
-            if opt_offloading:
-                opt = opt[bucket_asc_ids_list[table_index].view(-1)]
             new_ref_weight = torch.addcdiv(
                 emb_r_w.float(),
                 value=-lr,
@@ -1817,6 +1813,7 @@ def test_kv_emb_state_dict(
     @given(
         **default_st,
         num_buckets=st.integers(min_value=10, max_value=15),
+        enable_optimizer_offloading=st.booleans(),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=MAX_EXAMPLES, deadline=None)
     def test_kv_opt_state_w_offloading(
@@ -1835,6 +1832,7 @@ def test_kv_opt_state_w_offloading(
         trigger_bounds_check: bool,
         mixed_B: bool,
         num_buckets: int,
+        enable_optimizer_offloading: bool,
     ) -> None:
         # Constants
         lr = 0.5
@@ -1870,7 +1868,7 @@ def test_kv_opt_state_w_offloading(
             output_dtype=output_dtype,
             share_table=share_table,
             num_buckets=num_buckets,
-            enable_optimizer_offloading=False,
+            enable_optimizer_offloading=enable_optimizer_offloading,
         )

         # Generate inputs
```
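With the frontend flag landed, the tests replace the pinned `st.just(False)` strategy with `st.booleans()`, so Hypothesis now exercises both offloading modes. A self-contained example of the same strategy pattern (requires the hypothesis package; the test body is illustrative only):

```python
from hypothesis import given, settings, strategies as st


@settings(max_examples=10, deadline=None)
@given(enable_optimizer_offloading=st.booleans())
def check_offloading_flag(enable_optimizer_offloading: bool) -> None:
    # Hypothesis draws both True and False across examples, so both the
    # offloaded and non-offloaded code paths get covered.
    assert isinstance(enable_optimizer_offloading, bool)


check_offloading_flag()
```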
