Skip to content

Commit 31ddb07

Browse files
authored
Automatically download SPANet weights if not found (#215)
* Automatically download SPANet weights if not found * Run precommit * Automatically download SPANet weights if not found * Run precommit * Change SSH username to nano@nano (#200) * Map weights to CUDA or CPU depending on which is enabled (#216) * Automatically download SPANet weights if not found * Run precommit
1 parent 2f5be8f commit 31ddb07

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

ada_feeding_action_select/ada_feeding_action_select/adapters/spanet_adapter.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
# Standard imports
1111
import os
12+
import gdown
1213

1314
# Third-party imports
1415
from ament_index_python.packages import get_package_share_directory
@@ -33,7 +34,8 @@ class SPANetContext(ContextAdapter):
3334

3435
def __init__(
3536
self,
36-
checkpoint: str,
37+
checkpoint_url: str,
38+
checkpoint_path: str,
3739
n_features: int = 2048,
3840
gpu_index: int = 0,
3941
) -> None:
@@ -63,8 +65,23 @@ def __init__(
6365

6466
# Load Checkpoint
6567
ckpt_file = os.path.join(
66-
get_package_share_directory("ada_feeding_action_select"), "data", checkpoint
68+
get_package_share_directory("ada_feeding_action_select"),
69+
"data",
70+
checkpoint_path,
6771
)
72+
if not os.path.exists(ckpt_file):
73+
logger.info(
74+
f"Checkpoint file not found at {ckpt_file}. Downloading from {checkpoint_url}..."
75+
)
76+
77+
try:
78+
gdown.download(checkpoint_url, ckpt_file, quiet=False)
79+
logger.info(f"Checkpoint file downloaded successfully to {ckpt_file}")
80+
except Exception as e:
81+
raise RuntimeError(f"Error downloading checkpoint: {e}")
82+
else:
83+
logger.info(f"Checkpoint file found at {ckpt_file}. Loading...")
84+
6885
ckpt = torch.load(ckpt_file, map_location=self.device)
6986
self.spanet.load_state_dict(ckpt["net"])
7087
self.spanet.eval()

ada_feeding_action_select/config/policies.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ ada_feeding_action_select:
1717

1818
context_class: ada_feeding_action_select.adapters.SPANetContext
1919
context_kws:
20-
- checkpoint # Relative to share data directory
20+
- checkpoint_url
21+
- checkpoint_path # Relative to share data directory
2122
context_kwargs:
22-
checkpoint: checkpoint/adapter/food_spanet_all_rgb_wall_ckpt_best.pth
23+
checkpoint_url: "https://drive.google.com/uc?id=1BsFe3xyex2_e7MWQEA3Q4oZEzrjzLtiH&export=download" # Direct download link
24+
checkpoint_path: "checkpoint/adapter/food_spanet_all_rgb_wall_ckpt_best.pth" # Local path (relative to share dir)
2325

2426
#context_class: ada_feeding_action_select.adapters.ColorContext
2527

0 commit comments

Comments
 (0)