Commit 31d79d4

Commit message: lint
1 parent 5ab6578 · commit 31d79d4

23 files changed: +179 additions, -75 deletions

src/imitation/algorithms/adversarial.py

Lines changed: 8 additions & 2 deletions
@@ -180,7 +180,11 @@ def gen_policy(self) -> base_class.BaseRLModel:
         return self._gen_policy
 
     def _gen_log_action_prob_from_unnormalized(
-        self, observation: np.ndarray, *, actions: np.ndarray, logp=True,
+        self,
+        observation: np.ndarray,
+        *,
+        actions: np.ndarray,
+        logp=True,
     ) -> np.ndarray:
         """Calculate generator log action probabilility.
 
@@ -306,7 +310,9 @@ def train_gen(
         self._gen_replay_buffer.store(gen_samples)
 
     def train(
-        self, total_timesteps: int, callback: Optional[Callable[[int], None]] = None,
+        self,
+        total_timesteps: int,
+        callback: Optional[Callable[[int], None]] = None,
     ) -> None:
         """Alternates between training the generator and discriminator.

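The `train` method reformatted above alternates discriminator and generator updates. As context, here is a minimal sketch of that alternating loop; it assumes the trainer exposes `train_disc()` and `train_gen()` as suggested by the hunks above, and the step bookkeeping is illustrative rather than the library's actual implementation.

def alternating_train(trainer, total_timesteps, gen_batch_size, callback=None):
    """Sketch of the generator/discriminator alternation performed by train()."""
    n_rounds = total_timesteps // gen_batch_size
    for round_idx in range(n_rounds):
        trainer.train_disc()               # fit discriminator on expert vs. generated samples
        trainer.train_gen(gen_batch_size)  # update generator policy against the learned reward
        if callback is not None:
            callback(round_idx)            # mirrors the Optional[Callable[[int], None]] hook above
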
src/imitation/algorithms/bc.py

Lines changed: 7 additions & 3 deletions
@@ -67,7 +67,9 @@ def __init__(
         policy_class: Type[ActorCriticPolicy] = FeedForward32Policy,
         policy_kwargs: Optional[Mapping[str, Any]] = None,
         expert_data: Union[
-            types.TransitionsMinimal, datasets.Dataset[types.TransitionsMinimal], None,
+            types.TransitionsMinimal,
+            datasets.Dataset[types.TransitionsMinimal],
+            None,
         ] = None,
         batch_size: int = 32,
         optimizer_cls: Type[tf.train.Optimizer] = tf.train.AdamOptimizer,
@@ -122,7 +124,8 @@ def __init__(
     def set_expert_dataset(
         self,
         expert_data: Union[
-            types.TransitionsMinimal, datasets.Dataset[types.TransitionsMinimal],
+            types.TransitionsMinimal,
+            datasets.Dataset[types.TransitionsMinimal],
         ],
     ):
         """Replace the current expert dataset with a new one.
@@ -253,7 +256,8 @@ def save_policy(self, policy_path: str):
 
     @staticmethod
     def reconstruct_policy(
-        policy_path: str, sess: Optional[tf.Session] = None,
+        policy_path: str,
+        sess: Optional[tf.Session] = None,
     ) -> BasePolicy:
         """Reconstruct a saved policy.

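For context, `save_policy` and `reconstruct_policy` above round-trip a BC policy to disk. A hypothetical usage sketch (the path and the trained `bc_trainer` instance are assumptions, not taken from this diff):

# Assumes a trained BC instance `bc_trainer`; the path is illustrative.
bc_trainer.save_policy("output/bc_policy.pkl")

# Later, rebuild the policy; sess=None presumably lets the static method
# create or reuse a TensorFlow session, per the signature above.
policy = BC.reconstruct_policy("output/bc_policy.pkl", sess=None)
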
src/imitation/algorithms/dagger.py

Lines changed: 4 additions & 1 deletion
@@ -41,7 +41,10 @@ def schedule(i: int) -> float:
     return schedule
 
 
-def _save_trajectory(npz_path: str, trajectory: types.Trajectory,) -> None:
+def _save_trajectory(
+    npz_path: str,
+    trajectory: types.Trajectory,
+) -> None:
     """Save a trajectory as a compressed Numpy file."""
     save_dir = os.path.dirname(npz_path)
     if save_dir:

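`_save_trajectory` above writes a trajectory as a compressed NumPy archive. A minimal sketch of that idea (field names follow the `obs`/`acts`/`infos` attributes used elsewhere in this diff; the exact archive layout is an assumption):

import os
import numpy as np

def save_trajectory_sketch(npz_path, trajectory):
    """Illustrative only: dump trajectory arrays with np.savez_compressed."""
    save_dir = os.path.dirname(npz_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.savez_compressed(
        npz_path,
        obs=trajectory.obs,
        acts=trajectory.acts,
        infos=trajectory.infos,
    )
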
src/imitation/algorithms/density_baselines.py

Lines changed: 13 additions & 13 deletions
@@ -224,20 +224,20 @@ def __init__(
         is_stationary: bool = False,
     ):
         r"""Family of simple imitation learning baseline algorithms that apply RL to
-    maximise a rough density estimate of the demonstration trajectories.
-    Specifically, it constructs a non-parametric estimate of `p(s)`, `p(s,s')`,
-    `p_t(s,a)`, etc. (depending on options), then rewards the imitation learner
-    with `r_t(s,a,s')=\log p_t(s,a,s')` (or `\log p(s,s')`, or whatever the
-    user wants the model to condition on).
+        maximise a rough density estimate of the demonstration trajectories.
+        Specifically, it constructs a non-parametric estimate of `p(s)`, `p(s,s')`,
+        `p_t(s,a)`, etc. (depending on options), then rewards the imitation learner
+        with `r_t(s,a,s')=\log p_t(s,a,s')` (or `\log p(s,s')`, or whatever the
+        user wants the model to condition on).
 
-    Args:
-        venv: environment to train on.
-        rollouts: list of expert trajectories to imitate.
-        imitation_trainer: RL algorithm & initial policy that will
-            be used to train the imitation learner.
-        kernel, kernel_bandwidth, density_type, is_stationary,
-        n_expert_trajectories: these are passed directly to `DensityReward`;
-            refer to documentation for that class."""
+        Args:
+            venv: environment to train on.
+            rollouts: list of expert trajectories to imitate.
+            imitation_trainer: RL algorithm & initial policy that will
+                be used to train the imitation learner.
+            kernel, kernel_bandwidth, density_type, is_stationary,
+            n_expert_trajectories: these are passed directly to `DensityReward`;
+                refer to documentation for that class."""
         self.venv = venv
         self.imitation_trainer = imitation_trainer
         self.reward_fn = DensityReward(

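The docstring above describes rewarding the imitation learner with the log density of demonstration data. As a rough sketch of that idea only (not the `DensityReward` class itself), the snippet below fits a kernel density model on expert state-action pairs and returns log p(s, a) as a reward; the use of scikit-learn's `KernelDensity` is an assumption about tooling, not something shown in this diff.

import numpy as np
from sklearn.neighbors import KernelDensity  # assumes scikit-learn is available

def make_log_density_reward(expert_obs, expert_acts, bandwidth=0.5):
    """Fit a Gaussian KDE on expert (s, a) pairs; reward is the log-density."""
    # expects 2-D arrays of shape (n_samples, dim)
    expert_sa = np.concatenate([expert_obs, expert_acts], axis=1)
    kde = KernelDensity(kernel="gaussian", bandwidth=bandwidth).fit(expert_sa)

    def reward_fn(obs, acts):
        sa = np.concatenate([obs, acts], axis=1)
        return kde.score_samples(sa)  # log-density, one value per (s, a) pair

    return reward_fn
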
src/imitation/data/buffer.py

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ def from_data(
             obs_shape=obs_shape,
             act_shape=act_shape,
             obs_dtype=transitions.obs.dtype,
-            act_dtype=transitions.acts.dtype,
+            act_dtype=transitions.acts.dtype,  # pytype: disable=wrong-arg-types
         )
         instance.store(transitions, truncate_ok=truncate_ok)
         return instance

src/imitation/data/rollout.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def f(trajectories: Sequence[types.TrajectoryWithRew]):
187187

188188

189189
def make_sample_until(
190-
n_timesteps: Optional[int], n_episodes: Optional[int],
190+
n_timesteps: Optional[int],
191+
n_episodes: Optional[int],
191192
) -> GenTrajTerminationFn:
192193
"""Returns a termination condition sampling until n_timesteps or n_episodes.
193194

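`make_sample_until` returns a predicate that tells rollout collection when to stop. A hedged sketch of what such a predicate might look like (names and exact semantics are assumptions, not the library's implementation):

from typing import Optional, Sequence

def make_sample_until_sketch(n_timesteps: Optional[int], n_episodes: Optional[int]):
    """Illustrative stand-in for the termination condition described above."""
    def sample_until(trajectories: Sequence) -> bool:
        if n_episodes is not None and len(trajectories) >= n_episodes:
            return True
        total_steps = sum(len(traj.acts) for traj in trajectories)
        return n_timesteps is not None and total_steps >= n_timesteps
    return sample_until
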
src/imitation/envs/examples/airl_envs/dynamic_mjc/model_builder.py

Lines changed: 4 additions & 4 deletions
@@ -48,10 +48,10 @@ def __init__(self, name):
     @contextmanager
     def asfile(self):
         """Usage:
-            model = MJCModel('reacher')
-            with model.asfile() as f:
-                print f.read() # prints a dump of the model
-            """
+        model = MJCModel('reacher')
+        with model.asfile() as f:
+            print f.read() # prints a dump of the model
+        """
         with tempfile.NamedTemporaryFile(mode="w+", suffix=".xml", delete=True) as f:
             self.root.write(f)
             f.seek(0)

src/imitation/envs/examples/airl_envs/utils.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
def flat_to_one_hot(val, ndim):
55
"""
66
7-
>>> flat_to_one_hot(2, ndim=4)
8-
array([ 0., 0., 1., 0.])
9-
>>> flat_to_one_hot(4, ndim=5)
10-
array([ 0., 0., 0., 0., 1.])
11-
>>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
12-
array([[ 0., 0., 1., 0., 0.],
13-
[ 0., 0., 0., 0., 1.],
14-
[ 0., 0., 0., 1., 0.]])
15-
"""
7+
>>> flat_to_one_hot(2, ndim=4)
8+
array([ 0., 0., 1., 0.])
9+
>>> flat_to_one_hot(4, ndim=5)
10+
array([ 0., 0., 0., 0., 1.])
11+
>>> flat_to_one_hot(np.array([2, 4, 3]), ndim=5)
12+
array([[ 0., 0., 1., 0., 0.],
13+
[ 0., 0., 0., 0., 1.],
14+
[ 0., 0., 0., 1., 0.]])
15+
"""
1616
shape = np.array(val).shape
1717
v = np.zeros(shape + (ndim,))
1818
if len(shape) == 1:
@@ -24,13 +24,13 @@ def flat_to_one_hot(val, ndim):
2424

2525
def one_hot_to_flat(val):
2626
"""
27-
>>> one_hot_to_flat(np.array([0,0,0,0,1]))
28-
4
29-
>>> one_hot_to_flat(np.array([0,0,1,0]))
30-
2
31-
>>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
32-
array([2, 0, 1])
33-
"""
27+
>>> one_hot_to_flat(np.array([0,0,0,0,1]))
28+
4
29+
>>> one_hot_to_flat(np.array([0,0,1,0]))
30+
2
31+
>>> one_hot_to_flat(np.array([[0,0,1,0], [1,0,0,0], [0,1,0,0]]))
32+
array([2, 0, 1])
33+
"""
3434
idxs = np.array(np.where(val == 1.0))[-1]
3535
if len(val.shape) == 1:
3636
return int(idxs)

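The doctests above pin down the one-hot conversion in both directions. A compact NumPy sketch with equivalent behaviour (illustrative only; the repository's own implementation, partially visible above, is structured differently):

import numpy as np

def flat_to_one_hot_sketch(val, ndim):
    # Indexing an identity matrix by integer labels yields one-hot rows,
    # and works for scalars or arrays of labels alike.
    return np.eye(ndim)[np.asarray(val)]

def one_hot_to_flat_sketch(val):
    # argmax along the last axis recovers the flat index (or index array).
    return np.asarray(val).argmax(axis=-1)
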
src/imitation/policies/serialize.py

Lines changed: 9 additions & 3 deletions
@@ -100,11 +100,15 @@ def f(path: str, venv: VecEnv) -> Iterator[BasePolicy]:
 
 policy_registry.register(
     "random",
-    value=registry.build_loader_fn_require_space(registry.dummy_context(RandomPolicy),),
+    value=registry.build_loader_fn_require_space(
+        registry.dummy_context(RandomPolicy),
+    ),
 )
 policy_registry.register(
     "zero",
-    value=registry.build_loader_fn_require_space(registry.dummy_context(ZeroPolicy),),
+    value=registry.build_loader_fn_require_space(
+        registry.dummy_context(ZeroPolicy),
+    ),
 )
 
 
@@ -142,7 +146,9 @@ def load_policy(
 
 
 def save_stable_model(
-    output_dir: str, model: BaseRLModel, vec_normalize: Optional[VecNormalize] = None,
+    output_dir: str,
+    model: BaseRLModel,
+    vec_normalize: Optional[VecNormalize] = None,
 ) -> None:
     """Serialize policy.

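The registration calls reformatted above show the pattern for hooking a hard-coded policy type into `policy_registry`. A hedged sketch of registering an additional policy the same way; `MyConstantPolicy` and the "my-constant" key are hypothetical and not part of this diff:

# Hypothetical registration, mirroring the "random" and "zero" entries above.
policy_registry.register(
    "my-constant",
    value=registry.build_loader_fn_require_space(
        registry.dummy_context(MyConstantPolicy),
    ),
)
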
src/imitation/rewards/discrim_net.py

Lines changed: 5 additions & 1 deletion
@@ -231,7 +231,11 @@ def reward_train(
         return rew
 
     def reward_test(
-        self, obs: np.ndarray, act: np.ndarray, next_obs: np.ndarray, dones: np.ndarray,
+        self,
+        obs: np.ndarray,
+        act: np.ndarray,
+        next_obs: np.ndarray,
+        dones: np.ndarray,
     ) -> np.ndarray:
         """Vectorized reward for training an expert during transfer learning.

src/imitation/rewards/reward_net.py

Lines changed: 7 additions & 2 deletions
@@ -217,7 +217,9 @@ def reward_output_train(self):
 
     @abstractmethod
     def build_phi_network(
-        self, obs_input: tf.Tensor, next_obs_input: tf.Tensor,
+        self,
+        obs_input: tf.Tensor,
+        next_obs_input: tf.Tensor,
     ) -> Tuple[tf.Tensor, tf.Tensor, networks.LayersDict]:
         """Build the reward shaping network (disentangles dynamics from reward).
 
@@ -373,7 +375,10 @@ def __init__(
         self.theta_kwargs = theta_kwargs or {}
         self.phi_kwargs = phi_kwargs or {}
         RewardNetShaped.__init__(
-            self, observation_space, action_space, **kwargs,
+            self,
+            observation_space,
+            action_space,
+            **kwargs,
         )
         serialize.LayersSerializable.__init__(**params, layers=self._layers)

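`build_phi_network` above supplies the potential function used for reward shaping. For context only, the textbook potential-based shaping form that such a network feeds into is sketched below; the symbols stand for the two outputs of the phi network and the discount factor, and the exact expression used in reward_net.py may differ.

def shaped_reward_sketch(base_reward, phi_obs, phi_next_obs, gamma=0.99):
    # Potential-based shaping (Ng et al., 1999): add the discounted potential of the
    # next state and subtract the potential of the current state. Shown as a sketch,
    # not as the library's implementation.
    return base_reward + gamma * phi_next_obs - phi_obs
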
src/imitation/scripts/config/expert_demos.py

Lines changed: 6 additions & 2 deletions
@@ -98,7 +98,9 @@ def hopper():
 @expert_demos_ex.named_config
 def humanoid():
     env_name = "Humanoid-v2"
-    init_rl_kwargs = dict(n_steps=2048,)  # batch size of 2048*8=16384 due to num_vec
+    init_rl_kwargs = dict(
+        n_steps=2048,
+    )  # batch size of 2048*8=16384 due to num_vec
     total_timesteps = int(10e6)  # fairly discontinuous, needs at least 5e6
 
 
@@ -160,7 +162,9 @@ def fast():
 # Shared settings
 
 ant_shared_locals = dict(
-    init_rl_kwargs=dict(n_steps=2048,),  # batch size of 2048*8=16384 due to num_vec
+    init_rl_kwargs=dict(
+        n_steps=2048,
+    ),  # batch size of 2048*8=16384 due to num_vec
     total_timesteps=int(5e6),
     max_episode_steps=500,  # To match `inverse_rl` settings.
 )

src/imitation/scripts/config/train_adversarial.py

Lines changed: 2 additions & 1 deletion
@@ -50,7 +50,8 @@ def train_defaults():
 
     # Modifies the __init__ arguments for the imitation policy
     init_rl_kwargs = dict(
-        policy_class=base.FeedForward32Policy, **DEFAULT_INIT_RL_KWARGS,
+        policy_class=base.FeedForward32Policy,
+        **DEFAULT_INIT_RL_KWARGS,
     )
     gen_batch_size = 2048  # Batch size for generator updates

src/imitation/scripts/parallel.py

Lines changed: 4 additions & 1 deletion
@@ -89,7 +89,10 @@ def parallel(
         base_config_updates["data_dir"] = data_dir
 
     trainable = _ray_tune_sacred_wrapper(
-        sacred_ex_name, run_name, base_named_configs, base_config_updates,
+        sacred_ex_name,
+        run_name,
+        base_named_configs,
+        base_config_updates,
     )
 
     # Disable all Ray Loggers.

src/imitation/scripts/train_adversarial.py

Lines changed: 2 additions & 1 deletion
@@ -162,7 +162,8 @@ def train(
     algorithm_kwargs_shared = algorithm_kwargs.get("shared", {})
     algorithm_kwargs_algo = algorithm_kwargs.get(algorithm, {})
     final_algorithm_kwargs = dict(
-        **algorithm_kwargs_shared, **algorithm_kwargs_algo,
+        **algorithm_kwargs_shared,
+        **algorithm_kwargs_algo,
     )
 
     if algorithm.lower() == "gail":

src/imitation/util/logger.py

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 
 
 def _build_output_formats(
-    folder: str, format_strs: Sequence[str] = None,
+    folder: str,
+    format_strs: Sequence[str] = None,
 ) -> Sequence[sb_logger.KVWriter]:
     """Build output formats for initializing a Stable Baselines Logger.

tests/test_buffer.py

Lines changed: 6 additions & 1 deletion
@@ -230,7 +230,12 @@ def _check_buf(buf):
     rews = np.array([0.5, 1.0], dtype=float)
     buf_rew = ReplayBuffer.from_data(
         types.TransitionsWithRew(
-            obs=obs, acts=acts, next_obs=next_obs, rews=rews, dones=dones, infos=infos,
+            obs=obs,
+            acts=acts,
+            next_obs=next_obs,
+            rews=rews,
+            dones=dones,
+            infos=infos,
         )
     )
     _check_buf(buf_rew)

tests/test_buffering_wrapper.py

Lines changed: 9 additions & 2 deletions
@@ -46,7 +46,9 @@ def step(self, action):
         return t, t * 10, done, {}
 
 
-def _make_buffering_venv(error_on_premature_reset: bool,) -> BufferingWrapper:
+def _make_buffering_venv(
+    error_on_premature_reset: bool,
+) -> BufferingWrapper:
     venv = DummyVecEnv([_CountingEnv] * 2)
     venv = BufferingWrapper(venv, error_on_premature_reset)
     venv.reset()
@@ -73,7 +75,12 @@ def concat(x):
     dones = concat(t.dones for t in trans_list)
     infos = concat(t.infos for t in trans_list)
     return types.TransitionsWithRew(
-        obs=obs, next_obs=next_obs, rews=rews, acts=acts, dones=dones, infos=infos,
+        obs=obs,
+        next_obs=next_obs,
+        rews=rews,
+        acts=acts,
+        dones=dones,
+        infos=infos,
     )
7986

tests/test_data.py

Lines changed: 19 additions & 4 deletions
@@ -102,7 +102,9 @@ def test_valid_trajectories(
         assert len(traj) == length
 
     def test_invalid_trajectories(
-        self, trajectory: types.Trajectory, trajectory_rew: types.TrajectoryWithRew,
+        self,
+        trajectory: types.Trajectory,
+        trajectory_rew: types.TrajectoryWithRew,
     ) -> None:
         """Checks input validation catches space and dtype related errors."""
         trajs = [trajectory, trajectory_rew]
@@ -304,7 +306,13 @@ def test_dict_dataset_parallel_rows(
     Nontrivially, shuffled datasets should maintain this order.
    """
    dataset_cls, kwargs = dict_dataset_params
-    range_data_map = {k: i + np.arange(50,) for i, k in enumerate("abcd")}
+    range_data_map = {
+        k: i
+        + np.arange(
+            50,
+        )
+        for i, k in enumerate("abcd")
+    }
    dict_dataset = dataset_cls(range_data_map, **kwargs)
    for _ in range(n_checks):
        n_samples = np.random.randint(max_batch_size) + 1
@@ -333,7 +341,11 @@ def arange_dataset(self, shuffle, dataset_size):
         return ds
 
     def test_epoch_order_dict_dataset_shuffle_order(
-        self, arange_dataset, shuffle, dataset_size, n_checks=3,
+        self,
+        arange_dataset,
+        shuffle,
+        dataset_size,
+        n_checks=3,
     ):
         """Check that epoch order is deterministic iff not shuffled.
@@ -350,7 +362,10 @@ def test_epoch_order_dict_dataset_shuffle_order(
         assert same_order != shuffle
 
     def test_epoch_order_dict_dataset_order_property(
-        self, arange_dataset, max_batch_size=31, n_epochs=4,
+        self,
+        arange_dataset,
+        max_batch_size=31,
+        n_epochs=4,
     ):
         """No sample should be returned n+1 times until others are returned n times."""
         counter = collections.Counter({i: 0 for i in range(arange_dataset.size())})
