This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 0da14f8

edis219 authored and bhushan23 committed
Remove constant for Gemm and add node for A if A is a constant in Gemm (#472)
* 1. Remove constant for Gemm; 2. add node for A if A is a constant
* Rework on PR #472 to fit the convention
1 parent: f0a9327

2 files changed: +28 -3 lines changed


onnx_coreml/_operators_nd.py

Lines changed: 12 additions & 2 deletions
@@ -674,11 +674,21 @@ def _convert_gemm(builder, node, graph, err):
     transB = node.attrs.get('transB', False)

     A = node.inputs[0]
+    if A in node.input_tensors:
+        A_tensor = node.input_tensors[A]
+        builder.add_load_constant_nd(
+            name=node.name + A + "_const",
+            output_name='const_' + A,
+            constant_value=A_tensor,
+            shape=A_tensor.shape
+        )
+        A = 'const_'+A
+
     if alpha != 1.0:
         builder.add_load_constant_nd(
             name=node.name + '_load_alpha',
             output_name='alpha_for_'+A,
-            constant_value=alpha,
+            constant_value=np.array([alpha]),
             shape=[1]
         )
         builder.add_multiply_broadcastable(
@@ -715,7 +725,7 @@ def _convert_gemm(builder, node, graph, err):
         builder.add_load_constant_nd(
             name=node.name + '_load_beta',
             output_name='beta_for_'+B,
-            constant_value=beta,
+            constant_value=np.array([beta]),
             shape=[1]
         )
         builder.add_multiply_broadcastable(
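
For reference, a minimal standalone sketch of the builder calls this hunk now emits: a constant Gemm input A is materialized with add_load_constant_nd, and the scalar alpha is wrapped in np.array([alpha]) so constant_value is an ndarray consistent with shape=[1]. The builder construction (feature names/shapes, disable_rank5_shape_mapping) and all layer/blob names below are illustrative assumptions, not taken from the converter.

# Standalone sketch of the layers this hunk emits; builder setup and
# layer/blob names are illustrative assumptions only.
import numpy as np
from coremltools.models import datatypes
from coremltools.models.neural_network import NeuralNetworkBuilder

builder = NeuralNetworkBuilder(
    input_features=[('x', datatypes.Array(4))],
    output_features=[('y', datatypes.Array(3, 4))],
    disable_rank5_shape_mapping=True)   # needed for the *_nd layers

A_tensor = np.ones((3, 4), dtype=np.float32)   # stand-in for a constant Gemm input A
alpha = 0.5

# A constant A input becomes an explicit load_constant_nd layer ...
builder.add_load_constant_nd(
    name='gemm_A_const',
    output_name='const_A',
    constant_value=A_tensor,
    shape=A_tensor.shape)

# ... and the scalar alpha is wrapped so constant_value is an ndarray
# matching shape=[1], which is the reason for np.array([alpha]) above.
builder.add_load_constant_nd(
    name='gemm_load_alpha',
    output_name='alpha_for_const_A',
    constant_value=np.array([alpha]),
    shape=[1])

builder.add_multiply_broadcastable(
    name='gemm_scale_A',
    input_names=['const_A', 'alpha_for_const_A'],
    output_name='const_A_scaled')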

onnx_coreml/_transformers.py

Lines changed: 16 additions & 1 deletion
@@ -284,7 +284,7 @@ def __call__(self, graph): # type: (Graph) -> Graph
             if any([s == 0 for s in shape]):
                 continue

-            reshaped_tensor = tensor.reshape(shape)
+            reshaped_tensor = tensor.reshape(shape.astype(int))

             for child in node.children:
                 child.parents.remove(node)
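
The .astype(int) cast matters because numpy rejects floating-point dimensions in reshape, and the target shape here can arrive as a float-typed tensor. A minimal, self-contained illustration with made-up values, independent of the transformer's surrounding code:

# Made-up values showing why the cast is needed.
import numpy as np

tensor = np.arange(6, dtype=np.float32)
shape = np.array([2.0, 3.0])                # float-typed target shape

try:
    tensor.reshape(shape)                   # float dims are rejected
except TypeError as e:
    print('without astype(int):', e)

print(tensor.reshape(shape.astype(int)))    # works: 2x3 array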
@@ -760,6 +760,21 @@ def __call__(self, graph): # type: (Graph) -> Graph
                 axes = node.attrs.get('axes', None)
                 output = np.squeeze(x, axis = tuple(axes))
                 transformation_performed = True
+            elif node.op_type == 'Gemm':
+                alpha = node.attrs.get('alpha', 1.0)
+                beta = node.attrs.get('beta', 1.0)
+                transA = node.attrs.get('transA', False)
+                transB = node.attrs.get('transB', False)
+
+                A_tensor = node.input_tensors[node.inputs[0]]
+                B_tensor = node.input_tensors[node.inputs[1]]
+                C_tensor = node.input_tensors[node.inputs[2]]
+
+                A_tensor = np.transpose(A_tensor) if transA else A_tensor
+                B_tensor = np.transpose(B_tensor) if transB else B_tensor
+
+                output = alpha * np.dot(A_tensor, B_tensor) + beta * C_tensor
+                transformation_performed = True

             if transformation_performed:
                 nodes_to_be_removed.append(node)
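
This branch folds a Gemm whose A, B, and C inputs are all constants into a single precomputed tensor, following the ONNX Gemm definition Y = alpha * op(A) @ op(B) + beta * C, where op applies the optional transA/transB transpose. A small self-contained check of the same expression with made-up constants, mirroring the branch but not part of the transformer:

# Self-contained check of the folded Gemm expression with made-up constants.
import numpy as np

alpha, beta = 0.5, 2.0
transA, transB = False, True

A = np.arange(6, dtype=np.float32).reshape(2, 3)
B = np.arange(12, dtype=np.float32).reshape(4, 3)   # transB=True, so B is stored as (N, K)
C = np.ones(4, dtype=np.float32)

A_tensor = np.transpose(A) if transA else A
B_tensor = np.transpose(B) if transB else B

output = alpha * np.dot(A_tensor, B_tensor) + beta * C
print(output.shape)   # (2, 4): the Gemm node can be replaced by this constant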
