@@ -1,11 +1,9 @@
 import os
-import re
 import weakref
 from functools import wraps
 from typing import Optional
 
 import torch
-import transformers
 from accelerate.accelerator import get_state_dict_offloaded_model
 from compressed_tensors import (
     CompressionFormat,
@@ -86,11 +84,6 @@ def save_pretrained_wrapper(
             :param kwargs: additional kwargs to pass on to model.save_pretrained
             """
 
-            # HACK: Override the dtype_byte_size function in transformers to
-            # support float8 types. Fix is posted upstream
-            # https://github.com/huggingface/transformers/pull/30488
-            transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
             # compress model using compressor
             compressor = get_model_compressor(
                 model=model,
@@ -128,18 +121,6 @@ def save_pretrained_wrapper(
     model.save_pretrained = save_pretrained_compressed(model.save_pretrained)
 
 
-# HACK: Override the dtype_byte_size function in transformers to support float8 types
-# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
-def new_dtype_byte_size(dtype):
-    if dtype == torch.bool:
-        return 1 / 8
-    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
-    if bit_search is None:
-        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
-    bit_size = int(bit_search.groups()[0])
-    return bit_size // 8
-
-
 def patch_tied_tensors_bug(model: torch.nn.Module):
     """
     Patches bug where HF transformers will fail to untie weights under specific
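For context: the deleted `new_dtype_byte_size` helper existed only so that `transformers.modeling_utils.dtype_byte_size` (used by transformers when computing checkpoint shard sizes) would handle `torch.float8_*` dtypes. With huggingface/transformers#30488 merged upstream, the monkey-patch and its `re`/`transformers` imports can be dropped. Below is a standalone sketch of what the removed helper computed, assuming a PyTorch build that exposes `torch.float8_e4m3fn`:

```python
import re

import torch


def new_dtype_byte_size(dtype: torch.dtype) -> float:
    # Same logic as the removed workaround: read the bit width out of the
    # dtype's string form (e.g. "torch.float8_e4m3fn" -> 8) and convert to bytes.
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8


print(new_dtype_byte_size(torch.float16))        # 2
print(new_dtype_byte_size(torch.float8_e4m3fn))  # 1 -- the case the hack covered
```

The regex simply pulls the bit width out of the dtype's repr, so float8 dtypes resolve to 1 byte per element instead of raising.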