Skip to content

Commit 1e74bc6

Browse files
Read port range from container support for TFS port (#106)
1 parent cfa6520 commit 1e74bc6

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

src/tf_container/serve.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
from tf_container.run import logger
3333
import time
3434

35-
36-
# Port TF Serving listens on when the container supplies no port range.
DEFAULT_TF_SERVING_PORT = 9000

# Name the model is registered under with tensorflow_model_server.
GENERIC_MODEL_NAME = "generic_model"

# Upper bound on how long we wait for the model to load (15 minutes).
TF_SERVING_MAXIMUM_LOAD_MODEL_TIME_IN_SECONDS = 60 * 15
3938

@@ -97,17 +96,23 @@ def _recursive_copy(src, dst):
9796

9897

9998
def transformer(user_module):
    """Create a Transformer for *user_module* backed by a gRPC proxy to TF Serving.

    The serving port is taken from the container's configured port range when
    one is present; otherwise DEFAULT_TF_SERVING_PORT is used. Blocks until the
    model has loaded (up to TF_SERVING_MAXIMUM_LOAD_MODEL_TIME_IN_SECONDS).
    """
    env = cs.HostingEnvironment()

    if env.port_range:
        # next_safe_port may yield a string — TF Serving's gRPC client needs an int.
        serving_port = int(cs.Server.next_safe_port(env.port_range))
    else:
        serving_port = DEFAULT_TF_SERVING_PORT

    client = proxy_client.GRPCProxyClient(serving_port)
    _wait_model_to_load(client, TF_SERVING_MAXIMUM_LOAD_MODEL_TIME_IN_SECONDS)

    return Transformer.from_module(user_module, client)
104107

105108

106109
def load_dependencies():
    """Launch tensorflow_model_server for the exported SavedModel.

    The serving port is read from the container's configured port range when
    one is present (matching the logic in ``transformer``); otherwise
    DEFAULT_TF_SERVING_PORT is used. The model base path is
    ``<model_dir>/export/Servo``.

    Returns:
        The ``subprocess.Popen`` handle of the launched server (previously
        discarded) so callers may monitor or terminate the process.
    """
    env = cs.HostingEnvironment()
    # Cast to int for consistency with transformer(); next_safe_port appears to
    # return a string port (see unit tests) — the formatted CLI flag is
    # identical either way for numeric values.
    port = int(cs.Server.next_safe_port(env.port_range)) if env.port_range else DEFAULT_TF_SERVING_PORT
    saved_model_path = os.path.join(env.model_dir, 'export/Servo')
    return subprocess.Popen(['tensorflow_model_server',
                             '--port={}'.format(port),
                             '--model_name={}'.format(GENERIC_MODEL_NAME),
                             '--model_base_path={}'.format(saved_model_path)])

test/unit/test_serve.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from container_support.serving import UnsupportedAcceptTypeError, UnsupportedContentTypeError
2222

2323
JSON_CONTENT_TYPE = "application/json"

# Endpoints of the fake safe port range handed to the port-selection tests.
FIRST_PORT = '1111'
LAST_PORT = '2222'
SAFE_PORT_RANGE = '-'.join((FIRST_PORT, LAST_PORT))
2427

2528

2629
@pytest.fixture(scope="module")
@@ -271,11 +274,15 @@ def test_transformer_method(proxy_client, serve):
271274

272275

273276
@patch('subprocess.Popen')
274-
def test_load_dependencies(popen, serve):
277+
@patch('container_support.HostingEnvironment')
278+
def test_load_dependencies_with_default_port(hosting_env, popen, serve):
275279
with patch('os.environ') as env:
276280
env['SAGEMAKER_PROGRAM'] = 'script.py'
277281
env['SAGEMAKER_SUBMIT_DIRECTORY'] = 's3://what/ever'
278282

283+
hosting_env.return_value.port_range = None
284+
hosting_env.return_value.model_dir = '/opt/ml/model'
285+
279286
serve.Transformer.from_module = Mock()
280287
serve.load_dependencies()
281288

@@ -285,6 +292,25 @@ def test_load_dependencies(popen, serve):
285292
'--model_base_path=/opt/ml/model/export/Servo'])
286293

287294

295+
@patch('subprocess.Popen')
@patch('container_support.HostingEnvironment')
def test_load_dependencies_with_safe_port(hosting_env, popen, serve):
    """When a port range is configured, the model server gets its first safe port."""
    with patch('os.environ') as env:
        env['SAGEMAKER_PROGRAM'] = 'script.py'
        env['SAGEMAKER_SUBMIT_DIRECTORY'] = 's3://what/ever'

        hosting_env.return_value.port_range = SAFE_PORT_RANGE
        hosting_env.return_value.model_dir = '/opt/ml/model'

        serve.Transformer.from_module = Mock()
        serve.load_dependencies()

        expected_command = ['tensorflow_model_server',
                            '--port={}'.format(FIRST_PORT),
                            '--model_name=generic_model',
                            '--model_base_path=/opt/ml/model/export/Servo']
        popen.assert_called_with(expected_command)
312+
313+
288314
@patch('tf_container.proxy_client.GRPCProxyClient')
289315
def test_wait_model_to_load(proxy_client, serve):
290316
client = proxy_client()

0 commit comments

Comments
 (0)