Skip to content

Commit 6eecaa5

Browse files
fix allowed reuse factors for transpose layers
1 parent 49ea6d4 commit 6eecaa5

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,22 @@ def get_layer_mult_size(self, layer):
7979
n_out = layer.get_attr('n_out')
8080
return n_in, n_out
8181

82+
if 'Conv1DTranspose' in layer.class_name:
83+
trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) \
84+
// layer.get_attr('stride_width')
85+
n_in = layer.get_attr('n_chan') * trfilt_width
86+
n_out = layer.get_attr('n_filt')
87+
return n_in, n_out
88+
89+
if 'Conv2DTranspose' in layer.class_name:
90+
trfilt_width = (layer.get_attr('filt_width') + layer.get_attr('stride_width') - 1) \
91+
// layer.get_attr('stride_width')
92+
trfilt_height = (layer.get_attr('filt_height') + layer.get_attr('stride_height') - 1) \
93+
// layer.get_attr('stride_height')
94+
n_in = layer.get_attr('n_chan') * trfilt_height * trfilt_width
95+
n_out = layer.get_attr('n_filt')
96+
return n_in, n_out
97+
8298
if 'Conv1D' in layer.class_name:
8399
n_in = layer.get_attr('n_chan') * layer.get_attr('filt_width')
84100
n_out = layer.get_attr('n_filt')

0 commit comments

Comments
 (0)