Skip to content

Commit 6714dd1

Browse files
bo3zvloncar
authored andcommitted
Quartus Merge layers
1 parent 67234ce commit 6714dd1

File tree

4 files changed

+770
-28
lines changed

4 files changed

+770
-28
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from hls4ml.backends.backend import get_backend
2+
from hls4ml.model.layers import Concatenate, Dot, Merge
3+
from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
4+
5+
# TODO - Very similar to vivado/merge_templates.py - only difference is on line 67: get_backend('vivado').product_type(inp1.type.precision, inp2.type.precision)
6+
# TODO - Look into ways of having passes similar accross many backends in a shared folder thorugh inheritance and overriding.
7+
8+
# Merge templates
9+
merge_config_template = """struct config{index} : nnet::merge_config {{
10+
static const unsigned n_elem = {n_elem};
11+
}};\n"""
12+
13+
merge_function_template = 'nnet::{merge}<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});'
14+
merge_include_list = ['nnet_utils/nnet_merge.h', 'nnet_utils/nnet_merge_stream.h']
15+
16+
class MergeConfigTemplate(LayerConfigTemplate):
17+
def __init__(self):
18+
super().__init__(Merge)
19+
self.template = merge_config_template
20+
21+
def format(self, node):
22+
params = self._default_config_params(node)
23+
params['n_elem'] = node.get_input_variable(node.inputs[0]).size_cpp()
24+
25+
return self.template.format(**params)
26+
27+
class MergeFunctionTemplate(FunctionCallTemplate):
28+
def __init__(self):
29+
super().__init__((Merge, Concatenate, Dot), include_header=merge_include_list)
30+
self.template = merge_function_template
31+
32+
def format(self, node):
33+
params = {}
34+
params['merge'] = node.get_attr('op').lower()
35+
params['config'] = 'config{}'.format(node.index)
36+
params['input1_t'] = node.get_input_variable(node.inputs[0]).type.name
37+
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
38+
params['output_t'] = node.get_output_variable().type.name
39+
params['input1'] = node.get_input_variable(node.inputs[0]).name
40+
params['input2'] = node.get_input_variable(node.inputs[1]).name
41+
params['output'] = node.get_output_variable().name
42+
43+
return self.template.format(**params)
44+
45+
46+
# Dot templates
47+
dot_config_template = """struct config{index} : nnet::dot_config {{
48+
static const unsigned n_in = {n_in};
49+
static const unsigned n_out = {n_out};
50+
51+
static const unsigned reuse_factor = {reuse};
52+
53+
typedef {accum_t.name} accum_t;
54+
55+
template<class x_T, class y_T>
56+
using product = nnet::product::{product_type}<x_T, y_T>;
57+
}};\n"""
58+
59+
class DotConfigTemplate(LayerConfigTemplate):
60+
def __init__(self):
61+
super().__init__(Dot)
62+
self.template = dot_config_template
63+
64+
def format(self, node):
65+
inp1 = node.get_input_variable(node.inputs[0])
66+
inp2 = node.get_input_variable(node.inputs[1])
67+
params = node._default_config_params()
68+
params['n_out'] = 1
69+
params['n_in'] = inp1.shape[0]
70+
params['product_type'] = get_backend('quartus').product_type(inp1.type.precision, inp2.type.precision)
71+
72+
return self.template.format(**params)
73+
74+
75+
# Concatenate templates
76+
concat_config_template = """struct config{index} : nnet::concat_config {{
77+
static const unsigned n_elem1_0 = {n_elem1_0};
78+
static const unsigned n_elem1_1 = {n_elem1_1};
79+
static const unsigned n_elem1_2 = {n_elem1_2};
80+
static const unsigned n_elem2_0 = {n_elem2_0};
81+
static const unsigned n_elem2_1 = {n_elem2_1};
82+
static const unsigned n_elem2_2 = {n_elem2_2};
83+
84+
static const int axis = {axis};
85+
}};\n"""
86+
87+
class ConcatenateConfigTemplate(LayerConfigTemplate):
88+
def __init__(self):
89+
super().__init__(Concatenate)
90+
self.template = concat_config_template
91+
92+
def format(self, node):
93+
params = self._default_config_params(node)
94+
for i in range(3):
95+
params.setdefault('n_elem1_{}'.format(i), 0)
96+
params.setdefault('n_elem2_{}'.format(i), 0)
97+
inp1 = node.get_input_variable(node.inputs[0])
98+
inp2 = node.get_input_variable(node.inputs[1])
99+
for i, (s1, s2) in enumerate(zip(inp1.shape, inp2.shape)):
100+
params['n_elem1_{}'.format(i)] = s1
101+
params['n_elem2_{}'.format(i)] = s2
102+
103+
return self.template.format(**params)
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
#ifndef NNET_MERGE_H_
2+
#define NNET_MERGE_H_
3+
4+
#include "nnet_mult.h"
5+
6+
namespace nnet {
7+
8+
struct merge_config {
9+
static const unsigned n_elem = 10;
10+
};
11+
12+
struct dot_config {
13+
static const unsigned n_in = 10;
14+
static const unsigned n_out = 1;
15+
16+
static const unsigned reuse_factor = 1;
17+
18+
typedef float accum_t;
19+
20+
template<class x_T, class y_T>
21+
using product = nnet::product::mult<x_T, y_T>;
22+
};
23+
24+
struct concat_config {
25+
static const unsigned n_elem1_0 = 10;
26+
static const unsigned n_elem1_1 = 10;
27+
static const unsigned n_elem1_2 = 10;
28+
static const unsigned n_elem2_0 = 10;
29+
static const unsigned n_elem2_1 = 10;
30+
static const unsigned n_elem2_2 = 10;
31+
32+
static const unsigned axis = -1;
33+
};
34+
35+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
36+
void add(
37+
input1_T data1[CONFIG_T::n_elem],
38+
input2_T data2[CONFIG_T::n_elem],
39+
res_T res[CONFIG_T::n_elem]
40+
) {
41+
#pragma unroll
42+
for (int i = 0; i < CONFIG_T::n_elem; i++) {
43+
res[i] = static_cast<res_T>(data1[i] + data2[i]);
44+
}
45+
}
46+
47+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
48+
void subtract(
49+
input1_T data1[CONFIG_T::n_elem],
50+
input2_T data2[CONFIG_T::n_elem],
51+
res_T res[CONFIG_T::n_elem]
52+
) {
53+
#pragma unroll
54+
for (int i = 0; i < CONFIG_T::n_elem; i++) {
55+
res[i] = static_cast<res_T>(data1[i] - data2[i]);
56+
}
57+
}
58+
59+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
60+
void multiply(
61+
input1_T data1[CONFIG_T::n_elem],
62+
input2_T data2[CONFIG_T::n_elem],
63+
res_T res[CONFIG_T::n_elem]
64+
) {
65+
#pragma unroll
66+
for (int i = 0; i < CONFIG_T::n_elem; i++) {
67+
res[i] = static_cast<res_T>(data1[i] * data2[i]);
68+
}
69+
}
70+
71+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
72+
void average(
73+
input1_T data1[CONFIG_T::n_elem],
74+
input2_T data2[CONFIG_T::n_elem],
75+
res_T res[CONFIG_T::n_elem]
76+
) {
77+
#pragma unroll
78+
for (int i = 0; i < CONFIG_T::n_elem; i++) {
79+
res[i] = static_cast<res_T>((data1[i] + data2[i]) / (res_T) 2);
80+
}
81+
}
82+
83+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
84+
void maximum(
85+
input1_T data1[CONFIG_T::n_elem],
86+
input2_T data2[CONFIG_T::n_elem],
87+
res_T res[CONFIG_T::n_elem]
88+
) {
89+
#pragma unroll
90+
for (int i = 0; i < CONFIG_T::n_elem; i++) {
91+
res[i] = (data1[i] > data2[i]) ? static_cast<res_T>(data1[i]) : static_cast<res_T>(data2[i]);
92+
}
93+
}
94+
95+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
96+
void minimum(
97+
input1_T data1[CONFIG_T::n_elem],
98+
input2_T data2[CONFIG_T::n_elem],
99+
res_T res[CONFIG_T::n_elem]
100+
) {
101+
#pragma unroll
102+
for (int i = 0; i < CONFIG_T::n_elem; i++) {
103+
res[i] = (data1[i] < data2[i]) ? static_cast<res_T>(data1[i]) : static_cast<res_T>(data2[i]);
104+
}
105+
}
106+
107+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
108+
void dot1d(
109+
input1_T data1[CONFIG_T::n_in],
110+
input2_T data2[CONFIG_T::n_in],
111+
res_T res[CONFIG_T::n_out]
112+
) {
113+
constexpr unsigned multiplier_limit = DIV_ROUNDUP(CONFIG_T::n_in, CONFIG_T::reuse_factor);
114+
115+
hls_register typename CONFIG_T::accum_t mult[CONFIG_T::n_in];
116+
Product:
117+
#pragma unroll multiplier_limit
118+
for(int i=0; i < CONFIG_T::n_in; i++) {
119+
mult[i] = CONFIG_T::template product<input1_T, input2_T>::product(data1[i], data2[i]);
120+
}
121+
122+
hls_register typename CONFIG_T::accum_t acc = 0;
123+
Accum:
124+
#pragma unroll
125+
for(int i = 0; i < CONFIG_T::n_in; i++) {
126+
acc += mult[i];
127+
}
128+
129+
res[0] = static_cast<res_T>(acc);
130+
}
131+
132+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
133+
void concatenate1d(
134+
input1_T data1[CONFIG_T::n_elem1_0],
135+
input2_T data2[CONFIG_T::n_elem2_0],
136+
res_T res[CONFIG_T::n_elem1_0 + CONFIG_T::n_elem2_0]
137+
) {
138+
#pragma unroll
139+
for (int i = 0; i < CONFIG_T::n_elem1_0; i++) {
140+
res[i] = static_cast<res_T>(data1[i]);
141+
}
142+
143+
#pragma unroll
144+
for (int i = 0; i < CONFIG_T::n_elem2_0; i++) {
145+
res[CONFIG_T::n_elem1_0 + i] = static_cast<res_T>(data2[i]);
146+
}
147+
}
148+
149+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
150+
void concatenate2d_0(
151+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1],
152+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1],
153+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1]
154+
) {
155+
#pragma unroll
156+
for (int i = 0; i < CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1; i++) {
157+
res[i] = static_cast<res_T>(data1[i]);
158+
}
159+
160+
#pragma unroll
161+
for (int i = 0; i < CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1; i++) {
162+
res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + i] = static_cast<res_T>(data2[i]);
163+
}
164+
}
165+
166+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
167+
void concatenate2d_1(
168+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1],
169+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1],
170+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1]
171+
){
172+
for (int i = 0; i < CONFIG_T::n_elem1_0; i++) {
173+
#pragma unroll
174+
for (int j = 0; j < CONFIG_T::n_elem1_1; j++) {
175+
res[i * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) + j] = static_cast<res_T>(data1[i * CONFIG_T::n_elem1_1 + j]);
176+
}
177+
178+
#pragma unroll
179+
for (int j = 0; j < CONFIG_T::n_elem2_1; j++) {
180+
res[i * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) + CONFIG_T::n_elem1_1 + j] = static_cast<res_T>(data2[i * CONFIG_T::n_elem2_1 + j]);
181+
}
182+
}
183+
}
184+
185+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
186+
void concatenate2d(
187+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1],
188+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1],
189+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1]
190+
) {
191+
if (CONFIG_T::axis == 2 || CONFIG_T::axis == -1) {
192+
concatenate2d_1<input1_T, input2_T, res_T, CONFIG_T>(data1, data2, res);
193+
} else {
194+
concatenate2d_0<input1_T, input2_T, res_T, CONFIG_T>(data1, data2, res);
195+
}
196+
}
197+
198+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
199+
void concatenate3d_0(
200+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2],
201+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2],
202+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]
203+
) {
204+
#pragma unroll
205+
for (int i = 0; i < CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2; i++) {
206+
res[i] = static_cast<res_T>(data1[i]);
207+
}
208+
209+
#pragma unroll
210+
for (int i = 0; i < CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2; i++) {
211+
res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + i] = static_cast<res_T>(data2[i]);
212+
}
213+
}
214+
215+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
216+
void concatenate3d_1(
217+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2],
218+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2],
219+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]
220+
) {
221+
for (int i = 0; i < CONFIG_T::n_elem1_0; i++) {
222+
for (int j = 0; j < CONFIG_T::n_elem1_1; j++) {
223+
#pragma unroll
224+
for (int k = 0; k < CONFIG_T::n_elem1_2; k++) {
225+
int res_idx = i * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) * CONFIG_T::n_elem1_2 + j * CONFIG_T::n_elem1_2 + k;
226+
int data_idx = i * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + j * CONFIG_T::n_elem1_2 + k;
227+
res[res_idx] = static_cast<res_T>(data1[data_idx]);
228+
}
229+
}
230+
231+
for (int j = 0; j < CONFIG_T::n_elem2_1; j++) {
232+
#pragma unroll
233+
for (int k=0; k<CONFIG_T::n_elem2_2; k++) {
234+
int res_idx = i * (CONFIG_T::n_elem1_1 + CONFIG_T::n_elem2_1) * CONFIG_T::n_elem1_2 + (j + CONFIG_T::n_elem1_1) * CONFIG_T::n_elem1_2 + k;
235+
int data_idx = i * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2 + j * CONFIG_T::n_elem2_2 + k;
236+
res[res_idx] = static_cast<res_T>(data2[data_idx]);
237+
}
238+
}
239+
}
240+
}
241+
242+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
243+
void concatenate3d_2(
244+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2],
245+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2],
246+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]
247+
) {
248+
for (int i = 0; i < CONFIG_T::n_elem1_0; i++) {
249+
for (int j = 0; j < CONFIG_T::n_elem1_1; j++) {
250+
251+
#pragma unroll
252+
for (int k = 0; k < CONFIG_T::n_elem1_2; k++) {
253+
int res_idx = i * CONFIG_T::n_elem1_1 * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + j * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + k;
254+
int data_idx = i * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + j * CONFIG_T::n_elem1_2 + k;
255+
res[res_idx] = static_cast<res_T>(data1[data_idx]);
256+
}
257+
258+
#pragma unroll
259+
for (int k = 0; k < CONFIG_T::n_elem1_2; k++) {
260+
int res_idx = i * CONFIG_T::n_elem1_1 * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + j * (CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_2) + k + CONFIG_T::n_elem1_2;
261+
int data_idx = i * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2 + j * CONFIG_T::n_elem2_2 + k;
262+
res[res_idx] = static_cast<res_T>(data2[data_idx]);
263+
}
264+
}
265+
}
266+
}
267+
268+
template<class input1_T, class input2_T, class res_T, typename CONFIG_T>
269+
void concatenate3d(
270+
input1_T data1[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2],
271+
input2_T data2[CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2],
272+
res_T res[CONFIG_T::n_elem1_0 * CONFIG_T::n_elem1_1 * CONFIG_T::n_elem1_2 + CONFIG_T::n_elem2_0 * CONFIG_T::n_elem2_1 * CONFIG_T::n_elem2_2]
273+
) {
274+
if (CONFIG_T::axis == 3 || CONFIG_T::axis == -1) {
275+
concatenate3d_2<input1_T, input2_T, res_T, CONFIG_T>(data1, data2, res);
276+
} else if (CONFIG_T::axis == 2 || CONFIG_T::axis == -2) {
277+
concatenate3d_1<input1_T, input2_T, res_T, CONFIG_T>(data1, data2, res);
278+
} else {
279+
concatenate3d_0<input1_T, input2_T, res_T, CONFIG_T>(data1, data2, res);
280+
}
281+
}
282+
283+
}
284+
285+
#endif

0 commit comments

Comments
 (0)