Skip to content

Commit eb43a42

Browse files
authored
ENH: Config syntax for importing objects (#463)
1 parent f9940ed commit eb43a42

File tree

5 files changed

+26
-16
lines changed

5 files changed

+26
-16
lines changed

examples/galaxies/sdss_photometry.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
cosmology: !astropy.cosmology.default_cosmology.get
1+
cosmology: !astropy.cosmology.default_cosmology.get []
22
z_range: !numpy.linspace [0, 2, 21]
33
M_star: !astropy.modeling.models.Linear1D [-0.9408582, -20.40492365]
44
phi_star: !astropy.modeling.models.Exponential1D [0.00370253, -9.73858]

skypy/pipeline/_config.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,23 @@ def construct_ref(self, node):
5252
return Ref(ref)
5353

5454
def construct_call(self, name, node):
55-
if isinstance(node, yaml.ScalarNode):
56-
arg = self.construct_scalar(node)
57-
args = [arg] if arg != '' else []
58-
kwargs = {}
59-
elif isinstance(node, yaml.SequenceNode):
60-
args = self.construct_sequence(node)
61-
kwargs = {}
62-
elif isinstance(node, yaml.MappingNode):
63-
args = []
64-
kwargs = self.construct_mapping(node)
65-
6655
try:
67-
function = import_function(name)
56+
object = import_function(name)
6857
except (ModuleNotFoundError, AttributeError) as e:
6958
raise ImportError(f'{e}\n{node.start_mark}') from e
7059

71-
return Call(function, args, kwargs)
60+
if isinstance(node, yaml.ScalarNode):
61+
if node.value:
62+
raise ValueError(f'{node.value}: ScalarNode should be empty to import an object')
63+
return object
64+
else:
65+
if isinstance(node, yaml.SequenceNode):
66+
args = self.construct_sequence(node)
67+
kwargs = {}
68+
if isinstance(node, yaml.MappingNode):
69+
args = []
70+
kwargs = self.construct_mapping(node)
71+
return Call(object, args, kwargs)
7272

7373
def construct_quantity(self, node):
7474
value = self.construct_scalar(node)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
bad_object: !astropy.cosmology.Planck15 "bad value"

skypy/pipeline/tests/data/test_config.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
test_int: 1
22
test_float: 1.0
33
test_str: hello world
4-
test_func: !numpy.random.uniform
5-
test_func_with_arg: !len 'hello world'
4+
test_func: !numpy.random.uniform []
5+
test_func_with_arg: !len ['hello world']
6+
test_object: !astropy.cosmology.Planck15
67
cosmology: !astropy.cosmology.FlatLambdaCDM
78
H0: 67.74
89
Om0: 0.3075

skypy/pipeline/tests/test_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from skypy.pipeline import load_skypy_yaml
44
from skypy.pipeline._items import Call
55
from astropy import units
6+
from astropy.cosmology.core import Cosmology
67

78

89
def test_load_skypy_yaml():
@@ -18,6 +19,8 @@ def test_load_skypy_yaml():
1819
assert isinstance(config['test_float'], float)
1920
assert isinstance(config['test_str'], str)
2021
assert isinstance(config['test_func'], Call)
22+
assert isinstance(config['test_func_with_arg'], Call)
23+
assert isinstance(config['test_object'], Cosmology)
2124
assert isinstance(config['cosmology'], Call)
2225
assert isinstance(config['tables']['test_table_1']['test_column_3'], Call)
2326

@@ -31,6 +34,11 @@ def test_load_skypy_yaml():
3134
with pytest.raises(ImportError):
3235
load_skypy_yaml(filename)
3336

37+
# Bad object
38+
filename = get_pkg_data_filename('data/bad_object.yml')
39+
with pytest.raises(ValueError):
40+
load_skypy_yaml(filename)
41+
3442

3543
def test_empty_ref():
3644
filename = get_pkg_data_filename('data/test_empty_ref.yml')

0 commit comments

Comments
 (0)