
Commit 1af0e2f

fix: smp will not be imported if not specified by user (#651)
* fix: smp will not be imported if not specified by user
* fix: add pipeline_parallel_degree for smp after v1.60
* remove tf related model parallel vars
* version bump
1 parent 6cb0d55 commit 1af0e2f
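
For context, the decision the import now hinges on is made by check_smmodelparallel_training(), which inspects the SageMaker hyperparameter JSON in the SM_HPS environment variable. Below is a minimal, self-contained sketch of that detection logic under the assumptions visible in this diff; uses_smmodelparallel is a hypothetical name used only for illustration, while the real helper lives in smdebug/core/utils.py and caches its result in _is_using_smmodelparallel.

import json
import os


def uses_smmodelparallel(env=os.environ):
    # Sketch of the check this commit gates the smp import on: SM_HPS must
    # carry "mp_parameters" with either "pipeline_parallel_degree"
    # (smp >= 1.60) or "partitions" (older smp releases).
    raw = env.get("SM_HPS")
    if not raw:
        return False
    try:
        hps = json.loads(raw)
    except ValueError:
        return False
    mp = hps.get("mp_parameters", {})
    return "pipeline_parallel_degree" in mp or "partitions" in mp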

File tree: 3 files changed (+13 additions, -24 deletions)

smdebug/_version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "1.0.29"
+__version__ = "1.0.30"

smdebug/core/utils.py

Lines changed: 12 additions & 6 deletions
@@ -35,6 +35,7 @@
     SMDebugRuntimeError,
     SMDebugTypeError,
     SMDebugValueError,
+    SMDebugError
 )


@@ -49,18 +50,18 @@ class FRAMEWORK(Enum):
 _smddp_tf_imported = None
 _smddp_pt_imported = None
 _is_using_smmodelparallel = None
+_smp_imported = None

-try:
-    import smdistributed.modelparallel.tensorflow as smp

-    _smp_imported = smp
-except (ImportError, ModuleNotFoundError):
+if check_smmodelparallel_training():
     try:
         import smdistributed.modelparallel.torch as smp

         _smp_imported = smp
     except (ImportError, ModuleNotFoundError):
         _smp_imported = None
+    except Exception as e:
+        raise SMDebugError(e)


 try:
@@ -644,8 +645,13 @@ def check_smmodelparallel_training():
     else:
         try:
             smp_flag = json.loads(os.getenv("SM_HPS"))
-            if "mp_parameters" in smp_flag and "partitions" in smp_flag["mp_parameters"]:
-                _is_using_smmodelparallel = True
+            if "mp_parameters" in smp_flag:
+                if "pipeline_parallel_degree" in smp_flag["mp_parameters"]:
+                    _is_using_smmodelparallel = True
+                elif "partitions" in smp_flag["mp_parameters"]:
+                    _is_using_smmodelparallel = True
+                else:
+                    _is_using_smmodelparallel = False
             else:
                 _is_using_smmodelparallel = False
         except:
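
To illustrate the new branch, here are hypothetical SM_HPS payloads (the hyperparameter keys mp_parameters, pipeline_parallel_degree, and partitions come from the diff above; the surrounding values such as microbatches, epochs, and lr are made up for illustration). Fed through the updated check (or the uses_smmodelparallel sketch after the commit header), the first two mark the job as model parallel and the last one does not, so smdistributed.modelparallel is never imported for it.

import json
import os

# smp >= 1.60 style configuration: detected via "pipeline_parallel_degree".
os.environ["SM_HPS"] = json.dumps(
    {"mp_parameters": {"pipeline_parallel_degree": 2, "microbatches": 4}}
)

# Pre-1.60 style configuration: still detected via "partitions".
os.environ["SM_HPS"] = json.dumps({"mp_parameters": {"partitions": 2}})

# No mp_parameters at all: _is_using_smmodelparallel stays False and the
# smdistributed.modelparallel import is skipped entirely.
os.environ["SM_HPS"] = json.dumps({"epochs": 10, "lr": 0.001})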

smdebug/tensorflow/base_hook.py

Lines changed: 0 additions & 17 deletions
@@ -40,13 +40,6 @@
     load_tf_config_json,
 )

-try:
-    import smdistributed.modelparallel.tensorflow as smp  # noqa isort:skip
-
-    _smp_imported = smp
-except ImportError:
-    _smp_imported = None
-

 DEFAULT_INCLUDE_COLLECTIONS = [
     CollectionKeys.METRICS,
@@ -195,11 +188,6 @@ def _get_worker_name(self) -> str:
         """
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
-            if _smp_imported and _smp_imported.core.initialized:
-                # when model parallel is being used, there will be multiple processes
-                # with same hvd rank, hence use smp.rank
-                return f"worker_{smp.rank()}"
-
             import horovod.tensorflow as hvd

             return f"worker_{hvd.rank()}"
@@ -277,11 +265,6 @@ def _get_custom_and_default_collections(self) -> Tuple[Set["Collection"], Set["C
     def _get_num_workers(self):
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
-            if _smp_imported and smp.core.initialized:
-                # when model parallel is being used, there will be multiple hvd process groups,
-                # hence use smp.size
-                return smp.size()
-
             import horovod.tensorflow as hvd

             return hvd.size()
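
With the TensorFlow smp fallback removed, the Horovod branch of the hook relies on Horovod alone for worker identity and count. A minimal sketch of the remaining behavior, assuming horovod.tensorflow is installed and hvd.init() has already been called by the training script (function names here are simplified stand-ins for the hook methods shown above):

import horovod.tensorflow as hvd


def get_worker_name() -> str:
    # One process per Horovod rank; no smp rank disambiguation any more.
    return f"worker_{hvd.rank()}"


def get_num_workers() -> int:
    return hvd.size()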
