Commit cb8f410

Remove FP8 Patch (#1585)
## Purpose ##

* Remove dead code

## Background ##

The `new_dtype_byte_size` patch was added because transformers' `dtype_byte_size` function lacked support for FP8. The corresponding fix was merged into transformers main over a year ago, so this patch is now safe to remove.

## Changes ##

* Remove the `new_dtype_byte_size` patch

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ec6345d commit cb8f410
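
For context, here is a minimal sketch (not part of this commit) of the parsing issue behind the patch: the bit width is extracted from the dtype's string form with a regex, and an end-anchored pattern cannot parse float8 names such as `torch.float8_e4m3fn`, while the pattern used by the removed `new_dtype_byte_size` patch (and by the upstream fix in huggingface/transformers#30488) matches the first digit group instead. The end-anchored pattern shown for the older behavior is an assumption, and `byte_size` is a hypothetical helper for illustration only.

```python
import re

import torch


# Illustrative sketch only (not from this commit or from transformers itself).
# byte_size is a hypothetical helper showing the dtype-name parsing at issue.
def byte_size(dtype: torch.dtype, pattern: str) -> int:
    """Extract the bit width from a dtype's string form and convert to bytes."""
    match = re.search(pattern, str(dtype))
    if match is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(match.groups()[0]) // 8


# The patched-style pattern matches the first digit group, so it handles
# float8 names such as "torch.float8_e4m3fn".
print(byte_size(torch.float16, r"[^\d](\d+)_?"))        # 2
print(byte_size(torch.float8_e4m3fn, r"[^\d](\d+)_?"))  # 1

# An end-anchored pattern (assumed here to approximate older transformers
# behavior) finds no digits at the end of "torch.float8_e4m3fn" and raises.
try:
    byte_size(torch.float8_e4m3fn, r"[^\d](\d+)$")
except ValueError as err:
    print(err)
```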

File tree

1 file changed: +0 additions, −19 deletions


src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py

Lines changed: 0 additions & 19 deletions
```diff
@@ -1,11 +1,9 @@
 import os
-import re
 import weakref
 from functools import wraps
 from typing import Optional

 import torch
-import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import (
     CompressionFormat,
@@ -86,11 +84,6 @@ def save_pretrained_wrapper(
             :param kwargs: additional kwargs to pass on to model.save_pretrained
             """

-            # HACK: Override the dtype_byte_size function in transformers to
-            # support float8 types. Fix is posted upstream
-            # https://github.com/huggingface/transformers/pull/30488
-            transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
             # compress model using compressor
             compressor = get_model_compressor(
                 model=model,
@@ -128,18 +121,6 @@ def save_pretrained_wrapper(
     model.save_pretrained = save_pretrained_compressed(model.save_pretrained)


-# HACK: Override the dtype_byte_size function in transformers to support float8 types
-# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
-def new_dtype_byte_size(dtype):
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def patch_tied_tensors_bug(model: torch.nn.Module):
     """
     Patches bug where HF transformers will fail to untie weights under specific
```
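
As a follow-up, here is a minimal sanity-check sketch, assuming the installed transformers release includes huggingface/transformers#30488 and still exposes `dtype_byte_size` under `transformers.modeling_utils`:

```python
import torch
import transformers

# Assumption: the installed transformers release includes the upstream fix
# (huggingface/transformers#30488) and still exposes dtype_byte_size at this
# location. With that fix, float8 dtypes resolve to 1 byte without any patch.
assert transformers.modeling_utils.dtype_byte_size(torch.float8_e4m3fn) == 1
assert transformers.modeling_utils.dtype_byte_size(torch.float16) == 2
assert transformers.modeling_utils.dtype_byte_size(torch.bool) == 1 / 8
```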
