Skip to content

Commit 64dfaff

Browse files
authored
[WIP] repro: simplify the logic for frozen stages (#5114)
Fixes #5082
1 parent: aec8d65; commit: 64dfaff

File tree

3 files changed

+40
-75
lines changed

3 files changed

+40
-75
lines changed

dvc/repo/reproduce.py

Lines changed: 27 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -66,28 +66,6 @@ def _track_stage(stage):
6666
stage.repo.scm.track_changed_files()
6767

6868

69-
def _get_active_graph(G):
70-
import networkx as nx
71-
72-
active = G.copy()
73-
for stage in G:
74-
if not stage.frozen:
75-
continue
76-
active.remove_edges_from(G.out_edges(stage))
77-
for edge in G.out_edges(stage):
78-
_, to_stage = edge
79-
for node in nx.dfs_preorder_nodes(G, to_stage):
80-
# NOTE: `in_degree` will return InDegreeView({}) if stage
81-
# no longer exists in the `active` DAG.
82-
if not active.in_degree(node):
83-
# NOTE: if some edge no longer exists `remove_edges_from`
84-
# will ignore it without error.
85-
active.remove_edges_from(G.out_edges(node))
86-
active.remove_node(node)
87-
88-
return active
89-
90-
9169
@locked
9270
@scm_context
9371
def reproduce(
@@ -104,7 +82,7 @@ def reproduce(
10482
if isinstance(targets, str):
10583
targets = [targets]
10684

107-
if not all_pipelines and targets is None:
85+
if not all_pipelines and not targets:
10886
from dvc.dvcfile import PIPELINE_FILE
10987

11088
targets = [PIPELINE_FILE]
@@ -113,36 +91,33 @@ def reproduce(
11391
if not interactive:
11492
kwargs["interactive"] = self.config["core"].get("interactive", False)
11593

116-
active_graph = _get_active_graph(self.graph)
117-
active_pipelines = get_pipelines(active_graph)
118-
11994
stages = set()
12095
if pipeline or all_pipelines:
96+
pipelines = get_pipelines(self.graph)
12197
if all_pipelines:
122-
pipelines = active_pipelines
98+
used_pipelines = pipelines
12399
else:
124-
pipelines = []
100+
used_pipelines = []
125101
for target in targets:
126102
stage = self.stage.get_target(target)
127-
pipelines.append(get_pipeline(active_pipelines, stage))
103+
used_pipelines.append(get_pipeline(pipelines, stage))
128104

129-
for pipeline in pipelines:
130-
for stage in pipeline:
131-
if pipeline.in_degree(stage) == 0:
105+
for pline in used_pipelines:
106+
for stage in pline:
107+
if pline.in_degree(stage) == 0:
132108
stages.add(stage)
133109
else:
134110
for target in targets:
135111
stages.update(
136112
self.stage.collect(
137113
target,
138114
recursive=recursive,
139-
graph=active_graph,
140115
accept_group=accept_group,
141116
glob=glob,
142117
)
143118
)
144119

145-
return _reproduce_stages(active_graph, list(stages), **kwargs)
120+
return _reproduce_stages(self.graph, list(stages), **kwargs)
146121

147122

148123
def _reproduce_stages(
@@ -183,15 +158,15 @@ def _reproduce_stages(
183158
184159
The derived evaluation of _downstream_ B would be: [B, D, E]
185160
"""
186-
pipeline = _get_pipeline(G, stages, downstream, single_item)
161+
steps = _get_steps(G, stages, downstream, single_item)
187162

188163
force_downstream = kwargs.pop("force_downstream", False)
189164
result = []
190165
unchanged = []
191166
# `ret` is used to add a cosmetic newline.
192167
ret = []
193168
checkpoint_func = kwargs.pop("checkpoint_func", None)
194-
for stage in pipeline:
169+
for stage in steps:
195170
if ret:
196171
logger.info("")
197172

@@ -225,9 +200,17 @@ def _reproduce_stages(
225200
return result
226201

227202

228-
def _get_pipeline(G, stages, downstream, single_item):
203+
def _get_steps(G, stages, downstream, single_item):
229204
import networkx as nx
230205

206+
active = G.copy()
207+
if not single_item:
208+
# NOTE: frozen stages don't matter for single_item
209+
for stage in G:
210+
if stage.frozen:
211+
# NOTE: disconnect frozen stage from its dependencies
212+
active.remove_edges_from(G.out_edges(stage))
213+
231214
all_pipelines = []
232215
for stage in stages:
233216
if downstream:
@@ -239,20 +222,21 @@ def _get_pipeline(G, stages, downstream, single_item):
239222
# itself, and then reverse it, instead of using
240223
# graph.reverse() directly because it calls `deepcopy`
241224
# underneath -- unless copy=False is specified.
242-
nodes = nx.dfs_postorder_nodes(G.copy().reverse(copy=False), stage)
225+
nodes = nx.dfs_postorder_nodes(active.reverse(copy=False), stage)
243226
all_pipelines += reversed(list(nodes))
244227
else:
245-
all_pipelines += nx.dfs_postorder_nodes(G, stage)
228+
all_pipelines += nx.dfs_postorder_nodes(active, stage)
246229

247-
pipeline = []
230+
steps = []
248231
for stage in all_pipelines:
249-
if stage not in pipeline:
232+
if stage not in steps:
233+
# NOTE: order of steps still matters for single_item
250234
if single_item and stage not in stages:
251235
continue
252236

253-
pipeline.append(stage)
237+
steps.append(stage)
254238

255-
return pipeline
239+
return steps
256240

257241

258242
def _repro_callback(experiments_callback, unchanged, stages):

tests/func/test_repro_multistage.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,19 @@ def test_non_existing_stage_name(tmp_dir, dvc, run_copy):
170170
assert main(["freeze", ":copy-file1-file3"]) != 0
171171

172172

173+
def test_repro_frozen(tmp_dir, dvc, run_copy):
174+
(data_stage,) = tmp_dir.dvc_gen("data", "foo")
175+
stage0 = run_copy("data", "stage0", name="copy-data-stage0")
176+
run_copy("stage0", "stage1", name="copy-data-stage1")
177+
run_copy("stage1", "stage2", name="copy-data-stage2")
178+
179+
dvc.freeze("copy-data-stage1")
180+
181+
tmp_dir.gen("data", "bar")
182+
stages = dvc.reproduce()
183+
assert stages == [data_stage, stage0]
184+
185+
173186
def test_downstream(tmp_dir, dvc):
174187
# The dependency graph should look like this:
175188
#

tests/unit/repo/test_reproduce.py

Lines changed: 0 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -1,37 +1,5 @@
11
import mock
22

3-
from dvc.repo.reproduce import _get_active_graph
4-
5-
6-
def test_get_active_graph(tmp_dir, dvc):
7-
(pre_foo_stage,) = tmp_dir.dvc_gen({"pre-foo": "pre-foo"})
8-
foo_stage = dvc.run(
9-
single_stage=True, deps=["pre-foo"], outs=["foo"], cmd="echo foo > foo"
10-
)
11-
bar_stage = dvc.run(
12-
single_stage=True, deps=["foo"], outs=["bar"], cmd="echo bar > bar"
13-
)
14-
baz_stage = dvc.run(
15-
single_stage=True, deps=["foo"], outs=["baz"], cmd="echo baz > baz"
16-
)
17-
18-
dvc.freeze("bar.dvc")
19-
20-
graph = dvc.graph
21-
active_graph = _get_active_graph(graph)
22-
assert active_graph.nodes == graph.nodes
23-
assert set(active_graph.edges) == {
24-
(foo_stage, pre_foo_stage),
25-
(baz_stage, foo_stage),
26-
}
27-
28-
dvc.freeze("baz.dvc")
29-
30-
graph = dvc.graph
31-
active_graph = _get_active_graph(graph)
32-
assert set(active_graph.nodes) == {bar_stage, baz_stage}
33-
assert not active_graph.edges
34-
353

364
@mock.patch("dvc.repo.reproduce._reproduce_stage", returns=[])
375
def test_number_reproduces(reproduce_stage_mock, tmp_dir, dvc):

0 commit comments

Comments
 (0)