Skip to content

Commit c1645e3

Browse files
fix signatures on model_validator functions (#314)
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent e918d1e commit c1645e3

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
206206
return value
207207

208208
@model_validator(mode="after")
209-
def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
209+
def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
210210
# extract user-passed values from dictionary
211211
strategy = model.strategy
212212
group_size = model.group_size

src/compressed_tensors/quantization/quant_scheme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class QuantizationScheme(BaseModel):
4848
output_activations: Optional[QuantizationArgs] = None
4949

5050
@model_validator(mode="after")
51-
def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
51+
def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
5252
inputs = model.input_activations
5353
outputs = model.output_activations
5454

0 commit comments

Comments
 (0)