@@ -97,80 +97,76 @@ def get_torch_white_list(approach):
 
 
 def pytorch_forward_wrapper(model, input, device='cpu', conf=None, running_mode='inference'):
-    if device == "ipex" and IPEX_110:  # pragma: no cover
-        if isinstance(input, torch.Tensor):
-            if running_mode == "calibration":
-                with ipex.quantization.calibrate(conf, default_recipe=True):
-                    input = input.contiguous(memory_format=torch.channels_last)
-                    output = model(input)
-            else:
-                input = input.contiguous(memory_format=torch.channels_last)
-                output = model(input)
-        elif isinstance(input, list) or isinstance(input, tuple):
-            if running_mode == "calibration":
-                with ipex.quantization.calibrate(conf, default_recipe=True):
-                    output = model(*input)
-            else:
-                output = model(*input)
-        elif isinstance(input, dict):
-            if running_mode == "calibration":
-                with ipex.quantization.calibrate(conf, default_recipe=True):
-                    output = model(**input)
-            else:
-                output = model(**input)
-    elif device == "ipex" and IPEX_112:  # pragma: no cover
-        if isinstance(input, torch.Tensor):
-            input = input.contiguous(memory_format=torch.channels_last)
-            output = model(input)
-        elif isinstance(input, list) or isinstance(input, tuple):
-            output = model(*input)
-        elif isinstance(input, dict):
+    if isinstance(input, dict) or isinstance(input, UserDict):
+        if device == 'cpu':
             output = model(**input)
-    else:
-        if isinstance(input, dict) or isinstance(input, UserDict):
-            if device == 'cpu':
+        elif device == 'ipex':  # pragma: no cover
+            # have to split the case to avoid exposing ipex.DEVICE outside
+            # which require intel extension installed
+            if IPEX_110:
+                if running_mode == "calibration":
+                    with ipex.quantization.calibrate(conf, default_recipe=True):
+                        output = model(**input)
+                else:
+                    output = model(**input)
+            elif IPEX_112:
                 output = model(**input)
-            elif device == 'ipex':  # pragma: no cover
-                # have to split the case to avoid exposing ipex.DEVICE outside
-                # which require intel extension installed
+            else:
                 for inp in input.keys():
                     input[inp] = input[inp].to(ipex.DEVICE) \
                         if isinstance(input[inp], torch.Tensor) else input[inp]
                 with ipex.AutoMixPrecision(conf, running_mode=running_mode):
                     output = model(**input)
-            else:  # pragma: no cover
-                for inp in input.keys():
-                    input[inp] = input[inp].to("dpcpp" if device == "gpu" else device) \
-                        if isinstance(input[inp], torch.Tensor) else input[inp]
-                output = model(**input)
-        elif isinstance(input, list) or isinstance(input, tuple):
-            if device == 'cpu':
+        else:  # pragma: no cover
+            for inp in input.keys():
+                input[inp] = input[inp].to("dpcpp" if device == "gpu" else device) \
+                    if isinstance(input[inp], torch.Tensor) else input[inp]
+            output = model(**input)
+    elif isinstance(input, list) or isinstance(input, tuple):
+        if device == 'cpu':
+            output = model(*input)
+        elif device == 'ipex':  # pragma: no cover
+            if IPEX_110:
+                if running_mode == "calibration":
+                    with ipex.quantization.calibrate(conf, default_recipe=True):
+                        output = model(*input)
+                else:
+                    output = model(*input)
+            elif IPEX_112:
                 output = model(*input)
-            elif device == 'ipex':  # pragma: no cover
+            else:
                 input = [inp.to(ipex.DEVICE) \
                         if isinstance(inp, torch.Tensor) else inp
                         for inp in input]
                 with ipex.AutoMixPrecision(conf, running_mode=running_mode):
                     output = model(*input)
-            else:  # pragma: no cover
-                tmp_device = "dpcpp" if device == "gpu" else device
-                input = [inp.to(tmp_device) \
-                        if isinstance(inp, torch.Tensor) else inp
-                        for inp in input]  # pylint: disable=E1133
-                output = model(*input)
-        else:
-            if device == 'cpu' or not isinstance(input, torch.Tensor):
+        else:  # pragma: no cover
+            tmp_device = "dpcpp" if device == "gpu" else device
+            input = [inp.to(tmp_device) \
+                    if isinstance(inp, torch.Tensor) else inp
+                    for inp in input]  # pylint: disable=E1133
+            output = model(*input)
+    else:
+        if device == 'cpu' or not isinstance(input, torch.Tensor):
+            output = model(input)
+        elif device == 'ipex':  # pragma: no cover
+            if IPEX_110:
+                if running_mode == "calibration":
+                    with ipex.quantization.calibrate(conf, default_recipe=True):
+                        output = model(input)
+                else:
+                    output = model(input)
+            elif IPEX_112:
                 output = model(input)
-            elif device == 'ipex':  # pragma: no cover
+            else:
                 input = input.to(ipex.DEVICE)
                 with ipex.AutoMixPrecision(conf, running_mode=running_mode):
                     output = model(input)
-            else:  # pragma: no cover
-                input = input.to("dpcpp" if device == "gpu" else device)  # pylint: disable=no-member
-                output = model(input)
+        else:  # pragma: no cover
+            input = input.to("dpcpp" if device == "gpu" else device)  # pylint: disable=no-member
+            output = model(input)
     return output
 
-
 def get_ops_recursively(model, prefix, ops={}):
     """This is a helper function for `graph_info`,
     and it will get all ops from model.
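
After this change the wrapper dispatches on the input container first (dict/UserDict -> model(**input), list/tuple -> model(*input), anything else -> model(input)) and only then on device and IPEX version. A minimal sketch of the default CPU path, assuming it runs in the same module that defines pytorch_forward_wrapper; ToyModel and its inputs are illustrative only, not part of the patch:

import torch

# Illustrative-only module exercising the three unpacking conventions.
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x, scale=1.0):
        return self.linear(x) * scale

model = ToyModel()
x = torch.randn(1, 4)

out_tensor = pytorch_forward_wrapper(model, x)                       # -> model(input)
out_args = pytorch_forward_wrapper(model, (x, 2.0))                  # -> model(*input)
out_kwargs = pytorch_forward_wrapper(model, {"x": x, "scale": 2.0})  # -> model(**input)

On the 'ipex' path the same three conventions apply, with IPEX_110 wrapping calibration runs in ipex.quantization.calibrate(conf, default_recipe=True) and the pre-1.10 fallback moving tensors to ipex.DEVICE under ipex.AutoMixPrecision. One behavioral difference worth noting: the removed top-level IPEX_110/IPEX_112 branches converted tensor inputs with input.contiguous(memory_format=torch.channels_last), and the new IPEX branches do not.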