File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -79,6 +79,22 @@ def get_layer_mult_size(self, layer):
79
79
n_out = layer .get_attr ('n_out' )
80
80
return n_in , n_out
81
81
82
+ if 'Conv1DTranspose' in layer .class_name :
83
+ trfilt_width = (layer .get_attr ('filt_width' ) + layer .get_attr ('stride_width' ) - 1 ) \
84
+ // layer .get_attr ('stride_width' )
85
+ n_in = layer .get_attr ('n_chan' ) * trfilt_width
86
+ n_out = layer .get_attr ('n_filt' )
87
+ return n_in , n_out
88
+
89
+ if 'Conv2DTranspose' in layer .class_name :
90
+ trfilt_width = (layer .get_attr ('filt_width' ) + layer .get_attr ('stride_width' ) - 1 ) \
91
+ // layer .get_attr ('stride_width' )
92
+ trfilt_height = (layer .get_attr ('filt_height' ) + layer .get_attr ('stride_height' ) - 1 ) \
93
+ // layer .get_attr ('stride_height' )
94
+ n_in = layer .get_attr ('n_chan' ) * trfilt_height * trfilt_width
95
+ n_out = layer .get_attr ('n_filt' )
96
+ return n_in , n_out
97
+
82
98
if 'Conv1D' in layer .class_name :
83
99
n_in = layer .get_attr ('n_chan' ) * layer .get_attr ('filt_width' )
84
100
n_out = layer .get_attr ('n_filt' )
You can’t perform that action at this time.
0 commit comments