@@ -34,30 +34,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
34
34
has_clamp);
35
35
36
36
std::vector<char > activation_data (
37
- activation_data_size<has_weight_zeros> (m, k, group_size));
38
- prepare_activation_data<has_weight_zeros> (
37
+ activation_data_size (m, k, group_size, has_weight_zeros ));
38
+ prepare_activation_data (
39
39
(void *)activation_data.data (),
40
40
m,
41
41
k,
42
42
group_size,
43
- test_case.activations .data ());
43
+ test_case.activations .data (),
44
+ has_weight_zeros);
44
45
45
- std::vector<char > weight_data (
46
- weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
47
- n, k, group_size));
48
- prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
46
+ std::vector<char > weight_data (weight_data_size<weight_nbit>(
47
+ n, k, group_size, has_weight_zeros, has_bias));
48
+ int8_t * weight_zeros_ptr = nullptr ;
49
+ if (has_weight_zeros) {
50
+ weight_zeros_ptr = test_case.weight_zeros .data ();
51
+ }
52
+ float * bias_ptr = nullptr ;
53
+ if (has_bias) {
54
+ bias_ptr = test_case.bias .data ();
55
+ }
56
+ prepare_weight_data<weight_nbit>(
49
57
(void *)weight_data.data (),
50
58
n,
51
59
k,
52
60
group_size,
53
61
test_case.weight_qvals .data (),
54
62
test_case.weight_scales .data (),
55
- test_case. weight_zeros . data () ,
56
- test_case. bias . data () );
63
+ weight_zeros_ptr ,
64
+ bias_ptr );
57
65
58
66
std::vector<float > output (m * k);
59
67
for (auto _ : state) {
60
- kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp >(
68
+ kernel<weight_nbit>(
61
69
output.data (),
62
70
/* output_m_stride=*/ n,
63
71
m,
@@ -67,7 +75,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
67
75
weight_data.data (),
68
76
activation_data.data (),
69
77
test_case.clamp_min ,
70
- test_case.clamp_max );
78
+ test_case.clamp_max ,
79
+ has_weight_zeros,
80
+ has_bias,
81
+ has_clamp);
71
82
}
72
83
}
73
84
@@ -95,30 +106,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
95
106
has_clamp);
96
107
97
108
std::vector<char > activation_data (
98
- activation_data_size<has_weight_zeros> (m, k, group_size));
99
- prepare_activation_data<has_weight_zeros> (
109
+ activation_data_size (m, k, group_size, has_weight_zeros ));
110
+ prepare_activation_data (
100
111
(void *)activation_data.data (),
101
112
m,
102
113
k,
103
114
group_size,
104
- test_case.activations .data ());
115
+ test_case.activations .data (),
116
+ has_weight_zeros);
105
117
106
- std::vector<char > weight_data (
107
- weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
108
- n, k, group_size));
109
- prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
118
+ std::vector<char > weight_data (weight_data_size<weight_nbit>(
119
+ n, k, group_size, has_weight_zeros, has_bias));
120
+ int8_t * weight_zeros_ptr = nullptr ;
121
+ if (has_weight_zeros) {
122
+ weight_zeros_ptr = test_case.weight_zeros .data ();
123
+ }
124
+ float * bias_ptr = nullptr ;
125
+ if (has_bias) {
126
+ bias_ptr = test_case.bias .data ();
127
+ }
128
+ prepare_weight_data<weight_nbit>(
110
129
(void *)weight_data.data (),
111
130
n,
112
131
k,
113
132
group_size,
114
133
test_case.weight_qvals .data (),
115
134
test_case.weight_scales .data (),
116
- test_case. weight_zeros . data () ,
117
- test_case. bias . data () );
135
+ weight_zeros_ptr ,
136
+ bias_ptr );
118
137
119
138
std::vector<float > output (m * k);
120
139
for (auto _ : state) {
121
- kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp >(
140
+ kernel<weight_nbit>(
122
141
output.data (),
123
142
/* output_m_stride=*/ n,
124
143
m,
@@ -128,7 +147,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
128
147
weight_data.data (),
129
148
activation_data.data (),
130
149
test_case.clamp_min ,
131
- test_case.clamp_max );
150
+ test_case.clamp_max ,
151
+ has_weight_zeros,
152
+ has_bias,
153
+ has_clamp);
132
154
}
133
155
}
134
156
@@ -156,30 +178,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
156
178
has_clamp);
157
179
158
180
std::vector<char > activation_data (
159
- activation_data_size<has_weight_zeros> (m, k, group_size));
160
- prepare_activation_data<has_weight_zeros> (
181
+ activation_data_size (m, k, group_size, has_weight_zeros ));
182
+ prepare_activation_data (
161
183
(void *)activation_data.data (),
162
184
m,
163
185
k,
164
186
group_size,
165
- test_case.activations .data ());
187
+ test_case.activations .data (),
188
+ has_weight_zeros);
166
189
167
- std::vector<char > weight_data (
168
- weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
169
- n, k, group_size));
170
- prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
190
+ std::vector<char > weight_data (weight_data_size<weight_nbit>(
191
+ n, k, group_size, has_weight_zeros, has_bias));
192
+ int8_t * weight_zeros_ptr = nullptr ;
193
+ if (has_weight_zeros) {
194
+ weight_zeros_ptr = test_case.weight_zeros .data ();
195
+ }
196
+ float * bias_ptr = nullptr ;
197
+ if (has_bias) {
198
+ bias_ptr = test_case.bias .data ();
199
+ }
200
+ prepare_weight_data<weight_nbit>(
171
201
(void *)weight_data.data (),
172
202
n,
173
203
k,
174
204
group_size,
175
205
test_case.weight_qvals .data (),
176
206
test_case.weight_scales .data (),
177
- test_case. weight_zeros . data () ,
178
- test_case. bias . data () );
207
+ weight_zeros_ptr ,
208
+ bias_ptr );
179
209
180
210
std::vector<float > output (m * k);
181
211
for (auto _ : state) {
182
- kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp >(
212
+ kernel<weight_nbit>(
183
213
output.data (),
184
214
/* output_m_stride=*/ n,
185
215
m,
@@ -189,7 +219,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
189
219
weight_data.data (),
190
220
activation_data.data (),
191
221
test_case.clamp_min ,
192
- test_case.clamp_max );
222
+ test_case.clamp_max ,
223
+ has_weight_zeros,
224
+ has_bias,
225
+ has_clamp);
193
226
}
194
227
}
195
228
0 commit comments