Skip to content

Commit aec8d65

Browse files
authored
repro: accept multiple targets (#5111)
Currently, we iterate over targets manually in CLI, which results in a complicated code and a DAG collection/check overhead. With this PR we collect the DAG once, figure out all the needed targets within it and just reproduce them in one pass.
1 parent 8a38611 commit aec8d65

File tree

8 files changed

+77
-112
lines changed

8 files changed

+77
-112
lines changed

dvc/command/base.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,6 @@ def __init__(self, args):
4545
updater = Updater(self.repo.tmp_dir, hardlink_lock=hardlink_lock)
4646
updater.check()
4747

48-
@property
49-
def default_targets(self):
50-
"""Default targets for `dvc repro`."""
51-
from dvc.dvcfile import PIPELINE_FILE
52-
53-
logger.trace(f"assuming default target '{PIPELINE_FILE}'.")
54-
return [PIPELINE_FILE]
55-
5648
@abstractmethod
5749
def run(self):
5850
pass

dvc/command/experiments.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -441,31 +441,17 @@ def run(self):
441441

442442
class CmdExperimentsRun(CmdRepro):
443443
def run(self):
444-
# Dirty hack so the for loop below can at least enter once
445-
if self.args.all_pipelines:
446-
self.args.targets = [None]
447-
elif not self.args.targets:
448-
self.args.targets = self.default_targets
449-
450-
ret = 0
451-
for target in self.args.targets:
452-
try:
453-
self.repo.experiments.run(
454-
target,
455-
name=self.args.name,
456-
queue=self.args.queue,
457-
run_all=self.args.run_all,
458-
jobs=self.args.jobs,
459-
params=self.args.params,
460-
checkpoint_resume=self.args.checkpoint_resume,
461-
**self._repro_kwargs,
462-
)
463-
except DvcException:
464-
logger.exception("")
465-
ret = 1
466-
break
444+
self.repo.experiments.run(
445+
name=self.args.name,
446+
queue=self.args.queue,
447+
run_all=self.args.run_all,
448+
jobs=self.args.jobs,
449+
params=self.args.params,
450+
checkpoint_resume=self.args.checkpoint_resume,
451+
**self._repro_kwargs,
452+
)
467453

468-
return ret
454+
return 0
469455

470456

471457
class CmdExperimentsGC(CmdRepro):

dvc/command/repro.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,46 +6,30 @@
66
from dvc.command.metrics import _show_metrics
77
from dvc.command.status import CmdDataStatus
88
from dvc.dvcfile import PIPELINE_FILE
9-
from dvc.exceptions import DvcException
109

1110
logger = logging.getLogger(__name__)
1211

1312

1413
class CmdRepro(CmdBase):
1514
def run(self):
16-
# Dirty hack so the for loop below can at least enter once
17-
if self.args.all_pipelines:
18-
self.args.targets = [None]
19-
elif not self.args.targets:
20-
self.args.targets = self.default_targets
15+
stages = self.repo.reproduce(**self._repro_kwargs)
16+
if len(stages) == 0:
17+
logger.info(CmdDataStatus.UP_TO_DATE_MSG)
18+
else:
19+
logger.info(
20+
"Use `dvc push` to send your updates to " "remote storage."
21+
)
2122

22-
ret = 0
23-
for target in self.args.targets:
24-
try:
25-
stages = self.repo.reproduce(target, **self._repro_kwargs)
23+
if self.args.metrics:
24+
metrics = self.repo.metrics.show()
25+
logger.info(_show_metrics(metrics))
2626

27-
if len(stages) == 0:
28-
logger.info(CmdDataStatus.UP_TO_DATE_MSG)
29-
else:
30-
logger.info(
31-
"Use `dvc push` to send your updates to "
32-
"remote storage."
33-
)
34-
35-
if self.args.metrics:
36-
metrics = self.repo.metrics.show()
37-
logger.info(_show_metrics(metrics))
38-
39-
except DvcException:
40-
logger.exception("")
41-
ret = 1
42-
break
43-
44-
return ret
27+
return 0
4528

4629
@property
4730
def _repro_kwargs(self):
4831
return {
32+
"targets": self.args.targets,
4933
"single_item": self.args.single_item,
5034
"force": self.args.force,
5135
"dry": self.args.dry,

dvc/repo/experiments/run.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@ def _parse_params(path_params: Iterable):
3737
@locked
3838
def run(
3939
repo,
40-
target: Optional[str] = None,
40+
targets: Optional[Iterable] = None,
4141
params: Optional[Iterable] = None,
4242
run_all: Optional[bool] = False,
4343
jobs: Optional[int] = 1,
4444
**kwargs,
4545
) -> dict:
46-
"""Reproduce the specified target as an experiment.
46+
"""Reproduce the specified targets as an experiment.
4747
4848
Accepts the same additional kwargs as Repo.reproduce.
4949
@@ -59,10 +59,10 @@ def run(
5959
params = []
6060
try:
6161
return repo.experiments.reproduce_one(
62-
target=target, params=params, **kwargs
62+
targets=targets, params=params, **kwargs
6363
)
6464
except UnchangedExperimentError:
6565
# If experiment contains no changes, just run regular repro
6666
kwargs.pop("queue", None)
6767
kwargs.pop("checkpoint_resume", None)
68-
return {None: repo.reproduce(target=target, **kwargs)}
68+
return {None: repo.reproduce(targets=targets, **kwargs)}

dvc/repo/reproduce.py

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import typing
44
from functools import partial
55

6-
from dvc.exceptions import InvalidArgumentError, ReproductionError
6+
from dvc.exceptions import ReproductionError
77
from dvc.repo.scm_context import scm_context
88
from dvc.stage.run import CheckpointKilledError
99

@@ -92,19 +92,22 @@ def _get_active_graph(G):
9292
@scm_context
9393
def reproduce(
9494
self: "Repo",
95-
target=None,
95+
targets=None,
9696
recursive=False,
9797
pipeline=False,
9898
all_pipelines=False,
9999
**kwargs,
100100
):
101101
glob = kwargs.pop("glob", False)
102102
accept_group = not glob
103-
assert target is None or isinstance(target, str)
104-
if not target and not all_pipelines:
105-
raise InvalidArgumentError(
106-
"Neither `target` nor `--all-pipelines` are specified."
107-
)
103+
104+
if isinstance(targets, str):
105+
targets = [targets]
106+
107+
if not all_pipelines and targets is None:
108+
from dvc.dvcfile import PIPELINE_FILE
109+
110+
targets = [PIPELINE_FILE]
108111

109112
interactive = kwargs.get("interactive", False)
110113
if not interactive:
@@ -113,28 +116,33 @@ def reproduce(
113116
active_graph = _get_active_graph(self.graph)
114117
active_pipelines = get_pipelines(active_graph)
115118

119+
stages = set()
116120
if pipeline or all_pipelines:
117121
if all_pipelines:
118122
pipelines = active_pipelines
119123
else:
120-
stage = self.stage.get_target(target)
121-
pipelines = [get_pipeline(active_pipelines, stage)]
124+
pipelines = []
125+
for target in targets:
126+
stage = self.stage.get_target(target)
127+
pipelines.append(get_pipeline(active_pipelines, stage))
122128

123-
targets = []
124129
for pipeline in pipelines:
125130
for stage in pipeline:
126131
if pipeline.in_degree(stage) == 0:
127-
targets.append(stage)
132+
stages.add(stage)
128133
else:
129-
targets = self.stage.collect(
130-
target,
131-
recursive=recursive,
132-
graph=active_graph,
133-
accept_group=accept_group,
134-
glob=glob,
135-
)
134+
for target in targets:
135+
stages.update(
136+
self.stage.collect(
137+
target,
138+
recursive=recursive,
139+
graph=active_graph,
140+
accept_group=accept_group,
141+
glob=glob,
142+
)
143+
)
136144

137-
return _reproduce_stages(active_graph, targets, **kwargs)
145+
return _reproduce_stages(active_graph, list(stages), **kwargs)
138146

139147

140148
def _reproduce_stages(
@@ -220,30 +228,28 @@ def _reproduce_stages(
220228
def _get_pipeline(G, stages, downstream, single_item):
221229
import networkx as nx
222230

223-
if single_item:
224-
all_pipelines = stages
225-
else:
226-
all_pipelines = []
227-
for stage in stages:
228-
if downstream:
229-
# NOTE (py3 only):
230-
# Python's `deepcopy` defaults to pickle/unpickle the object.
231-
# Stages are complex objects (with references to `repo`,
232-
# `outs`, and `deps`) that cause struggles when you try
233-
# to serialize them. We need to create a copy of the graph
234-
# itself, and then reverse it, instead of using
235-
# graph.reverse() directly because it calls `deepcopy`
236-
# underneath -- unless copy=False is specified.
237-
nodes = nx.dfs_postorder_nodes(
238-
G.copy().reverse(copy=False), stage
239-
)
240-
all_pipelines += reversed(list(nodes))
241-
else:
242-
all_pipelines += nx.dfs_postorder_nodes(G, stage)
231+
all_pipelines = []
232+
for stage in stages:
233+
if downstream:
234+
# NOTE (py3 only):
235+
# Python's `deepcopy` defaults to pickle/unpickle the object.
236+
# Stages are complex objects (with references to `repo`,
237+
# `outs`, and `deps`) that cause struggles when you try
238+
# to serialize them. We need to create a copy of the graph
239+
# itself, and then reverse it, instead of using
240+
# graph.reverse() directly because it calls `deepcopy`
241+
# underneath -- unless copy=False is specified.
242+
nodes = nx.dfs_postorder_nodes(G.copy().reverse(copy=False), stage)
243+
all_pipelines += reversed(list(nodes))
244+
else:
245+
all_pipelines += nx.dfs_postorder_nodes(G, stage)
243246

244247
pipeline = []
245248
for stage in all_pipelines:
246249
if stage not in pipeline:
250+
if single_item and stage not in stages:
251+
continue
252+
247253
pipeline.append(stage)
248254

249255
return pipeline

tests/func/test_repro_multistage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def test_repro_list_of_commands_in_order(tmp_dir, dvc):
528528
"""
529529
)
530530
)
531-
dvc.reproduce(target="multi")
531+
dvc.reproduce(targets=["multi"])
532532
assert (tmp_dir / "foo").read_text() == "foo\n"
533533
assert (tmp_dir / "bar").read_text() == "bar\n"
534534

@@ -547,6 +547,6 @@ def test_repro_list_of_commands_raise_and_stops_after_failure(tmp_dir, dvc):
547547
)
548548
)
549549
with pytest.raises(ReproductionError):
550-
dvc.reproduce(target="multi")
550+
dvc.reproduce(targets=["multi"])
551551
assert (tmp_dir / "foo").read_text() == "foo\n"
552552
assert not (tmp_dir / "bar").exists()

tests/unit/command/test_experiments.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
CmdExperimentsRun,
1313
CmdExperimentsShow,
1414
)
15-
from dvc.dvcfile import PIPELINE_FILE
1615
from dvc.exceptions import InvalidArgumentError
1716

1817
from .test_repro import default_arguments as repro_arguments
@@ -103,9 +102,7 @@ def test_experiments_run(dvc, scm, mocker, args, resume):
103102
mocker.patch.object(cmd.repo, "reproduce")
104103
mocker.patch.object(cmd.repo.experiments, "run")
105104
cmd.run()
106-
cmd.repo.experiments.run.assert_called_with(
107-
PIPELINE_FILE, **default_arguments
108-
)
105+
cmd.repo.experiments.run.assert_called_with(**default_arguments)
109106

110107

111108
def test_experiments_gc(dvc, scm, mocker):

tests/unit/command/test_repro.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from dvc.cli import parse_args
22
from dvc.command.repro import CmdRepro
3-
from dvc.dvcfile import PIPELINE_FILE
43

54
default_arguments = {
65
"all_pipelines": False,
@@ -16,14 +15,15 @@
1615
"force_downstream": False,
1716
"pull": False,
1817
"glob": False,
18+
"targets": [],
1919
}
2020

2121

2222
def test_default_arguments(dvc, mocker):
2323
cmd = CmdRepro(parse_args(["repro"]))
2424
mocker.patch.object(cmd.repo, "reproduce")
2525
cmd.run()
26-
cmd.repo.reproduce.assert_called_with(PIPELINE_FILE, **default_arguments)
26+
cmd.repo.reproduce.assert_called_with(**default_arguments)
2727

2828

2929
def test_downstream(dvc, mocker):
@@ -32,4 +32,4 @@ def test_downstream(dvc, mocker):
3232
cmd.run()
3333
arguments = default_arguments.copy()
3434
arguments.update({"downstream": True})
35-
cmd.repo.reproduce.assert_called_with(PIPELINE_FILE, **arguments)
35+
cmd.repo.reproduce.assert_called_with(**arguments)

0 commit comments

Comments
 (0)