|
7 | 7 | # granted to it by virtue of its status as an intergovernmental organisation |
8 | 8 | # nor does it submit to any jurisdiction. |
9 | 9 |
|
10 | | - |
11 | | -import json |
12 | | -import logging |
13 | | -import os |
14 | | -import subprocess |
15 | | -from argparse import ArgumentParser |
16 | | -from argparse import Namespace |
17 | | -from tempfile import TemporaryDirectory |
18 | | -from typing import Any |
19 | | -from typing import Dict |
20 | | - |
21 | | -import yaml |
22 | | - |
23 | | -from . import Command |
24 | | - |
25 | | -LOG = logging.getLogger(__name__) |
26 | | - |
27 | | -EDITOR_OPTIONS = {"code": ["--wait"]} |
28 | | - |
29 | | - |
30 | | -class Metadata(Command): |
31 | | - """Edit, remove, dump or load metadata from a checkpoint file.""" |
32 | | - |
33 | | - def add_arguments(self, command_parser: ArgumentParser) -> None: |
34 | | - """Add command line arguments to the parser. |
35 | | -
|
36 | | - Parameters |
37 | | - ---------- |
38 | | - command_parser : ArgumentParser |
39 | | - The argument parser to which the arguments will be added. |
40 | | - """ |
41 | | - from anemoi.utils.checkpoints import DEFAULT_NAME |
42 | | - |
43 | | - command_parser.add_argument("path", help="Path to the checkpoint.") |
44 | | - |
45 | | - group = command_parser.add_mutually_exclusive_group(required=True) |
46 | | - |
47 | | - group.add_argument( |
48 | | - "--dump", |
49 | | - action="store_true", |
50 | | - help=( |
51 | | - "Extract the metadata from the checkpoint and print it to the standard output" |
52 | | - " or the file specified by ``--output``, in JSON or YAML format." |
53 | | - ), |
54 | | - ) |
55 | | - group.add_argument( |
56 | | - "--load", |
57 | | - action="store_true", |
58 | | - help=( |
59 | | - "Set the metadata in the checkpoint from the content" |
60 | | - " of a file specified by the ``--input`` argument." |
61 | | - ), |
62 | | - ) |
63 | | - |
64 | | - group.add_argument( |
65 | | - "--edit", |
66 | | - action="store_true", |
67 | | - help=( |
68 | | - "Edit the metadata in place, using the specified editor." |
69 | | - " See the ``--editor`` argument for more information." |
70 | | - ), |
71 | | - ) |
72 | | - |
73 | | - group.add_argument( |
74 | | - "--view", |
75 | | - action="store_true", |
76 | | - help=( |
77 | | - "View the metadata in place, using the specified pager." |
78 | | - " See the ``--pager`` argument for more information." |
79 | | - ), |
80 | | - ) |
81 | | - |
82 | | - group.add_argument( |
83 | | - "--remove", |
84 | | - action="store_true", |
85 | | - help="Remove the metadata from the checkpoint.", |
86 | | - ) |
87 | | - |
88 | | - group.add_argument( |
89 | | - "--supporting-arrays", |
90 | | - action="store_true", |
91 | | - help="Print the supporting arrays.", |
92 | | - ) |
93 | | - |
94 | | - group.add_argument( |
95 | | - "--get", |
96 | | - help="Navigate the metadata via dot-separated path.", |
97 | | - ) |
98 | | - |
99 | | - group.add_argument( |
100 | | - "--pytest", |
101 | | - action="store_true", |
102 | | - help=("Extract the metadata from the checkpoint so it can be added to the test suite."), |
103 | | - ) |
104 | | - |
105 | | - command_parser.add_argument( |
106 | | - "--name", |
107 | | - default=DEFAULT_NAME, |
108 | | - help="Name of metadata record to be used with the actions above.", |
109 | | - ) |
110 | | - |
111 | | - command_parser.add_argument( |
112 | | - "--input", |
113 | | - help="The output file name to be used by the ``--load`` option.", |
114 | | - ) |
115 | | - |
116 | | - command_parser.add_argument( |
117 | | - "--output", |
118 | | - help="The output file name to be used by the ``--dump`` option.", |
119 | | - ) |
120 | | - |
121 | | - command_parser.add_argument( |
122 | | - "--editor", |
123 | | - help="Editor to use for the ``--edit`` option. Default to ``$EDITOR`` if defined, else ``vi``.", |
124 | | - default=os.environ.get("EDITOR", "vi"), |
125 | | - ) |
126 | | - |
127 | | - command_parser.add_argument( |
128 | | - "--pager", |
129 | | - help="Editor to use for the ``--view`` option. Default to ``$PAGER`` if defined, else ``less``.", |
130 | | - default=os.environ.get("PAGER", "less"), |
131 | | - ) |
132 | | - |
133 | | - command_parser.add_argument( |
134 | | - "--json", |
135 | | - action="store_true", |
136 | | - help="Use the JSON format with ``--dump``, ``--view`` and ``--edit``.", |
137 | | - ) |
138 | | - |
139 | | - command_parser.add_argument( |
140 | | - "--yaml", |
141 | | - action="store_true", |
142 | | - help="Use the YAML format with ``--dump``, ``--view`` and ``--edit``.", |
143 | | - ) |
144 | | - |
145 | | - def run(self, args: Namespace) -> None: |
146 | | - """Execute the command based on the provided arguments. |
147 | | -
|
148 | | - Parameters |
149 | | - ---------- |
150 | | - args : Namespace |
151 | | - The arguments passed to the command. |
152 | | - """ |
153 | | - if args.edit: |
154 | | - return self.edit(args) |
155 | | - |
156 | | - if args.view: |
157 | | - return self.view(args) |
158 | | - |
159 | | - if args.get: |
160 | | - return self.get(args) |
161 | | - |
162 | | - if args.remove: |
163 | | - return self.remove(args) |
164 | | - |
165 | | - if args.dump or args.pytest: |
166 | | - return self.dump(args) |
167 | | - |
168 | | - if args.load: |
169 | | - return self.load(args) |
170 | | - |
171 | | - if args.supporting_arrays: |
172 | | - return self.supporting_arrays(args) |
173 | | - |
174 | | - def edit(self, args: Namespace) -> None: |
175 | | - """Edit the metadata in place using the specified editor. |
176 | | -
|
177 | | - Parameters |
178 | | - ---------- |
179 | | - args : Namespace |
180 | | - The arguments passed to the command. |
181 | | - """ |
182 | | - return self._edit(args, view=False, cmd=args.editor) |
183 | | - |
184 | | - def view(self, args: Namespace) -> None: |
185 | | - """View the metadata in place using the specified pager. |
186 | | -
|
187 | | - Parameters |
188 | | - ---------- |
189 | | - args : Namespace |
190 | | - The arguments passed to the command. |
191 | | - """ |
192 | | - return self._edit(args, view=True, cmd=args.pager) |
193 | | - |
194 | | - def _edit(self, args: Namespace, view: bool, cmd: str) -> None: |
195 | | - """Internal method to edit or view the metadata. |
196 | | -
|
197 | | - Parameters |
198 | | - ---------- |
199 | | - args : Namespace |
200 | | - The arguments passed to the command. |
201 | | - view : bool |
202 | | - If True, view the metadata; otherwise, edit it. |
203 | | - cmd : str |
204 | | - The command to use for editing or viewing. |
205 | | - """ |
206 | | - from anemoi.utils.checkpoints import load_metadata |
207 | | - from anemoi.utils.checkpoints import replace_metadata |
208 | | - |
209 | | - kwargs: Dict[str, Any] = {} |
210 | | - |
211 | | - if args.json: |
212 | | - ext = "json" |
213 | | - dump = json.dump |
214 | | - load = json.load |
215 | | - if args.test: |
216 | | - kwargs = {"sort_keys": True} |
217 | | - else: |
218 | | - kwargs = {"indent": 4, "sort_keys": True} |
219 | | - else: |
220 | | - ext = "yaml" |
221 | | - dump = yaml.dump |
222 | | - load = yaml.safe_load |
223 | | - kwargs = {"default_flow_style": False} |
224 | | - |
225 | | - with TemporaryDirectory() as temp_dir: |
226 | | - |
227 | | - path = os.path.join(temp_dir, f"checkpoint.{ext}") |
228 | | - metadata = load_metadata(args.path) |
229 | | - |
230 | | - with open(path, "w") as f: |
231 | | - dump(metadata, f, **kwargs) |
232 | | - |
233 | | - subprocess.check_call([cmd, *EDITOR_OPTIONS.get(cmd, []), path]) |
234 | | - |
235 | | - if not view: |
236 | | - with open(path) as f: |
237 | | - edited = load(f) |
238 | | - |
239 | | - if edited != metadata: |
240 | | - replace_metadata(args.path, edited) |
241 | | - else: |
242 | | - LOG.info("No changes made.") |
243 | | - |
244 | | - def remove(self, args: Namespace) -> None: |
245 | | - """Remove the metadata from the checkpoint. |
246 | | -
|
247 | | - Parameters |
248 | | - ---------- |
249 | | - args : Namespace |
250 | | - The arguments passed to the command. |
251 | | - """ |
252 | | - from anemoi.utils.checkpoints import remove_metadata |
253 | | - |
254 | | - remove_metadata(args.path, args.name) |
255 | | - |
256 | | - def dump(self, args: Namespace) -> None: |
257 | | - """Dump the metadata from the checkpoint to a file or standard output. |
258 | | -
|
259 | | - Parameters |
260 | | - ---------- |
261 | | - args : Namespace |
262 | | - The arguments passed to the command. |
263 | | - """ |
264 | | - from anemoi.utils.checkpoints import load_metadata |
265 | | - |
266 | | - if args.output: |
267 | | - file = open(args.output, "w") |
268 | | - else: |
269 | | - file = None |
270 | | - |
271 | | - metadata = load_metadata(args.path) |
272 | | - if args.pytest: |
273 | | - from anemoi.inference.testing.mock_checkpoint import minimum_mock_checkpoint |
274 | | - |
275 | | - # We remove all unessential metadata for testing purposes |
276 | | - metadata = minimum_mock_checkpoint(metadata) |
277 | | - |
278 | | - if args.yaml: |
279 | | - print(yaml.dump(metadata, indent=2, sort_keys=True), file=file) |
280 | | - return |
281 | | - |
282 | | - if args.json or True: |
283 | | - if args.pytest: |
284 | | - print(json.dumps(metadata, sort_keys=True), file=file) |
285 | | - else: |
286 | | - print(json.dumps(metadata, indent=4, sort_keys=True), file=file) |
287 | | - return |
288 | | - |
289 | | - def get(self, args: Namespace) -> None: |
290 | | - """Navigate and print the metadata via a dot-separated path. |
291 | | -
|
292 | | - Parameters |
293 | | - ---------- |
294 | | - args : Namespace |
295 | | - The arguments passed to the command. |
296 | | - """ |
297 | | - from pprint import pprint |
298 | | - |
299 | | - from anemoi.utils.checkpoints import load_metadata |
300 | | - |
301 | | - metadata = load_metadata(args.path, name=args.name) |
302 | | - |
303 | | - if args.get == ".": |
304 | | - print("Metadata from root: ", list(metadata.keys())) |
305 | | - return |
306 | | - |
307 | | - for key in args.get.split("."): |
308 | | - if key == "": |
309 | | - keys = list(metadata.keys()) |
310 | | - print(f"Metadata keys from {args.get[:-1]}: ", keys) |
311 | | - return |
312 | | - else: |
313 | | - metadata = metadata[key] |
314 | | - |
315 | | - print(f"Metadata values for {args.get}: ", end="\n" if isinstance(metadata, (dict, list)) else "") |
316 | | - if isinstance(metadata, dict): |
317 | | - pprint(metadata, indent=2, compact=True) |
318 | | - else: |
319 | | - print(metadata) |
320 | | - |
321 | | - def load(self, args: Namespace) -> None: |
322 | | - """Load metadata into the checkpoint from a specified file. |
323 | | -
|
324 | | - Parameters |
325 | | - ---------- |
326 | | - args : Namespace |
327 | | - The arguments passed to the command. |
328 | | - """ |
329 | | - from anemoi.utils.checkpoints import has_metadata |
330 | | - from anemoi.utils.checkpoints import replace_metadata |
331 | | - from anemoi.utils.checkpoints import save_metadata |
332 | | - |
333 | | - if args.input is None: |
334 | | - raise ValueError("Please specify a value for --input") |
335 | | - |
336 | | - _, ext = os.path.splitext(args.input) |
337 | | - if ext == ".json" or args.json: |
338 | | - with open(args.input) as f: |
339 | | - metadata = json.load(f) |
340 | | - |
341 | | - elif ext in (".yaml", ".yml") or args.yaml: |
342 | | - with open(args.input) as f: |
343 | | - metadata = yaml.safe_load(f) |
344 | | - |
345 | | - else: |
346 | | - raise ValueError(f"Unknown file extension {ext}. Please specify --json or --yaml") |
347 | | - |
348 | | - if has_metadata(args.path, name=args.name): |
349 | | - replace_metadata(args.path, metadata) |
350 | | - else: |
351 | | - save_metadata(args.path, metadata, name=args.name) |
352 | | - |
353 | | - def supporting_arrays(self, args: Namespace) -> None: |
354 | | - """Print the supporting arrays from the metadata. |
355 | | -
|
356 | | - Parameters |
357 | | - ---------- |
358 | | - args : Namespace |
359 | | - The arguments passed to the command. |
360 | | - """ |
361 | | - from anemoi.utils.checkpoints import load_metadata |
362 | | - |
363 | | - _, supporting_arrays = load_metadata(args.path, supporting_arrays=True) |
364 | | - |
365 | | - for name, array in supporting_arrays.items(): |
366 | | - print(f"{name}: shape={array.shape} dtype={array.dtype}") |
367 | | - |
| 10 | +from anemoi.utils.commands.metadata import Metadata |
368 | 11 |
|
369 | 12 | command = Metadata |
0 commit comments