Skip to content

Commit b2d607f

Browse files
authored
Fix onnxrt split op (#1686)
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 6a013c5 commit b2d607f

File tree

3 files changed

+6
-1
lines changed
  • examples/onnxrt/nlp/huggingface_model/token_classification/layoutlm/quantization
  • neural_compressor/adaptor/ox_utils/operators

3 files changed

+6
-1
lines changed

examples/onnxrt/nlp/huggingface_model/token_classification/layoutlm/quantization/ptq_dynamic/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def get_label_list(labels):
244244
cache_dir=model_args.cache_dir,
245245
revision=model_args.model_revision,
246246
use_auth_token=True if model_args.use_auth_token else None,
247+
use_safetensors = False,
247248
)
248249

249250
# Tokenizer check: this script requires a fast tokenizer.

examples/onnxrt/nlp/huggingface_model/token_classification/layoutlm/quantization/ptq_static/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def get_label_list(labels):
248248
cache_dir=model_args.cache_dir,
249249
revision=model_args.model_revision,
250250
use_auth_token=True if model_args.use_auth_token else None,
251+
use_safetensors = False,
251252
)
252253

253254
# Tokenizer check: this script requires a fast tokenizer.

neural_compressor/adaptor/ox_utils/operators/split.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,10 @@ def convert_check(self, convert_format):
4343
node = self.node
4444
assert convert_format in ["static"], "convert format for {} should be in ['static']".format(node.op_type)
4545

46-
parent = self.quantizer.model.get_parents(node)[0]
46+
parents = self.quantizer.model.get_parents(node)
47+
if len(parents) == 0:
48+
return False
49+
parent = parents[0]
4750
children = self.quantizer.model.get_children(node)
4851
if (
4952
parent.op_type != "DequantizeLinear" or len(children) == 0 or not node.name.endswith("_quant")

0 commit comments

Comments
 (0)