@@ -38,9 +38,8 @@ static ProtobufShutter protobufShutter;
38
38
void setTensorLocations (
39
39
ImporterContext* ctx, std::vector<std::string> const & tensors, std::vector<std::string> const & locations)
40
40
{
41
- ONNXTRT_CHECK ((tensors.size () >= locations.size ())
42
- && " The size of tensors misaligns with the size of the attribute trt_outputs_loc." ,
43
- nvonnxparser::ErrorCode::kINVALID_GRAPH );
41
+ ONNXTRT_CHECK (tensors.size () >= locations.size (),
42
+ " The size of tensors misaligns with the size of the attribute trt_outputs_loc." , ErrorCode::kINVALID_GRAPH );
44
43
for (size_t i = 0 ; i < locations.size (); ++i)
45
44
{
46
45
std::string tensor = tensors.at (i);
@@ -50,8 +49,8 @@ void setTensorLocations(
50
49
51
50
if (ctx->tensorLocations ().count (tensor) > 0 )
52
51
{
53
- ONNXTRT_CHECK (( ctx->tensorLocations ()[tensor] == loc) && " The tensor location cannot be changed." ,
54
- nvonnxparser:: ErrorCode::kINVALID_GRAPH );
52
+ ONNXTRT_CHECK (ctx->tensorLocations ()[tensor] == loc, " The tensor location cannot be changed." ,
53
+ ErrorCode::kINVALID_GRAPH );
55
54
}
56
55
else
57
56
{
@@ -65,16 +64,19 @@ template <typename T>
65
64
void setStringMap (
66
65
ImporterContext* ctx, std::vector<std::string> const & tensors, std::vector<T> const & data, StringMap<T>& map)
67
66
{
68
- ONNXTRT_CHECK (( tensors.size () >= data.size ())
69
- && " The size of tensors misaligns with the size of the attribute trt_outputs_range_min/max." ,
70
- nvonnxparser:: ErrorCode::kINVALID_GRAPH );
67
+ ONNXTRT_CHECK (tensors.size () >= data.size (),
68
+ " The size of tensors misaligns with the size of the attribute trt_outputs_range_min/max." ,
69
+ ErrorCode::kINVALID_GRAPH );
71
70
for (size_t i = 0 ; i < data.size (); ++i)
72
71
{
73
72
std::string name = tensors.at (i);
74
73
T dataName = data.at (i);
75
74
if (map.count (name) > 0 )
76
75
{
77
- ONNXTRT_CHECK ( (map[name] == dataName) && " The order of tensorRangeMin/Max in context misaligns with the order of the attribute trt_outputs_range_min/max." , nvonnxparser::ErrorCode::kINVALID_GRAPH );
76
+ ONNXTRT_CHECK (map[name] == dataName,
77
+ " The order of tensorRangeMin/Max in context misaligns with the order of the attribute "
78
+ " trt_outputs_range_min/max." ,
79
+ ErrorCode::kINVALID_GRAPH );
78
80
}
79
81
else
80
82
{
@@ -163,7 +165,14 @@ void parseNode(
163
165
LOG_VERBOSE (ssInputs.str ());
164
166
165
167
// UINT8 weights that are not Q/DQ inputs will be converted to INT32
166
- if (node.op_type () != " QuantizeLinear" && node.op_type () != " DequantizeLinear" )
168
+ // If the UINT8 quantization flag is enabled, constants with UINT8 will also be permitted.
169
+ uint32_t uint8AsymmetricQuantizationFlag = 1U
170
+ << static_cast <uint32_t >(nvonnxparser::OnnxParserFlag::kENABLE_UINT8_AND_ASYMMETRIC_QUANTIZATION_DLA );
171
+ bool allowUint8Quantization = ctx->getFlags () & uint8AsymmetricQuantizationFlag;
172
+
173
+ bool skipUInt8Conversion = (node.op_type () == " QuantizeLinear" || node.op_type () == " DequantizeLinear"
174
+ || (allowUint8Quantization && node.op_type () == " Constant" ));
175
+ if (!skipUInt8Conversion)
167
176
{
168
177
for (auto & nodeInput : nodeInputs)
169
178
{
@@ -289,20 +298,26 @@ void parseNode(
289
298
{
290
299
ctx->registerTensor (std::move (output), outputName);
291
300
}
292
- // UINT8 is only allowed as network inputs, network outputs, and constants for QDQ nodes. Therefore any
293
- // non-constant node that produces an UINT8-typed output that is not also a graph output is unsupported.
294
- if (output.getType () == " UINT8" && node.op_type () != " Constant" )
301
+ // UINT8 is only allowed as network inputs, network outputs, and constants for QDQ nodes unless the UINT8
302
+ // quantization flag is set. If the UINT8 quantization flag is set, then UINT8 is also permitted as a
303
+ // QuantizeLinear output or Gather output (when they feed into a dequantize node). Other than the cases listed,
304
+ // any non-constant node that produces an UINT8-typed output that is not also a graph output is unsupported.
305
+ if (output.getType () == " UINT8" )
295
306
{
296
- bool legalUINT8 = false ;
307
+ bool legalUINT8 = node.op_type () == " Constant"
308
+ || (allowUint8Quantization && (node.op_type () == " Gather" || node.op_type () == " QuantizeLinear" ));
297
309
for (auto const & graphOutput : ctx->getGraphOutputNames ())
298
310
{
299
311
if (graphOutput.name () == outputName)
300
312
{
301
313
legalUINT8 = true ;
314
+ break ;
302
315
}
303
316
}
304
- ONNXTRT_CHECK_NODE (legalUINT8, " TensorRT does not support UINT8 types for intermediate tensors!" , node,
305
- nodeIdx, ErrorCode::kUNSUPPORTED_NODE );
317
+ ONNXTRT_CHECK_NODE (legalUINT8,
318
+ " TensorRT does not support UINT8 types for intermediate tensors. For UINT8 quantization, the "
319
+ " kIMPORT_UINT8_QUANTIZATION flag must be set. (DLA version >= 3.16 only)" ,
320
+ node, nodeIdx, ErrorCode::kUNSUPPORTED_NODE );
306
321
}
307
322
trtCnt++;
308
323
}
@@ -366,9 +381,8 @@ void parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& graph,
366
381
{
367
382
LOG_VERBOSE (" Importing initializer: " << initializer.name ());
368
383
ShapedWeights weights;
369
- ONNXTRT_CHECK (
370
- ctx->getWeightsContext ().convertOnnxWeights (initializer, &weights) && " Failed to import initializer." ,
371
- ErrorCode::kUNSUPPORTED_NODE );
384
+ ONNXTRT_CHECK (ctx->getWeightsContext ().convertOnnxWeights (initializer, &weights),
385
+ " Failed to import initializer: " << initializer.name (), ErrorCode::kUNSUPPORTED_NODE );
372
386
ctx->registerTensor (TensorOrWeights{std::move (weights)}, initializer.name ());
373
387
}
374
388
}
@@ -385,7 +399,7 @@ void parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& graph,
385
399
386
400
std::vector<size_t > topoOrder;
387
401
ONNXTRT_CHECK (
388
- toposort (graph.node (), &topoOrder) && " Failed to sort the model topologically." , ErrorCode::kINVALID_GRAPH );
402
+ toposort (graph.node (), &topoOrder), " Failed to sort the model topologically." , ErrorCode::kINVALID_GRAPH );
389
403
390
404
for (auto const & nodeIndex : topoOrder)
391
405
{
@@ -682,7 +696,7 @@ bool ModelImporter::isSubgraphSupported(int64_t const index) noexcept
682
696
errorMessage << " Query index " << index
683
697
<< " exceeds subgraph support vector (size = " << mSubGraphSupportVector .size ()
684
698
<< " ). Have you called supports_model_v2?" ;
685
- ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index) && errorMessage.str (). c_str (),
699
+ ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index), errorMessage.str (),
686
700
ErrorCode::kINVALID_VALUE );
687
701
return mSubGraphSupportVector [index].second ;
688
702
}
@@ -698,7 +712,7 @@ int64_t* ModelImporter::getSubgraphNodes(int64_t const index, int64_t& subgraphL
698
712
errorMessage << " Query index " << index
699
713
<< " exceeds subgraph support vector (size = " << mSubGraphSupportVector .size ()
700
714
<< " ). Have you called supports_model_v2?" ;
701
- ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index) && errorMessage.str (). c_str (),
715
+ ONNXTRT_CHECK (mSubGraphSupportVector .size () > static_cast <uint64_t >(index), errorMessage.str (),
702
716
ErrorCode::kINVALID_VALUE );
703
717
subgraphLength = mSubGraphSupportVector [index].first .size ();
704
718
return mSubGraphSupportVector [index].first .data ();
@@ -769,8 +783,8 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
769
783
mImporterCtx .clearOpsets ();
770
784
// Add domain import limit for security reasons
771
785
int32_t const MAX_DOMAINS = 1024 ;
772
- ONNXTRT_CHECK (model.opset_import ().size () <= MAX_DOMAINS
773
- && " Model contains more than 1024 domains! Parsing will halt for security reasons." ,
786
+ ONNXTRT_CHECK (model.opset_import ().size () <= MAX_DOMAINS,
787
+ " Model contains more than 1024 domains! Parsing will halt for security reasons." ,
774
788
ErrorCode::kUNSUPPORTED_GRAPH );
775
789
for (int32_t i = 0 ; i < model.opset_import ().size (); ++i)
776
790
{
@@ -808,8 +822,8 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
808
822
// Mark outputs defined in the ONNX model (unless tensors are user-requested)
809
823
for (::ONNX_NAMESPACE::ValueInfoProto const & output : graph.output ())
810
824
{
811
- ONNXTRT_CHECK ((mImporterCtx .tensors ().count (output.name ())) && " The output tensor was not registered. " ,
812
- ErrorCode::kINVALID_GRAPH );
825
+ ONNXTRT_CHECK ((mImporterCtx .tensors ().count (output.name ())),
826
+ " The output tensor " << output. name () << " was not registered. " , ErrorCode::kINVALID_GRAPH );
813
827
nvinfer1::ITensor* output_tensor_ptr
814
828
= &convertToTensor (mImporterCtx .tensors ().at (output.name ()), &mImporterCtx );
815
829
LOG_VERBOSE (" Marking " << output_tensor_ptr->getName () << " as output: " << output.name ());
@@ -821,21 +835,19 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
821
835
// TODO: Does this break things by changing the name of the input tensor?
822
836
output_tensor_ptr->setName ((" __" + output.name ()).c_str ());
823
837
output_tensor_ptr = &identity (&mImporterCtx , output_tensor_ptr).tensor ();
824
- ONNXTRT_CHECK (output_tensor_ptr && " Failed to add an Identity layer." , ErrorCode::kUNSUPPORTED_NODE );
838
+ ONNXTRT_CHECK (output_tensor_ptr, " Failed to add an Identity layer." , ErrorCode::kUNSUPPORTED_NODE );
825
839
output_tensor_ptr->setName (output.name ().c_str ());
826
840
}
827
841
828
842
mImporterCtx .network ()->markOutput (*output_tensor_ptr);
829
843
nvinfer1::DataType output_trt_dtype;
830
844
831
- ONNXTRT_CHECK (convertDtype (output.type ().tensor_type ().elem_type (), &output_trt_dtype)
832
- && " Failed to convert ONNX date type to TensorRT data type." ,
833
- ErrorCode::kUNSUPPORTED_NODE );
845
+ ONNXTRT_CHECK (convertDtype (output.type ().tensor_type ().elem_type (), &output_trt_dtype),
846
+ " Failed to convert ONNX date type to TensorRT data type." , ErrorCode::kUNSUPPORTED_NODE );
834
847
// For INT32 data type, output type must match tensor type
835
848
ONNXTRT_CHECK ((output_tensor_ptr->getType () != nvinfer1::DataType::kINT32
836
- || output_trt_dtype == nvinfer1::DataType::kINT32 )
837
- && " For INT32 tensors, the output type must also be INT32." ,
838
- ErrorCode::kUNSUPPORTED_NODE );
849
+ || output_trt_dtype == nvinfer1::DataType::kINT32 ),
850
+ " For INT32 tensors, the output type must also be INT32." , ErrorCode::kUNSUPPORTED_NODE );
839
851
// Note: Without this, output type is always float32
840
852
output_tensor_ptr->setType (output_trt_dtype);
841
853
if (output_trt_dtype == nvinfer1::DataType::kINT64 )
@@ -890,15 +902,15 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
890
902
// Set locations for all tensors
891
903
for (auto const & tensor : ctx->tensorLocations ())
892
904
{
893
- ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ) && " The tensor does not have an assigned location." ,
905
+ ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ), " The tensor does not have an assigned location." ,
894
906
nvonnxparser::ErrorCode::kINVALID_GRAPH );
895
907
tensors.at (tensor.first )->setLocation (tensor.second );
896
908
}
897
909
// Set dynamic range for all tensors
898
910
for (auto const & tensor : ctx->tensorRangeMins ())
899
911
{
900
912
// if there's a min range, there must be a max range as well
901
- ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ) && " The tensor does not have an assigned location ." ,
913
+ ONNXTRT_CHECK ((tensors.count (tensor.first ) > 0 ), " The tensor does not have its dynamic range set ." ,
902
914
nvonnxparser::ErrorCode::kINVALID_GRAPH );
903
915
if (!std::isnan (tensor.second ))
904
916
{
@@ -911,7 +923,7 @@ void ModelImporter::importModel(::ONNX_NAMESPACE::ModelProto const& model)
911
923
// Set precisions for all layers.
912
924
for (auto const & layer : ctx->layerPrecisions ())
913
925
{
914
- ONNXTRT_CHECK ((layers.count (layer.first ) > 0 ) && " The layer does not have an assigned precision." ,
926
+ ONNXTRT_CHECK ((layers.count (layer.first ) > 0 ), " The layer does not have an assigned precision." ,
915
927
nvonnxparser::ErrorCode::kINVALID_GRAPH );
916
928
layers.at (layer.first )->setPrecision (layer.second );
917
929
}
@@ -932,6 +944,7 @@ bool ModelImporter::parseFromFile(char const* onnxModelFile, int32_t verbosity)
932
944
{
933
945
ONNXTRT_TRY
934
946
{
947
+ ONNXTRT_CHECK (onnxModelFile, " Input file cannot be empty." , ErrorCode::kINVALID_VALUE );
935
948
auto * ctx = &mImporterCtx ;
936
949
937
950
// Define S_ISREG macro for Windows
@@ -940,23 +953,16 @@ bool ModelImporter::parseFromFile(char const* onnxModelFile, int32_t verbosity)
940
953
#endif
941
954
942
955
struct stat sb;
943
- if (stat (onnxModelFile, &sb) == 0 && !S_ISREG (sb.st_mode ))
944
- {
945
- LOG_ERROR (" Input is not a regular file: " << onnxModelFile);
946
- return false ;
947
- }
956
+ ONNXTRT_CHECK (stat (onnxModelFile, &sb) == 0 && S_ISREG (sb.st_mode ),
957
+ " Input file cannot be found, or is not a regular file: " << onnxModelFile, ErrorCode::kINVALID_VALUE );
948
958
949
959
GOOGLE_PROTOBUF_VERIFY_VERSION;
950
960
951
961
// Own the ONNX model for weights to persist.
952
962
mONNXModels .emplace_back ();
953
963
::ONNX_NAMESPACE::ModelProto& onnxModel = mONNXModels .back ();
954
- bool const fileLoadSuccess = ParseFromFileAsBinary (&onnxModel, onnxModelFile);
955
- if (!fileLoadSuccess)
956
- {
957
- LOG_ERROR (" Failed to parse ONNX model from file: " << onnxModelFile << " !" );
958
- return false ;
959
- }
964
+ ONNXTRT_CHECK (ParseFromFileAsBinary (&onnxModel, onnxModelFile),
965
+ " Cannot read from input file: " << onnxModelFile, ErrorCode::kINVALID_VALUE );
960
966
961
967
// Keep track of the absolute path to the ONNX file.
962
968
mImporterCtx .setOnnxFileLocation (onnxModelFile);
0 commit comments