
Commit d195771

feat: Add convert_single_tensor_to_hf API for state dict adapter (#759)
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent d8b3778 commit d195771

13 files changed: +982 -116 lines changed


nemo_automodel/components/checkpoint/state_dict_adapter.py

Lines changed: 15 additions & 0 deletions
@@ -58,3 +58,18 @@ def from_hf(
             The converted native model state dict
         """
         pass
+
+    @abstractmethod
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from native format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in native format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments for conversion
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format.
+            Returns a list because some native tensors may split into multiple HF tensors.
+        """
+        pass
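
The abstract method establishes a per-tensor conversion contract: callers can convert and persist a checkpoint tensor-by-tensor instead of materializing the full HuggingFace state dict in memory. A minimal sketch of such a caller, assuming `adapter` is any concrete StateDictAdapter subclass and `save_fn` is a hypothetical per-tensor writer (e.g. a safetensors shard writer):

def stream_to_hf(adapter, state_dict, save_fn, **kwargs):
    for fqn, tensor in state_dict.items():
        # One native tensor may map to zero (excluded by regex), one, or
        # several HF tensors, hence the list return type.
        for hf_fqn, hf_tensor in adapter.convert_single_tensor_to_hf(fqn, tensor, **kwargs):
            save_fn(hf_fqn, hf_tensor)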

nemo_automodel/components/models/deepseek_v3/state_dict_adapter.py

Lines changed: 56 additions & 8 deletions
@@ -96,15 +96,15 @@ def to_hf(
         """Convert from native model state dict to HuggingFace format.
         Automatically detects format based on backend.enable_deepep configuration.
         """
-        hf_state_dict = self._to_hf_w_split_experts(state_dict)
+        hf_state_dict = {}
+        for fqn, tensor in state_dict.items():
+            converted_tensors = self.convert_single_tensor_to_hf(
+                fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+            )
+            for key, value in converted_tensors:
+                hf_state_dict[key] = value
 
-        if exclude_key_regex:
-            hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)}
-
-        if quantization:
-            return self._add_quantization_scale_inv_tensors(hf_state_dict)
-        else:
-            return hf_state_dict
+        return hf_state_dict
 
     def from_hf(
         self,
@@ -124,6 +124,54 @@ def from_hf(
             hf_state_dict = self._dequantize(hf_state_dict)
         return self._from_hf_w_merged_experts(hf_state_dict, device_mesh)
 
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from native format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in native format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments for conversion
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format
+        """
+        quantization = kwargs.get("quantization", False)
+        exclude_key_regex = kwargs.get("exclude_key_regex", None)
+
+        expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs)
+        if expert_result is not None:
+            result = expert_result
+        else:
+            result = [(fqn, tensor)]
+
+        if exclude_key_regex:
+            result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
+
+        if quantization:
+            quantized_result = []
+            for key, value in result:
+                if key.endswith(".weight") and not any(
+                    non_quantized_key in key
+                    for non_quantized_key in [
+                        "input_layernorm.weight",
+                        "post_attention_layernorm.weight",
+                        "norm.weight",
+                        "lm_head.weight",
+                        "embed_tokens.weight",
+                        "mlp.gate.weight",
+                    ]
+                ):
+                    value = value.to(dtype=torch.float8_e4m3fn)
+                    expected_scale_shape = calculate_scale_shape(value)
+                    weight_scale_inv = torch.ones(expected_scale_shape, dtype=torch.float32, device=value.device)
+                    quantized_result.append((key, value))
+                    quantized_result.append((key + "_scale_inv", weight_scale_inv))
+                else:
+                    quantized_result.append((key, value))
+            return quantized_result
+
+        return result
+
 
 def calculate_scale_shape(weight: torch.Tensor, BLOCK_SIZE: int = BLOCK_SIZE) -> torch.Size:
     # Calculate the scale tensor shape
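
The quantization branch above emits a placeholder weight_scale_inv of ones shaped by calculate_scale_shape. For block-wise FP8 quantization that shape is one float32 scale per BLOCK_SIZE x BLOCK_SIZE tile, i.e. a ceil-division of each weight dimension. A sketch of that shape math, assuming BLOCK_SIZE is 128 (the constant's value is not shown in this diff; DeepSeek-V3 FP8 checkpoints commonly use 128):

import math
import torch

BLOCK_SIZE = 128  # assumption; the real module constant is not visible in this diff

def scale_shape(weight: torch.Tensor, block_size: int = BLOCK_SIZE) -> torch.Size:
    # One scale entry per block_size x block_size tile, rounding up at the edges.
    rows, cols = weight.shape[-2], weight.shape[-1]
    return torch.Size((math.ceil(rows / block_size), math.ceil(cols / block_size)))

# e.g. a (7168, 2048) projection weight gets a (56, 16) weight_scale_inv tensor
assert scale_shape(torch.empty(7168, 2048)) == torch.Size((56, 16))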

nemo_automodel/components/models/glm4_moe/state_dict_adapter.py

Lines changed: 33 additions & 4 deletions
@@ -61,11 +61,14 @@ def __init__(
     def to_hf(
         self, state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None, quantization: bool = False, **kwargs
     ) -> dict[str, Any]:
-        hf_state_dict = self._to_hf_w_split_experts(state_dict)
-        if exclude_key_regex:
-            import re
+        hf_state_dict = {}
+        for fqn, tensor in state_dict.items():
+            converted_tensors = self.convert_single_tensor_to_hf(
+                fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+            )
+            for key, value in converted_tensors:
+                hf_state_dict[key] = value
 
-            hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)}
         return hf_state_dict
 
     def from_hf(
@@ -80,3 +83,29 @@ def from_hf(
                 self._uses_model_prefix = key.startswith("model.")
                 break
         return self._from_hf_w_merged_experts(hf_state_dict, device_mesh)
+
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from native format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in native format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments for conversion
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format
+        """
+        exclude_key_regex = kwargs.get("exclude_key_regex", None)
+
+        expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs)
+        if expert_result is not None:
+            result = expert_result
+        else:
+            result = [(fqn, tensor)]
+
+        if exclude_key_regex:
+            import re
+
+            result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
+
+        return result
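
Note that the exclude filter uses re.match, which anchors the pattern at the start of the key, so patterns must account for any leading prefix. An illustrative check (keys and pattern invented for the example):

import re

exclude_key_regex = r".*rotary_emb\.inv_freq"  # illustrative pattern
keys = [
    "model.layers.0.self_attn.rotary_emb.inv_freq",
    "model.layers.0.mlp.experts.3.gate_proj.weight",
]
kept = [k for k in keys if not re.match(exclude_key_regex, k)]
assert kept == ["model.layers.0.mlp.experts.3.gate_proj.weight"]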

nemo_automodel/components/models/gpt_oss/state_dict_adapter.py

Lines changed: 55 additions & 8 deletions
@@ -212,15 +212,14 @@ def to_hf(
         self, state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None, quantization: bool = False, **kwargs
     ) -> dict[str, Any]:
         """Convert from native model state dict to HuggingFace format."""
-        hf_state_dict = dict(state_dict)
-        hf_state_dict = self._apply_key_mapping(hf_state_dict, self.internal_to_hf_map)
-
-        # Apply exclude regex if provided
-        if exclude_key_regex:
-            hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)}
+        hf_state_dict = {}
+        for fqn, tensor in state_dict.items():
+            converted_tensors = self.convert_single_tensor_to_hf(
+                fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+            )
+            for key, value in converted_tensors:
+                hf_state_dict[key] = value
 
-        if quantization:
-            hf_state_dict = self._add_quantization_block_scale_tensors(hf_state_dict)
         return hf_state_dict
 
     def from_hf(
@@ -244,3 +243,51 @@ def from_hf(
         native_state_dict = self._apply_key_mapping(native_state_dict, self.hf_to_internal_map)
 
         return native_state_dict
+
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from native format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in native format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments for conversion
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format
+        """
+        quantization = kwargs.get("quantization", False)
+        exclude_key_regex = kwargs.get("exclude_key_regex", None)
+
+        hf_fqn = fqn
+        for pattern, replacement in self.internal_to_hf_map.items():
+            if fqn.endswith(pattern):
+                hf_fqn = fqn[: -len(pattern)] + replacement
+                break
+
+        if exclude_key_regex:
+            if re.match(exclude_key_regex, hf_fqn):
+                return []
+
+        if quantization:
+            if hf_fqn.endswith("gate_up_proj") or hf_fqn.endswith("down_proj"):
+                layer_name, projection_type = hf_fqn.rsplit(".", 1)
+                n_experts, _, dim = tensor.shape
+
+                if isinstance(tensor, torch.distributed.tensor.DTensor):
+                    placements, device_mesh = tensor.placements, tensor.device_mesh
+                    blocks_tensors = torch.distributed.tensor.ones(
+                        (n_experts, dim, 90, 16), placements=placements, device_mesh=device_mesh, dtype=torch.uint8
+                    )
+                    scales_tensors = torch.distributed.tensor.ones(
+                        (n_experts, dim, 90), placements=placements, device_mesh=device_mesh, dtype=torch.uint8
+                    )
+                else:
+                    blocks_tensors = torch.ones((n_experts, dim, 90, 16), dtype=torch.uint8)
+                    scales_tensors = torch.ones((n_experts, dim, 90), dtype=torch.uint8)
+
+                return [
+                    (f"{layer_name}.{projection_type}_blocks", blocks_tensors),
+                    (f"{layer_name}.{projection_type}_scales", scales_tensors),
+                ]
+
+        return [(hf_fqn, tensor)]
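
The placeholder _blocks/_scales shapes appear to encode MXFP4 group quantization: groups of 32 elements along a 2880-wide axis give 90 groups per row, with each group packing 32 4-bit values into 16 uint8 bytes plus one uint8 scale. This reading of the constants is an inference, not something the diff states:

GROUP_SIZE = 32   # elements per MXFP4 group (assumed)
QUANT_DIM = 2880  # width of the quantized axis (assumed; matches the gpt-oss hidden size)

n_groups = QUANT_DIM // GROUP_SIZE  # 90 groups per row -> the trailing 90 dim
bytes_per_group = GROUP_SIZE // 2   # 32 4-bit values pack into 16 uint8 bytes -> the trailing 16 dim
assert (n_groups, bytes_per_group) == (90, 16)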

nemo_automodel/components/models/qwen3_moe/state_dict_adapter.py

Lines changed: 33 additions & 4 deletions
@@ -55,11 +55,14 @@ def __init__(
     def to_hf(
         self, state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None, quantization: bool = False, **kwargs
     ) -> dict[str, Any]:
-        hf_state_dict = self._to_hf_w_split_experts(state_dict)
-        if exclude_key_regex:
-            import re
+        hf_state_dict = {}
+        for fqn, tensor in state_dict.items():
+            converted_tensors = self.convert_single_tensor_to_hf(
+                fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+            )
+            for key, value in converted_tensors:
+                hf_state_dict[key] = value
 
-            hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)}
         return hf_state_dict
 
     def from_hf(
@@ -74,3 +77,29 @@ def from_hf(
                 self._uses_model_prefix = key.startswith("model.")
                 break
         return self._from_hf_w_merged_experts(hf_state_dict, device_mesh)
+
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from native format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in native format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments for conversion
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format
+        """
+        exclude_key_regex = kwargs.get("exclude_key_regex", None)
+
+        expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs)
+        if expert_result is not None:
+            result = expert_result
+        else:
+            result = [(fqn, tensor)]
+
+        if exclude_key_regex:
+            import re
+
+            result = [(k, v) for k, v in result if not re.match(exclude_key_regex, k)]
+
+        return result
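
Since to_hf is now just a fold over convert_single_tensor_to_hf, the batch and per-tensor paths should agree key-for-key. A minimal consistency check, assuming `adapter` is a constructed adapter instance from this module and `state_dict` a native-format dict:

def check_consistency(adapter, state_dict, **kwargs):
    batch = adapter.to_hf(state_dict, **kwargs)
    streamed = {}
    for fqn, tensor in state_dict.items():
        # dict.update accepts the list of (fqn, tensor) pairs directly
        streamed.update(adapter.convert_single_tensor_to_hf(fqn, tensor, **kwargs))
    assert batch.keys() == streamed.keys()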

nemo_automodel/components/models/qwen3_next/state_dict_adapter.py

Lines changed: 42 additions & 9 deletions
@@ -93,16 +93,14 @@ def _apply_key_mapping(self, state_dict: dict[str, Any], mapping: dict[str, str]
     def to_hf(
         self, state_dict: dict[str, Any], exclude_key_regex: Optional[str] = None, quantization: bool = False, **kwargs
     ) -> dict[str, Any]:
-        # First convert routed experts from grouped to split format
-        hf_state_dict = self._to_hf_w_split_experts(state_dict)
+        hf_state_dict = {}
+        for fqn, tensor in state_dict.items():
+            converted_tensors = self.convert_single_tensor_to_hf(
+                fqn, tensor, exclude_key_regex=exclude_key_regex, quantization=quantization, **kwargs
+            )
+            for key, value in converted_tensors:
+                hf_state_dict[key] = value
 
-        # Then apply key mappings for shared experts (shared_experts -> shared_expert)
-        hf_state_dict = self._apply_key_mapping(hf_state_dict, self.internal_to_hf_map)
-
-        if exclude_key_regex:
-            import re
-
-            hf_state_dict = {k: v for k, v in hf_state_dict.items() if not re.match(exclude_key_regex, k)}
         return hf_state_dict
 
     def from_hf(
@@ -122,3 +120,38 @@ def from_hf(
 
         # Then convert routed experts from split to grouped format
        return self._from_hf_w_merged_experts(hf_state_dict, device_mesh)
+
+    def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
+        """Convert a single tensor from native format to HuggingFace format.
+
+        Args:
+            fqn: Fully qualified name of the tensor in native format
+            tensor: The tensor to convert
+            **kwargs: Additional arguments for conversion
+
+        Returns:
+            List of (fqn, tensor) tuples in HuggingFace format
+        """
+        exclude_key_regex = kwargs.get("exclude_key_regex", None)
+
+        expert_result = self._convert_single_merged_expert_to_hf_split_experts(fqn, tensor, **kwargs)
+        if expert_result is not None:
+            result = expert_result
+        else:
+            result = [(fqn, tensor)]
+
+        mapped_result = []
+        for key, value in result:
+            new_key = key
+            for pattern, replacement in self.internal_to_hf_map.items():
+                if pattern in key:
+                    new_key = new_key.replace(pattern, replacement)
+                    break
+            mapped_result.append((new_key, value))
+
+        if exclude_key_regex:
+            import re
+
+            mapped_result = [(k, v) for k, v in mapped_result if not re.match(exclude_key_regex, k)]
+
+        return mapped_result
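
Unlike the gpt_oss adapter's suffix matching, this adapter maps substrings anywhere in the key. The removed to_hf comment names one real mapping, shared_experts -> shared_expert; the standalone sketch below mirrors the loop with a single illustrative map entry:

internal_to_hf_map = {"mlp.shared_experts.": "mlp.shared_expert."}  # illustrative entry

key = "model.layers.0.mlp.shared_experts.gate_proj.weight"
new_key = key
for pattern, replacement in internal_to_hf_map.items():
    if pattern in key:
        new_key = new_key.replace(pattern, replacement)
        break  # first matching pattern wins
assert new_key == "model.layers.0.mlp.shared_expert.gate_proj.weight"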
