Skip to content

Commit 6b9d842

Browse files
authored
experiments: handle exp run from within subdirs/subrepos (#5093)
* experiments: add test cases for stage in subdir & subrepo * git: fix relpath handling in dulwich add() * repro: move exp file git staging from stage.run into repro * experiments: support running experiments from outside DVC root * experiments: setup logger inside multiprocessing (executor) context * git: fix dulwich add in submodules
1 parent 84a3a59 commit 6b9d842

File tree

6 files changed

+145
-14
lines changed

6 files changed

+145
-14
lines changed

dvc/repo/experiments/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def _update_params(self, params: dict):
269269
logger.debug("Using experiment params '%s'", params)
270270

271271
for params_fname in params:
272-
path = PathInfo(self.repo.root_dir) / params_fname
272+
path = PathInfo(params_fname)
273273
suffix = path.suffix.lower()
274274
modify_data = MODIFIERS[suffix]
275275
with modify_data(path, tree=self.repo.tree) as data:
@@ -496,6 +496,8 @@ def _reproduce(
496496

497497
manager = Manager()
498498
pid_q = manager.Queue()
499+
500+
rel_cwd = relpath(os.getcwd(), self.repo.root_dir)
499501
with ProcessPoolExecutor(max_workers=jobs) as workers:
500502
futures = {}
501503
for rev, executor in executors.items():
@@ -505,6 +507,8 @@ def _reproduce(
505507
pid_q,
506508
rev,
507509
name=executor.name,
510+
rel_cwd=rel_cwd,
511+
log_level=logger.getEffectiveLevel(),
508512
)
509513
futures[future] = (rev, executor)
510514

dvc/repo/experiments/executor.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,9 @@ def reproduce(
196196
dvc_dir: str,
197197
queue: "Queue",
198198
rev: str,
199-
cwd: Optional[str] = None,
199+
rel_cwd: Optional[str] = None,
200200
name: Optional[str] = None,
201+
log_level: Optional[int] = None,
201202
) -> Tuple[Optional[str], bool]:
202203
"""Run dvc repro and return the result.
203204
@@ -211,6 +212,7 @@ def reproduce(
211212
unchanged = []
212213

213214
queue.put((rev, os.getpid()))
215+
cls._set_log_level(log_level)
214216

215217
def filter_pipeline(stages):
216218
unchanged.extend(
@@ -223,9 +225,11 @@ def filter_pipeline(stages):
223225
try:
224226
dvc = Repo(dvc_dir)
225227
old_cwd = os.getcwd()
226-
new_cwd = cwd if cwd else dvc.root_dir
227-
os.chdir(new_cwd)
228-
logger.debug("Running repro in '%s'", cwd)
228+
if rel_cwd:
229+
os.chdir(os.path.join(dvc.root_dir, rel_cwd))
230+
else:
231+
os.chdir(dvc.root_dir)
232+
logger.debug("Running repro in '%s'", os.getcwd())
229233

230234
args_path = os.path.join(
231235
dvc.tmp_dir, BaseExecutor.PACKED_ARGS_FILE
@@ -321,6 +325,18 @@ def commit(cls, scm: "Git", exp_hash: str, exp_name: Optional[str] = None):
321325
scm.set_ref(EXEC_BRANCH, branch, symbolic=True)
322326
return new_rev
323327

328+
@staticmethod
329+
def _set_log_level(level):
330+
from dvc.logger import disable_other_loggers
331+
332+
# When executor.reproduce is run in a multiprocessing child process,
333+
# dvc.main will not be called for that child process so we need to
334+
# setup logging ourselves
335+
dvc_logger = logging.getLogger("dvc")
336+
disable_other_loggers()
337+
if level is not None:
338+
dvc_logger.setLevel(level)
339+
324340

325341
class LocalExecutor(BaseExecutor):
326342
"""Local machine experiment executor."""

dvc/repo/reproduce.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
import typing
34
from functools import partial
45

@@ -18,6 +19,7 @@
1819
def _reproduce_stage(stage, **kwargs):
1920
def _run_callback(repro_callback):
2021
_dump_stage(stage)
22+
_track_stage(stage)
2123
repro_callback([stage])
2224

2325
checkpoint_func = kwargs.pop("checkpoint_func", None)
@@ -42,7 +44,10 @@ def _run_callback(repro_callback):
4244
return []
4345

4446
if not kwargs.get("dry", False):
47+
track = checkpoint_func is not None
4548
_dump_stage(stage)
49+
if track:
50+
_track_stage(stage)
4651

4752
return [stage]
4853

@@ -54,6 +59,13 @@ def _dump_stage(stage):
5459
dvcfile.dump(stage, update_pipeline=False)
5560

5661

62+
def _track_stage(stage):
63+
for out in stage.outs:
64+
if not out.use_scm_ignore and out.is_in_repo:
65+
stage.repo.scm.track_file(os.fspath(out.path_info))
66+
stage.repo.scm.track_changed_files()
67+
68+
5769
def _get_active_graph(G):
5870
import networkx as nx
5971

dvc/scm/git/backend/dulwich.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import os
55
import stat
66
from io import BytesIO, StringIO
7-
from typing import Callable, Iterable, Optional, Tuple
7+
from typing import Callable, Dict, Iterable, Optional, Tuple
88

9+
from dvc.path_info import PathInfo
910
from dvc.scm.base import SCMError
1011
from dvc.utils import relpath
1112

@@ -67,8 +68,24 @@ def __init__( # pylint:disable=W0231
6768
except NotGitRepository as exc:
6869
raise SCMError(f"{root_dir} is not a git repository") from exc
6970

71+
self._submodules: Dict[str, "PathInfo"] = self._find_submodules()
7072
self._stashes: dict = {}
7173

74+
def _find_submodules(self) -> Dict[str, "PathInfo"]:
75+
"""Return dict mapping submodule names to submodule paths.
76+
77+
Submodule paths will be relative to Git repo root.
78+
"""
79+
from dulwich.config import ConfigFile, parse_submodules
80+
81+
submodules: Dict[str, "PathInfo"] = {}
82+
config_path = os.path.join(self.root_dir, ".gitmodules")
83+
if os.path.isfile(config_path):
84+
config = ConfigFile.from_path(config_path)
85+
for path, _url, section in parse_submodules(config):
86+
submodules[os.fsdecode(section)] = PathInfo(os.fsdecode(path))
87+
return submodules
88+
7289
def close(self):
7390
self.repo.close()
7491

@@ -101,8 +118,19 @@ def add(self, paths: Iterable[str]):
101118

102119
files = []
103120
for path in paths:
104-
if not os.path.isabs(path):
105-
path = os.path.join(self.root_dir, path)
121+
if not os.path.isabs(path) and self._submodules:
122+
# NOTE: If path is inside a submodule, Dulwich expects the
123+
# staged paths to be relative to the submodule root (not the
124+
# parent git repo root). We append path to root_dir here so
125+
# that the result of relpath(path, root_dir) is actually the
126+
# path relative to the submodule root.
127+
path_info = PathInfo(path).relative_to(self.root_dir)
128+
for sm_path in self._submodules.values():
129+
if path_info.isin(sm_path):
130+
path = os.path.join(
131+
self.root_dir, path_info.relative_to(sm_path)
132+
)
133+
break
106134
if os.path.isdir(path):
107135
files.extend(walk_files(path))
108136
else:
@@ -138,8 +166,6 @@ def untracked_files(self) -> Iterable[str]:
138166
raise NotImplementedError
139167

140168
def is_tracked(self, path: str) -> bool:
141-
from dvc.path_info import PathInfo
142-
143169
rel = PathInfo(path).relative_to(self.root_dir).as_posix().encode()
144170
rel_dir = rel + b"/"
145171
for path in self.repo.open_index():

dvc/stage/run.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,5 @@ def _kill_nt(proc):
211211
def _run_callback(stage, callback_func):
212212
stage.save(allow_missing=True)
213213
stage.commit(allow_missing=True)
214-
for out in stage.outs:
215-
if not out.use_scm_ignore and out.is_in_repo:
216-
stage.repo.scm.track_file(os.fspath(out.path_info))
217-
stage.repo.scm.track_changed_files()
218214
logger.debug("Running checkpoint callback for stage '%s'", stage)
219215
callback_func()

tests/func/experiments/test_experiments.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
from funcy import first
77

8+
from dvc.dvcfile import PIPELINE_FILE
89
from dvc.repo.experiments.utils import exp_refs_by_rev
910
from dvc.utils.serialize import PythonFileCorruptedError
1011
from tests.func.test_repro_multistage import COPY_SCRIPT
@@ -436,3 +437,79 @@ def test_list(tmp_dir, scm, dvc, exp_stage):
436437
baseline_a: {ref_info_a.name, ref_info_b.name},
437438
baseline_c: {ref_info_c.name},
438439
}
440+
441+
442+
def test_subdir(tmp_dir, scm, dvc):
443+
subdir = tmp_dir / "dir"
444+
subdir.gen("copy.py", COPY_SCRIPT)
445+
subdir.gen("params.yaml", "foo: 1")
446+
447+
with subdir.chdir():
448+
dvc.run(
449+
cmd="python copy.py params.yaml metrics.yaml",
450+
metrics_no_cache=["metrics.yaml"],
451+
params=["foo"],
452+
name="copy-file",
453+
no_exec=True,
454+
)
455+
scm.add(
456+
[subdir / "dvc.yaml", subdir / "copy.py", subdir / "params.yaml"]
457+
)
458+
scm.commit("init")
459+
460+
results = dvc.experiments.run(PIPELINE_FILE, params=["foo=2"])
461+
assert results
462+
463+
exp = first(results)
464+
ref_info = first(exp_refs_by_rev(scm, exp))
465+
466+
tree = scm.get_tree(exp)
467+
for fname in ["metrics.yaml", "dvc.lock"]:
468+
assert tree.exists(subdir / fname)
469+
with tree.open(subdir / "metrics.yaml") as fobj:
470+
assert fobj.read().strip() == "foo: 2"
471+
472+
assert dvc.experiments.get_exact_name(exp) == ref_info.name
473+
assert scm.resolve_rev(ref_info.name) == exp
474+
475+
476+
def test_subrepo(tmp_dir, scm):
477+
from tests.unit.tree.test_repo import make_subrepo
478+
479+
subrepo = tmp_dir / "dir" / "repo"
480+
make_subrepo(subrepo, scm)
481+
482+
subrepo.gen("copy.py", COPY_SCRIPT)
483+
subrepo.gen("params.yaml", "foo: 1")
484+
485+
with subrepo.chdir():
486+
subrepo.dvc.run(
487+
cmd="python copy.py params.yaml metrics.yaml",
488+
metrics_no_cache=["metrics.yaml"],
489+
params=["foo"],
490+
name="copy-file",
491+
no_exec=True,
492+
)
493+
scm.add(
494+
[
495+
subrepo / "dvc.yaml",
496+
subrepo / "copy.py",
497+
subrepo / "params.yaml",
498+
]
499+
)
500+
scm.commit("init")
501+
502+
results = subrepo.dvc.experiments.run(PIPELINE_FILE, params=["foo=2"])
503+
assert results
504+
505+
exp = first(results)
506+
ref_info = first(exp_refs_by_rev(scm, exp))
507+
508+
tree = scm.get_tree(exp)
509+
for fname in ["metrics.yaml", "dvc.lock"]:
510+
assert tree.exists(subrepo / fname)
511+
with tree.open(subrepo / "metrics.yaml") as fobj:
512+
assert fobj.read().strip() == "foo: 2"
513+
514+
assert subrepo.dvc.experiments.get_exact_name(exp) == ref_info.name
515+
assert scm.resolve_rev(ref_info.name) == exp

0 commit comments

Comments
 (0)