14
14
from . import fx
15
15
from torch ._decomp import get_decompositions
16
16
17
-
18
17
def import_exported_model (
19
18
prog : torch .export .ExportedProgram ,
20
19
output_type : str ,
@@ -24,18 +23,44 @@ def import_exported_model(
24
23
decomp_table = get_decompositions (
25
24
[torch .ops .aten .lstm .input , torch .ops .aten .gru .input ]
26
25
)
27
-
28
26
prog = prog .run_decompositions (decomp_table )
29
27
28
+ backend_legal_ops = None
29
+
30
30
match output_type :
31
31
case "torch" :
32
32
output_type = OutputType .TORCH
33
33
case "tosa" :
34
34
output_type = OutputType .TOSA
35
+ backend_legal_ops = [
36
+ "aten.flatten.using_ints" ,
37
+ "aten.native_layer_norm" ,
38
+ "aten.adaptive_avg_pool1d" ,
39
+ "aten.adaptive_avg_pool2d" ,
40
+ "aten.adaptive_max_pool1d" ,
41
+ "aten.adaptive_max_pool2d" ,
42
+ "aten.linear" ]
35
43
case "linalg_on_tensors" :
36
44
output_type = OutputType .LINALG_ON_TENSORS
45
+ backend_legal_ops = [
46
+ "aten.flatten.using_ints" ,
47
+ "aten.adaptive_avg_pool1d" ,
48
+ "aten.adaptive_avg_pool2d" ,
49
+ "aten.adaptive_max_pool1d" ,
50
+ "aten.adaptive_max_pool2d" ,
51
+ "aten.unflatten.int" ,
52
+ ]
37
53
case "tosa_linalg" :
38
54
output_type = OutputType .TOSA_LINALG
55
+ backend_legal_ops = [
56
+ "aten.flatten.using_ints" ,
57
+ "aten.native_layer_norm" ,
58
+ "aten.adaptive_avg_pool1d" ,
59
+ "aten.adaptive_avg_pool2d" ,
60
+ "aten.adaptive_max_pool1d" ,
61
+ "aten.adaptive_max_pool2d" ,
62
+ "aten.linear" ,
63
+ "aten.unflatten.int" ]
39
64
case "raw" :
40
65
output_type = OutputType .RAW
41
66
case _:
@@ -49,6 +74,12 @@ def import_exported_model(
49
74
50
75
if output_type != OutputType .RAW :
51
76
backend_legal_op_arg_str = ""
77
+ if backend_legal_ops is not None :
78
+ if not len (backend_legal_ops ) == 0 :
79
+ backend_legal_op_arg_str = "backend-legal-ops=" + "," .join (
80
+ backend_legal_ops
81
+ )
82
+
52
83
extra_library_file_name = ""
53
84
option_string = (
54
85
"{"
0 commit comments