16
16
import boto3
17
17
import pytest
18
18
from sagemaker import Session
19
+ from sagemaker .tensorflow import TensorFlow
19
20
20
21
logger = logging .getLogger (__name__ )
21
22
logging .getLogger ('boto' ).setLevel (logging .INFO )
@@ -31,6 +32,8 @@ def pytest_addoption(parser):
31
32
parser .addoption ('--instance-type' )
32
33
parser .addoption ('--accelerator-type' , default = None )
33
34
parser .addoption ('--region' , default = 'us-west-2' )
35
+ parser .addoption ('--framework-version' , default = TensorFlow .LATEST_VERSION )
36
+ parser .addoption ('--processor' , default = 'cpu' , choices = ['gpu' , 'cpu' ])
34
37
parser .addoption ('--tag' )
35
38
36
39
@@ -60,8 +63,20 @@ def region(request):
60
63
61
64
62
65
@pytest .fixture (scope = 'session' )
63
- def tag (request ):
64
- return request .config .getoption ('--tag' )
66
+ def framework_version (request ):
67
+ return request .config .getoption ('--framework-version' )
68
+
69
+
70
+ @pytest .fixture (scope = 'session' )
71
+ def processor (request ):
72
+ return request .config .getoption ('--processor' )
73
+
74
+
75
+ @pytest .fixture (scope = 'session' )
76
+ def tag (request , framework_version , processor ):
77
+ provided_tag = request .config .getoption ('--tag' )
78
+ default_tag = '{}-{}-py2' .format (framework_version , processor )
79
+ return provided_tag if provided_tag is not None else default_tag
65
80
66
81
67
82
@pytest .fixture (scope = 'session' )
0 commit comments