Skip to content

Commit 1d2b226

Browse files
committed
changes for Cobaya 2.0 compatibility
1 parent 0ebc06f commit 1d2b226

File tree

8 files changed

+159
-118
lines changed

8 files changed

+159
-118
lines changed

getdist/chains.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def slice_or_none(x, start=None, end=None):
4040
return getattr(x, "__getitem__", lambda _: None)(slice(start, end))
4141

4242

43-
def chainFiles(root, chain_indices=None, ext='.txt', first_chain=0, last_chain=-1, chain_exclude=None):
43+
def chainFiles(root, chain_indices=None, ext='.txt', separator="_",
44+
first_chain=0, last_chain=-1, chain_exclude=None):
4445
"""
4546
Creates a list of file names for samples given a root name and optional filters
4647
@@ -59,8 +60,8 @@ def chainFiles(root, chain_indices=None, ext='.txt', first_chain=0, last_chain=-
5960
fname = root
6061
if index > 0:
6162
# deal with just-folder prefix
62-
if not root.endswith("/"):
63-
fname += '_'
63+
if not root.endswith((os.sep, "/")):
64+
fname += separator
6465
fname += str(index)
6566
if not fname.endswith(ext): fname += ext
6667
if index > first_chain and not os.path.exists(fname) or 0 < last_chain < index: break
@@ -849,17 +850,22 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab
849850
:param kwargs: extra options for :class:`~.chains.WeightedSamples`'s constructor
850851
851852
"""
853+
from getdist.cobaya_interface import get_sampler_type, _separator_files
854+
852855
self.chains = None
853856
WeightedSamples.__init__(self, **kwargs)
854857
self.jobItem = jobItem
855858
self.ignore_lines = float(kwargs.get('ignore_rows', 0))
856859
self.root = root
857860
if not paramNamesFile and root:
858-
mid = ('' if root.endswith("/") else "__")
859-
if os.path.exists(root + '.paramnames'):
860-
paramNamesFile = root + '.paramnames'
861-
elif os.path.exists(root + mid + 'full.yaml'):
862-
paramNamesFile = root + mid + 'full.yaml'
861+
mid = not root.endswith((os.sep, "/"))
862+
endings = ['.paramnames', ('__' if mid else '') + 'full.yaml',
863+
(_separator_files if mid else '') + 'updated.yaml']
864+
try:
865+
paramNamesFile = next(
866+
root + ending for ending in endings if os.path.exists(root + ending))
867+
except StopIteration:
868+
paramNamesFile = None
863869
self.setParamNames(paramNamesFile or names)
864870
if labels is not None:
865871
self.paramNames.setLabels(labels)
@@ -871,7 +877,6 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab
871877
raise ValueError("Unknown sampler type %s" % sampler)
872878
self.sampler = sampler.lower()
873879
elif isinstance(paramNamesFile, six.string_types) and paramNamesFile.endswith("yaml"):
874-
from getdist.yaml_format_tools import get_sampler_type
875880
self.sampler = get_sampler_type(paramNamesFile)
876881
else:
877882
self.sampler = "mcmc"

getdist/yaml_format_tools.py renamed to getdist/cobaya_interface.py

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
# JT 2017-18
1+
# JT 2017-19
22

33
from __future__ import division
44
from importlib import import_module
55
from six import string_types
66
from copy import deepcopy
7-
import re
87
from collections import OrderedDict as odict
98
import numpy as np
10-
import yaml
9+
1110

1211
# Conventions
1312
_prior = "prior"
@@ -21,6 +20,7 @@
2120
_p_derived = "derived"
2221
_p_renames = "renames"
2322
_separator = "__"
23+
_separator_files = "."
2424
_minuslogprior = "minuslogprior"
2525
_prior_1d_name = "0"
2626
_chi2 = "chi2"
@@ -29,78 +29,6 @@
2929
_post = "post"
3030

3131

