@@ -18,7 +18,7 @@ def transform(self, model, node):
18
18
dim = len (node .get_input_variable ().shape ) - 1
19
19
input_shape = node .get_input_variable ().shape
20
20
21
- pointwise_attrs = {
21
+ conv_attrs = {
22
22
'data_format' : 'channels_last' ,
23
23
'padding' : 'valid' ,
24
24
'n_chan' : input_shape [- 1 ],
@@ -28,7 +28,7 @@ def transform(self, model, node):
28
28
}
29
29
30
30
if dim == 1 :
31
- pointwise_attrs .update (
31
+ conv_attrs .update (
32
32
{
33
33
'in_width' : input_shape [0 ],
34
34
'out_width' : input_shape [0 ],
@@ -39,7 +39,7 @@ def transform(self, model, node):
39
39
}
40
40
)
41
41
elif dim == 2 :
42
- pointwise_attrs .update (
42
+ conv_attrs .update (
43
43
{
44
44
'in_height' : input_shape [0 ],
45
45
'in_width' : input_shape [1 ],
@@ -59,7 +59,7 @@ def transform(self, model, node):
59
59
raise Exception ('Cannot replace Dense over {dim}D tensor with Conv{dim}D.' .format (dim = dim ))
60
60
61
61
class_name = 'Conv' + str (dim ) + 'D'
62
- pw_node = model .make_node (class_name , node .name , pointwise_attrs , node .inputs .copy ())
63
- model .replace_node (node , pw_node )
62
+ conv_node = model .make_node (class_name , node .name , conv_attrs , node .inputs .copy ())
63
+ model .replace_node (node , conv_node )
64
64
65
65
return True
0 commit comments