Skip to content

Commit 8497de6

Browse files
Merge pull request #4 from pyrddlgym-project/config-reader
Config reader
2 parents 1dc19c2 + 3a62580 commit 8497de6

File tree

4 files changed

+95
-15
lines changed

4 files changed

+95
-15
lines changed

pyRDDLGym_gurobi/core/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(self, rddl: RDDLLiftedModel,
119119

120120
def summarize_hyperparameters(self) -> None:
121121
print(f'Gurobi compiler hyper-params:\n'
122+
f' plan ={type(self.plan).__name__}\n'
122123
f' float_range ={self.float_range}\n'
123124
f' float_equality_tol={self.epsilon}\n'
124125
f' lookahead_horizon ={self.horizon}\n'

pyRDDLGym_gurobi/core/planner.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
1+
from ast import literal_eval
2+
import configparser
3+
import os
4+
import sys
15
from typing import Any, Dict, List, Tuple, Optional
26

7+
Kwargs = Dict[str, Any]
8+
39
import gurobipy
410
from gurobipy import GRB
511

@@ -11,6 +17,70 @@
1117

1218
UNBOUNDED = (-GRB.INFINITY, +GRB.INFINITY)
1319

20+
# ***********************************************************************
21+
# CONFIG FILE MANAGEMENT
22+
#
23+
# - read config files from file path
24+
# - extract experiment settings
25+
# - instantiate planner
26+
#
27+
# ***********************************************************************
28+
29+
30+
def _parse_config_file(path: str):
31+
if not os.path.isfile(path):
32+
raise FileNotFoundError(f'File {path} does not exist.')
33+
config = configparser.RawConfigParser()
34+
config.optionxform = str
35+
config.read(path)
36+
args = {k: literal_eval(v)
37+
for section in config.sections()
38+
for (k, v) in config.items(section)}
39+
return config, args
40+
41+
42+
def _parse_config_string(value: str):
43+
config = configparser.RawConfigParser()
44+
config.optionxform = str
45+
config.read_string(value)
46+
args = {k: literal_eval(v)
47+
for section in config.sections()
48+
for (k, v) in config.items(section)}
49+
return config, args
50+
51+
52+
def _getattr_any(packages, item):
53+
for package in packages:
54+
loaded = getattr(package, item, None)
55+
if loaded is not None:
56+
return loaded
57+
return None
58+
59+
60+
def _load_config(config, args):
61+
gurobi_args = {k: args[k] for (k, _) in config.items('Gurobi')}
62+
compiler_args = {k: args[k] for (k, _) in config.items('Optimizer')}
63+
64+
# policy class
65+
plan_method = compiler_args.pop('method')
66+
plan_kwargs = compiler_args.pop('method_kwargs', {})
67+
compiler_args['plan'] = getattr(sys.modules[__name__], plan_method)(**plan_kwargs)
68+
compiler_args['model_params'] = gurobi_args
69+
70+
return compiler_args
71+
72+
73+
def load_config(path: str) -> Kwargs:
74+
'''Loads a config file at the specified file path.'''
75+
config, args = _parse_config_file(path)
76+
return _load_config(config, args)
77+
78+
79+
def load_config_from_string(value: str) -> Kwargs:
80+
'''Loads config file contents specified explicitly as a string value.'''
81+
config, args = _parse_config_string(value)
82+
return _load_config(config, args)
83+
1484

1585
# ***********************************************************************
1686
# ALL VERSIONS OF GUROBI PLANS
@@ -237,7 +307,7 @@ def params(self, compiled: GurobiRDDLCompiler,
237307
lb_name = f'lb__{action}__{k}'
238308
ub_name = f'ub__{action}__{k}'
239309
if values is None:
240-
lb, ub = self.state_bounds[states[0]]
310+
lb, ub = self.state_bounds.get(states[0], UNBOUNDED)
241311
var_bounds = UNBOUNDED if is_linear else (lb - 1, ub + 1)
242312
lb_var = compiled._add_var(model, vtype, *var_bounds)
243313
ub_var = compiled._add_var(model, vtype, *var_bounds)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[Gurobi]
2+
NonConvex=2
3+
OutputFlag=0
4+
5+
[Optimizer]
6+
method='GurobiStraightLinePlan'
7+
method_kwargs={}
8+
rollout_horizon=5
9+
verbose=1

pyRDDLGym_gurobi/examples/run_plan.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,35 +9,35 @@
99
<instance> is the instance number
1010
<horizon> is a positive integer representing the lookahead horizon
1111
'''
12+
import os
1213
import sys
1314

1415
import pyRDDLGym
15-
from pyRDDLGym_gurobi.core.planner import (
16-
GurobiStraightLinePlan, GurobiOnlineController
17-
)
16+
from pyRDDLGym_gurobi.core.planner import GurobiOnlineController, load_config
1817

1918

20-
def main(domain, instance, horizon):
19+
def main(domain, instance):
2120

2221
# create the environment
2322
env = pyRDDLGym.make(domain, instance, enforce_action_constraints=True)
2423

25-
# create the controller
26-
controller = GurobiOnlineController(rddl=env.model,
27-
plan=GurobiStraightLinePlan(),
28-
rollout_horizon=horizon,
29-
model_params={'NonConvex': 2, 'OutputFlag': 1})
24+
# load the config
25+
abs_path = os.path.dirname(os.path.abspath(__file__))
26+
config_path = os.path.join(abs_path, 'default.cfg')
27+
controller_kwargs = load_config(config_path)
28+
29+
# create the controller
30+
controller = GurobiOnlineController(rddl=env.model, **controller_kwargs)
3031
controller.evaluate(env, verbose=True, render=True)
3132

3233
env.close()
3334

3435

3536
if __name__ == "__main__":
3637
args = sys.argv[1:]
37-
if len(args) < 3:
38-
print('python run_plan.py <domain> <instance> <horizon>')
38+
if len(args) < 2:
39+
print('python run_plan.py <domain> <instance>')
3940
exit(1)
40-
domain, instance, horizon = args[:3]
41-
horizon = int(horizon)
42-
main(domain, instance, horizon)
41+
domain, instance = args[:2]
42+
main(domain, instance)
4343

0 commit comments

Comments
 (0)