
Commit a7d53a5

hlky and sayakpaul authored
Don't override torch_dtype and don't use when quantization_config is set (#11039)
* Don't use `torch_dtype` when `quantization_config` is set
* up
* djkajka
* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 8a63aa5 commit a7d53a5

File tree

8 files changed: +19 -42 lines

src/diffusers/loaders/single_file.py
src/diffusers/loaders/single_file_model.py
src/diffusers/models/modeling_utils.py
src/diffusers/pipelines/kolors/text_encoder.py
src/diffusers/pipelines/pipeline_utils.py
tests/pipelines/kolors/test_kolors.py
tests/pipelines/kolors/test_kolors_img2img.py
tests/pipelines/pag/test_pag_kolors.py

src/diffusers/loaders/single_file.py

Lines changed: 2 additions & 2 deletions
@@ -360,12 +360,12 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
         is_legacy_loading = False
 
-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
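
With the default now `None`, the pipeline-level `from_single_file` only validates `torch_dtype` when the caller actually passes one, and omitting it no longer forces an implicit `torch.float32`. A minimal sketch of the resulting call pattern (the checkpoint path is a placeholder, not taken from this commit):

import torch
from diffusers import StableDiffusionXLPipeline

ckpt = "path/to/checkpoint.safetensors"  # placeholder single-file checkpoint

# No torch_dtype: the loader no longer injects a float32 default on the caller's behalf.
pipe = StableDiffusionXLPipeline.from_single_file(ckpt)

# An explicit torch.dtype is still accepted and validated as before.
pipe_fp16 = StableDiffusionXLPipeline.from_single_file(ckpt, torch_dtype=torch.float16)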

src/diffusers/loaders/single_file_model.py

Lines changed: 2 additions & 2 deletions
@@ -255,12 +255,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
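
This is the loader the commit title is mainly about: `quantization_config` is popped right next to `torch_dtype`, and with the default now `None` a quantized load is no longer seeded with a spurious float32 request. A sketch, assuming the GGUF backend is installed and using a placeholder checkpoint path:

import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

ckpt = "path/to/flux1-dev-Q4_0.gguf"  # placeholder quantized checkpoint

# torch_dtype is deliberately omitted; the quantization config decides how weights are handled.
transformer = FluxTransformer2DModel.from_single_file(
    ckpt,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
)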

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 2 deletions
@@ -880,7 +880,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -893,7 +893,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
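
The same guard is repeated in every loader touched by this commit. Restated as a standalone helper (the name `_validate_torch_dtype` is mine, not part of the codebase), the new contract is: `None` passes through untouched, a valid `torch.dtype` is kept, and anything else warns and falls back to `torch.float32`:

import torch

def _validate_torch_dtype(torch_dtype):
    # None is now a legal value and is returned unchanged; only non-dtype values fall back.
    if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
        print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
        return torch.float32
    return torch_dtype

assert _validate_torch_dtype(None) is None
assert _validate_torch_dtype(torch.bfloat16) is torch.bfloat16
assert _validate_torch_dtype("bfloat16") is torch.float32  # strings still warn and fall back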

src/diffusers/pipelines/kolors/text_encoder.py

Lines changed: 6 additions & 29 deletions
@@ -104,13 +104,6 @@ def forward(self, hidden_states: torch.Tensor):
         return (self.weight * hidden_states).to(input_dtype)
 
 
-def _config_to_kwargs(args):
-    common_kwargs = {
-        "dtype": args.torch_dtype,
-    }
-    return common_kwargs
-
-
 class CoreAttention(torch.nn.Module):
     def __init__(self, config: ChatGLMConfig, layer_number):
         super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             self.qkv_hidden_size,
             bias=config.add_bias_linear or config.add_qkv_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
             config.hidden_size,
             bias=config.add_bias_linear,
             device=device,
-            **_config_to_kwargs(config),
         )
 
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ def __init__(self, config: ChatGLMConfig, device=None):
             config.ffn_hidden_size * 2,
             bias=self.add_bias,
             device=device,
-            **_config_to_kwargs(config),
         )
 
         def swiglu(x):
@@ -459,9 +449,7 @@ def swiglu(x):
         self.activation_func = swiglu
 
         # Project back to h.
-        self.dense_4h_to_h = nn.Linear(
-            config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
-        )
+        self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
 
     def forward(self, hidden_states):
         # [s, b, 4hp]
@@ -488,18 +476,14 @@ def __init__(self, config: ChatGLMConfig, layer_number, device=None):
 
         LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
         # Layernorm on the input data.
-        self.input_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # Self attention.
         self.self_attention = SelfAttention(config, layer_number, device=device)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
-        self.post_attention_layernorm = LayerNormFunc(
-            config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-        )
+        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         # MLP
         self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ def build_layer(layer_number):
         if self.post_layer_norm:
             LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
             # Final layer norm before output.
-            self.final_layernorm = LayerNormFunc(
-                config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
-            )
+            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
 
         self.gradient_checkpointing = False
 
@@ -679,9 +661,7 @@ def __init__(self, config: ChatGLMConfig, device=None):
 
         self.hidden_size = config.hidden_size
         # Word embeddings (parallel).
-        self.word_embeddings = nn.Embedding(
-            config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
-        )
+        self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
         self.fp32_residual_connection = config.fp32_residual_connection
 
     def forward(self, input_ids):
@@ -784,16 +764,13 @@ def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-        self.rotary_pos_emb = RotaryEmbedding(
-            rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
-        )
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(
             nn.Linear,
             config.hidden_size,
             config.padded_vocab_size,
             bias=False,
-            dtype=config.torch_dtype,
             **init_kwargs,
         )
         self.pre_seq_len = config.pre_seq_len
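
Previously every ChatGLM submodule was pinned to `config.torch_dtype` at construction time through `_config_to_kwargs` and explicit `dtype=` arguments. With those removed, layers are created in torch's default dtype and the caller decides the final precision, for example via the `torch_dtype` argument to `from_pretrained` or an explicit cast. A minimal illustration of the difference (not taken from the file itself):

import torch
import torch.nn as nn

# Construction no longer fixes the dtype; the layer starts in torch's default (float32)...
layer = nn.Linear(16, 16)
print(layer.weight.dtype)  # torch.float32

# ...and the caller casts afterwards if a different precision is wanted.
layer = layer.to(torch.bfloat16)
print(layer.weight.dtype)  # torch.bfloat16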

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 4 deletions
@@ -686,7 +686,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
+        torch_dtype = kwargs.pop("torch_dtype", None)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -703,7 +703,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
 
-        if not isinstance(torch_dtype, torch.dtype):
+        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
                 f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -1456,8 +1456,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
 
         if load_components_from_hub and not trust_remote_code:
             raise ValueError(
-                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
-                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
                 f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
             )
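
At the pipeline level the effect is the same: `DiffusionPipeline.from_pretrained` only coerces `torch_dtype` when a value was actually supplied, so leaving it out no longer amounts to a float32 request. A sketch with a placeholder repository id:

import torch
from diffusers import DiffusionPipeline

repo_id = "org/some-diffusion-model"  # placeholder; any diffusers-format repository

# torch_dtype omitted: no implicit float32 default is injected by the pipeline loader.
pipe = DiffusionPipeline.from_pretrained(repo_id)

# Passing a dtype explicitly still casts the loaded components as before.
pipe_fp16 = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)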

tests/pipelines/kolors/test_kolors.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
 
tests/pipelines/kolors/test_kolors_img2img.py

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
         )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
 
tests/pipelines/pag/test_pag_kolors.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
         )
         torch.manual_seed(0)
         text_encoder = ChatGLMModel.from_pretrained(
-            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
 
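
With the encoder no longer pinning its own dtype from the config, the three Kolors test suites load the tiny ChatGLM checkpoint in float32 so it matches the other dummy components. Roughly, assuming the test module's import path for ChatGLMModel:

import torch
from diffusers.pipelines.kolors.text_encoder import ChatGLMModel

# float32 matches the rest of the dummy components; bfloat16 would now need an explicit request.
text_encoder = ChatGLMModel.from_pretrained(
    "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.float32
)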