From 113beb789be9add8067cdf385ada2adb4589485c Mon Sep 17 00:00:00 2001 From: Willyzw Date: Sun, 17 Jan 2021 00:07:20 +0100 Subject: [PATCH] minor fixes performed minor fixes to 1. enable running on cpu only 2. resolve pytorch-scatter cuda version conflict --- README.md | 4 ++-- generate.py | 2 +- src/checkpoints.py | 8 ++++++-- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 432aa3d..eedb26e 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ You can create an anaconda environment called `conv_onet` using conda env create -f environment.yaml conda activate conv_onet ``` -**Note**: you might need to install **torch-scatter** mannually following [the official instruction](https://github.com/rusty1s/pytorch_scatter#pytorch-140): +**Note**: you might need to (re-)install **torch-scatter** mannually following [the official instruction](https://github.com/rusty1s/pytorch_scatter#pytorch-140): ``` -pip install torch-scatter==2.0.4 -f https://pytorch-geometric.com/whl/torch-1.4.0+cu101.html +pip install torch-scatter==2.0.4 --force-reinstall -f https://pytorch-geometric.com/whl/torch-1.4.0+cu101.html ``` Next, compile the extension modules. diff --git a/generate.py b/generate.py index 85fc59d..d11f638 100644 --- a/generate.py +++ b/generate.py @@ -40,7 +40,7 @@ # Model model = config.get_model(cfg, device=device, dataset=dataset) -checkpoint_io = CheckpointIO(out_dir, model=model) +checkpoint_io = CheckpointIO(out_dir, model=model, is_cuda=is_cuda) checkpoint_io.load(cfg['test']['model_file']) # Generator diff --git a/src/checkpoints.py b/src/checkpoints.py index 4070b42..edd04f3 100644 --- a/src/checkpoints.py +++ b/src/checkpoints.py @@ -12,8 +12,9 @@ class CheckpointIO(object): Args: checkpoint_dir (str): path where checkpoints are saved ''' - def __init__(self, checkpoint_dir='./chkpts', **kwargs): + def __init__(self, checkpoint_dir='./chkpts', is_cuda=True, **kwargs): self.module_dict = kwargs + self.is_cuda = is_cuda self.checkpoint_dir = checkpoint_dir if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) @@ -75,7 +76,10 @@ def load_url(self, url): ''' print(url) print('=> Loading checkpoint from url...') - state_dict = model_zoo.load_url(url, progress=True) + if self.is_cuda: + state_dict = model_zoo.load_url(url, progress=True) + else: + state_dict = model_zoo.load_url(url, progress=True, map_location=torch.device('cpu')) scalars = self.parse_state_dict(state_dict) return scalars