
Commit 2651ee4

rewrite pytorch forward wrapper (#1121)
1 parent fbc5ee0 commit 2651ee4

File tree

1 file changed: +51 −55 lines changed


neural_compressor/adaptor/pytorch.py

Lines changed: 51 additions & 55 deletions
@@ -97,80 +97,76 @@ def get_torch_white_list(approach):
 
 
 def pytorch_forward_wrapper(model, input, device='cpu', conf=None, running_mode='inference'):
-    if device == "ipex" and IPEX_110: # pragma: no cover
-        if isinstance(input, torch.Tensor):
-            if running_mode == "calibration":
-                with ipex.quantization.calibrate(conf, default_recipe=True):
-                    input = input.contiguous(memory_format=torch.channels_last)
-                    output = model(input)
-            else:
-                input = input.contiguous(memory_format=torch.channels_last)
-                output = model(input)
-        elif isinstance(input, list) or isinstance(input, tuple):
-            if running_mode == "calibration":
-                with ipex.quantization.calibrate(conf, default_recipe=True):
-                    output = model(*input)
-            else:
-                output = model(*input)
-        elif isinstance(input, dict):
-            if running_mode == "calibration":
-                with ipex.quantization.calibrate(conf, default_recipe=True):
-                    output = model(**input)
-            else:
-                output = model(**input)
-    elif device == "ipex" and IPEX_112: # pragma: no cover
-        if isinstance(input, torch.Tensor):
-            input = input.contiguous(memory_format=torch.channels_last)
-            output = model(input)
-        elif isinstance(input, list) or isinstance(input, tuple):
-            output = model(*input)
-        elif isinstance(input, dict):
+    if isinstance(input, dict) or isinstance(input, UserDict):
+        if device=='cpu':
             output = model(**input)
-    else:
-        if isinstance(input, dict) or isinstance(input, UserDict):
-            if device=='cpu':
+        elif device=='ipex': # pragma: no cover
+            # have to split the case to avoid exposing ipex.DEVICE outside
+            # which require intel extension installed
+            if IPEX_110:
+                if running_mode == "calibration":
+                    with ipex.quantization.calibrate(conf, default_recipe=True):
+                        output = model(**input)
+                else:
+                    output = model(**input)
+            elif IPEX_112:
                 output = model(**input)
-            elif device=='ipex': # pragma: no cover
-                # have to split the case to avoid exposing ipex.DEVICE outside
-                # which require intel extension installed
+            else:
                 for inp in input.keys():
                     input[inp] = input[inp].to(ipex.DEVICE) \
                         if isinstance(input[inp], torch.Tensor) else input[inp]
                 with ipex.AutoMixPrecision(conf, running_mode=running_mode):
                     output = model(**input)
-            else: # pragma: no cover
-                for inp in input.keys():
-                    input[inp] = input[inp].to("dpcpp" if device=="gpu" else device) \
-                        if isinstance(input[inp], torch.Tensor) else input[inp]
-                output = model(**input)
-        elif isinstance(input, list) or isinstance(input, tuple):
-            if device=='cpu':
+        else: # pragma: no cover
+            for inp in input.keys():
+                input[inp] = input[inp].to("dpcpp" if device=="gpu" else device) \
+                    if isinstance(input[inp], torch.Tensor) else input[inp]
+            output = model(**input)
+    elif isinstance(input, list) or isinstance(input, tuple):
+        if device=='cpu':
+            output = model(*input)
+        elif device=='ipex': # pragma: no cover
+            if IPEX_110:
+                if running_mode == "calibration":
+                    with ipex.quantization.calibrate(conf, default_recipe=True):
+                        output = model(*input)
+                else:
+                    output = model(*input)
+            elif IPEX_112:
                 output = model(*input)
-            elif device=='ipex': # pragma: no cover
+            else:
                 input = [inp.to(ipex.DEVICE) \
                     if isinstance(inp, torch.Tensor) else inp
                     for inp in input]
                 with ipex.AutoMixPrecision(conf, running_mode=running_mode):
                     output = model(*input)
-            else: # pragma: no cover
-                tmp_device = "dpcpp" if device=="gpu" else device
-                input = [inp.to(tmp_device) \
-                    if isinstance(inp, torch.Tensor) else inp
-                    for inp in input] # pylint: disable=E1133
-                output = model(*input)
-        else:
-            if device=='cpu' or not isinstance(input, torch.Tensor):
+        else: # pragma: no cover
+            tmp_device = "dpcpp" if device=="gpu" else device
+            input = [inp.to(tmp_device) \
+                if isinstance(inp, torch.Tensor) else inp
+                for inp in input] # pylint: disable=E1133
+            output = model(*input)
+    else:
+        if device=='cpu' or not isinstance(input, torch.Tensor):
+            output = model(input)
+        elif device=='ipex': # pragma: no cover
+            if IPEX_110:
+                if running_mode == "calibration":
+                    with ipex.quantization.calibrate(conf, default_recipe=True):
+                        output = model(input)
+                else:
+                    output = model(input)
+            elif IPEX_112:
                 output = model(input)
-            elif device=='ipex': # pragma: no cover
+            else:
                 input = input.to(ipex.DEVICE)
                 with ipex.AutoMixPrecision(conf, running_mode=running_mode):
                     output = model(input)
-            else: # pragma: no cover
-                input = input.to("dpcpp" if device=="gpu" else device) # pylint: disable=no-member
-                output = model(input)
+        else: # pragma: no cover
+            input = input.to("dpcpp" if device=="gpu" else device) # pylint: disable=no-member
+            output = model(input)
     return output
 
-
 def get_ops_recursively(model, prefix, ops={}):
     """This is a helper function for `graph_info`,
     and it will get all ops from model.
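For context, the rewrite inverts the wrapper's dispatch order: it now branches on the input container type first (dict/UserDict, list/tuple, or a single tensor) and on the device second, so each IPEX version check lives once inside its type branch instead of duplicating the type handling per device. Below is a minimal sketch of the resulting CPU-path dispatch; the forward_dispatch helper and TinyNet module are illustrative only and are not part of this commit.

    import torch
    from collections import UserDict

    def forward_dispatch(model, input):
        # Same shape as the rewritten wrapper's CPU branches: pick the call
        # convention from the container type, then invoke the model.
        if isinstance(input, (dict, UserDict)):
            return model(**input)        # dict-like -> keyword arguments
        elif isinstance(input, (list, tuple)):
            return model(*input)         # sequence -> positional arguments
        else:
            return model(input)          # single tensor (or other object)

    class TinyNet(torch.nn.Module):
        # Toy module used only to exercise the three call styles.
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    model = TinyNet().eval()
    x = torch.randn(1, 4)
    with torch.no_grad():
        print(forward_dispatch(model, x).shape)         # tensor input
        print(forward_dispatch(model, (x,)).shape)      # tuple input
        print(forward_dispatch(model, {"x": x}).shape)  # dict input

Each container type maps to one call convention (keyword, positional, or single argument), which is why type-first dispatch lets the IPEX 1.10/1.12 branches share a single unpacking site per type.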
