|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import enum
|
11 |
| -from typing import Any, Dict # noqa: F401 |
| 11 | +import itertools |
| 12 | +from typing import Any, Dict, List, Tuple # noqa: F401 |
12 | 13 |
|
13 | 14 | import torch
|
14 | 15 |
|
| 16 | +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( |
| 17 | + EmbeddingLocation, |
| 18 | + SplitState, |
| 19 | +) |
| 20 | + |
15 | 21 |
|
16 | 22 | @enum.unique
|
17 | 23 | class EmbOptimType(enum.Enum):
|
@@ -68,6 +74,49 @@ def state_size_nbytes(
|
68 | 74 | else:
|
69 | 75 | return 0
|
70 | 76 |
|
| 77 | + def ssd_state_splits( |
| 78 | + self, |
| 79 | + embedding_specs: List[Tuple[int, int]], # Tuple of (rows, dims) |
| 80 | + optimizer_state_dtypes: Dict[str, "SparseType"] = {}, # noqa: B006 |
| 81 | + enable_optimizer_offloading: bool = False, |
| 82 | + ) -> List[Tuple[SplitState, str, torch.dtype]]: |
| 83 | + """ |
| 84 | + Returns the split planning for the optimizer states |
| 85 | + """ |
| 86 | + (rows, _) = zip(*embedding_specs) |
| 87 | + T_ = len(embedding_specs) |
| 88 | + |
| 89 | + # This is the cumulative row counts for rowwise states |
| 90 | + row_count_cumsum: List[int] = [0] + list(itertools.accumulate(rows)) |
| 91 | + # This is the cumulative element counts for elementwise states |
| 92 | + table_size_cumsum: List[int] = [0] + list( |
| 93 | + itertools.accumulate([r * d for r, d in embedding_specs]) |
| 94 | + ) |
| 95 | + |
| 96 | + if self == EmbOptimType.EXACT_ROWWISE_ADAGRAD: |
| 97 | + params = {"momentum1": row_count_cumsum} |
| 98 | + elif self == EmbOptimType.PARTIAL_ROWWISE_ADAM: |
| 99 | + params = {"momentum1": table_size_cumsum, "momentum2": row_count_cumsum} |
| 100 | + else: |
| 101 | + params = {} |
| 102 | + |
| 103 | + return [ |
| 104 | + ( |
| 105 | + SplitState( |
| 106 | + dev_size=( |
| 107 | + cumsum_table[-1] if not enable_optimizer_offloading else 0 |
| 108 | + ), |
| 109 | + host_size=0, |
| 110 | + uvm_size=0, |
| 111 | + placements=[EmbeddingLocation.DEVICE for _ in range(T_)], |
| 112 | + offsets=cumsum_table[:-1], |
| 113 | + ), |
| 114 | + name, |
| 115 | + self._extract_dtype(optimizer_state_dtypes, name), |
| 116 | + ) |
| 117 | + for (name, cumsum_table) in params.items() |
| 118 | + ] |
| 119 | + |
71 | 120 | def dtype(self) -> torch.dtype:
|
72 | 121 | """
|
73 | 122 | Returns the dtype of the optimizer state
|
|
0 commit comments