
Commit b8a4c36

Allow warm-starting through pre-trained policies for fine-tuning (#480)
* Initial go at adding support for pre-trained policies
* Reconstruct policy for imitation learning
* Add tests for warmstart feature
* Suggested changes and fixes
* Typos and linting issues
* Included Adam suggestions
* Fix multiple linting issues
* Fix sphinx version
1 parent f3c870b commit b8a4c36
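
In short: both training scripts gain an optional `agent_path` config key. When it is set, the learner is warm-started from a previously saved policy instead of a freshly initialized one. A minimal sketch of a warm-started BC run, modeled on the tests added in this commit (paths are placeholders, not outputs of this snippet):

# Sketch: warm-start behavior cloning via the new `agent_path` key.
# Paths are placeholders; a real run also needs demonstrations configured,
# exactly as the tests added below do.
from imitation.scripts import train_imitation

run = train_imitation.train_imitation_ex.run(
    command_name="bc",
    config_updates=dict(
        demonstrations=dict(rollout_path="path/to/rollouts.pkl"),  # placeholder
        agent_path="path/to/previous_run/final.th",  # placeholder
    ),
)
assert run.status == "COMPLETED"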

File tree

7 files changed (+132 −13 lines)


src/imitation/scripts/common/common.py

Lines changed: 1 addition & 1 deletion

@@ -138,7 +138,7 @@ def make_venv(
 ) -> vec_env.VecEnv:
     """Builds the vector environment.
 
-    Args:
+    Args:
         env_name: The environment to train in.
         num_vec: Number of `gym.Env` instances to combine into a vector environment.
         parallel: Whether to use "true" parallelism. If True, then use `SubProcVecEnv`.

src/imitation/scripts/common/rl.py

Lines changed: 3 additions & 3 deletions

@@ -144,9 +144,9 @@ def load_rl_algo_from_path(
     _seed: int,
 ) -> base_class.BaseAlgorithm:
     agent = serialize.load_stable_baselines_model(
-        rl_cls,
-        agent_path,
-        venv,
+        cls=rl_cls,
+        path=agent_path,
+        venv=venv,
         seed=_seed,
         **rl_kwargs,
     )
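
For intuition, the renamed keyword arguments line up with Stable Baselines3's own loading API; roughly the following (an assumption about what `serialize.load_stable_baselines_model` wraps, which this diff does not spell out):

# Rough SB3 equivalent (assumption: the serialize helper ultimately
# delegates to the algorithm class's .load classmethod).
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

venv = DummyVecEnv([lambda: gym.make("CartPole-v1")])
agent = PPO.load("path/to/model.zip", env=venv)  # placeholder checkpoint path

Passing `cls`, `path`, and `venv` by keyword also guards against silent argument-order mistakes if the helper's signature ever changes.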

src/imitation/scripts/config/train_adversarial.py

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ def defaults():
     algorithm_specific = {}  # algorithm_specific[algorithm] is merged with config
 
     checkpoint_interval = 0  # Num epochs between checkpoints (<0 disables)
+    agent_path = None  # Path to load agent from, optional.
 
 
 @train_adversarial_ex.config

src/imitation/scripts/config/train_imitation.py

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ def config():
         expert_policy_type=None,  # 'ppo', 'random', or 'zero'
         total_timesteps=1e5,
     )
+    agent_path = None  # Path to load agent from, optional.
 
 
 @train_imitation_ex.config

src/imitation/scripts/train_adversarial.py

Lines changed: 12 additions & 2 deletions

@@ -3,7 +3,7 @@
 import logging
 import os
 import os.path as osp
-from typing import Any, Mapping, Type
+from typing import Any, Mapping, Optional, Type
 
 import sacred.commands
 import torch as th
@@ -72,6 +72,7 @@ def train_adversarial(
     algorithm_kwargs: Mapping[str, Any],
     total_timesteps: int,
     checkpoint_interval: int,
+    agent_path: Optional[str],
 ) -> Mapping[str, Mapping[str, float]]:
     """Train an adversarial-network-based imitation learning algorithm.
 
@@ -94,6 +95,10 @@ def train_adversarial(
             `checkpoint_interval` rounds and after training is complete. If 0,
             then only save weights after training is complete. If <0, then don't
             save weights at all.
+        agent_path: Path to a directory containing a pre-trained agent. If
+            provided, then the agent will be initialized using this stored policy
+            (warm start). If not provided, then the agent will be initialized using
+            a random policy.
 
     Returns:
         A dictionary with two keys. "imit_stats" gives the return value of
@@ -111,7 +116,12 @@ def train_adversarial(
     expert_trajs = demonstrations.load_expert_trajs()
 
     venv = common_config.make_venv()
-    gen_algo = rl.make_rl_algo(venv)
+
+    if agent_path is None:
+        gen_algo = rl.make_rl_algo(venv)
+    else:
+        gen_algo = rl.load_rl_algo_from_path(agent_path=agent_path, venv=venv)
+
     reward_net = reward.make_reward_net(venv)
 
     logger.info(f"Using '{algo_cls}' algorithm")
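
Because the generator policy is checkpointed under the run's log directory, one adversarial run can now seed the next. A sketch chaining two runs, following the `checkpoints/final/gen_policy` layout exercised by the tests in this commit (other config, e.g. demonstrations, is elided):

# Sketch: warm-start a second GAIL run from the first run's final
# generator checkpoint. Demonstration/config plumbing is elided.
import pathlib

from imitation.scripts import train_adversarial

first = train_adversarial.train_adversarial_ex.run(command_name="gail")
log_dir = pathlib.Path(first.config["common"]["log_dir"])

second = train_adversarial.train_adversarial_ex.run(
    command_name="gail",
    config_updates={
        "agent_path": str(log_dir / "checkpoints" / "final" / "gen_policy"),
    },
)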

src/imitation/scripts/train_imitation.py

Lines changed: 23 additions & 5 deletions

@@ -2,12 +2,13 @@
 
 import logging
 import os.path as osp
+import warnings
 from typing import Any, Mapping, Optional, Type
 
 from sacred.observers import FileStorageObserver
 from stable_baselines3.common import policies, utils, vec_env
 
-from imitation.algorithms.bc import BC
+from imitation.algorithms import bc as bc_algorithm
 from imitation.algorithms.dagger import SimpleDAggerTrainer
 from imitation.data import rollout
 from imitation.policies import serialize
@@ -22,13 +23,19 @@ def make_policy(
     venv: vec_env.VecEnv,
     policy_cls: Type[policies.BasePolicy],
     policy_kwargs: Mapping[str, Any],
+    agent_path: Optional[str],
 ) -> policies.BasePolicy:
     """Makes policy.
 
     Args:
         venv: Vectorized environment we will be imitating demos from.
         policy_cls: Type of a Stable Baselines3 policy architecture.
+            Specify only if agent_path is not specified.
         policy_kwargs: Keyword arguments for policy constructor.
+            Specify only if agent_path is not specified.
+        agent_path: Path to serialized policy. If provided, then load the
+            policy from this path. Otherwise, make a new policy.
+            Specify only if policy_cls and policy_kwargs are not specified.
 
     Returns:
         A Stable Baselines3 policy.
@@ -43,7 +50,14 @@ def make_policy(
             "lr_schedule": utils.get_schedule_fn(1),
         },
     )
-    policy = policy_cls(**policy_kwargs)
+    if agent_path is not None:
+        warnings.warn(
+            "When agent_path is specified, policy_cls and policy_kwargs are ignored.",
+            RuntimeWarning,
+        )
+        policy = bc_algorithm.reconstruct_policy(agent_path)
+    else:
+        policy = policy_cls(**policy_kwargs)
     logger.info(f"Policy network summary:\n {policy}")
     return policy
 
@@ -88,27 +102,31 @@ def train_imitation(
     bc_train_kwargs: Mapping[str, Any],
     dagger: Mapping[str, Any],
     use_dagger: bool,
+    agent_path: Optional[str],
 ) -> Mapping[str, Mapping[str, float]]:
     """Runs DAgger (if `use_dagger`) or BC (otherwise) training.
 
     Args:
         bc_kwargs: Keyword arguments passed through to `bc.BC` constructor.
-        bc_train_kwargs: Keyword arguments passed through to `BC.train` method.
+        bc_train_kwargs: Keyword arguments passed through to `BC.train()` method.
         dagger: Arguments for DAgger training.
        use_dagger: If True, train using DAgger; otherwise, use BC.
+        agent_path: Path to serialized policy. If provided, then load the
+            policy from this path. Otherwise, make a new policy.
+            Specify only if policy_cls and policy_kwargs are not specified.
 
     Returns:
         Statistics for rollouts from the trained policy and demonstration data.
     """
     custom_logger, log_dir = common.setup_logging()
     venv = common.make_venv()
-    imit_policy = make_policy(venv)
+    imit_policy = make_policy(venv, agent_path=agent_path)
 
     expert_trajs = None
     if not use_dagger or dagger["use_offline_rollouts"]:
         expert_trajs = demonstrations.load_expert_trajs()
 
-    bc_trainer = BC(
+    bc_trainer = bc_algorithm.BC(
         observation_space=venv.observation_space,
         action_space=venv.action_space,
         policy=imit_policy,
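
Note the asymmetry between the two scripts: `train_adversarial` reloads a full SB3 algorithm (`base_class.BaseAlgorithm`, optimizer state included) via `rl.load_rl_algo_from_path`, while `train_imitation` reloads only the policy network through `bc_algorithm.reconstruct_policy`. A minimal sketch of the latter path (the `.th` filename follows the BC test below; the path is a placeholder):

# Sketch: rebuild a policy for BC/DAgger warm-starting, as make_policy
# now does when agent_path is set.
from imitation.algorithms import bc as bc_algorithm

policy = bc_algorithm.reconstruct_policy("path/to/previous_run/final.th")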

tests/test_scripts.py

Lines changed: 91 additions & 2 deletions

@@ -88,8 +88,8 @@ def test_main_console(script_mod):
 _rl_agent_loading_configs = {
     "agent_path": CARTPOLE_TEST_POLICY_PATH,
     # FIXME(yawen): the policy we load was trained on 8 parallel environments
-    # and for some reason using it breaks if we use just 1 (like would be the
-    # default with the fast named_config)
+    # and for some reason using it breaks if we use just 1 (like would be the
+    # default with the fast named_config)
     "common": dict(num_vec=8),
 }
 
@@ -232,6 +232,40 @@ def test_train_dagger_main(tmpdir):
     assert isinstance(run.result, dict)
 
 
+def test_train_dagger_warmstart(tmpdir):
+    run = train_imitation.train_imitation_ex.run(
+        command_name="dagger",
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
+        config_updates=dict(
+            common=dict(log_root=tmpdir),
+            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            dagger=dict(
+                expert_policy_type="ppo",
+                expert_policy_path=CARTPOLE_TEST_POLICY_PATH,
+            ),
+        ),
+    )
+    assert run.status == "COMPLETED"
+
+    log_dir = pathlib.Path(run.config["common"]["log_dir"])
+    policy_path = log_dir / "scratch" / "policy-latest.pt"
+    run_warmstart = train_imitation.train_imitation_ex.run(
+        command_name="dagger",
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
+        config_updates=dict(
+            common=dict(log_root=tmpdir),
+            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            dagger=dict(
+                expert_policy_type="ppo",
+                expert_policy_path=CARTPOLE_TEST_POLICY_PATH,
+            ),
+            agent_path=policy_path,
+        ),
+    )
+    assert run_warmstart.status == "COMPLETED"
+    assert isinstance(run_warmstart.result, dict)
+
+
 def test_train_dagger_error_and_exceptions(tmpdir):
     with pytest.raises(Exception, match=".*expert_policy_path cannot be None.*"):
         train_imitation.train_imitation_ex.run(
@@ -261,6 +295,32 @@ def test_train_bc_main(tmpdir):
     assert isinstance(run.result, dict)
 
 
+def test_train_bc_warmstart(tmpdir):
+    run = train_imitation.train_imitation_ex.run(
+        command_name="bc",
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
+        config_updates=dict(
+            common=dict(log_root=tmpdir),
+            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+        ),
+    )
+    assert run.status == "COMPLETED"
+
+    policy_path = pathlib.Path(run.config["common"]["log_dir"]) / "final.th"
+    run_warmstart = train_imitation.train_imitation_ex.run(
+        command_name="bc",
+        named_configs=["cartpole"] + ALGO_FAST_CONFIGS["imitation"],
+        config_updates=dict(
+            common=dict(log_root=tmpdir),
+            demonstrations=dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+            agent_path=policy_path,
+        ),
+    )
+
+    assert run_warmstart.status == "COMPLETED"
+    assert isinstance(run_warmstart.result, dict)
+
+
 TRAIN_RL_PPO_CONFIGS = [{}, _rl_agent_loading_configs]
 
 
@@ -376,6 +436,35 @@ def test_train_adversarial(tmpdir, named_configs, command):
     _check_train_ex_result(run.result)
 
 
+@pytest.mark.parametrize("command", ("airl", "gail"))
+def test_train_adversarial_warmstart(tmpdir, command):
+    named_configs = ["cartpole"] + ALGO_FAST_CONFIGS["adversarial"]
+    config_updates = {
+        "common": dict(log_root=tmpdir),
+        "demonstrations": dict(rollout_path=CARTPOLE_TEST_ROLLOUT_PATH),
+    }
+    run = train_adversarial.train_adversarial_ex.run(
+        command_name=command,
+        named_configs=named_configs,
+        config_updates=config_updates,
+    )
+
+    log_dir = pathlib.Path(run.config["common"]["log_dir"])
+    policy_path = log_dir / "checkpoints" / "final" / "gen_policy"
+
+    run_warmstart = train_adversarial.train_adversarial_ex.run(
+        command_name=command,
+        named_configs=named_configs,
+        config_updates={
+            "agent_path": policy_path,
+            **config_updates,
+        },
+    )
+
+    assert run_warmstart.status == "COMPLETED"
+    _check_train_ex_result(run_warmstart.result)
+
+
 @pytest.mark.parametrize("command", ("airl", "gail"))
 def test_train_adversarial_sac(tmpdir, command):
     """Smoke test for imitation.scripts.train_adversarial."""
