Skip to content

Commit c27a22b

Browse files
authored
fix scale init in tests (#338)
1 parent bbe5491 commit c27a22b

File tree

3 files changed: +16 / −16 lines changed

tests/test_compressors/quantized_compressors/test_fp8_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def make_dummy_g_idx(columns: int, group_size: int) -> torch.Tensor:
6161
[
6262
QuantizationStrategy.GROUP,
6363
128,
64-
torch.rand((512, 8, 1)) * 0.01,
65-
torch.zeros((512, 8, 1), dtype=torch.int8),
64+
torch.rand((512, 8)) * 0.01,
65+
torch.zeros((512, 8), dtype=torch.int8),
6666
],
6767
[
6868
QuantizationStrategy.CHANNEL,

tests/test_compressors/quantized_compressors/test_int_quant.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def get_dummy_quant_config(strategy, group_size=None, symmetric=True):
5353
QuantizationStrategy.GROUP,
5454
True,
5555
128,
56-
torch.rand((512, 8, 1)) * 0.01,
57-
torch.zeros((512, 8, 1), dtype=torch.int8),
56+
torch.rand((512, 8)) * 0.01,
57+
torch.zeros((512, 8), dtype=torch.int8),
5858
],
5959
[
6060
QuantizationStrategy.CHANNEL,

tests/test_quantization/lifecycle/test_forward.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,17 +108,17 @@ def test_forward_quantize(
108108
"int",
109109
QuantizationStrategy.GROUP,
110110
128,
111-
torch.rand((512, 8, 1)) * 0.01,
112-
torch.zeros((512, 8, 1)),
111+
torch.rand((512, 8)) * 0.01,
112+
torch.zeros((512, 8)),
113113
None,
114114
),
115115
(
116116
4,
117117
"int",
118118
QuantizationStrategy.GROUP,
119119
128,
120-
torch.rand((512, 8, 1)) * 0.01,
121-
torch.zeros((512, 8, 1)),
120+
torch.rand((512, 8)) * 0.01,
121+
torch.zeros((512, 8)),
122122
make_dummy_g_idx(1024, 128),
123123
),
124124
(
@@ -135,17 +135,17 @@ def test_forward_quantize(
135135
"float",
136136
QuantizationStrategy.GROUP,
137137
128,
138-
torch.rand((512, 8, 1)) * 0.01,
139-
torch.zeros((512, 8, 1)),
138+
torch.rand((512, 8)) * 0.01,
139+
torch.zeros((512, 8)),
140140
None,
141141
),
142142
(
143143
8,
144144
"float",
145145
QuantizationStrategy.GROUP,
146146
128,
147-
torch.rand((512, 8, 1)) * 0.01,
148-
torch.zeros((512, 8, 1)),
147+
torch.rand((512, 8)) * 0.01,
148+
torch.zeros((512, 8)),
149149
make_dummy_g_idx(1024, 128),
150150
),
151151
],
@@ -174,17 +174,17 @@ def test_quantize(num_bits, type, strategy, group_size, scale, zero_point, g_idx
174174
"int",
175175
QuantizationStrategy.GROUP,
176176
128,
177-
torch.rand((512, 8, 1)) * 0.01,
178-
torch.zeros((512, 8, 1)),
177+
torch.rand((512, 8)) * 0.01,
178+
torch.zeros((512, 8)),
179179
None,
180180
),
181181
(
182182
8,
183183
"int",
184184
QuantizationStrategy.GROUP,
185185
128,
186-
torch.rand((512, 8, 1)) * 0.01,
187-
torch.zeros((512, 8, 1)),
186+
torch.rand((512, 8)) * 0.01,
187+
torch.zeros((512, 8)),
188188
make_dummy_g_idx(1024, 128),
189189
),
190190
],

0 commit comments

Comments (0)