diff --git a/thop/fx_profile.py b/thop/fx_profile.py index 8faadf7..f8b85b4 100644 --- a/thop/fx_profile.py +++ b/thop/fx_profile.py @@ -2,9 +2,9 @@ import torch import torch as th import torch.nn as nn -from distutils.version import LooseVersion +import pkg_resources as pkg -if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): +if pkg.parse_version(torch.__version__) < pkg.parse_version("1.8.0"): logging.warning( f"torch.fx requires version higher than 1.8.0. " f"But You are using an old version PyTorch {torch.__version__}. " diff --git a/thop/profile.py b/thop/profile.py index 6b15d27..db2c84e 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -1,4 +1,4 @@ -from distutils.version import LooseVersion +import pkg_resources as pkg from thop.vision.basic_hooks import * from thop.rnn_hooks import * @@ -9,7 +9,7 @@ from .utils import prGreen, prRed, prYellow -if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): +if pkg.parse_version(torch.__version__) < pkg.parse_version("1.0.0"): logging.warning( "You are using an old version PyTorch {version}, which THOP does NOT support.".format( version=torch.__version__ @@ -65,7 +65,7 @@ nn.PixelShuffle: zero_ops, } -if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): +if pkg.parse_version(torch.__version__) >= pkg.parse_version("1.1.0"): register_hooks.update({nn.SyncBatchNorm: count_normalization})