32-
# Exceptions
33-
class InputSyntaxError(Exception):
34-
"""Syntax error in YAML input."""
35-
36-
37-
# Better loader for YAML
38-
# 1. Matches 1e2 as 100 (no need for dot, or sign after e),
39-
# from http://stackoverflow.com/a/30462009
40-
# 2. Wrapper to load mappings as OrderedDict (for likelihoods and params),
41-
# from http://stackoverflow.com/a/21912744
42-
def yaml_load(text_stream, Loader=yaml.Loader, object_pairs_hook=odict, file_name=None):
43-
class OrderedLoader(Loader):
44-
pass
45-
46-
def construct_mapping(loader, node):
47-
loader.flatten_mapping(node)
48-
return object_pairs_hook(loader.construct_pairs(node))
49-
50-
OrderedLoader.add_constructor(
51-
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
52-
OrderedLoader.add_implicit_resolver(
53-
u'tag:yaml.org,2002:float',
54-
re.compile(u'''^(?:
55-
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
56-
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
57-
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
58-
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
59-
|[-+]?\\.(?:inf|Inf|INF)
60-
|\\.(?:nan|NaN|NAN))$''', re.X),
61-
list(u'-+0123456789.'))
62-
63-
# Ignore python objects
64-
def dummy_object_loader(loader, suffix, node):
65-
return None
66-
67-
OrderedLoader.add_multi_constructor(
68-
u'tag:yaml.org,2002:python/name:', dummy_object_loader)
69-
try:
70-
return yaml.load(text_stream, OrderedLoader)
71-
# Redefining the general exception to give more user-friendly information
72-
except yaml.YAMLError as exception:
73-
errstr = "Error in your input file " + ("'" + file_name + "'" if file_name else "")
74-
if hasattr(exception, "problem_mark"):
75-
line = 1 + exception.problem_mark.line
76-
column = 1 + exception.problem_mark.column
77-
signal = " --> "
78-
signal_right = " <---- "
79-
sep = "|"
80-
context = 4
81-
lines = text_stream.split("\n")
82-
pre = ((("\n" + " " * len(signal) + sep).join(
83-
[""] + lines[max(line - 1 - context, 0):line - 1]))) + "\n"
84-
errorline = (signal + sep + lines[line - 1] +
85-
signal_right + "column %s" % column)
86-
post = ((("\n" + " " * len(signal) + sep).join(
87-
[""] + lines[line + 1 - 1:min(line + 1 + context - 1, len(lines))]))) + "\n"
88-
raise InputSyntaxError(
89-
errstr + " at line %d, column %d." % (line, column) +
90-
pre + errorline + post +
91-
"Maybe inconsistent indentation, '=' instead of ':', "
92-
"no space after ':', or a missing ':' on an empty group?")
93-
else:
94-
raise InputSyntaxError(errstr)
95-
96-
97-
def yaml_load_file(input_file):
98-
"""Wrapper to load a yaml file."""
99-
with open(input_file, "r") as f:
100-
lines = "".join(f.readlines())
101-
return yaml_load(lines, file_name=input_file)
102-
103-
10432
def get_info_params(info):
10533
"""
10634
Extracts parameter info from the new yaml format.
@@ -207,7 +135,8 @@ def expand_info_param(info_param):
207135

208136
def get_sampler_type(filename_or_info):
209137
if isinstance(filename_or_info, string_types):
138+
from getdist.yaml_tools import yaml_load_file
210139
filename_or_info = yaml_load_file(filename_or_info)
211140
default_sampler_for_chain_type = "mcmc"
212141
sampler = list(filename_or_info.get(_sampler, [default_sampler_for_chain_type]))[0]
213-
return {"mcmc": "mcmc", "polychord": "nested"}[sampler]
142+
return {"mcmc": "mcmc", "polychord": "nested", "minimize": "minimize"}[sampler]

getdist/gui/mainwindow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,8 +1006,8 @@ def _updateComboBoxRootname(self, listOfRoots):
10061006
self.comboBoxRootname.clear()
10071007
self.listRoots.show()
10081008
self.pushButtonRemove.show()
1009-
baseRoots = [(os.path.basename(root) if not root.endswith("/")
1010-
else os.path.basename(root[:-1]) + "/")
1009+
baseRoots = [(os.path.basename(root) if not root.endswith((os.sep, "/"))
1010+
else os.path.basename(root[:-1]) + os.sep)
10111011
for root in listOfRoots]
10121012
self.comboBoxRootname.addItems(baseRoots)
10131013
if len(baseRoots) > 1:
@@ -1039,8 +1039,8 @@ def newRootItem(self, root):
10391039
else:
10401040
path = self.rootdirname
10411041
# new style, if the prefix is just a folder
1042-
if root[-1] == "/":
1043-
path = "/".join(path.split("/")[:-1])
1042+
if root[-1] in (os.sep, "/"):
1043+
path = os.sep.join(path.replace("/", os.sep).split(os.sep)[:-1])
10441044
info = plots.RootInfo(root, path, self.batch)
10451045
plotter.sampleAnalyser.addRoot(info)
10461046

getdist/mcsamples.py

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,22 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={}
6666
if settings and dist_settings: raise ValueError('Use settings or dist_settings')
6767
if dist_settings: settings = dist_settings
6868
files = chainFiles(file_root)
69+
if not files: # try new Cobaya format
70+
files = chainFiles(file_root, separator='.')
6971
path, name = os.path.split(file_root)
7072
path = getdist.cache_dir or path
7173
if not os.path.exists(path): os.mkdir(path)
7274
cachefile = os.path.join(path, name) + '.py_mcsamples'
7375
samples = MCSamples(file_root, jobItem=jobItem, ini=ini, settings=settings)
7476
if os.path.isfile(file_root + '.paramnames'):
7577
allfiles = files + [file_root + '.ranges', file_root + '.paramnames', file_root + '.properties.ini']
76-
else: # new format (txt+yaml)
77-
mid = "" if file_root.endswith("/") else "__"
78-
allfiles = files + [file_root + mid + ending for ending in ['input.yaml', 'full.yaml']]
78+
else: # Cobaya
79+
folder = os.path.dirname(file_root)
80+
prefix = os.path.basename(file_root)
81+
allfiles = files + [
82+
os.path.join(folder, f) for f in os.listdir(folder) if (
83+
f.startswith(prefix) and
84+
any([f.lower().endswith(end) for end in ['updated.yaml', 'full.yaml']]))]
7985
if not no_cache and os.path.exists(cachefile) and lastModified(allfiles) < os.path.getmtime(cachefile):
8086
try:
8187
with open(cachefile, 'rb') as inp:
@@ -95,12 +101,18 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={}
95101
return samples
96102

97103

98-
def loadCobayaSamples(info, collections, name_tag=None,
99-
ignore_rows=0, ini=None, settings={}):
104+
def loadCobayaSamples(*args, **kwargs):
105+
logging.warning("'loadCobayaSamples' will be deprecated in the future. "
106+
"Use 'MCSamplesFromCobaya' instead.")
107+
return MCSamplesFromCobaya(*args, **kwargs)
108+
109+
110+
def MCSamplesFromCobaya(info, collections, name_tag=None,
111+
ignore_rows=0, ini=None, settings={}):
100112
"""
101-
Loads a set of samples from Cobaya's output.
113+
Creates a set of samples from Cobaya's output.
102114
Parameter names, ranges and labels are taken from the "info" dictionary
103-
(always use the "full", updated one generated by `cobaya.run`).
115+
(always use the "updated" one generated by `cobaya.run`).
104116
105117
For a description of the various analysis settings and default values see
106118
`analysis_defaults.ini <http://getdist.readthedocs.org/en/latest/analysis_settings.html>`_.
@@ -114,6 +126,10 @@ def loadCobayaSamples(info, collections, name_tag=None,
114126
:param settings: dictionary of analysis settings to override defaults
115127
:return: The :class:`MCSamples` instance
116128
"""
129+
from getdist.cobaya_interface import _p_label, _p_renames, _weight, _minuslogpost
130+
from getdist.cobaya_interface import get_info_params, get_range, is_derived_param
131+
from getdist.cobaya_interface import get_sampler_type, _post
132+
117133
if not hasattr(info, "keys"):
118134
raise TypeError("Cannot regonise arguments. Are you sure you are calling "
119135
"with (info, collections, ...) in that order?")
@@ -127,9 +143,6 @@ def loadCobayaSamples(info, collections, name_tag=None,
127143
"The second argument does not appear to be a (list of) samples `Collection`.")
128144
if not all([list(c.data) == columns for c in collections[1:]]):
129145
raise ValueError("The given collections don't have the same columns.")
130-
from getdist.yaml_format_tools import _p_label, _p_renames, _weight, _minuslogpost
131-
from getdist.yaml_format_tools import get_info_params, get_range, is_derived_param
132-
from getdist.yaml_format_tools import get_sampler_type, _post
133146
# Check consistency with info
134147
info_params = get_info_params(info)
135148
# ####################################################################################
@@ -139,8 +152,8 @@ def loadCobayaSamples(info, collections, name_tag=None,
139152
thin = info.get(_post, {}).get("thin", 1)
140153
# Maybe warn if trying to ignore rows twice?
141154
if ignore_rows != 0 and skip != 0:
142-
logging.warn("You are asking for rows to be ignored (%r), but some (%r) were "
143-
"already ignored in the original chain.", ignore_rows, skip)
155+
logging.warning("You are asking for rows to be ignored (%r), but some (%r) were "
156+
"already ignored in the original chain.", ignore_rows, skip)
144157
# Should we warn about thin too?
145158
# Most importantly: do we want to save somewhere the fact that we have *already*
146159
# thinned/skipped?
@@ -182,7 +195,8 @@ class MCSamples(Chains):
182195
"""
183196
The main high-level class for a collection of parameter samples.
184197
185-
Derives from :class:`.chains.Chains`, adding high-level functions including Kernel Density estimates, parameter ranges and custom settings.
198+
Derives from :class:`.chains.Chains`, adding high-level functions including
199+
Kernel Density estimates, parameter ranges and custom settings.
186200
"""
187201

188202
def __init__(self, root=None, jobItem=None, ini=None, settings=None, ranges=None,
@@ -2083,10 +2097,15 @@ def _setLikeStats(self):
20832097

20842098
def _readRanges(self):
20852099
if self.root:
2100+
from getdist.cobaya_interface import _separator_files
20862101
ranges_file_classic = self.root + '.ranges'
2087-
ranges_file_new = (
2088-
self.root + ('' if self.root.endswith('/') else '__') + 'full.yaml')
2089-
for ranges_file in [ranges_file_classic, ranges_file_new]:
2102+
ranges_file_cobaya_old = (
2103+
self.root + ('' if self.root.endswith((os.sep, "/")) else '__') + 'full.yaml')
2104+
ranges_file_cobaya = (
2105+
self.root + (
2106+
'' if self.root.endswith((os.sep, "/")) else _separator_files) + 'updated.yaml')
2107+
for ranges_file in [
2108+
ranges_file_classic, ranges_file_cobaya_old, ranges_file_cobaya]:
20902109
if os.path.isfile(ranges_file):
20912110
self.ranges = ParamBounds(ranges_file)
20922111
return
@@ -2553,9 +2572,9 @@ def GetChainRootFiles(rootdir):
25532572
"""
25542573
pattern = os.path.join(rootdir, '*.paramnames')
25552574
files = [os.path.splitext(f)[0] for f in glob.glob(pattern)]
2556-
ending = 'full.yaml'
2557-
pattern = os.path.join(rootdir, "*" + ending)
2558-
files += [f[:-len(ending)].rstrip("_") for f in glob.glob(pattern)]
2575+
for ending in ['full.yaml', 'updated.yaml']:
2576+
pattern = os.path.join(rootdir, "*" + ending)
2577+
files += [f[:-len(ending)].rstrip("_.") for f in glob.glob(pattern)]
25592578
files.sort()
25602579
return files
25612580

getdist/paramnames.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -378,10 +378,9 @@ def loadFromFile(self, fileName):
378378
with open(fileName) as f:
379379
self.names = [ParamInfo(line) for line in [s.strip() for s in f] if line != '']
380380
elif extension.lower() in ('.yaml', '.yml'):
381-
from getdist.yaml_format_tools import yaml_load_file, get_info_params
382-
from getdist.yaml_format_tools import is_sampled_param, is_derived_param
383-
from getdist.yaml_format_tools import _p_label, _p_renames
384-
381+
from getdist.yaml_tools import yaml_load_file
382+
from getdist.cobaya_interface import get_info_params, is_sampled_param
383+
from getdist.cobaya_interface import is_derived_param, _p_label, _p_renames
385384
info_params = get_info_params(yaml_load_file(fileName))
386385
# first sampled, then derived
387386
self.names = [ParamInfo(name=param, label=(info or {}).get(_p_label, param),

getdist/parampriors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def loadFromFile(self, fileName):
3131
if len(strings) == 3:
3232
self.setRange(strings[0], strings[1:])
3333
elif extension in ('.yaml', '.yml'):
34-
from getdist.yaml_format_tools import yaml_load_file, get_info_params
35-
from getdist.yaml_format_tools import get_range, is_fixed_param
34+
from getdist.cobaya_interface import get_range, is_fixed_param, get_info_params
35+
from getdist.yaml_tools import yaml_load_file
3636
info_params = get_info_params(yaml_load_file(fileName))
3737
for p, info in info_params.items():
3838
if not is_fixed_param(info):

getdist/plots.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from paramgrid import gridconfig, batchjob
1515
import getdist
1616
from getdist import MCSamples, loadMCSamples, ParamNames, ParamInfo, IniFile
17+
from getdist.chains import chainFiles
1718
from getdist.paramnames import escapeLatex, makeList, mergeRenames
1819
from getdist.parampriors import ParamBounds
1920
from getdist.densities import Density1D, Density2D
@@ -446,8 +447,8 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None):
446447
if isinstance(root, MCSamples): return root
447448
if os.path.isabs(root):
448449
# deal with just-folder prefix
449-
if root.endswith("/"):
450-
root = os.path.basename(root[:-1]) + "/"
450+
if root.endswith((os.sep, "/")):
451+
root = os.path.basename(root[:-1]) + os.sep
451452
else:
452453
root = os.path.basename(root)
453454
if root in self.mcsamples and cache: return self.mcsamples[root]
@@ -457,6 +458,7 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None):
457458
else:
458459
dist_settings = {}
459460
if not file_root:
461+
from getdist.cobaya_interface import _separator_files
460462
for chain_dir in self.chain_dirs:
461463
if hasattr(chain_dir, "resolveRoot"):
462464
jobItem = chain_dir.resolveRoot(root)
@@ -468,7 +470,8 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None):
468470
break
469471
else:
470472
name = os.path.join(chain_dir, root)
471-
if os.path.exists(name + '_1.txt') or os.path.exists(name + '.txt'):
473+
if any([chainFiles(name, separator=sep)
474+
for sep in ['_', _separator_files]]):
472475
file_root = name
473476
break
474477
if not file_root:
@@ -1944,7 +1947,7 @@ def triangle_plot(self, roots, params=None, legend_labels=None, plot_3d_with_par
19441947
:param title_limit:if not None, a maginalized limit (1,2..) to print as the title of the first root on the diagonal 1D plots
19451948
:param upper_kwargs: dict for same-named arguments for use when making upper-triangle 2D plots (contour_colors, etc). Set show_1d=False to not add to the diagonal.
19461949
:param diag1d_kwargs: list of dict for arguments when making 1D plots on grid diagonal
1947-
:param markers: optional dict giving marker values indexed by parameter, or a list of marker values for each parameter plotted
1950+
:param markers: optional dict giving marker values indexed by parameter, or a list of marker values for each parameter plotted
19481951
:param param_limits: a dictionary holding a mapping from parameter names to axis limits for that parameter
19491952
:param kwargs: optional keyword arguments for :func:`~GetDistPlotter.plot_2d` or :func:`~GetDistPlotter.plot_3d` (lower triangle only)
19501953

0 commit comments

Comments
 (0)