
Commit 53240c6

[Bugfix] Untie word embeddings (#1646)

## Purpose ##
* Fix models which have tied word embeddings by untying the word embeddings
* This was previously thought to have been fixed by `patch_tied_tensors_bug`, but that solution came from a misunderstanding that the model config was prescriptive rather than descriptive (i.e. that modifying the config would untie the model weights)

## Changes ##
* Replace `patch_tied_tensors_bug` with `untie_word_embeddings`
* Do not load models with a changed `tie_word_embeddings` config

## Testing ##
* Verified that the saved model now has untied weights
* Previously failing tied-tensor tests now pass

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent 2b00d04 commit 53240c6
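
The descriptive-vs-prescriptive distinction is the crux of the fix. Below is a minimal sketch (my illustration, not code from this commit) of the underlying behavior, assuming a checkpoint whose embeddings are tied:

```python
from transformers import AutoModelForCausalLM

# small checkpoint used elsewhere in this PR's tests
model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")

input_embed = model.get_input_embeddings()
output_embed = model.get_output_embeddings()

# when the checkpoint ties embeddings, both modules point at the same storage
print(input_embed.weight.data_ptr() == output_embed.weight.data_ptr())

# flipping the flag only changes the description; the tensors stay shared,
# which is why an explicit untie step is needed
model.config.tie_word_embeddings = False
print(input_embed.weight.data_ptr() == output_embed.weight.data_ptr())
```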

3 files changed: +43 −57 lines

src/llmcompressor/entrypoints/utils.py

Lines changed: 3 additions & 4 deletions

@@ -20,7 +20,7 @@
 from llmcompressor.pytorch.model_load.helpers import parse_dtype
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     modify_save_pretrained,
-    patch_tied_tensors_bug,
+    untie_word_embeddings,
 )
 from llmcompressor.transformers.utils.helpers import (
     detect_last_checkpoint,
@@ -61,7 +61,8 @@ def pre_process(model_args: "ModelArguments"):
     )

     # untie tie_word_embeddings weights
-    patch_tied_tensors_bug(model_args.model)
+    if not model_args.tie_word_embeddings:
+        untie_word_embeddings(model_args.model)

     # wrap model.save_pretrained
     modify_save_pretrained(model_args.model)
@@ -143,7 +144,6 @@ def initialize_model_from_path(
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
-        tie_word_embeddings=model_args.tie_word_embeddings,
         trust_remote_code=model_args.trust_remote_code_model,
     )

@@ -156,7 +156,6 @@ def initialize_model_from_path(
         AutoConfig.from_pretrained(
             model_args.distill_teacher,
             use_auth_token=True if model_args.use_auth_token else None,
-            tie_word_embeddings=model_args.tie_word_embeddings,
             trust_remote_code=model_args.trust_remote_code_model,
         )
         if model_args.distill_teacher
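
With this change the entrypoint no longer forwards `tie_word_embeddings` into `AutoConfig.from_pretrained`; instead it unties the loaded weights directly when `model_args.tie_word_embeddings` is False. A hedged sketch of the same flow outside the entrypoint (the flag value below is a stand-in for `model_args.tie_word_embeddings`, not a documented default):

```python
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    untie_word_embeddings,
)

# load the model as-is; the config is left untouched at load time
model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")

# only untie when the caller explicitly wants independent embedding weights
tie_word_embeddings = False  # stand-in for model_args.tie_word_embeddings
if not tie_word_embeddings:
    untie_word_embeddings(model)

# wrap model.save_pretrained so compressed checkpoints can be written later
modify_save_pretrained(model)
```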

src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 25 additions & 26 deletions

@@ -9,11 +9,11 @@
     CompressionFormat,
     ModelCompressor,
     SparsityCompressionConfig,
+    delete_offload_parameter,
     is_module_offloaded,
-    update_offload_parameter,
+    register_offload_parameter,
 )
 from loguru import logger
-from safetensors.torch import storage_ptr
 from transformers import PreTrainedModel

 from llmcompressor.core import active_session
@@ -27,7 +27,7 @@
 from llmcompressor.transformers.utils import RECIPE_FILE_NAME
 from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path

-__all__ = ["modify_save_pretrained"]
+__all__ = ["modify_save_pretrained", "untie_word_embeddings"]


 def modify_save_pretrained(model: PreTrainedModel):
@@ -120,7 +120,7 @@ def save_pretrained_wrapper(
     model.save_pretrained = save_pretrained_compressed(model.save_pretrained)


-def patch_tied_tensors_bug(model: torch.nn.Module):
+def untie_word_embeddings(model: PreTrainedModel):
     """
     Patches bug where HF transformers will fail to untie weights under specific
     circumstances (https://github.com/huggingface/transformers/issues/33689).
@@ -129,28 +129,27 @@ def patch_tied_tensors_bug(model: torch.nn.Module):

     :param model: model to fix
     """
-    if (
-        hasattr(model.config, "tie_word_embeddings")
-        and not model.config.tie_word_embeddings
-    ):
-        input_embed = model.get_input_embeddings()
-        output_embed = model.get_output_embeddings()
-
-        if input_embed is None or output_embed is None:
-            # some models fail to properly override the abstract methods
-            return
-
-        if storage_ptr(input_embed.weight) == storage_ptr(output_embed.weight):
-            for module in (input_embed, output_embed):
-                if not is_module_offloaded(module):
-                    # create new storage ptr for onloaded weight
-                    untied_data = module.weight.data.clone()
-                    module.weight.data = untied_data
-                else:
-                    # create new storage ptr for offloaded weight
-                    # note `update_offload_parameter` does not create a new storage ptr
-                    untied_data = module._hf_hook.weights_map["weight"].clone()
-                    update_offload_parameter(module, "weight", untied_data)
+    input_embed = model.get_input_embeddings()
+    output_embed = model.get_output_embeddings()
+
+    for module in (input_embed, output_embed):
+        if module is None or not hasattr(module, "weight"):
+            logger.warning(f"Cannot untie {module} which does not have weight param")
+            continue
+
+        # this could be replaced by a `get_offloaded_parameter` util
+        if not is_module_offloaded(module):
+            untied_data = module.weight.data.clone()
+        else:
+            untied_data = module._hf_hook.weights_map["weight"].clone()
+
+        requires_grad = module.weight.requires_grad
+        new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad)
+        delete_offload_parameter(module, "weight")
+        register_offload_parameter(module, "weight", new_parameter)
+
+    if hasattr(model.config, "tie_word_embeddings"):
+        model.config.tie_word_embeddings = False


 def get_model_compressor(
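
The rewritten helper clones each (possibly offloaded) embedding weight, re-registers it as a fresh `torch.nn.Parameter`, and only then records the untied state in the config. A small usage sketch (assumed usage, not taken from the commit) that checks both effects:

```python
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    untie_word_embeddings,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
untie_word_embeddings(model)

input_weight = model.get_input_embeddings().weight
output_weight = model.get_output_embeddings().weight

# the clone + delete/register cycle gives each module its own storage
assert input_weight.data_ptr() != output_weight.data_ptr()
# and the config now describes the untied state
assert model.config.tie_word_embeddings is False
```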

tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py

Lines changed: 15 additions & 27 deletions

@@ -28,7 +28,7 @@
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
     get_model_compressor,
     modify_save_pretrained,
-    patch_tied_tensors_bug,
+    untie_word_embeddings,
 )
 from tests.testing_utils import requires_gpu

@@ -224,8 +224,6 @@ def test_quant_model_reload(format, dtype, tmp_path):
     shutil.rmtree(tmp_path)


-# technically only tie_word_embeddings=False is supported right now
-# setting to True is discouraged
 @pytest.mark.parametrize(
     "offload,torch_dtype,tie_word_embeddings,device",
     [
@@ -237,25 +235,23 @@ def test_quant_model_reload(format, dtype, tmp_path):
         # offloading
         (True, torch.float16, False, "cpu"),
         (True, torch.float32, False, "cpu"),
-        # (True, torch.float16, True, "cpu"),  # TODO: fails
-        # (True, torch.float32, True, "cpu"),  # TODO: fails
+        (True, torch.float16, True, "cpu"),
+        (True, torch.float32, True, "cpu"),
     ],
 )
 def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path):
     model_path = "nm-testing/llama2.c-stories15M"
     save_path = tmp_path / "save_path"

-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        tie_word_embeddings=tie_word_embeddings,
-        torch_dtype=torch_dtype,
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
     if offload:
         model = dispatch_model(model, {"": device}, force_hooks=True)
     else:
         model = model.to(device)

-    patch_tied_tensors_bug(model)
+    if not tie_word_embeddings:
+        untie_word_embeddings(model)
+
     modify_save_pretrained(model)
     model.save_pretrained(save_path, safe_serialization=True)

@@ -294,22 +290,18 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp
         (True, torch.float32, True, "cpu"),
     ],
 )
-def test_model_shared_tensors(
-    offload, torch_dtype, tie_word_embeddings, device, tmp_path
-):
+def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device):
     # load model
-    model = AutoModelForCausalLM.from_pretrained(
-        "nm-testing/llama2.c-stories15M",
-        torch_dtype=torch_dtype,
-        tie_word_embeddings=tie_word_embeddings,
-    )
-    patch_tied_tensors_bug(model)
-
+    model_path = "nm-testing/llama2.c-stories15M"
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype)
     if offload:
         model = dispatch_model(model, {"": device}, force_hooks=True)
     else:
         model = model.to(device)

+    if not tie_word_embeddings:
+        untie_word_embeddings(model)
+
     # modify lm head
     with torch.no_grad(), align_module_device(model.lm_head):
         update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1)
@@ -332,12 +324,8 @@ def test_model_shared_tensors(
         (False, torch.float32, True, "cuda:0"),
     ],
 )
-def test_model_shared_tensors_gpu(
-    offload, torch_dtype, tie_word_embeddings, device, tmp_path
-):
-    test_model_shared_tensors(
-        offload, torch_dtype, tie_word_embeddings, device, tmp_path
-    )
+def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device):
+    test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device)


 @requires_gpu
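
The Testing notes above state that the saved model now has untied weights. A hedged end-to-end sketch of how that could be verified (the save path is illustrative, and the wrapped `save_pretrained` is used the same way the tests use it):

```python
from transformers import AutoModelForCausalLM

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    untie_word_embeddings,
)

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
untie_word_embeddings(model)
modify_save_pretrained(model)
model.save_pretrained("./untied_model", safe_serialization=True)  # illustrative path

# on reload, the embedding and lm_head weights should no longer share storage
reloaded = AutoModelForCausalLM.from_pretrained("./untied_model")
lm_head_weight = reloaded.get_output_embeddings().weight
embedding_weight = reloaded.get_input_embeddings().weight
assert lm_head_weight.data_ptr() != embedding_weight.data_ptr()
```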
