Skip to content

Commit 1bc36c3

Browse files
pskiran1 and GuanLuo authored
fix: Resolve integer overflow in Load API file decoding (#7787)
Co-authored-by: GuanLuo <41310872+GuanLuo@users.noreply.github.com>
1 parent 9175390 commit 1bc36c3

File tree

7 files changed

+235
-22
lines changed

7 files changed

+235
-22
lines changed

qa/L0_cuda_shared_memory/cuda_shared_memory_test.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@
3030

3131
sys.path.append("../common")
3232

33+
import base64
3334
import os
3435
import time
3536
import unittest
3637
from functools import partial
3738

3839
import infer_util as iu
3940
import numpy as np
41+
import requests
4042
import test_util as tu
4143
import tritonclient.grpc as grpcclient
4244
import tritonclient.http as httpclient
@@ -564,5 +566,108 @@ def callback(user_data, result, error):
564566
self._cleanup_server(shm_handles)
565567

566568

569+
class CudaSharedMemoryTestRawHttpRequest(unittest.TestCase):
570+
def setUp(self):
571+
self.url = "localhost:8000"
572+
self.client = httpclient.InferenceServerClient(url=self.url, verbose=True)
573+
self.valid_shm_handle = None
574+
575+
def tearDown(self):
576+
self.client.unregister_cuda_shared_memory()
577+
if self.valid_shm_handle:
578+
cshm.destroy_shared_memory_region(self.valid_shm_handle)
579+
self.client.close()
580+
581+
def _generate_mock_base64_raw_handle(self, data_length):
582+
original_data_length = data_length * 3 // 4
583+
random_data = b"A" * original_data_length
584+
encoded_data = base64.b64encode(random_data)
585+
586+
assert (
587+
len(encoded_data) == data_length
588+
), "Encoded data length does not match the required length."
589+
return encoded_data
590+
591+
def _send_register_cshm_request(self, raw_handle, device_id, byte_size, shm_name):
592+
cuda_shared_memory_register_request = {
593+
"raw_handle": {"b64": raw_handle.decode("utf-8")},
594+
"device_id": device_id,
595+
"byte_size": byte_size,
596+
}
597+
598+
url = "http://{}/v2/cudasharedmemory/region/{}/register".format(
599+
self.url, shm_name
600+
)
601+
headers = {"Content-Type": "application/json"}
602+
603+
# Send POST request
604+
response = requests.post(
605+
url, headers=headers, json=cuda_shared_memory_register_request
606+
)
607+
return response
608+
609+
def test_exceeds_cshm_handle_size_limit(self):
610+
# byte_size greater than INT_MAX
611+
byte_size = 1 << 31
612+
device_id = 0
613+
shm_name = "invalid_shm"
614+
615+
raw_handle = self._generate_mock_base64_raw_handle(byte_size)
616+
response = self._send_register_cshm_request(
617+
raw_handle, device_id, byte_size, shm_name
618+
)
619+
self.assertNotEqual(response.status_code, 200)
620+
621+
try:
622+
error_message = response.json().get("error", "")
623+
self.assertIn(
624+
"'raw_handle' exceeds the maximum allowed data size limit INT_MAX",
625+
error_message,
626+
)
627+
except ValueError:
628+
self.fail("Response is not valid JSON")
629+
630+
def test_invalid_small_cshm_handle(self):
631+
byte_size = 64
632+
device_id = 0
633+
shm_name = "invalid_shm"
634+
635+
raw_handle = self._generate_mock_base64_raw_handle(byte_size)
636+
response = self._send_register_cshm_request(
637+
raw_handle, device_id, byte_size, shm_name
638+
)
639+
self.assertNotEqual(response.status_code, 200)
640+
641+
try:
642+
error_message = response.json().get("error", "")
643+
self.assertIn(
644+
"'raw_handle' must be a valid base64 encoded cudaIpcMemHandle_t",
645+
error_message,
646+
)
647+
except ValueError:
648+
self.fail("Response is not valid JSON")
649+
650+
def test_valid_cshm_handle(self):
651+
byte_size = 64
652+
device_id = 0
653+
shm_name = "test_shm"
654+
655+
# Create valid shared memory
656+
self.valid_shm_handle = cshm.create_shared_memory_region(
657+
shm_name, byte_size, device_id
658+
)
659+
raw_handle = cshm.get_raw_handle(self.valid_shm_handle)
660+
661+
response = self._send_register_cshm_request(
662+
raw_handle, device_id, byte_size, shm_name
663+
)
664+
self.assertEqual(response.status_code, 200)
665+
666+
# Verify shared memory status
667+
status = self.client.get_cuda_shared_memory_status()
668+
self.assertEqual(len(status), 1)
669+
self.assertEqual(status[0]["name"], shm_name)
670+
671+
567672
if __name__ == "__main__":
568673
unittest.main()

qa/L0_cuda_shared_memory/test.sh

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,38 @@ for i in \
8484
done
8585
done
8686

87+
# Run each raw-HTTP CUDA shared-memory test against its own freshly started
# server, logging server and client output to per-test files.
for i in \
        test_exceeds_cshm_handle_size_limit \
        test_invalid_small_cshm_handle \
        test_valid_cshm_handle; do
    SERVER_ARGS="--model-repository=$(pwd)/models --log-verbose=1"
    SERVER_LOG="./$i.server.log"
    CLIENT_LOG="./$i.client.log"
    run_server
    if [ "$SERVER_PID" == "0" ]; then
        echo -e "\n***\n*** Failed to start $SERVER\n***"
        cat $SERVER_LOG
        exit 1
    fi
    echo "Test: $i, client type: HTTP" >>$CLIENT_LOG
    # A failing test must not abort the script; it is recorded in RET instead.
    set +e
    if python $SHM_TEST CudaSharedMemoryTestRawHttpRequest.$i >>$CLIENT_LOG 2>&1; then
        if ! check_test_results $TEST_RESULT_FILE 1; then
            cat $CLIENT_LOG
            echo -e "\n***\n*** Test Result Verification Failed\n***"
            RET=1
        fi
    else
        echo -e "\n***\n*** Test Failed\n***"
        RET=1
    fi
    set -e
    kill $SERVER_PID
    wait $SERVER_PID
done
118+
87119
mkdir -p python_models/simple/1/
88120
cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/
89121
cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/

qa/L0_http/http_test.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/python
2-
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -29,6 +29,8 @@
2929

3030
sys.path.append("../common")
3131

32+
import base64
33+
import json
3234
import threading
3335
import time
3436
import unittest
@@ -44,6 +46,9 @@ class HttpTest(tu.TestResultCollector):
4446
def _get_infer_url(self, model_name):
4547
return "http://localhost:8000/v2/models/{}/infer".format(model_name)
4648

49+
def _get_load_model_url(self, model_name):
50+
return "http://localhost:8000/v2/repository/models/{}/load".format(model_name)
51+
4752
def _raw_binary_helper(
4853
self, model, input_bytes, expected_output_bytes, extra_headers={}
4954
):
@@ -231,6 +236,43 @@ def test_descriptive_status_code(self):
231236
)
232237
t.join()
233238

239+
def test_loading_large_invalid_model(self):
    """A 'file:' load parameter whose base64 text is longer than INT_MAX
    must be rejected with a descriptive error.
    """
    int_max = (1 << 31) - 1

    # Generate large base64 encoded data. base64 emits 4 characters per
    # 3 input bytes, so 3 * 2**29 bytes encode to exactly 2**31 characters,
    # the shortest padded length above INT_MAX. The payload is held several
    # times over (raw bytes, encoded bytes, str, request body), so keeping
    # it minimal saves a large amount of memory on the test machine.
    data_length = 3 * (1 << 29)
    encoded_data = base64.b64encode(b"A" * data_length)

    assert len(encoded_data) > int_max, "Encoded data length does not exceed INT_MAX."

    # Prepare payload with large base64 encoded data
    payload = {
        "parameters": {
            "config": json.dumps({"backend": "onnxruntime"}),
            "file:1/model.onnx": encoded_data.decode("utf-8"),
        }
    }
    # The str copy in the payload is all that is needed from here on.
    del encoded_data
    headers = {"Content-Type": "application/json"}

    # Send POST request
    response = requests.post(
        self._get_load_model_url("invalid_onnx"), headers=headers, json=payload
    )

    # Assert the response is not successful
    self.assertNotEqual(response.status_code, 200)
    try:
        error_message = response.json().get("error", "")
    except ValueError:
        self.fail("Response is not valid JSON")
    self.assertIn(
        "'file:1/model.onnx' exceeds the maximum allowed data size limit "
        "INT_MAX",
        error_message,
    )
275+
234276

235277
if __name__ == "__main__":
236278
unittest.main()

qa/L0_http/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ fi
624624

625625
TEST_RESULT_FILE='test_results.txt'
626626
PYTHON_TEST=http_test.py
627-
EXPECTED_NUM_TESTS=9
627+
EXPECTED_NUM_TESTS=10
628628
set +e
629629
python $PYTHON_TEST >$CLIENT_LOG 2>&1
630630
if [ $? -ne 0 ]; then

src/common.cc

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -27,11 +27,16 @@
2727
#include "common.h"
2828

2929
#include <algorithm>
30+
#include <climits>
3031
#include <iterator>
3132

3233
#include "restricted_features.h"
3334
#include "triton/core/tritonserver.h"
3435

36+
extern "C" {
37+
#include <b64/cdecode.h>
38+
}
39+
3540
namespace triton { namespace server {
3641

3742
TRITONSERVER_Error*
@@ -102,4 +107,27 @@ Contains(const std::vector<std::string>& vec, const std::string& str)
102107
return std::find(vec.begin(), vec.end(), str) != vec.end();
103108
}
104109

110+
// Decode 'input_len' characters of base64 text at 'input' into
// 'decoded_data' and report the number of decoded bytes in 'decoded_size'.
//
// 'decoded_data' is resized to 'input_len + 1', so only its first
// 'decoded_size' bytes are payload. 'name' identifies the field in the
// error message. Returns nullptr on success, or an INVALID_ARG error
// (with 'decoded_size' set to 0 and 'decoded_data' untouched) when
// 'input_len' exceeds INT_MAX.
TRITONSERVER_Error*
DecodeBase64(
    const char* input, size_t input_len, std::vector<char>& decoded_data,
    size_t& decoded_size, const std::string& name)
{
  // Give the out-parameter a defined value on every path so it is never
  // left indeterminate when the size check below rejects the input.
  decoded_size = 0;

  // base64_decode_block() takes the input length as an 'int'; reject
  // anything that would not survive the conversion.
  if (input_len > static_cast<size_t>(INT_MAX)) {
    return TRITONSERVER_ErrorNew(
        TRITONSERVER_ERROR_INVALID_ARG,
        ("'" + name + "' exceeds the maximum allowed data size limit INT_MAX")
            .c_str());
  }

  // The decoded size cannot be larger than the input
  decoded_data.resize(input_len + 1);
  base64_decodestate state;
  base64_init_decodestate(&state);

  // The narrowing cast is safe: input_len <= INT_MAX was checked above.
  decoded_size = static_cast<size_t>(base64_decode_block(
      input, static_cast<int>(input_len), decoded_data.data(), &state));

  return nullptr;
}
132+
105133
}} // namespace triton::server

src/common.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
//
33
// Redistribution and use in source and binary forms, with or without
44
// modification, are permitted provided that the following conditions
@@ -167,6 +167,18 @@ int64_t GetElementCount(const std::vector<int64_t>& dims);
167167
/// \return True if the str is found, false otherwise.
168168
bool Contains(const std::vector<std::string>& vec, const std::string& str);
169169

170+
/// Decodes a Base64 encoded string and stores the result in a vector.
171+
///
172+
/// \param input The Base64 encoded input string to decode.
173+
/// \param input_len The length of the input string.
174+
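/// \param decoded_data A vector that receives the decoded bytes. It is resized to input_len + 1, so use decoded_size rather than decoded_data.size() as the payload length.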
175+
/// \param decoded_size The size of the decoded data.
176+
/// \param name The name of the field being decoded; it is quoted in the error message returned when the input is too large.
177+
/// \return nullptr on success, or a TRITONSERVER_ERROR_INVALID_ARG error if input_len exceeds INT_MAX.
178+
TRITONSERVER_Error* DecodeBase64(
179+
const char* input, size_t input_len, std::vector<char>& decoded_data,
180+
size_t& decoded_size, const std::string& name);
181+
170182
/// Joins container of strings into a single string delimited by
171183
/// 'delim'.
172184
///

src/http_server.cc

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,6 @@
4646
#define TRITONJSON_STATUSSUCCESS nullptr
4747
#include "triton/common/triton_json.h"
4848

49-
extern "C" {
50-
#include <b64/cdecode.h>
51-
}
52-
5349
namespace triton { namespace server {
5450

5551
#define RETURN_AND_CALLBACK_IF_ERR(X, CALLBACK) \
@@ -1546,14 +1542,12 @@ HTTPAPIServer::HandleRepositoryControl(
15461542
param = TRITONSERVER_ParameterNew(
15471543
m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str);
15481544
} else if (m.rfind("file:", 0) == 0) {
1549-
// Decode base64
1550-
base64_decodestate s;
1551-
base64_init_decodestate(&s);
1552-
1553-
// The decoded can not be larger than the input...
1554-
binary_files.emplace_back(std::vector<char>(param_len + 1));
1555-
size_t decoded_size = base64_decode_block(
1556-
param_str, param_len, binary_files.back().data(), &s);
1545+
size_t decoded_size;
1546+
binary_files.emplace_back(std::vector<char>());
1547+
RETURN_AND_RESPOND_IF_ERR(
1548+
req, DecodeBase64(
1549+
param_str, param_len, binary_files.back(),
1550+
decoded_size, m));
15571551
param = TRITONSERVER_ParameterBytesNew(
15581552
m.c_str(), binary_files.back().data(), decoded_size);
15591553
}
@@ -2443,13 +2437,13 @@ HTTPAPIServer::HandleCudaSharedMemory(
24432437
}
24442438

24452439
if (err == nullptr) {
2446-
base64_decodestate s;
2447-
base64_init_decodestate(&s);
2440+
size_t decoded_size;
2441+
std::vector<char> raw_handle;
2442+
RETURN_AND_RESPOND_IF_ERR(
2443+
req, DecodeBase64(
2444+
b64_handle, b64_handle_len, raw_handle, decoded_size,
2445+
"raw_handle"));
24482446

2449-
// The decoded can not be larger than the input...
2450-
std::vector<char> raw_handle(b64_handle_len + 1);
2451-
size_t decoded_size = base64_decode_block(
2452-
b64_handle, b64_handle_len, raw_handle.data(), &s);
24532447
if (decoded_size != sizeof(cudaIpcMemHandle_t)) {
24542448
err = TRITONSERVER_ErrorNew(
24552449
TRITONSERVER_ERROR_INVALID_ARG,

0 commit comments

Comments
 (0)