From fdd36b8d30fe96f57787d03ee52ccbf37e3fb43b Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 12 May 2025 19:04:28 -0700 Subject: [PATCH 1/5] First commit --- metaflow/cli.py | 64 ++++- metaflow/cli_components/run_cmds.py | 143 +++++++++-- metaflow/cli_components/step_cmd.py | 141 ++++++++++- metaflow/client/core.py | 234 ++++++++++-------- metaflow/datastore/task_datastore.py | 1 + metaflow/metaflow_config.py | 7 + metaflow/plugins/__init__.py | 1 + metaflow/plugins/metadata_providers/spin.py | 87 +++++++ metaflow/runtime.py | 251 +++++++++++++++++++- metaflow/util.py | 56 +++++ 10 files changed, 849 insertions(+), 136 deletions(-) create mode 100644 metaflow/plugins/metadata_providers/spin.py diff --git a/metaflow/cli.py b/metaflow/cli.py index d24829e6db6..f480d9078f4 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -114,6 +114,8 @@ def logger(body="", system_msg=False, head="", bad=False, timestamp=True, nl=Tru "step": "metaflow.cli_components.step_cmd.step", "run": "metaflow.cli_components.run_cmds.run", "resume": "metaflow.cli_components.run_cmds.resume", + "spin": "metaflow.cli_components.run_cmds.spin", + "spin-step": "metaflow.cli_components.step_cmd.spin_step", }, ) def cli(ctx): @@ -440,14 +442,10 @@ def start( ctx.obj.event_logger = LOGGING_SIDECARS[event_logger]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.event_logger.start() - _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) ctx.obj.monitor = MONITOR_SIDECARS[monitor]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.monitor.start() - _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor @@ -462,6 +460,54 @@ def start( ) ctx.obj.config_options = config_options + ctx.obj.is_spin = False + + # Override values for spin + if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: + # To minimize side-effects for spin, we will only use the following: + # - spin metadata provider, + # - local datastore, + # - local environment, + # - null event logger, + # - null monitor + ctx.obj.is_spin = True + ctx.obj.spin_metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + if datastore_root is None: + datastore_root = ctx.obj.spin_datastore_impl.get_datastore_root_from_config( + ctx.obj.echo + ) + ctx.obj.spin_datastore_impl.datastore_root = datastore_root + ctx.obj.spin_flow_datastore = FlowDataStore( + ctx.obj.flow.name, + ctx.obj.environment, # Same environment as run/resume + ctx.obj.spin_metadata, # spin metadata provider (no-op) + ctx.obj.event_logger, # null event logger + ctx.obj.monitor, # null monitor + storage_impl=ctx.obj.spin_datastore_impl, + ) + + # Start event logger and monitor + ctx.obj.event_logger.start() + _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) + + ctx.obj.monitor.start() + _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) + + ctx.obj.effective_flow_datastore = ( + ctx.obj.spin_flow_datastore if ctx.obj.is_spin else 
ctx.obj.flow_datastore + ) + ctx.obj.effective_metadata = ( + ctx.obj.spin_metadata if ctx.obj.is_spin else ctx.obj.metadata + ) decorators._init(ctx.obj.flow) @@ -471,14 +517,14 @@ def start( ctx.obj.flow, ctx.obj.graph, ctx.obj.environment, - ctx.obj.flow_datastore, - ctx.obj.metadata, + ctx.obj.effective_flow_datastore, + ctx.obj.effective_metadata, ctx.obj.logger, echo, deco_options, ) - # In the case of run/resume, we will want to apply the TL decospecs + # In the case of run/resume/spin, we will want to apply the TL decospecs # *after* the run decospecs so that they don't take precedence. In other # words, for the same decorator, we want `myflow.py run --with foo` to # take precedence over any other `foo` decospec @@ -494,7 +540,7 @@ def start( parameters.set_parameter_context( ctx.obj.flow.name, ctx.obj.echo, - ctx.obj.flow_datastore, + ctx.obj.effective_flow_datastore, { k: ConfigValue(v) for k, v in ctx.obj.flow.__class__._flow_state.get( @@ -506,7 +552,7 @@ def start( if ( hasattr(ctx, "saved_args") and ctx.saved_args - and ctx.saved_args[0] not in ("run", "resume") + and ctx.saved_args[0] not in ("run", "resume", "spin") ): # run/resume are special cases because they can add more decorators with --with, # so they have to take care of themselves. diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 681fc8f94f9..dfa7914a6d1 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -10,17 +10,17 @@ from ..metaflow_current import current from ..metaflow_config import DEFAULT_DECOSPECS from ..package import MetaflowPackage -from ..runtime import NativeRuntime +from ..runtime import NativeRuntime, SpinRuntime from ..system import _system_logger from ..tagging_util import validate_tags -from ..util import get_latest_run_id, write_latest_run_id +from ..util import get_latest_run_id, write_latest_run_id, get_latest_task_pathspec def before_run(obj, tags, decospecs): validate_tags(tags) - # There's a --with option both at the top-level and for the run + # There's a --with option both at the top-level and for the run/resume/spin # subcommand. Why? # # "run --with shoes" looks so much better than "--with shoes run". @@ -40,7 +40,7 @@ def before_run(obj, tags, decospecs): + list(obj.environment.decospecs() or []) ) if all_decospecs: - # These decospecs are the ones from run/resume PLUS the ones from the + # These decospecs are the ones from run/resume/spin PLUS the ones from the # environment (for example the @conda) decorators._attach_decorators(obj.flow, all_decospecs) decorators._init(obj.flow) @@ -52,7 +52,7 @@ def before_run(obj, tags, decospecs): # obj.environment.init_environment(obj.logger) decorators._init_step_decorators( - obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger + obj.flow, obj.graph, obj.environment, obj.effective_flow_datastore, obj.logger ) obj.metadata.add_sticky_tags(tags=tags) @@ -65,6 +65,29 @@ def before_run(obj, tags, decospecs): ) +def common_runner_options(func): + @click.option( + "--run-id-file", + default=None, + show_default=True, + type=str, + help="Write the ID of this run to the file specified.", + ) + @click.option( + "--runner-attribute-file", + default=None, + show_default=True, + type=str, + help="Write the metadata and pathspec of this run to the file specified. 
Used internally " + "for Metaflow's Runner API.", + ) + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def write_file(file_path, content): if file_path is not None: with open(file_path, "w", encoding="utf-8") as f: @@ -129,20 +152,6 @@ def common_run_options(func): "in steps.", callback=config_callback, ) - @click.option( - "--run-id-file", - default=None, - show_default=True, - type=str, - help="Write the ID of this run to the file specified.", - ) - @click.option( - "--runner-attribute-file", - default=None, - show_default=True, - type=str, - help="Write the metadata and pathspec of this run to the file specified. Used internally for Metaflow's Runner API.", - ) @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -187,6 +196,7 @@ def wrapper(*args, **kwargs): @click.command(help="Resume execution of a previous run of this flow.") @tracing.cli("cli/resume") @common_run_options +@common_runner_options @click.pass_obj def resume( obj, @@ -305,6 +315,7 @@ def resume( @click.command(help="Run the workflow locally.") @tracing.cli("cli/run") @common_run_options +@common_runner_options @click.option( "--namespace", "user_namespace", @@ -380,3 +391,97 @@ def run( ) with runtime.run_heartbeat(): runtime.execute() + + +@click.command(help="Spins up a task for a given step from a previous run locally.") +@click.argument("spin-pathspec") +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. The artifacts should " + "be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the " + "artifact values. The artifact values will overwrite the default values of the artifacts used in " + "the spun step.", +) +@click.option( + "--max-log-size", + default=10, + show_default=True, + help="Maximum size of stdout and stderr captured in " + "megabytes. If a step outputs more than this to " + "stdout/stderr, its output will be truncated.", +) +@common_runner_options +@click.pass_obj +def spin( + obj, + spin_pathspec=None, + artifacts_module=None, + skip_decorators=False, + max_log_size=None, + run_id_file=None, + runner_attribute_file=None, + **kwargs +): + before_run(obj, [], []) + # Verify whether the user has provided step-name or spin-pathspec + if "/" in spin_pathspec: + # spin_pathspec is in the form of a task pathspec + if len(spin_pathspec.split("/")) != 4: + raise CommandException( + "Invalid spin-pathspec format. 
Expected format: {flow_name}/{run_id}/{step_name}/{task_id}" + ) + _, _, step_name, _ = spin_pathspec.split("/") + else: + # spin_pathspec is in the form of a step name + step_name = spin_pathspec + spin_pathspec = get_latest_task_pathspec(obj.flow.name, step_name) + + obj.echo( + f"Spinning up step *{step_name}* locally using previous task pathspec *{spin_pathspec}*" + ) + obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + step_func = getattr(obj.flow, step_name) + + spin_runtime = SpinRuntime( + obj.flow, + obj.graph, + obj.effective_flow_datastore, + obj.effective_metadata, + obj.environment, + obj.package, + obj.logger, + obj.entrypoint, + obj.event_logger, + obj.monitor, + step_func, + spin_pathspec, + skip_decorators, + artifacts_module, + max_log_size * 1024 * 1024, + ) + + write_latest_run_id(obj, spin_runtime.run_id) + write_file(run_id_file, spin_runtime.run_id) + spin_runtime.execute() + + if runner_attribute_file: + with open(runner_attribute_file, "w") as f: + json.dump( + { + "task_id": spin_runtime.task.task_id, + "step_name": step_name, + "run_id": spin_runtime.run_id, + "flow_name": obj.flow.name, + "metadata": f"{obj.spin_metadata.__class__.TYPE}@{obj.spin_metadata.__class__.INFO}", + }, + f, + ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index f4bef099e42..e662557fa84 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -6,7 +6,7 @@ from ..exception import CommandException from ..task import MetaflowTask from ..unbounded_foreach import UBF_CONTROL, UBF_TASK -from ..util import decompress_list +from ..util import decompress_list, read_artifacts_module import metaflow.tracing as tracing @@ -176,3 +176,142 @@ def step( ) echo("Success", fg="green", bold=True, indent=True) + + +@click.command(help="Internal command to spin a single task.", hidden=True) +@click.argument("step-name") +@click.option( + "--run-id", + default=None, + required=True, + help="Run ID for the step that's about to be spun", +) +@click.option( + "--task-id", + default=None, + required=True, + help="Task ID for the step that's about to be spun", +) +@click.option( + "--spin-pathspec", + default=None, + show_default=True, + help="Task Pathspec to be used in the spun step.", +) +@click.option( + "--input-paths", + help="A comma-separated list of pathspecs specifying inputs for this step.", +) +@click.option( + "--split-index", + type=int, + default=None, + show_default=True, + help="Index of this foreach split.", +) +@click.option( + "--retry-count", + default=0, + help="How many times we have attempted to run this task.", +) +@click.option( + "--max-user-code-retries", + default=0, + help="How many times we should attempt running the user code.", +) +@click.option( + "--namespace", + "namespace", + default=None, + help="Change namespace from the default (your username) to the specified tag.", +) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. The artifacts should " + "be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the " + "artifact values. 
The artifact values will overwrite the default values of the artifacts used in " + "the spun step.", +) +@click.pass_context +def spin_step( + ctx, + step_name, + run_id=None, + task_id=None, + spin_pathspec=None, + input_paths=None, + split_index=None, + retry_count=None, + max_user_code_retries=None, + namespace=None, + skip_decorators=False, + artifacts_module=None, +): + import time + + start = time.time() + import sys + + if ctx.obj.is_quiet: + print("Echo dev null") + echo = echo_dev_null + else: + print("Echo always") + echo = echo_always + + input_paths = decompress_list(input_paths) if input_paths else [] + echo( + f"Spinning a task, *{step_name}* with previous task pathspec: {spin_pathspec}", + fg="magenta", + bold=False, + ) + # if namespace is not None: + # namespace(namespace or None) + + spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} + spin_artifacts = spin_artifacts.get("ARTIFACTS", {}) + + print(f"spin_artifacts: {spin_artifacts}") + print(f"spin_pathspec: {spin_pathspec}") + print(f"input_paths: {input_paths}") + + # task = MetaflowTask( + # ctx.obj.flow, + # ctx.obj.effective_flow_datastore, # local datastore + # ctx.obj.effective_metadata, # local metadata provider + # ctx.obj.environment, # local environment + # ctx.obj.echo, + # ctx.obj.event_logger, # null logger + # ctx.obj.monitor, # null monitor + # None, # no unbounded foreach context + # ) + # echo( + # "I am here", + # fg="magenta", + # bold=False, + # ) + + # task.run_step( + # step_name, + # run_id, + # task_id, + # None, + # input_paths, + # split_index, + # retry_count, + # max_user_code_retries, + # spin_pathspec, + # skip_decorators, + # spin_artifacts, + # ) + + echo(f"Time taken for the whole thing: {time.time() - start}") diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 4edbcdac00c..f1848e38799 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1191,142 +1191,186 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def _iter_matching_tasks(self, steps, metadata_key, metadata_pattern): + def _get_matching_pathspecs(self, steps, metadata_key, metadata_pattern): """ - Yield tasks from specified steps matching a foreach path pattern. + Yield pathspecs of tasks from specified steps that match a given metadata pattern. Parameters ---------- steps : List[str] - List of step names to search for tasks - pattern : str - Regex pattern to match foreach-indices metadata + List of Step objects to search for tasks. + metadata_key : str + Metadata key to filter tasks on (e.g., 'foreach-execution-path'). + metadata_pattern : str + Regular expression pattern to match against the metadata value. - Returns - ------- - Iterator[Task] - Tasks matching the foreach path pattern + Yields + ------ + str + Pathspec of each task whose metadata value for the specified key matches the pattern. 
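+
+        For example, with ``metadata_key='foreach-execution-path'`` and a
+        ``metadata_pattern`` of ``'A:10,.*'``, only tasks whose recorded
+        foreach execution path matches that pattern (such as ``A:10,B:13``)
+        are yielded; the values shown here are purely illustrative.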
""" flow_id, run_id, _, _ = self.path_components - for step in steps: task_pathspecs = self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step.id, metadata_key, metadata_pattern + flow_id, run_id, step, metadata_key, metadata_pattern ) for task_pathspec in task_pathspecs: - yield Task(pathspec=task_pathspec, _namespace_check=False) + yield task_pathspec + + @staticmethod + def _get_previous_steps(graph_info, step_name): + # Get the parent steps + steps = [] + for node_name, attributes in graph_info["steps"].items(): + if step_name in attributes["next"]: + steps.append(node_name) + return steps @property - def parent_tasks(self) -> Iterator["Task"]: + def parent_task_pathspecs(self) -> Iterator[str]: """ - Yields all parent tasks of the current task if one exists. + Yields pathspecs of all parent tasks of the current task. Yields ------ - Task - Parent task of the current task - + str + Pathspec of the parent task of the current task """ - flow_id, run_id, _, _ = self.path_components + flow_id, run_id, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data - steps = list(self.parent.parent_steps) - if not steps: - return [] - - current_path = self.metadata_dict.get("foreach-execution-path", "") + # Get the parent steps + steps = self._get_previous_steps(graph_info, step_name) + node_type = graph_info["steps"][step_name]["type"] + current_path = metadata_dict.get("foreach-execution-path") if len(steps) > 1: # Static join - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Parent task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach join - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No parent steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach split or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) + current_depth = len(current_path.split(",")) + if node_type == "join": + # Foreach join + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach split or linear step + # Pattern: "A:10,B:13" + parent_step_type = graph_info["steps"][steps[0]]["type"] + target_depth = current_depth + if parent_step_type == "split-foreach" and current_depth == 1: + # (Current task, "A:10") and (Parent task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") + if parent_step_type == "split-foreach": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) - yield from self._iter_matching_tasks(steps, 
"foreach-execution-path", pattern) + metadata_key = "foreach-execution-path" + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec @property - def child_tasks(self) -> Iterator["Task"]: + def child_task_pathspecs(self) -> Iterator[str]: """ - Yield all child tasks of the current task if one exists. + Yields pathspecs of all child tasks of the current task. Yields ------ - Task - Child task of the current task + str + Pathspec of the child task of the current task """ - flow_id, run_id, _, _ = self.path_components - steps = list(self.parent.child_steps) - if not steps: - return [] + flow_id, run_id, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data + + # Get the child steps + steps = graph_info["steps"][step_name]["next"] - current_path = self.metadata_dict.get("foreach-execution-path", "") + node_type = graph_info["steps"][step_name]["type"] + current_path = self.metadata_dict.get("foreach-execution-path") if len(steps) > 1: # Static split - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Child task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach split - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No child steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach join or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) - - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + current_depth = len(current_path.split(",")) + if node_type == "split-foreach": + # Foreach split + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach join or linear step + # Pattern: "A:10,B:13" + child_step_type = graph_info["steps"][steps[0]]["type"] + + # We need to know if the child step is a foreach join or a static join + child_step_prev_steps = self._get_previous_steps( + graph_info, steps[0] + ) + if len(child_step_prev_steps) > 1: + child_step_type = "static-join" + target_depth = current_depth + if child_step_type == "join" and current_depth == 1: + # (Current task, "A:10") and (Child task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") + if child_step_type == "join": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) + + metadata_key = "foreach-execution-path" + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec + + @property + def parent_tasks(self) -> Iterator["Task"]: + """ + Yields all parent 
tasks of the current task if one exists. + + Yields + ------ + Task + Parent task of the current task + """ + for pathspec in self.parent_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) + + @property + def child_tasks(self) -> Iterator["Task"]: + """ + Yields all child tasks of the current task if one exists. + + Yields + ------ + Task + Child task of the current task + """ + for pathspec in self.child_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) @property def metadata(self) -> List[Metadata]: diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 325cc1ea1ae..7691862fc2a 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -165,6 +165,7 @@ def __init__( data_obj = self.load_metadata([self.METADATA_DATA_SUFFIX]) data_obj = data_obj[self.METADATA_DATA_SUFFIX] elif self._attempt is None or not allow_not_done: + print("HERE in self._mode == 'r'") raise DataException( "No completed attempts of the task was found for task '%s'" % self._path diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index eaf230bc383..fc38e5c727d 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -47,6 +47,13 @@ "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" ) +### +# Spin configuration +### +SPIN_ALLOWED_DECORATORS = from_conf( + "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"] +) + ### # User configuration ### diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index f74a70726ba..0ed913763e9 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -82,6 +82,7 @@ METADATA_PROVIDERS_DESC = [ ("service", ".metadata_providers.service.ServiceMetadataProvider"), ("local", ".metadata_providers.local.LocalMetadataProvider"), + ("spin", ".metadata_providers.spin.SpinMetadataProvider"), ] # Add datastore here diff --git a/metaflow/plugins/metadata_providers/spin.py b/metaflow/plugins/metadata_providers/spin.py new file mode 100644 index 00000000000..828f8b9b31b --- /dev/null +++ b/metaflow/plugins/metadata_providers/spin.py @@ -0,0 +1,87 @@ +""" +An implementation of a null metadata provider for Spin steps. +This provider does not store any metadata and is used when +we want to ensure that there are no side effects from the +metadata provider. + +""" + +import time + +from metaflow.metadata_provider import MetadataProvider +from typing import List + + +class SpinMetadataProvider(MetadataProvider): + TYPE = "spin" + + def __init__(self, environment, flow, event_logger, monitor): + super(SpinMetadataProvider, self).__init__( + environment, flow, event_logger, monitor + ) + # No provider-specific initialization needed for a null provider + + def version(self): + """Return provider type as version.""" + return self.TYPE + + def new_run_id(self, tags=None, sys_tags=None): + # We currently just use the timestamp to create an ID. We can be reasonably certain + # that it is unique and this makes it possible to do without coordination or + # reliance on POSIX locks in the filesystem. + run_id = "%d" % (time.time() * 1e6) + return run_id + + def register_run_id(self, run_id, tags=None, sys_tags=None): + """No-op register_run_id. 
Indicates no action taken.""" + return False + + def new_task_id(self, run_id, step_name, tags=None, sys_tags=None): + self._task_id_seq += 1 + task_id = str(self._task_id_seq) + return task_id + + def register_task_id( + self, run_id, step_name, task_id, attempt=0, tags=None, sys_tags=None + ): + """No-op register_task_id. Indicates no action taken.""" + return False + + def register_data_artifacts( + self, run_id, step_name, task_id, attempt_id, artifacts + ): + """No-op register_data_artifacts.""" + pass + + def register_metadata(self, run_id, step_name, task_id, metadata): + """No-op register_metadata.""" + pass + + @classmethod + def _mutate_user_tags_for_run( + cls, flow_id, run_id, tags_to_add=None, tags_to_remove=None + ): + """No-op _mutate_user_tags_for_run. Returns an empty set of tags.""" + return frozenset() + + @classmethod + def filter_tasks_by_metadata( + cls, + flow_name: str, + run_id: str, + step_name: str, + field_name: str, + pattern: str, + ) -> List[str]: + """No-op filter_tasks_by_metadata. Returns an empty list.""" + return [] + + @classmethod + def _get_object_internal( + cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args + ): + """ + No-op _get_object_internal. Returns an empty list, + which MetadataProvider.get_object will interpret as 'not found'. + """ + return [] diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 7e9269841fb..e187d3753d3 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -24,7 +24,7 @@ from . import get_namespace from .metadata_provider import MetaDatum -from .metaflow_config import MAX_ATTEMPTS, UI_URL +from .metaflow_config import MAX_ATTEMPTS, UI_URL, SPIN_ALLOWED_DECORATORS from .exception import ( MetaflowException, MetaflowInternalError, @@ -73,6 +73,187 @@ # TODO option: output dot graph periodically about execution +class SpinRuntime(object): + def __init__( + self, + flow, + graph, + flow_datastore, + metadata, + environment, + package, + logger, + entrypoint, + event_logger, + monitor, + step_func, + spin_pathspec, + skip_decorators=False, + artifacts_module=None, + max_log_size=MAX_LOG_SIZE, + ): + from metaflow import Task + + self._flow = flow + self._graph = graph + self._flow_datastore = flow_datastore + self._metadata = metadata + self._environment = environment + self._package = package + self._logger = logger + self._entrypoint = entrypoint + self._event_logger = event_logger + self._monitor = monitor + + self._step_func = step_func + self._spin_pathspec = spin_pathspec + self._spin_task = Task(self._spin_pathspec, _namespace_check=False) + self._input_paths = None + self._split_index = None + self._whitelist_decorators = None + self._config_file_name = None + self._skip_decorators = skip_decorators + self._artifacts_module = artifacts_module + self._max_log_size = max_log_size + self._encoding = sys.stdout.encoding or "UTF-8" + + # Create a new run_id for the spin task + self.run_id = self._metadata.new_run_id() + for deco in self.whitelist_decorators: + # print("-" * 100) + deco.runtime_init(flow, graph, package, self.run_id) + + @property + def split_index(self): + if self._split_index: + return self._split_index + + if hasattr(self._spin_task, "index"): + self._split_index = self._spin_task.index + return self._split_index + + @property + def input_paths(self): + st = time.time() + + def _format_input_paths(task_pathspec): + _, run_id, step_name, task_id = task_pathspec.split("/") + return f"{run_id}/{step_name}/{task_id}" + + if self._input_paths: + return self._input_paths + + 
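+        # The `start` step has no parent tasks, so its single input is the
+        # `_parameters` task of the original run; every other step reuses the
+        # parent task pathspecs recorded for the task being spun.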
if self._step_func.name == "start": + from metaflow import Step + + flow_name, run_id, _, _ = self._spin_pathspec.split("/") + task = Step( + f"{flow_name}/{run_id}/_parameters", _namespace_check=False + ).task + self._input_paths = [_format_input_paths(task.pathspec)] + else: + self._input_paths = [ + _format_input_paths(task_pathspec) + for task_pathspec in self._spin_task.parent_task_pathspecs + ] + et = time.time() + print(f"Time taken to get input paths: {et - st}") + return self._input_paths + + @property + def whitelist_decorators(self): + if self._skip_decorators: + return [] + if self._whitelist_decorators: + return self._whitelist_decorators + self._whitelist_decorators = [ + deco + for deco in self._step_func.decorators + if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS) + ] + return self._whitelist_decorators + + def _new_task(self, step, input_paths=None, **kwargs): + return Task( + flow_datastore=self._flow_datastore, + flow=self._flow, + step=step, + run_id=self.run_id, + metadata=self._metadata, + environment=self._environment, + entrypoint=self._entrypoint, + event_logger=self._event_logger, + monitor=self._monitor, + input_paths=input_paths, + decos=self.whitelist_decorators, + logger=self._logger, + split_index=self.split_index, + **kwargs, + ) + + def execute(self): + exception = None + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: + config_value = dump_config_values(self._flow) + if config_value: + json.dump(config_value, config_file) + config_file.flush() + self._config_file_name = config_file.name + else: + self._config_file_name = None + + # print("I am before _new_task") + self.task = self._new_task(self._step_func.name, self.input_paths) + # print(f"I am after _new_task") + try: + self._launch_and_monitor_task() + except Exception as ex: + self._logger("Task failed.", system_msg=True, bad=True) + exception = ex + raise + finally: + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) + + def _launch_and_monitor_task(self): + # print("I am in _launch_and_monitor_task") + worker = Worker( + self.task, + self._max_log_size, + self._config_file_name, + spin_pathspec=self._spin_pathspec, + skip_decorators=self._skip_decorators, + artifacts_module=self._artifacts_module, + ) + + # print("Worker created") + + poll = procpoll.make_poll() + fds = worker.fds() + for fd in fds: + poll.add(fd) + + active_fds = set(fds) + + while active_fds: + events = poll.poll(POLL_TIMEOUT) + for event in events: + if event.can_read: + worker.read_logline(event.fd) + if event.is_terminated: + poll.remove(event.fd) + active_fds.remove(event.fd) + + # print("I am after while loop") + returncode = worker.terminate() + # print(f"Return code: {returncode}") + + if returncode != 0: + raise TaskFailed(self.task, f"Task failed with return code {returncode}") + else: + self._logger("Task finished successfully.", system_msg=True) + + class NativeRuntime(object): def __init__( self, @@ -1508,8 +1689,13 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). 
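+
+    When `spin_pathspec` is provided, the generated command targets the
+    internal `spin-step` subcommand (see `spin_args` below) instead of the
+    regular `step` subcommand.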
""" - def __init__(self, task): + def __init__( + self, task, spin_pathspec=None, skip_decorators=False, artifacts_module=None + ): self.task = task + self.spin_pathspec = spin_pathspec + self.skip_decorators = skip_decorators + self.artifacts_module = artifacts_module self.entrypoint = list(task.entrypoint) self.top_level_options = { "quiet": True, @@ -1542,18 +1728,42 @@ def __init__(self, task): (k, ConfigInput.make_key_name(k)) for k in configs ] + if spin_pathspec: + self.spin_args() + else: + self.default_args() + + def default_args(self): self.commands = ["step"] self.command_args = [self.task.step] self.command_options = { - "run-id": task.run_id, - "task-id": task.task_id, - "input-paths": compress_list(task.input_paths), - "split-index": task.split_index, - "retry-count": task.retries, - "max-user-code-retries": task.user_code_retries, - "tag": task.tags, + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "tag": self.task.tags, "namespace": get_namespace() or "", - "ubf-context": task.ubf_context, + "ubf-context": self.task.ubf_context, + } + self.env = {} + + def spin_args(self): + self.commands = ["spin-step"] + self.command_args = [self.task.step] + + self.command_options = { + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "namespace": get_namespace() or "", + "spin-pathspec": self.spin_pathspec, + "skip-decorators": self.skip_decorators, + "artifacts-module": self.artifacts_module, } self.env = {} @@ -1595,9 +1805,20 @@ def __str__(self): class Worker(object): - def __init__(self, task, max_logs_size, config_file_name): + def __init__( + self, + task, + max_logs_size, + config_file_name, + spin_pathspec=None, + skip_decorators=False, + artifacts_module=None, + ): self.task = task self._config_file_name = config_file_name + self.spin_pathspec = spin_pathspec + self.skip_decorators = skip_decorators + self.artifacts_module = artifacts_module self._proc = self._launch() if task.retries > task.user_code_retries: @@ -1629,7 +1850,12 @@ def __init__(self, task, max_logs_size, config_file_name): # not it is properly shut down) def _launch(self): - args = CLIArgs(self.task) + args = CLIArgs( + self.task, + spin_pathspec=self.spin_pathspec, + skip_decorators=self.skip_decorators, + artifacts_module=self.artifacts_module, + ) env = dict(os.environ) if self.task.clone_run_id: @@ -1765,6 +1991,7 @@ def terminate(self): # Return early if the task is cloned since we don't want to # perform any log collection. if not self.task.is_cloned: + print("I am in terminate where task is not cloned") self.task.save_metadata( "runtime", { diff --git a/metaflow/util.py b/metaflow/util.py index f9051aff589..9fc80ad9681 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -177,6 +177,45 @@ def resolve_identity(): return "%s:%s" % (identity_type, identity_value) +def get_latest_task_pathspec(flow_name: str, step_name: str) -> str: + """ + Returns a task pathspec from the latest run of the flow for the queried step. + If the queried step has several tasks, the task pathspec of the first task is returned. + + Parameters + ---------- + flow_name : str + The name of the flow. 
+ step_name : str + The name of the step. + + Returns + ------- + str + The task pathspec of the first task of the queried step. + + Raises + ------ + MetaflowNotFound + If no task or run is found for the queried step. + """ + from metaflow import Flow, Step + from metaflow.exception import MetaflowNotFound + + run = Flow(flow_name, _namespace_check=False).latest_run + + if run is None: + raise MetaflowNotFound(f"No run found for the flow {flow_name}") + + try: + task = Step(f"{flow_name}/{run.id}/{step_name}", _namespace_check=False).task + return task.pathspec + except Exception: + raise MetaflowNotFound( + f"No step *{step_name}* found in run *{run.id}* for flow *{flow_name}*" + ) + + def get_latest_run_id(echo, flow_name): from metaflow.plugins.datastores.local_storage import LocalStorage @@ -467,6 +506,23 @@ def to_pod(value): return str(value) +def read_artifacts_module(file_path): + import importlib.util + + try: + spec = importlib.util.spec_from_file_location("artifacts_module", file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + variables = vars(module) + if "ARTIFACTS" not in variables: + raise MetaflowInternalError( + f"Module {file_path} does not contain ARTIFACTS variable" + ) + return variables + except Exception as e: + raise MetaflowInternalError(f"Error reading file {file_path}") from e + + if sys.version_info[:2] > (3, 5): from metaflow._vendor.packaging.version import parse as version_parse else: From dbcd63bbe1d3731b950488b831a1eea0886aca13 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Thu, 5 Jun 2025 10:42:50 -0700 Subject: [PATCH 2/5] Tmp commit --- metaflow/cli.py | 56 ++++--- metaflow/cli_components/run_cmds.py | 14 +- metaflow/cli_components/step_cmd.py | 71 +++++---- metaflow/datastore/__init__.py | 1 + metaflow/datastore/flow_datastore.py | 17 +++ metaflow/datastore/spin_datastore.py | 187 +++++++++++++++++++++++ metaflow/plugins/cards/card_decorator.py | 1 + metaflow/runtime.py | 11 ++ metaflow/task.py | 39 ++++- 9 files changed, 343 insertions(+), 54 deletions(-) create mode 100644 metaflow/datastore/spin_datastore.py diff --git a/metaflow/cli.py b/metaflow/cli.py index f480d9078f4..d91369ca229 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -471,29 +471,46 @@ def start( # - null event logger, # - null monitor ctx.obj.is_spin = True - ctx.obj.spin_metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( - ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor - ) ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( flow=ctx.obj.flow, env=ctx.obj.environment ) ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] if datastore_root is None: - datastore_root = ctx.obj.spin_datastore_impl.get_datastore_root_from_config( + datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config( ctx.obj.echo ) - ctx.obj.spin_datastore_impl.datastore_root = datastore_root - ctx.obj.spin_flow_datastore = FlowDataStore( + ctx.obj.datastore_impl.datastore_root = datastore_root + ctx.obj.flow_datastore = FlowDataStore( ctx.obj.flow.name, ctx.obj.environment, # Same environment as run/resume - ctx.obj.spin_metadata, # spin 
metadata provider (no-op) + ctx.obj.metadata, # spin metadata provider (no-op) ctx.obj.event_logger, # null event logger ctx.obj.monitor, # null monitor - storage_impl=ctx.obj.spin_datastore_impl, + storage_impl=ctx.obj.datastore_impl, ) + # ctx.obj.spin_metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( + # ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + # ) + # ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + # if datastore_root is None: + # datastore_root = ctx.obj.spin_datastore_impl.get_datastore_root_from_config( + # ctx.obj.echo + # ) + # ctx.obj.spin_datastore_impl.datastore_root = datastore_root + # ctx.obj.spin_flow_datastore = FlowDataStore( + # ctx.obj.flow.name, + # ctx.obj.environment, # Same environment as run/resume + # ctx.obj.spin_metadata, # spin metadata provider (no-op) + # ctx.obj.event_logger, # null event logger + # ctx.obj.monitor, # null monitor + # storage_impl=ctx.obj.spin_datastore_impl, + # ) # Start event logger and monitor ctx.obj.event_logger.start() @@ -502,12 +519,12 @@ def start( ctx.obj.monitor.start() _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) - ctx.obj.effective_flow_datastore = ( - ctx.obj.spin_flow_datastore if ctx.obj.is_spin else ctx.obj.flow_datastore - ) - ctx.obj.effective_metadata = ( - ctx.obj.spin_metadata if ctx.obj.is_spin else ctx.obj.metadata - ) + # ctx.obj.effective_flow_datastore = ( + # ctx.obj.spin_flow_datastore if ctx.obj.is_spin else ctx.obj.flow_datastore + # ) + # ctx.obj.effective_metadata = ( + # ctx.obj.spin_metadata if ctx.obj.is_spin else ctx.obj.metadata + # ) decorators._init(ctx.obj.flow) @@ -517,8 +534,10 @@ def start( ctx.obj.flow, ctx.obj.graph, ctx.obj.environment, - ctx.obj.effective_flow_datastore, - ctx.obj.effective_metadata, + # ctx.obj.effective_flow_datastore, + # ctx.obj.effective_metadata, + ctx.obj.flow_datastore, + ctx.obj.metadata, ctx.obj.logger, echo, deco_options, @@ -540,7 +559,8 @@ def start( parameters.set_parameter_context( ctx.obj.flow.name, ctx.obj.echo, - ctx.obj.effective_flow_datastore, + # ctx.obj.effective_flow_datastore, + ctx.obj.flow_datastore, { k: ConfigValue(v) for k, v in ctx.obj.flow.__class__._flow_state.get( diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index dfa7914a6d1..52569446313 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -51,8 +51,11 @@ def before_run(obj, tags, decospecs): obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) # obj.environment.init_environment(obj.logger) + # decorators._init_step_decorators( + # obj.flow, obj.graph, obj.environment, obj.effective_flow_datastore, obj.logger + # ) decorators._init_step_decorators( - obj.flow, obj.graph, obj.environment, obj.effective_flow_datastore, obj.logger + obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger ) obj.metadata.add_sticky_tags(tags=tags) @@ -448,14 +451,19 @@ def spin( obj.echo( f"Spinning up step *{step_name}* locally using previous task pathspec *{spin_pathspec}*" ) + # Set spin_pathspec of flow_datastore + obj.flow_datastore.is_spin = True + # obj.flow_datastore.spin_pathspec = spin_pathspec obj.flow._set_constants(obj.graph, kwargs, obj.config_options) step_func = getattr(obj.flow, step_name) spin_runtime = SpinRuntime( obj.flow, obj.graph, - obj.effective_flow_datastore, - obj.effective_metadata, + obj.flow_datastore, + obj.metadata, + # obj.effective_flow_datastore, + # 
obj.effective_metadata, obj.environment, obj.package, obj.logger, diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index e662557fa84..7631e5057b6 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -269,7 +269,7 @@ def spin_step( echo = echo_always input_paths = decompress_list(input_paths) if input_paths else [] - echo( + echo_always( f"Spinning a task, *{step_name}* with previous task pathspec: {spin_pathspec}", fg="magenta", bold=False, @@ -277,41 +277,50 @@ def spin_step( # if namespace is not None: # namespace(namespace or None) - spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} - spin_artifacts = spin_artifacts.get("ARTIFACTS", {}) + # spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} + # spin_artifacts = spin_artifacts.get("ARTIFACTS", {}) - print(f"spin_artifacts: {spin_artifacts}") + # Set spin_pathspec of flow_datastore + ctx.obj.flow_datastore.is_spin = True + + print(f"ctx.obj.flow_datastore: {ctx.obj.flow_datastore}") + print(f"ctx.obj.flow_datastore.is_spin: {ctx.obj.flow_datastore.is_spin}") + print(f"ctx.obj.metadata: {ctx.obj.metadata}") + print(f"type(ctx.obj.metadata): {type(ctx.obj.metadata)}") + # print(f"spin_artifacts: {spin_artifacts}") print(f"spin_pathspec: {spin_pathspec}") print(f"input_paths: {input_paths}") + print(f"ctx.obj.flow: {ctx.obj.flow}") - # task = MetaflowTask( - # ctx.obj.flow, - # ctx.obj.effective_flow_datastore, # local datastore - # ctx.obj.effective_metadata, # local metadata provider - # ctx.obj.environment, # local environment - # ctx.obj.echo, - # ctx.obj.event_logger, # null logger - # ctx.obj.monitor, # null monitor - # None, # no unbounded foreach context - # ) - # echo( - # "I am here", + task = MetaflowTask( + ctx.obj.flow, + ctx.obj.flow_datastore, + ctx.obj.metadata, + ctx.obj.environment, + ctx.obj.echo, + ctx.obj.event_logger, + ctx.obj.monitor, + None, # no unbounded foreach context + ) + # echo_always( + # "I am here Shashank", # fg="magenta", # bold=False, # ) + print(f"task: {task}") + # + task.run_step( + step_name, + run_id, + task_id, + None, + input_paths, + split_index, + retry_count, + max_user_code_retries, + # spin_pathspec, + # skip_decorators, + # spin_artifacts, + ) - # task.run_step( - # step_name, - # run_id, - # task_id, - # None, - # input_paths, - # split_index, - # retry_count, - # max_user_code_retries, - # spin_pathspec, - # skip_decorators, - # spin_artifacts, - # ) - - echo(f"Time taken for the whole thing: {time.time() - start}") + echo_always(f"Time taken for the whole thing: {time.time() - start}") diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py index 793251b0cff..9740b23dd00 100644 --- a/metaflow/datastore/__init__.py +++ b/metaflow/datastore/__init__.py @@ -2,3 +2,4 @@ from .flow_datastore import FlowDataStore from .datastore_set import TaskDataStoreSet from .task_datastore import TaskDataStore +from .spin_datastore import SpinDataStore diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 16318ed7693..a07edee7678 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -5,6 +5,7 @@ from .content_addressed_store import ContentAddressedStore from .task_datastore import TaskDataStore +from .spin_datastore import SpinDataStore class FlowDataStore(object): @@ -58,6 +59,8 @@ def __init__( self.metadata = metadata self.logger = event_logger self.monitor 
= monitor + # Set to None unless its a spin step + self.is_spin = False self.ca_store = ContentAddressedStore( self._storage_impl.path_join(self.flow_name, "data"), self._storage_impl @@ -213,6 +216,20 @@ def get_task_datastore( mode="r", allow_not_done=False, ): + # print(f"Is spin: {self.is_spin}") + if self.is_spin: + print( + f"Using SpinDataStore for {self.flow_name} {run_id} {step_name} {task_id}" + ) + # This is a spin step, so we need to use the spin datastore + tp = SpinDataStore( + flow_name=self.flow_name, + run_id=run_id, + step_name=step_name, + task_id=task_id, + ) + print(f"SpinDataStore created: {tp}") + return tp return TaskDataStore( self, run_id, diff --git a/metaflow/datastore/spin_datastore.py b/metaflow/datastore/spin_datastore.py new file mode 100644 index 00000000000..477ff1419ae --- /dev/null +++ b/metaflow/datastore/spin_datastore.py @@ -0,0 +1,187 @@ +# A *read-only* datastore that fetches artifacts through Metaflow’s +# Client API. All mutating helpers are implemented as cheap no-ops so +# that existing runtime paths which expect them won’t break. +from types import MethodType, FunctionType +from ..parameters import Parameter +from .task_datastore import ( + require_mode, +) + + +class SpinDataStore(object): + """ + Minimal, read-only replacement for TaskDataStore. + + Artefacts are lazily materialised through the Metaflow Client + (`metaflow.Task(...).data`). All write/side-effecting methods are + stubbed out. + """ + + def __init__(self, flow_name, run_id, step_name, task_id, mode="r"): + assert mode in ("r",) # write modes unsupported + self._mode = mode + self._flow_name = flow_name + self._run_id = run_id + self._step_name = step_name + self._task_id = task_id + self._is_done_set = True # always read-only + self._task = None + + # Public API + @property + def pathspec(self): + return f"{self._run_id}/{self._step_name}/{self._task_id}" + + @property + def run_id(self): + return self._run_id + + @property + def step_name(self): + return self._step_name + + @property + def task_id(self): + return self._task_id + + @property + def task(self): + if self._task is None: + # Metaflow client task handle + # from metaflow.client.core import get_metadata + from metaflow import Task + + # tp = get_metadata() + # print(f"tp: {tp}") + # print("LALALALA") + self._task = Task( + f"{self._flow_name}/{self._run_id}/{self._step_name}/{self._task_id}", + _namespace_check=False, + _current_metadata="mli@https://mliservice.dynprod.netflix.net:7002/api/v0", + ) + # print(f"_metaflow: {self._task._metaflow}") + return self._task + + # artifact access and iteration helpers + @require_mode(None) + def __getitem__(self, name): + print(f"I am in SpinDataStore __getitem__ for {name}") + try: + # Attempt to access the artifact directly from the task + # Used for `_foreach_stack`, `_graph_info`, etc. + print(f"Task: {self.task}") + print(f"Task ID: {self.task.id}") + print(f"_graph_info: {self.task['_graph_info']}") + res = self.task.__getitem__(name) + except Exception as e: + print(f"Exception accessing {name} directly from task: {e}") + print( + f"Failed to access {name} directly from task, falling back to artifacts." + ) + # If the direct access fails, fall back to the artifacts + try: + res = getattr(self.task.artifacts, name).data + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." 
+ ) + return res + + @require_mode("r") + def __contains__(self, name): + return hasattr(self.task.artifacts, name) + + @require_mode("r") + def __iter__(self): + for name in self.task.artifacts: + yield name, getattr(self.task.artifacts, name).data + + @require_mode("r") + def keys_for_artifacts(self, names): + return [None for _ in names] + + @require_mode(None) + def load_artifacts(self, names): + for n in names: + yield n, getattr(self.task.artifacts, n).data + + # metadata & logging helpers + def load_metadata(self, names, add_attempt=True): + return {n: None for n in names} + + def has_metadata(self, name, add_attempt=True): + return False + + def get_log_location(self, *a, **k): + return None + + def load_logs(self, *a, **k): + return [] + + def load_log_legacy(self, *a, **k): + return b"" + + def get_log_size(self, *a, **k): + return 0 + + def get_legacy_log_size(self, *a, **k): + return 0 + + # write-side no-ops + def init_task(self, *a, **k): + pass + + def save_artifacts(self, *a, **k): + pass + + def save_metadata(self, *a, **k): + pass + + def _dangerous_save_metadata_post_done(self, *a, **k): + pass + + def save_logs(self, *a, **k): + pass + + def scrub_logs(self, *a, **k): + pass + + def clone(self, *a, **k): + pass + + def passdown_partial(self, *a, **k): + pass + + def persist(self, flow, *a, **k): + # Should we just do __setitem__ or __setattr__ here? + + print(f"flow: {flow}") + valid_artifacts = [] + for var in dir(flow): + if var.startswith("__") or var in flow._EPHEMERAL: + continue + # Skip over properties of the class (Parameters or class variables) + if hasattr(flow.__class__, var) and isinstance( + getattr(flow.__class__, var), property + ): + continue + + val = getattr(flow, var) + if not ( + isinstance(val, MethodType) + or isinstance(val, FunctionType) + or isinstance(val, Parameter) + ): + valid_artifacts.append((var, val)) + + print(f"valid_artifacts: {valid_artifacts}") + # Use __setattr__ to set the attributes on the SpinDataStore instance + for name, value in valid_artifacts: + # print(f"Setting {name} to {value}") + setattr(self, name, value) + # print("Calling persist on SpinDataStore, which is a no-op.") + pass + + def done(self, *a, **k): + pass diff --git a/metaflow/plugins/cards/card_decorator.py b/metaflow/plugins/cards/card_decorator.py index 7006997e5ee..e01fc4feb19 100644 --- a/metaflow/plugins/cards/card_decorator.py +++ b/metaflow/plugins/cards/card_decorator.py @@ -137,6 +137,7 @@ def step_init( self._flow_datastore = flow_datastore self._environment = environment self._logger = logger + self.card_options = None # We check for configuration options. 
We do this here before they are diff --git a/metaflow/runtime.py b/metaflow/runtime.py index e187d3753d3..57e0de5b7f5 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -1599,6 +1599,7 @@ def results(self): self._results_ds = self._flow_datastore.get_task_datastore( self.run_id, self.step, self.task_id ) + print("I am in results property") return self._results_ds @property @@ -1889,6 +1890,7 @@ def _launch(self): # print('running', args) cmdline = args.get_args() debug.subcommand_exec(cmdline) + print(f"Command: {cmdline}") return subprocess.Popen( cmdline, env=env, @@ -1970,6 +1972,7 @@ def terminate(self): # this shouldn't block, since terminate() is called only # after the poller has decided that the worker is dead returncode = self._proc.wait() + print(f"I am in terminate where returncode is {returncode}") # consume all remaining loglines # we set the file descriptor to be non-blocking, since @@ -1992,6 +1995,7 @@ def terminate(self): # perform any log collection. if not self.task.is_cloned: print("I am in terminate where task is not cloned") + print("I am saving metadata and logs") self.task.save_metadata( "runtime", { @@ -2000,7 +2004,9 @@ def terminate(self): "success": returncode == 0, }, ) + print(f"Saving metadata for task is done") if returncode: + print("I am in if returncode") if not self.killed: if returncode == -11: self.emit_log( @@ -2011,7 +2017,11 @@ def terminate(self): else: self.emit_log(b"Task failed.", self._stderr, system_msg=True) else: + print("I am in else of terminate") + print(f"flow: {self.task.flow}") + # print(f"dadada: {self.task.flow._foreach_num_splits}") num = self.task.results["_foreach_num_splits"] + print(f"I am in {num} splits") if num: self.task.log( "Foreach yields %d child steps." % num, @@ -2021,6 +2031,7 @@ def terminate(self): self.task.log( "Task finished successfully.", system_msg=True, pid=self._proc.pid ) + print("I am just before task.save_logs") self.task.save_logs( { "stdout": self._stdout.get_buffer(), diff --git a/metaflow/task.py b/metaflow/task.py index 414b7e54710..f8d37dd0888 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -81,6 +81,15 @@ def set_as_parameter(name, value): property(fget=lambda _, val=value: val, fset=_set_cls_var), ) + print("In _init_parameters") + print(f"parameter_ds: {parameter_ds}") + print(f"type(parameter_ds): {type(parameter_ds)}") + print(f"parameter_ds.pathspec: {parameter_ds.pathspec}") + try: + print(f"parameter_ds['_graph_info']: {parameter_ds['_graph_info']}") + except Exception: + print("I am in exception of parameter_ds['_graph_info']") + # overwrite Parameters in the flow object all_vars = [] for var, param in self.flow._get_parameters(): @@ -115,14 +124,17 @@ def property_setter( all_vars.append(var) # We also passdown _graph_info through the entire graph + print("Calling set_as_parameter") set_as_parameter( "_graph_info", lambda _, parameter_ds=parameter_ds: parameter_ds["_graph_info"], ) all_vars.append("_graph_info") - if passdown: + print("Calling passdown_partial") self.flow._datastore.passdown_partial(parameter_ds, all_vars) + print("Done with passdown_partial") + print(f"param_only_vars: {param_only_vars}") return param_only_vars def _init_data(self, run_id, join_type, input_paths): @@ -130,6 +142,7 @@ def _init_data(self, run_id, join_type, input_paths): # (via TaskDataStoreSet) only with more than 4 datastores, because # the baseline overhead of using the set is ~1.5s and each datastore # init takes ~200-300ms when run sequentially. 
+ print(f"I am in in _init_data") if len(input_paths) > 4: prefetch_data_artifacts = None if join_type and join_type == "foreach": @@ -159,9 +172,13 @@ def _init_data(self, run_id, join_type, input_paths): ds_list = [] for input_path in input_paths: run_id, step_name, task_id = input_path.split("/") + print( + f"In _init_data: run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" + ) ds_list.append( self.flow_datastore.get_task_datastore(run_id, step_name, task_id) ) + print(f"ds_list: {ds_list}") if not ds_list: # this guards against errors in input paths raise MetaflowDataMissing( @@ -398,7 +415,7 @@ def run_step( raise MetaflowInternalError( "Too many task attempts (%d)! MAX_ATTEMPTS exceeded." % retry_count ) - + print(f"Just before metadata_tags") metadata_tags = ["attempt_id:{0}".format(retry_count)] metadata = [ @@ -438,6 +455,7 @@ def run_step( ) ) + print("Done until trace_id") step_func = getattr(self.flow, step_name) decorators = step_func.decorators @@ -450,15 +468,21 @@ def run_step( output = self.flow_datastore.get_task_datastore( run_id, step_name, task_id, attempt=retry_count, mode="w" ) + print(f"Done until output datastore") output.init_task() + print("Done until output.init_task") + print(f"input_paths after output.init_task: {input_paths}") if input_paths: # 2. initialize input datastores + print(f"input_paths: {input_paths}") inputs = self._init_data(run_id, join_type, input_paths) + print(f"inputs: {inputs}") # 3. initialize foreach state self._init_foreach(step_name, join_type, inputs, split_index) + print(f"Done until _init_foreach") # Add foreach stack to metadata of the task @@ -467,6 +491,7 @@ def run_step( if hasattr(self.flow, "_foreach_stack") and self.flow._foreach_stack else [] ) + print(f"foreach_stack: {foreach_stack}") foreach_stack_formatted = [] current_foreach_path_length = 0 @@ -533,6 +558,7 @@ def run_step( is_running=True, tags=self.metadata.sticky_tags, ) + print(f"Done until current._set_env") # 5. run task output.save_metadata( @@ -550,6 +576,7 @@ def run_step( # We also pass this context as part of the task payload to support implementations that # can't access the context directly + print(f"Setting task_payload") task_payload = { "run_id": run_id, "step_name": step_name, @@ -584,6 +611,7 @@ def run_step( # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. + print(f"Just before if join_type") if join_type: # Join step: @@ -612,6 +640,7 @@ def run_step( } ) else: + print(f"In linear step") # Linear step: # We are running with a single input context. # The context is embedded in the flow. @@ -624,10 +653,12 @@ def run_step( "inputs." % step_name ) self.flow._set_datastore(inputs[0]) + print(f"Just before if input_paths with input_paths: {input_paths}") if input_paths: # initialize parameters (if they exist) # We take Parameter values from the first input, # which is always safe since parameters are read-only + # print(f"self.flow._graph_info: {self.flow._graph_info}") current._update_env( { "parameter_names": self._init_parameters( @@ -636,6 +667,9 @@ def run_step( "graph_info": self.flow._graph_info, } ) + print(f"Just done with current._update_env") + print("HEHEHEHEHEHEH") + # print(f"Set flow datastore to inputs[0]: {inputs[0]}") for deco in decorators: deco.task_pre_step( step_name, @@ -727,6 +761,7 @@ def run_step( try: # persisting might fail due to unpicklable artifacts. 
output.persist(self.flow) + print("Done with output.persist") except Exception as ex: self.flow._task_ok = False raise ex From e6dedc4f9cd7bb0e8661cae0776597c1af578c19 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 9 Jun 2025 16:02:46 -0700 Subject: [PATCH 3/5] Working commit --- metaflow/cli.py | 42 +-- metaflow/cli_components/run_cmds.py | 52 ++-- metaflow/cli_components/step_cmd.py | 78 +++--- metaflow/datastore/__init__.py | 2 +- metaflow/datastore/datastore_set.py | 11 +- metaflow/datastore/flow_datastore.py | 68 +++-- metaflow/datastore/spin_datastore.py | 252 ++++++------------ metaflow/datastore/task_datastore.py | 4 + metaflow/metaflow_config.py | 1 + metaflow/plugins/__init__.py | 1 - metaflow/plugins/metadata_providers/spin.py | 87 ------ metaflow/runner/metaflow_runner.py | 206 +++++++++++++- metaflow/runtime.py | 111 +++++--- metaflow/task.py | 64 ++--- metaflow/util.py | 31 ++- .../unit/spin/artifacts/complex_dag_step_a.py | 1 + .../unit/spin/artifacts/complex_dag_step_d.py | 11 + test/unit/spin/complex_dag_flow.py | 115 ++++++++ test/unit/spin/merge_artifacts_flow.py | 63 +++++ test/unit/spin/test_spin.py | 84 ++++++ 20 files changed, 810 insertions(+), 474 deletions(-) delete mode 100644 metaflow/plugins/metadata_providers/spin.py create mode 100644 test/unit/spin/artifacts/complex_dag_step_a.py create mode 100644 test/unit/spin/artifacts/complex_dag_step_d.py create mode 100644 test/unit/spin/complex_dag_flow.py create mode 100644 test/unit/spin/merge_artifacts_flow.py create mode 100644 test/unit/spin/test_spin.py diff --git a/metaflow/cli.py b/metaflow/cli.py index d91369ca229..2d6df48688e 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -1,3 +1,4 @@ +import os import functools import inspect import sys @@ -23,6 +24,8 @@ DEFAULT_METADATA, DEFAULT_MONITOR, DEFAULT_PACKAGE_SUFFIXES, + DATASTORE_SYSROOT_SPIN, + DATASTORE_LOCAL_DIR, ) from .metaflow_current import current from metaflow.system import _system_monitor, _system_logger @@ -465,7 +468,7 @@ def start( # Override values for spin if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: # To minimize side-effects for spin, we will only use the following: - # - spin metadata provider, + # - local metadata provider, # - local datastore, # - local environment, # - null event logger, @@ -477,40 +480,21 @@ def start( ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor ) ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] - if datastore_root is None: - datastore_root = ctx.obj.datastore_impl.get_datastore_root_from_config( - ctx.obj.echo - ) + # Set datastore_root to be DATASTORE_SYSROOT_SPIN if not provided + datastore_root = os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) ctx.obj.datastore_impl.datastore_root = datastore_root ctx.obj.flow_datastore = FlowDataStore( ctx.obj.flow.name, ctx.obj.environment, # Same environment as run/resume - ctx.obj.metadata, # spin metadata provider (no-op) + ctx.obj.metadata, # local metadata ctx.obj.event_logger, # null event logger ctx.obj.monitor, # null monitor storage_impl=ctx.obj.datastore_impl, ) - # ctx.obj.spin_metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "spin"][0]( - # ctx.obj.environment, ctx.obj.flow, 
ctx.obj.event_logger, ctx.obj.monitor - # ) - # ctx.obj.spin_datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] - # if datastore_root is None: - # datastore_root = ctx.obj.spin_datastore_impl.get_datastore_root_from_config( - # ctx.obj.echo - # ) - # ctx.obj.spin_datastore_impl.datastore_root = datastore_root - # ctx.obj.spin_flow_datastore = FlowDataStore( - # ctx.obj.flow.name, - # ctx.obj.environment, # Same environment as run/resume - # ctx.obj.spin_metadata, # spin metadata provider (no-op) - # ctx.obj.event_logger, # null event logger - # ctx.obj.monitor, # null monitor - # storage_impl=ctx.obj.spin_datastore_impl, - # ) # Start event logger and monitor ctx.obj.event_logger.start() @@ -519,13 +503,6 @@ def start( ctx.obj.monitor.start() _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) - # ctx.obj.effective_flow_datastore = ( - # ctx.obj.spin_flow_datastore if ctx.obj.is_spin else ctx.obj.flow_datastore - # ) - # ctx.obj.effective_metadata = ( - # ctx.obj.spin_metadata if ctx.obj.is_spin else ctx.obj.metadata - # ) - decorators._init(ctx.obj.flow) # It is important to initialize flow decorators early as some of the @@ -534,8 +511,6 @@ def start( ctx.obj.flow, ctx.obj.graph, ctx.obj.environment, - # ctx.obj.effective_flow_datastore, - # ctx.obj.effective_metadata, ctx.obj.flow_datastore, ctx.obj.metadata, ctx.obj.logger, @@ -559,7 +534,6 @@ def start( parameters.set_parameter_context( ctx.obj.flow.name, ctx.obj.echo, - # ctx.obj.effective_flow_datastore, ctx.obj.flow_datastore, { k: ConfigValue(v) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 52569446313..e2e351d5ab2 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -14,7 +14,7 @@ from ..system import _system_logger from ..tagging_util import validate_tags -from ..util import get_latest_run_id, write_latest_run_id, get_latest_task_pathspec +from ..util import get_latest_run_id, write_latest_run_id def before_run(obj, tags, decospecs): @@ -397,7 +397,13 @@ def run( @click.command(help="Spins up a task for a given step from a previous run locally.") -@click.argument("spin-pathspec") +@click.argument("step-name") +@click.option( + "--spin-pathspec", + default=None, + type=str, + help="Use specified task pathspec from a previous run to spin up the step.", +) @click.option( "--skip-decorators/--no-skip-decorators", is_flag=True, @@ -414,6 +420,14 @@ def run( "artifact values. The artifact values will overwrite the default values of the artifacts used in " "the spun step.", ) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. If set to false, the artifacts will not" + " be persisted and will not be available in the spun step's datastore.", +) @click.option( "--max-log-size", default=10, @@ -426,7 +440,9 @@ def run( @click.pass_obj def spin( obj, + step_name, spin_pathspec=None, + persist=True, artifacts_module=None, skip_decorators=False, max_log_size=None, @@ -435,25 +451,7 @@ def spin( **kwargs ): before_run(obj, [], []) - # Verify whether the user has provided step-name or spin-pathspec - if "/" in spin_pathspec: - # spin_pathspec is in the form of a task pathspec - if len(spin_pathspec.split("/")) != 4: - raise CommandException( - "Invalid spin-pathspec format. 
Expected format: {flow_name}/{run_id}/{step_name}/{task_id}" - ) - _, _, step_name, _ = spin_pathspec.split("/") - else: - # spin_pathspec is in the form of a step name - step_name = spin_pathspec - spin_pathspec = get_latest_task_pathspec(obj.flow.name, step_name) - - obj.echo( - f"Spinning up step *{step_name}* locally using previous task pathspec *{spin_pathspec}*" - ) - # Set spin_pathspec of flow_datastore - obj.flow_datastore.is_spin = True - # obj.flow_datastore.spin_pathspec = spin_pathspec + obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") obj.flow._set_constants(obj.graph, kwargs, obj.config_options) step_func = getattr(obj.flow, step_name) @@ -462,8 +460,6 @@ def spin( obj.graph, obj.flow_datastore, obj.metadata, - # obj.effective_flow_datastore, - # obj.effective_metadata, obj.environment, obj.package, obj.logger, @@ -471,14 +467,21 @@ def spin( obj.event_logger, obj.monitor, step_func, + step_name, spin_pathspec, skip_decorators, artifacts_module, + persist, max_log_size * 1024 * 1024, ) write_latest_run_id(obj, spin_runtime.run_id) write_file(run_id_file, spin_runtime.run_id) + + # datastore_root is os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) + # We only neeed the root for the metadata, i.e. the portion before DATASTORE_LOCAL_DIR + datastore_root = spin_runtime._flow_datastore._storage_impl.datastore_root + spin_metadata_root = datastore_root.rsplit("/", 1)[0] spin_runtime.execute() if runner_attribute_file: @@ -489,7 +492,8 @@ def spin( "step_name": step_name, "run_id": spin_runtime.run_id, "flow_name": obj.flow.name, - "metadata": f"{obj.spin_metadata.__class__.TYPE}@{obj.spin_metadata.__class__.INFO}", + # Store metadata in a format that can be used by the Runner API + "metadata": f"{obj.metadata.__class__.TYPE}@{spin_metadata_root}", }, f, ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index 7631e5057b6..a5261fb0c5b 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -192,6 +192,12 @@ def step( required=True, help="Task ID for the step that's about to be spun", ) +@click.option( + "--spin-metadata", + default=None, + show_default=True, + help="Spin metadata provider to be used for fetching artifacts/data for the input datastore", +) @click.option( "--spin-pathspec", default=None, @@ -221,16 +227,21 @@ def step( ) @click.option( "--namespace", - "namespace", + "opt_namespace", default=None, help="Change namespace from the default (your username) to the specified tag.", ) @click.option( - "--skip-decorators/--no-skip-decorators", - is_flag=True, - default=False, + "--whitelist-decorators", + help="A comma-separated list of whitelisted decorators to use for the spin step", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, show_default=True, - help="Skip decorators attached to the step.", + help="Whether to persist the artifacts in the spun step. 
If set to false, the artifacts will not" + " be persisted and will not be available in the spun step's datastore.", ) @click.option( "--artifacts-module", @@ -247,14 +258,16 @@ def spin_step( step_name, run_id=None, task_id=None, + spin_metadata=None, spin_pathspec=None, input_paths=None, split_index=None, retry_count=None, max_user_code_retries=None, - namespace=None, - skip_decorators=False, + opt_namespace=None, + whitelist_decorators=None, artifacts_module=None, + persist=True, ): import time @@ -262,53 +275,37 @@ def spin_step( import sys if ctx.obj.is_quiet: - print("Echo dev null") echo = echo_dev_null else: - print("Echo always") echo = echo_always - input_paths = decompress_list(input_paths) if input_paths else [] - echo_always( - f"Spinning a task, *{step_name}* with previous task pathspec: {spin_pathspec}", - fg="magenta", - bold=False, - ) - # if namespace is not None: - # namespace(namespace or None) - - # spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} - # spin_artifacts = spin_artifacts.get("ARTIFACTS", {}) + if opt_namespace is not None: + namespace(opt_namespace or None) - # Set spin_pathspec of flow_datastore - ctx.obj.flow_datastore.is_spin = True + input_paths = decompress_list(input_paths) if input_paths else [] + # echo_always( + # f"Spinning a task, *{step_name}* with previous task pathspec: {spin_pathspec}", + # fg="magenta", + # bold=True, + # ) - print(f"ctx.obj.flow_datastore: {ctx.obj.flow_datastore}") - print(f"ctx.obj.flow_datastore.is_spin: {ctx.obj.flow_datastore.is_spin}") - print(f"ctx.obj.metadata: {ctx.obj.metadata}") - print(f"type(ctx.obj.metadata): {type(ctx.obj.metadata)}") - # print(f"spin_artifacts: {spin_artifacts}") - print(f"spin_pathspec: {spin_pathspec}") - print(f"input_paths: {input_paths}") - print(f"ctx.obj.flow: {ctx.obj.flow}") + whitelist_decorators = ( + decompress_list(whitelist_decorators) if whitelist_decorators else [] + ) + spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} task = MetaflowTask( ctx.obj.flow, ctx.obj.flow_datastore, ctx.obj.metadata, ctx.obj.environment, - ctx.obj.echo, + echo, ctx.obj.event_logger, ctx.obj.monitor, None, # no unbounded foreach context + spin_metadata=spin_metadata, + spin_artifacts=spin_artifacts, ) - # echo_always( - # "I am here Shashank", - # fg="magenta", - # bold=False, - # ) - print(f"task: {task}") - # task.run_step( step_name, run_id, @@ -318,9 +315,8 @@ def spin_step( split_index, retry_count, max_user_code_retries, - # spin_pathspec, - # skip_decorators, - # spin_artifacts, + whitelist_decorators, + persist, ) echo_always(f"Time taken for the whole thing: {time.time() - start}") diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py index 9740b23dd00..65bb33b0eb9 100644 --- a/metaflow/datastore/__init__.py +++ b/metaflow/datastore/__init__.py @@ -2,4 +2,4 @@ from .flow_datastore import FlowDataStore from .datastore_set import TaskDataStoreSet from .task_datastore import TaskDataStore -from .spin_datastore import SpinDataStore +from .spin_datastore import SpinTaskDatastore diff --git a/metaflow/datastore/datastore_set.py b/metaflow/datastore/datastore_set.py index f60642de73f..403eadbfd4d 100644 --- a/metaflow/datastore/datastore_set.py +++ b/metaflow/datastore/datastore_set.py @@ -21,9 +21,18 @@ def __init__( pathspecs=None, prefetch_data_artifacts=None, allow_not_done=False, + join_type=None, + spin_metadata=None, + spin_artifacts=None, ): self.task_datastores = 
flow_datastore.get_task_datastores( - run_id, steps=steps, pathspecs=pathspecs, allow_not_done=allow_not_done + run_id, + steps=steps, + pathspecs=pathspecs, + allow_not_done=allow_not_done, + join_type=join_type, + spin_metadata=spin_metadata, + spin_artifacts=spin_artifacts, ) if prefetch_data_artifacts: diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index a07edee7678..3de7a90f087 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -5,7 +5,7 @@ from .content_addressed_store import ContentAddressedStore from .task_datastore import TaskDataStore -from .spin_datastore import SpinDataStore +from .spin_datastore import SpinTaskDatastore class FlowDataStore(object): @@ -59,8 +59,6 @@ def __init__( self.metadata = metadata self.logger = event_logger self.monitor = monitor - # Set to None unless its a spin step - self.is_spin = False self.ca_store = ContentAddressedStore( self._storage_impl.path_join(self.flow_name, "data"), self._storage_impl @@ -79,6 +77,9 @@ def get_task_datastores( attempt=None, include_prior=False, mode="r", + join_type=None, + spin_metadata=None, + spin_artifacts=None, ): """ Return a list of TaskDataStore for a subset of the tasks. @@ -109,6 +110,16 @@ def get_task_datastores( If True, returns all attempts up to and including attempt. mode : str, default "r" Mode to initialize the returned TaskDataStores in. + join_type : str, optional + If specified, the join type for the task. This is used to determine + the user specified artifacts for the task in case of a spin task. + spin_metadata : str, optional + The metadata provider in case of a spin task. If provided, the + returned TaskDataStore will be a SpinTaskDatastore instead of a + TaskDataStore. + spin_artifacts : Dict[str, Any], optional + Artifacts provided by user that can override the artifacts fetched via the + spin pathspec. Returns ------- @@ -201,7 +212,18 @@ def get_task_datastores( else (latest_started_attempts & done_attempts) ) latest_to_fetch = [ - (v[0], v[1], v[2], v[3], data_objs.get(v), mode, allow_not_done) + ( + v[0], + v[1], + v[2], + v[3], + data_objs.get(v), + mode, + allow_not_done, + join_type, + spin_metadata, + spin_artifacts, + ) for v in latest_to_fetch ] return list(itertools.starmap(self.get_task_datastore, latest_to_fetch)) @@ -215,21 +237,32 @@ def get_task_datastore( data_metadata=None, mode="r", allow_not_done=False, + join_type=None, + spin_metadata=None, + spin_artifacts=None, + persist=True, ): - # print(f"Is spin: {self.is_spin}") - if self.is_spin: - print( - f"Using SpinDataStore for {self.flow_name} {run_id} {step_name} {task_id}" - ) - # This is a spin step, so we need to use the spin datastore - tp = SpinDataStore( - flow_name=self.flow_name, - run_id=run_id, - step_name=step_name, - task_id=task_id, + if spin_metadata is not None: + # In spin step subprocess, use SpinTaskDatastore for accessing artifacts + if join_type is not None: + # If join_type is specified, we need to use the artifacts corresponding + # to that particular join index, specified by the parent task pathspec. 
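                # Illustrative shapes for `spin_artifacts` (example values only):
                # for a linear spin step it is a flat mapping of artifact name to value,
                #     {"my_output": [10, 11, 12]}
                # while for a join step it is keyed by the parent task pathspec without
                # the flow name ("run_id/step_name/task_id"), e.g.
                #     {"2/step_c/456": {"my_output": [-1]}}
                # so the lookup below picks out only the overrides meant for this
                # particular input task.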
+ print(f"Spin Artifacts: {spin_artifacts}") + print(f"pathspec: {run_id}/{step_name}/{task_id}") + print( + f"Spin Artifacts tp: {spin_artifacts.get(f'{run_id}/{step_name}/{task_id}')}" + ) + spin_artifacts = spin_artifacts.get( + f"{run_id}/{step_name}/{task_id}", {} + ) + return SpinTaskDatastore( + self.flow_name, + run_id, + step_name, + task_id, + spin_metadata, + spin_artifacts, ) - print(f"SpinDataStore created: {tp}") - return tp return TaskDataStore( self, run_id, @@ -239,6 +272,7 @@ def get_task_datastore( data_metadata=data_metadata, mode=mode, allow_not_done=allow_not_done, + persist=persist, ) def save_data(self, data_iter, len_hint=0): diff --git a/metaflow/datastore/spin_datastore.py b/metaflow/datastore/spin_datastore.py index 477ff1419ae..1f7e215543f 100644 --- a/metaflow/datastore/spin_datastore.py +++ b/metaflow/datastore/spin_datastore.py @@ -1,187 +1,109 @@ -# A *read-only* datastore that fetches artifacts through Metaflow’s -# Client API. All mutating helpers are implemented as cheap no-ops so -# that existing runtime paths which expect them won’t break. -from types import MethodType, FunctionType -from ..parameters import Parameter -from .task_datastore import ( - require_mode, -) - - -class SpinDataStore(object): - """ - Minimal, read-only replacement for TaskDataStore. - - Artefacts are lazily materialised through the Metaflow Client - (`metaflow.Task(...).data`). All write/side-effecting methods are - stubbed out. - """ - - def __init__(self, flow_name, run_id, step_name, task_id, mode="r"): - assert mode in ("r",) # write modes unsupported - self._mode = mode - self._flow_name = flow_name - self._run_id = run_id - self._step_name = step_name - self._task_id = task_id - self._is_done_set = True # always read-only +from typing import Dict, Any +from .task_datastore import require_mode + + +class SpinTaskDatastore(object): + def __init__( + self, + flow_name: str, + run_id: str, + step_name: str, + task_id: str, + spin_metadata: str, + spin_artifacts: Dict[str, Any], + ): + """ + SpinTaskDatastore is a datastore for a task that is used to retrieve + artifacts and attributes for a spin step. It uses the task pathspec + from a previous execution of the step to access the artifacts and attributes. + + Parameters: + ----------- + flow_name : str + Name of the flow + run_id : str + Run ID of the flow + step_name : str + Name of the step + task_id : str + Task ID of the step + spin_metadata : str + Metadata for the spin task, typically a URI to the metadata service. + spin_artifacts : Dict[str, Any] + User provided artifacts that are to be used in the spin task. This is a dictionary + where keys are artifact names and values are the actual data or metadata. 
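            For example (values are illustrative only), a linear spin step might
            receive `{"my_output": [10, 11, 12]}`; for a join step,
            `FlowDataStore.get_task_datastore` passes in only the sub-dictionary
            keyed by this task's pathspec, so the value received here is always a
            flat name -> value mapping.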
+ """ + self.flow_name = flow_name + self.run_id = run_id + self.step_name = step_name + self.task_id = task_id + self.spin_metadata = spin_metadata + self.spin_artifacts = spin_artifacts self._task = None - # Public API - @property - def pathspec(self): - return f"{self._run_id}/{self._step_name}/{self._task_id}" - - @property - def run_id(self): - return self._run_id + # Update _objects and _info in order to persist artifacts + # See `persist` method in `TaskDatastore` for more details + self._objects = {} + self._info = {} - @property - def step_name(self): - return self._step_name - - @property - def task_id(self): - return self._task_id + for artifact in self.task.artifacts: + self._objects[artifact.id] = artifact.sha + # Fulfills the contract for _info: name -> metadata + self._info[artifact.id] = { + "size": artifact.size, + "encoding": artifact._object["content_type"], + } @property def task(self): if self._task is None: - # Metaflow client task handle - # from metaflow.client.core import get_metadata + # Initialize the metaflow from metaflow import Task - # tp = get_metadata() - # print(f"tp: {tp}") - # print("LALALALA") + # print(f"Setting task with metadata: {self.spin_metadata} and pathspec: {self.run_id}/{self.step_name}/{self.task_id}") self._task = Task( - f"{self._flow_name}/{self._run_id}/{self._step_name}/{self._task_id}", + f"{self.flow_name}/{self.run_id}/{self.step_name}/{self.task_id}", _namespace_check=False, - _current_metadata="mli@https://mliservice.dynprod.netflix.net:7002/api/v0", + # We need to get this form the task pathspec somehow + _current_metadata=self.spin_metadata, ) - # print(f"_metaflow: {self._task._metaflow}") return self._task - # artifact access and iteration helpers @require_mode(None) def __getitem__(self, name): - print(f"I am in SpinDataStore __getitem__ for {name}") try: - # Attempt to access the artifact directly from the task - # Used for `_foreach_stack`, `_graph_info`, etc. - print(f"Task: {self.task}") - print(f"Task ID: {self.task.id}") - print(f"_graph_info: {self.task['_graph_info']}") - res = self.task.__getitem__(name) - except Exception as e: - print(f"Exception accessing {name} directly from task: {e}") - print( - f"Failed to access {name} directly from task, falling back to artifacts." - ) - # If the direct access fails, fall back to the artifacts + # Check if it's an artifact in the spin_artifacts + return self.spin_artifacts[name] + except Exception: try: - res = getattr(self.task.artifacts, name).data - except AttributeError: - raise AttributeError( - f"Attribute '{name}' not found in the previous execution of the task for " - f"`{self.step_name}`." - ) - return res - - @require_mode("r") - def __contains__(self, name): - return hasattr(self.task.artifacts, name) - - @require_mode("r") - def __iter__(self): - for name in self.task.artifacts: - yield name, getattr(self.task.artifacts, name).data - - @require_mode("r") - def keys_for_artifacts(self, names): - return [None for _ in names] + # Check if it's an attribute of the task + # _foreach_stack, _foreach_index, ... + return self.task.__getitem__(name).data + except Exception: + # If not an attribute, check if it's an artifact + try: + return getattr(self.task.artifacts, name).data + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`." 
+ ) @require_mode(None) - def load_artifacts(self, names): - for n in names: - yield n, getattr(self.task.artifacts, n).data - - # metadata & logging helpers - def load_metadata(self, names, add_attempt=True): - return {n: None for n in names} - - def has_metadata(self, name, add_attempt=True): - return False - - def get_log_location(self, *a, **k): - return None - - def load_logs(self, *a, **k): - return [] + def is_none(self, name): + val = self.__getitem__(name) + return val is None - def load_log_legacy(self, *a, **k): - return b"" - - def get_log_size(self, *a, **k): - return 0 - - def get_legacy_log_size(self, *a, **k): - return 0 - - # write-side no-ops - def init_task(self, *a, **k): - pass - - def save_artifacts(self, *a, **k): - pass - - def save_metadata(self, *a, **k): - pass - - def _dangerous_save_metadata_post_done(self, *a, **k): - pass - - def save_logs(self, *a, **k): - pass - - def scrub_logs(self, *a, **k): - pass - - def clone(self, *a, **k): - pass - - def passdown_partial(self, *a, **k): - pass - - def persist(self, flow, *a, **k): - # Should we just do __setitem__ or __setattr__ here? - - print(f"flow: {flow}") - valid_artifacts = [] - for var in dir(flow): - if var.startswith("__") or var in flow._EPHEMERAL: - continue - # Skip over properties of the class (Parameters or class variables) - if hasattr(flow.__class__, var) and isinstance( - getattr(flow.__class__, var), property - ): - continue - - val = getattr(flow, var) - if not ( - isinstance(val, MethodType) - or isinstance(val, FunctionType) - or isinstance(val, Parameter) - ): - valid_artifacts.append((var, val)) - - print(f"valid_artifacts: {valid_artifacts}") - # Use __setattr__ to set the attributes on the SpinDataStore instance - for name, value in valid_artifacts: - # print(f"Setting {name} to {value}") - setattr(self, name, value) - # print("Calling persist on SpinDataStore, which is a no-op.") - pass + @require_mode(None) + def __contains__(self, name): + try: + _ = self.__getitem__(name) + return True + except AttributeError: + return False - def done(self, *a, **k): - pass + @require_mode(None) + def items(self): + if self._objects: + return self._objects.items() + return {} diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 7691862fc2a..f5e63a27410 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -98,6 +98,7 @@ def __init__( data_metadata=None, mode="r", allow_not_done=False, + persist=True, ): self._storage_impl = flow_datastore._storage_impl @@ -114,6 +115,7 @@ def __init__( self._attempt = attempt self._metadata = flow_datastore.metadata self._parent = flow_datastore + self._persist = persist # The GZIP encodings are for backward compatibility self._encodings = {"pickle-v2", "gzip+pickle-v2"} @@ -682,6 +684,8 @@ def persist(self, flow): flow : FlowSpec Flow to persist """ + if not self._persist: + return if flow._datastore: self._objects.update(flow._datastore._objects) diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index fc38e5c727d..4c030d76945 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -64,6 +64,7 @@ # Datastore configuration ### DATASTORE_SYSROOT_LOCAL = from_conf("DATASTORE_SYSROOT_LOCAL") +DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN", "/tmp/metaflow") # S3 bucket and prefix to store artifacts for 's3' datastore. 
DATASTORE_SYSROOT_S3 = from_conf("DATASTORE_SYSROOT_S3") # Azure Blob Storage container and blob prefix diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index 0ed913763e9..f74a70726ba 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -82,7 +82,6 @@ METADATA_PROVIDERS_DESC = [ ("service", ".metadata_providers.service.ServiceMetadataProvider"), ("local", ".metadata_providers.local.LocalMetadataProvider"), - ("spin", ".metadata_providers.spin.SpinMetadataProvider"), ] # Add datastore here diff --git a/metaflow/plugins/metadata_providers/spin.py b/metaflow/plugins/metadata_providers/spin.py deleted file mode 100644 index 828f8b9b31b..00000000000 --- a/metaflow/plugins/metadata_providers/spin.py +++ /dev/null @@ -1,87 +0,0 @@ -""" -An implementation of a null metadata provider for Spin steps. -This provider does not store any metadata and is used when -we want to ensure that there are no side effects from the -metadata provider. - -""" - -import time - -from metaflow.metadata_provider import MetadataProvider -from typing import List - - -class SpinMetadataProvider(MetadataProvider): - TYPE = "spin" - - def __init__(self, environment, flow, event_logger, monitor): - super(SpinMetadataProvider, self).__init__( - environment, flow, event_logger, monitor - ) - # No provider-specific initialization needed for a null provider - - def version(self): - """Return provider type as version.""" - return self.TYPE - - def new_run_id(self, tags=None, sys_tags=None): - # We currently just use the timestamp to create an ID. We can be reasonably certain - # that it is unique and this makes it possible to do without coordination or - # reliance on POSIX locks in the filesystem. - run_id = "%d" % (time.time() * 1e6) - return run_id - - def register_run_id(self, run_id, tags=None, sys_tags=None): - """No-op register_run_id. Indicates no action taken.""" - return False - - def new_task_id(self, run_id, step_name, tags=None, sys_tags=None): - self._task_id_seq += 1 - task_id = str(self._task_id_seq) - return task_id - - def register_task_id( - self, run_id, step_name, task_id, attempt=0, tags=None, sys_tags=None - ): - """No-op register_task_id. Indicates no action taken.""" - return False - - def register_data_artifacts( - self, run_id, step_name, task_id, attempt_id, artifacts - ): - """No-op register_data_artifacts.""" - pass - - def register_metadata(self, run_id, step_name, task_id, metadata): - """No-op register_metadata.""" - pass - - @classmethod - def _mutate_user_tags_for_run( - cls, flow_id, run_id, tags_to_add=None, tags_to_remove=None - ): - """No-op _mutate_user_tags_for_run. Returns an empty set of tags.""" - return frozenset() - - @classmethod - def filter_tasks_by_metadata( - cls, - flow_name: str, - run_id: str, - step_name: str, - field_name: str, - pattern: str, - ) -> List[str]: - """No-op filter_tasks_by_metadata. Returns an empty list.""" - return [] - - @classmethod - def _get_object_internal( - cls, obj_type, obj_order, sub_type, sub_order, filters, attempt, *args - ): - """ - No-op _get_object_internal. Returns an empty list, - which MetadataProvider.get_object will interpret as 'not found'. 
- """ - return [] diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 2de20e0a1e0..9e69342d14a 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -17,27 +17,33 @@ from .subprocess_manager import CommandManager, SubprocessManager -class ExecutingRun(object): +class ExecutingProcess(object): """ - This class contains a reference to a `metaflow.Run` object representing - the currently executing or finished run, as well as metadata related - to the process. + This is a base class for `ExecutingRun` and `ExecutingTask` classes. + The `ExecutingRun` and `ExecutingTask` classes are returned by methods + in `Runner` and `NBRunner`, and they are subclasses of this class. - `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not - meant to be instantiated directly. + The `ExecutingRun` class for instance contains a reference to a `metaflow.Run` + object representing the currently executing or finished run, as well as the metadata + related to the process. + + Similarly, the `ExecutingTask` class contains a reference to a `metaflow.Task` + object representing the currently executing or finished task, as well as the metadata + related to the process. + + This class or its subclasses are not meant to be instantiated directly. The class + works as a context manager, allowing you to use a pattern like: - This class works as a context manager, allowing you to use a pattern like ```python with Runner(...).run() as running: ... ``` - Note that you should use either this object as the context manager or - `Runner`, not both in a nested manner. + + Note that you should use either this object as the context manager or `Runner`, not both + in a nested manner. """ - def __init__( - self, runner: "Runner", command_obj: CommandManager, run_obj: Run - ) -> None: + def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but instead user Runner.run() @@ -53,7 +59,6 @@ def __init__( """ self.runner = runner self.command_obj = command_obj - self.run = run_obj def __enter__(self) -> "ExecutingRun": return self @@ -189,6 +194,76 @@ async def stream_log( yield position, line +class ExecutingTask(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Task` object representing + the currently executing or finished task, as well as metadata related + to the process. + `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).spin() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task" + ) -> None: + """ + Create a new ExecutingTask -- this should not be done by the user directly but + instead user Runner.spin() + Parameters + ---------- + runner : Runner + Parent runner for this task. + command_obj : CommandManager + CommandManager containing the subprocess executing this task. + task_obj : Task + Task object corresponding to this task. 
+ """ + super().__init__(runner, command_obj) + self.task = task_obj + + +class ExecutingRun(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Run` object representing + the currently executing or finished run, as well as metadata related + to the process. + `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).run() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, run_obj: Run + ) -> None: + """ + Create a new ExecutingRun -- this should not be done by the user directly but + instead user Runner.run() + Parameters + ---------- + runner : Runner + Parent runner for this run. + command_obj : CommandManager + CommandManager containing the subprocess executing this run. + run_obj : Run + Run object corresponding to this run. + """ + super().__init__(runner, command_obj) + self.run = run_obj + + class RunnerMeta(type): def __new__(mcs, name, bases, dct): cls = super().__new__(mcs, name, bases, dct) @@ -257,7 +332,7 @@ def __init__( env: Optional[Dict[str, str]] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, - **kwargs + **kwargs, ): # these imports are required here and not at the top # since they interfere with the user defined Parameters @@ -373,6 +448,73 @@ def run(self, **kwargs) -> ExecutingRun: return self.__get_executing_run(attribute_file_fd, command_obj) + def __get_executing_task(self, attribute_file_fd, command_obj): + content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout) + + command_obj.sync_wait() + + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + async def __async_get_executing_task(self, attribute_file_fd, command_obj): + content = await async_handle_timeout( + attribute_file_fd, command_obj, self.file_read_timeout + ) + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + def spin(self, step_name, **kwargs): + """ + Blocking spin execution of the run. + This method will wait until the spun run has completed execution. + Parameters + ---------- + step_name : str + The name of the step to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + Returns + ------- + ExecutingTask + ExecutingTask containing the results of the spun task. 
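        A minimal usage sketch (flow file, step name and pathspec below are
        illustrative):

        ```python
        with Runner("complex_dag_flow.py").spin(
            "step_a", spin_pathspec="ComplexDAGFlow/2/step_a/123"
        ) as spun:
            print(spun.task["my_output"].data)
        ```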
+ """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = self.spm.run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + show_output=self.show_output, + ) + command_obj = self.spm.get(pid) + + return self.__get_executing_task(attribute_file_fd, command_obj) + def resume(self, **kwargs) -> ExecutingRun: """ Blocking resume execution of the run. @@ -468,6 +610,42 @@ async def async_resume(self, **kwargs) -> ExecutingRun: return await self.__async_get_executing_run(attribute_file_fd, command_obj) + async def async_spin(self, step_name, spin_pathspec, **kwargs) -> ExecutingTask: + """ + Non-blocking spin execution of the run. + This method will return as soon as the spun task has launched. + + Note that this method is asynchronous and needs to be `await`ed. + + Parameters + ---------- + step_name : str + The name of the step to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + + Returns + ------- + ExecutingTask + ExecutingTask representing the spun task that was started. + """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = await self.spm.async_run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + ) + command_obj = self.spm.get(pid) + + return await self.__async_get_executing_task(attribute_file_fd, command_obj) + def __exit__(self, exc_type, exc_value, traceback): self.spm.cleanup() diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 57e0de5b7f5..5a6c797e34a 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -36,7 +36,7 @@ from .decorators import flow_decorators from .flowspec import _FlowState from .mflog import mflog, RUNTIME_LOG_SOURCE -from .util import to_unicode, compress_list, unicode_type +from .util import to_unicode, compress_list, unicode_type, get_latest_task_pathspec from .clone_util import clone_task_helper from .unbounded_foreach import ( CONTROL_TASK_TAG, @@ -87,9 +87,11 @@ def __init__( event_logger, monitor, step_func, + step_name, spin_pathspec, skip_decorators=False, artifacts_module=None, + persist=True, max_log_size=MAX_LOG_SIZE, ): from metaflow import Task @@ -106,8 +108,25 @@ def __init__( self._monitor = monitor self._step_func = step_func + + # Verify whether the use has provided step-name or spin-pathspec + if not spin_pathspec: + task = get_latest_task_pathspec(flow.name, step_name) + else: + # The user already provided a spin-pathspec, verify if its valid + try: + task = Task(spin_pathspec, _namespace_check=False) + except Exception: + raise MetaflowException( + f"Invalid spin-pathspec: {spin_pathspec} for step: {step_name}" + ) + + spin_pathspec = task.pathspec + spin_metadata = task._metaflow.metadata.metadata_str() + self._spin_metadata = spin_metadata self._spin_pathspec = spin_pathspec - self._spin_task = Task(self._spin_pathspec, _namespace_check=False) + self._persist = persist + self._spin_task = task self._input_paths = None self._split_index = None self._whitelist_decorators = None @@ -120,17 +139,17 @@ def __init__( # Create a new run_id for the spin task self.run_id = self._metadata.new_run_id() for deco in self.whitelist_decorators: - # print("-" * 100) deco.runtime_init(flow, 
graph, package, self.run_id) @property def split_index(self): - if self._split_index: - return self._split_index + """ + Returns the split index, caching the result after the first access. + """ + if self._split_index is None: + self._split_index = getattr(self._spin_task, "index", None) - if hasattr(self._spin_task, "index"): - self._split_index = self._spin_task.index - return self._split_index + return self._split_index @property def input_paths(self): @@ -202,9 +221,7 @@ def execute(self): else: self._config_file_name = None - # print("I am before _new_task") self.task = self._new_task(self._step_func.name, self.input_paths) - # print(f"I am after _new_task") try: self._launch_and_monitor_task() except Exception as ex: @@ -216,18 +233,17 @@ def execute(self): deco.runtime_finished(exception) def _launch_and_monitor_task(self): - # print("I am in _launch_and_monitor_task") worker = Worker( self.task, self._max_log_size, self._config_file_name, + spin_metadata=self._spin_metadata, spin_pathspec=self._spin_pathspec, - skip_decorators=self._skip_decorators, + whitelist_decorators=self.whitelist_decorators, artifacts_module=self._artifacts_module, + persist=self._persist, ) - # print("Worker created") - poll = procpoll.make_poll() fds = worker.fds() for fd in fds: @@ -1599,7 +1615,6 @@ def results(self): self._results_ds = self._flow_datastore.get_task_datastore( self.run_id, self.step, self.task_id ) - print("I am in results property") return self._results_ds @property @@ -1691,12 +1706,20 @@ class CLIArgs(object): """ def __init__( - self, task, spin_pathspec=None, skip_decorators=False, artifacts_module=None + self, + task, + spin_metadata=None, + spin_pathspec=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, ): self.task = task + self.spin_metadata = spin_metadata self.spin_pathspec = spin_pathspec - self.skip_decorators = skip_decorators + self.whitelist_decorators = whitelist_decorators self.artifacts_module = artifacts_module + self.persist = persist self.entrypoint = list(task.entrypoint) self.top_level_options = { "quiet": True, @@ -1754,6 +1777,8 @@ def spin_args(self): self.commands = ["spin-step"] self.command_args = [self.task.step] + whitelist_decos = [deco.name for deco in self.whitelist_decorators] + self.command_options = { "run-id": self.task.run_id, "task-id": self.task.task_id, @@ -1762,10 +1787,13 @@ def spin_args(self): "retry-count": self.task.retries, "max-user-code-retries": self.task.user_code_retries, "namespace": get_namespace() or "", + "spin-metadata": self.spin_metadata, "spin-pathspec": self.spin_pathspec, - "skip-decorators": self.skip_decorators, + "whitelist-decorators": compress_list(whitelist_decos), "artifacts-module": self.artifacts_module, } + if self.persist: + self.command_options["persist"] = True self.env = {} def get_args(self): @@ -1811,15 +1839,19 @@ def __init__( task, max_logs_size, config_file_name, + spin_metadata=None, spin_pathspec=None, - skip_decorators=False, + whitelist_decorators=None, artifacts_module=None, + persist=True, ): self.task = task self._config_file_name = config_file_name - self.spin_pathspec = spin_pathspec - self.skip_decorators = skip_decorators - self.artifacts_module = artifacts_module + self._spin_metadata = spin_metadata + self._spin_pathspec = spin_pathspec + self._whitelist_decorators = whitelist_decorators + self._artifacts_module = artifacts_module + self._persist = persist self._proc = self._launch() if task.retries > task.user_code_retries: @@ -1853,9 +1885,11 @@ def __init__( def 
_launch(self): args = CLIArgs( self.task, - spin_pathspec=self.spin_pathspec, - skip_decorators=self.skip_decorators, - artifacts_module=self.artifacts_module, + spin_metadata=self._spin_metadata, + spin_pathspec=self._spin_pathspec, + whitelist_decorators=self._whitelist_decorators, + artifacts_module=self._artifacts_module, + persist=self._persist, ) env = dict(os.environ) @@ -1890,7 +1924,7 @@ def _launch(self): # print('running', args) cmdline = args.get_args() debug.subcommand_exec(cmdline) - print(f"Command: {cmdline}") + # print(f"Command: {cmdline}") return subprocess.Popen( cmdline, env=env, @@ -1972,7 +2006,6 @@ def terminate(self): # this shouldn't block, since terminate() is called only # after the poller has decided that the worker is dead returncode = self._proc.wait() - print(f"I am in terminate where returncode is {returncode}") # consume all remaining loglines # we set the file descriptor to be non-blocking, since @@ -1994,8 +2027,6 @@ def terminate(self): # Return early if the task is cloned since we don't want to # perform any log collection. if not self.task.is_cloned: - print("I am in terminate where task is not cloned") - print("I am saving metadata and logs") self.task.save_metadata( "runtime", { @@ -2004,9 +2035,7 @@ def terminate(self): "success": returncode == 0, }, ) - print(f"Saving metadata for task is done") if returncode: - print("I am in if returncode") if not self.killed: if returncode == -11: self.emit_log( @@ -2017,21 +2046,17 @@ def terminate(self): else: self.emit_log(b"Task failed.", self._stderr, system_msg=True) else: - print("I am in else of terminate") - print(f"flow: {self.task.flow}") - # print(f"dadada: {self.task.flow._foreach_num_splits}") - num = self.task.results["_foreach_num_splits"] - print(f"I am in {num} splits") - if num: - self.task.log( - "Foreach yields %d child steps." % num, - system_msg=True, - pid=self._proc.pid, - ) + if not self._spin_pathspec: + num = self.task.results["_foreach_num_splits"] + if num: + self.task.log( + "Foreach yields %d child steps." 
% num, + system_msg=True, + pid=self._proc.pid, + ) self.task.log( "Task finished successfully.", system_msg=True, pid=self._proc.pid ) - print("I am just before task.save_logs") self.task.save_logs( { "stdout": self._stdout.get_buffer(), diff --git a/metaflow/task.py b/metaflow/task.py index f8d37dd0888..53315d8f0fc 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -47,6 +47,8 @@ def __init__( event_logger, monitor, ubf_context, + spin_metadata=None, + spin_artifacts=None, ): self.flow = flow self.flow_datastore = flow_datastore @@ -56,6 +58,8 @@ def __init__( self.event_logger = event_logger self.monitor = monitor self.ubf_context = ubf_context + self.spin_metadata = spin_metadata + self.spin_artifacts = spin_artifacts def _exec_step_function(self, step_function, input_obj=None): if input_obj is None: @@ -81,15 +85,6 @@ def set_as_parameter(name, value): property(fget=lambda _, val=value: val, fset=_set_cls_var), ) - print("In _init_parameters") - print(f"parameter_ds: {parameter_ds}") - print(f"type(parameter_ds): {type(parameter_ds)}") - print(f"parameter_ds.pathspec: {parameter_ds.pathspec}") - try: - print(f"parameter_ds['_graph_info']: {parameter_ds['_graph_info']}") - except Exception: - print("I am in exception of parameter_ds['_graph_info']") - # overwrite Parameters in the flow object all_vars = [] for var, param in self.flow._get_parameters(): @@ -124,17 +119,13 @@ def property_setter( all_vars.append(var) # We also passdown _graph_info through the entire graph - print("Calling set_as_parameter") set_as_parameter( "_graph_info", lambda _, parameter_ds=parameter_ds: parameter_ds["_graph_info"], ) all_vars.append("_graph_info") if passdown: - print("Calling passdown_partial") self.flow._datastore.passdown_partial(parameter_ds, all_vars) - print("Done with passdown_partial") - print(f"param_only_vars: {param_only_vars}") return param_only_vars def _init_data(self, run_id, join_type, input_paths): @@ -142,7 +133,7 @@ def _init_data(self, run_id, join_type, input_paths): # (via TaskDataStoreSet) only with more than 4 datastores, because # the baseline overhead of using the set is ~1.5s and each datastore # init takes ~200-300ms when run sequentially. 
- print(f"I am in in _init_data") + # print(f"I am in in _init_data") if len(input_paths) > 4: prefetch_data_artifacts = None if join_type and join_type == "foreach": @@ -160,6 +151,9 @@ def _init_data(self, run_id, join_type, input_paths): run_id, pathspecs=input_paths, prefetch_data_artifacts=prefetch_data_artifacts, + join_type=join_type, + spin_metadata=self.spin_metadata, + spin_artifacts=self.spin_artifacts, ) ds_list = [ds for ds in datastore_set] if len(ds_list) != len(input_paths): @@ -172,13 +166,16 @@ def _init_data(self, run_id, join_type, input_paths): ds_list = [] for input_path in input_paths: run_id, step_name, task_id = input_path.split("/") - print( - f"In _init_data: run_id: {run_id}, step_name: {step_name}, task_id: {task_id}" - ) ds_list.append( - self.flow_datastore.get_task_datastore(run_id, step_name, task_id) + self.flow_datastore.get_task_datastore( + run_id, + step_name, + task_id, + join_type=join_type, + spin_metadata=self.spin_metadata, + spin_artifacts=self.spin_artifacts, + ) ) - print(f"ds_list: {ds_list}") if not ds_list: # this guards against errors in input paths raise MetaflowDataMissing( @@ -399,6 +396,8 @@ def run_step( split_index, retry_count, max_user_code_retries, + whitelist_decorators=None, + persist=True, ): if run_id and task_id: self.metadata.register_run_id(run_id) @@ -415,7 +414,6 @@ def run_step( raise MetaflowInternalError( "Too many task attempts (%d)! MAX_ATTEMPTS exceeded." % retry_count ) - print(f"Just before metadata_tags") metadata_tags = ["attempt_id:{0}".format(retry_count)] metadata = [ @@ -455,9 +453,13 @@ def run_step( ) ) - print("Done until trace_id") step_func = getattr(self.flow, step_name) decorators = step_func.decorators + if self.spin_metadata: + # We filter only the whitelisted decorators + decorators = [ + deco for deco in decorators if deco.name in whitelist_decorators + ] node = self.flow._graph[step_name] join_type = None @@ -466,32 +468,24 @@ def run_step( # 1. initialize output datastore output = self.flow_datastore.get_task_datastore( - run_id, step_name, task_id, attempt=retry_count, mode="w" + run_id, step_name, task_id, attempt=retry_count, mode="w", persist=persist ) - print(f"Done until output datastore") output.init_task() - print("Done until output.init_task") - print(f"input_paths after output.init_task: {input_paths}") if input_paths: # 2. initialize input datastores - print(f"input_paths: {input_paths}") inputs = self._init_data(run_id, join_type, input_paths) - print(f"inputs: {inputs}") # 3. initialize foreach state self._init_foreach(step_name, join_type, inputs, split_index) - print(f"Done until _init_foreach") # Add foreach stack to metadata of the task - foreach_stack = ( self.flow._foreach_stack if hasattr(self.flow, "_foreach_stack") and self.flow._foreach_stack else [] ) - print(f"foreach_stack: {foreach_stack}") foreach_stack_formatted = [] current_foreach_path_length = 0 @@ -558,7 +552,6 @@ def run_step( is_running=True, tags=self.metadata.sticky_tags, ) - print(f"Done until current._set_env") # 5. run task output.save_metadata( @@ -576,7 +569,6 @@ def run_step( # We also pass this context as part of the task payload to support implementations that # can't access the context directly - print(f"Setting task_payload") task_payload = { "run_id": run_id, "step_name": step_name, @@ -610,8 +602,6 @@ def run_step( # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. 
- - print(f"Just before if join_type") if join_type: # Join step: @@ -640,7 +630,6 @@ def run_step( } ) else: - print(f"In linear step") # Linear step: # We are running with a single input context. # The context is embedded in the flow. @@ -653,7 +642,6 @@ def run_step( "inputs." % step_name ) self.flow._set_datastore(inputs[0]) - print(f"Just before if input_paths with input_paths: {input_paths}") if input_paths: # initialize parameters (if they exist) # We take Parameter values from the first input, @@ -667,9 +655,6 @@ def run_step( "graph_info": self.flow._graph_info, } ) - print(f"Just done with current._update_env") - print("HEHEHEHEHEHEH") - # print(f"Set flow datastore to inputs[0]: {inputs[0]}") for deco in decorators: deco.task_pre_step( step_name, @@ -761,7 +746,6 @@ def run_step( try: # persisting might fail due to unpicklable artifacts. output.persist(self.flow) - print("Done with output.persist") except Exception as ex: self.flow._task_ok = False raise ex diff --git a/metaflow/util.py b/metaflow/util.py index 9fc80ad9681..4666678d5e9 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -7,6 +7,7 @@ from functools import wraps from io import BytesIO from itertools import takewhile +from typing import Dict, Any import re from metaflow.exception import MetaflowUnknownUser, MetaflowInternalError @@ -177,7 +178,7 @@ def resolve_identity(): return "%s:%s" % (identity_type, identity_value) -def get_latest_task_pathspec(flow_name: str, step_name: str) -> str: +def get_latest_task_pathspec(flow_name: str, step_name: str) -> (str, str): """ Returns a task pathspec from the latest run of the flow for the queried step. If the queried step has several tasks, the task pathspec of the first task is returned. @@ -191,8 +192,8 @@ def get_latest_task_pathspec(flow_name: str, step_name: str) -> str: Returns ------- - str - The task pathspec of the first task of the queried step. + Task + A Metaflow Task instance containing the latest task for the queried step. Raises ------ @@ -209,7 +210,7 @@ def get_latest_task_pathspec(flow_name: str, step_name: str) -> str: try: task = Step(f"{flow_name}/{run.id}/{step_name}", _namespace_check=False).task - return task.pathspec + return task except Exception: raise MetaflowNotFound( f"No step *{step_name}* found in run *{run.id}* for flow *{flow_name}*" @@ -506,7 +507,25 @@ def to_pod(value): return str(value) -def read_artifacts_module(file_path): +def read_artifacts_module(file_path: str) -> Dict[str, Any]: + """ + Read a Python module from the given file path and return its ARTIFACTS variable. + + Parameters + ---------- + file_path : str + The path to the Python file containing the ARTIFACTS variable. + + Returns + ------- + Dict[str, Any] + A dictionary containing the ARTIFACTS variable from the module. + + Raises + ------- + MetaflowInternalError + If the file cannot be read or does not contain the ARTIFACTS variable. 
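        Notes
        -----
        The artifacts module is a regular Python file whose only contract is to
        define an ARTIFACTS variable, e.g. (illustrative values):

            ARTIFACTS = {"my_output": [10, 11, 12]}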
+ """ import importlib.util try: @@ -518,7 +537,7 @@ def read_artifacts_module(file_path): raise MetaflowInternalError( f"Module {file_path} does not contain ARTIFACTS variable" ) - return variables + return variables.get("ARTIFACTS") except Exception as e: raise MetaflowInternalError(f"Error reading file {file_path}") from e diff --git a/test/unit/spin/artifacts/complex_dag_step_a.py b/test/unit/spin/artifacts/complex_dag_step_a.py new file mode 100644 index 00000000000..b7e81bf1b6f --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_a.py @@ -0,0 +1 @@ +ARTIFACTS = {"my_output": [10, 11, 12]} diff --git a/test/unit/spin/artifacts/complex_dag_step_d.py b/test/unit/spin/artifacts/complex_dag_step_d.py new file mode 100644 index 00000000000..5aa40d64766 --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_d.py @@ -0,0 +1,11 @@ +from metaflow import Run + + +def _get_artifact(): + task = Run("ComplexDAGFlow/2")["step_d"].task + task_pathspec = next(task.parent_task_pathspecs) + _, inp_path = task_pathspec.split("/", 1) + return {inp_path: {"my_output": [-1]}} + + +ARTIFACTS = _get_artifact() diff --git a/test/unit/spin/complex_dag_flow.py b/test/unit/spin/complex_dag_flow.py new file mode 100644 index 00000000000..733df2d60d7 --- /dev/null +++ b/test/unit/spin/complex_dag_flow.py @@ -0,0 +1,115 @@ +from metaflow import FlowSpec, step, project, conda, Task, pypi + + +class ComplexDAGFlow(FlowSpec): + @step + def start(self): + self.split_start = [1, 2, 3] + self.my_output = [] + print("My output is: ", self.my_output) + self.next(self.step_a, foreach="split_start") + + @step + def step_a(self): + self.split_a = [4, 5] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_b, foreach="split_a") + + @step + def step_b(self): + self.split_b = [6, 7, 8] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_c, foreach="split_b") + + @conda(libraries={"numpy": "2.1.1"}) + @step + def step_c(self): + import numpy as np + + self.np_version = np.__version__ + print(f"numpy version: {self.np_version}") + self.my_output = self.my_output + [self.input] + [9, 10] + print("My output is: ", self.my_output) + self.next(self.step_d) + + @step + def step_d(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_e) + + @step + def step_e(self): + print(f"I am step E. Input is: {self.input}") + self.split_e = [9, 10] + print("My output is: ", self.my_output) + self.next(self.step_f, foreach="split_e") + + @step + def step_f(self): + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_g) + + @step + def step_g(self): + print("My output is: ", self.my_output) + self.next(self.step_h) + + @step + def step_h(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_i) + + @step + def step_i(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_j) + + @step + def step_j(self): + print("My output is: ", self.my_output) + self.next(self.step_k, self.step_l) + + @step + def step_k(self): + self.my_output = self.my_output + [11] + print("My output is: ", self.my_output) + self.next(self.step_m) + + @step + def step_l(self): + print(f"I am step L. 
Input is: {self.input}") + self.my_output = self.my_output + [12] + print("My output is: ", self.my_output) + self.next(self.step_m) + + @conda(libraries={"scikit-learn": ""}) + @step + def step_m(self, inputs): + import sklearn + + self.sklearn_version = sklearn.__version__ + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.step_n) + + @step + def step_n(self, inputs): + self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("My output is: ", self.my_output) + self.next(self.end) + + @step + def end(self): + self.my_output = self.my_output + [13] + print("My output is: ", self.my_output) + print("Flow is complete!") + + +if __name__ == "__main__": + ComplexDAGFlow() diff --git a/test/unit/spin/merge_artifacts_flow.py b/test/unit/spin/merge_artifacts_flow.py new file mode 100644 index 00000000000..59f1390e052 --- /dev/null +++ b/test/unit/spin/merge_artifacts_flow.py @@ -0,0 +1,63 @@ +from metaflow import FlowSpec, step + + +class MergeArtifactsFlow(FlowSpec): + + @step + def start(self): + self.pass_down = "a" + self.next(self.a, self.b) + + @step + def a(self): + self.common = 5 + self.x = 1 + self.y = 3 + self.from_a = 6 + self.next(self.join) + + @step + def b(self): + self.common = 5 + self.x = 2 + self.y = 4 + self.next(self.join) + + @step + def join(self, inputs): + print(f"In join step, self._datastore: {(type(self._datastore))}") + self.x = inputs.a.x + self.merge_artifacts(inputs, exclude=["y"]) + print("x is %s" % self.x) + print("pass_down is %s" % self.pass_down) + print("common is %d" % self.common) + print("from_a is %d" % self.from_a) + self.next(self.c) + + @step + def c(self): + self.next(self.d, self.e) + + @step + def d(self): + self.conflicting = 7 + self.next(self.join2) + + @step + def e(self): + self.conflicting = 8 + self.next(self.join2) + + @step + def join2(self, inputs): + self.merge_artifacts(inputs, include=["pass_down", "common"]) + print("Only pass_down and common exist here") + self.next(self.end) + + @step + def end(self): + pass + + +if __name__ == "__main__": + MergeArtifactsFlow() diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py new file mode 100644 index 00000000000..2940b453053 --- /dev/null +++ b/test/unit/spin/test_spin.py @@ -0,0 +1,84 @@ +import pytest +from metaflow import Runner, Run + + +@pytest.fixture +def complex_dag_run(): + # with Runner('complex_dag_flow.py').run() as running: + # yield running.run + return Run("ComplexDAGFlow/2", _namespace_check=False) + + +@pytest.fixture +def merge_artifacts_run(): + # with Runner('merge_artifacts_flow.py').run() as running: + # yield running.run + return Run("MergeArtifactsFlow/55", _namespace_check=False) + + +def _assert_artifacts(task, spin_task): + spin_task_artifacts = { + artifact.id: artifact.data for artifact in spin_task.artifacts + } + print(f"Spin task artifacts: {spin_task_artifacts}") + for artifact in task.artifacts: + assert ( + artifact.id in spin_task_artifacts + ), f"Artifact {artifact.id} not found in spin task" + assert ( + artifact.data == spin_task_artifacts[artifact.id] + ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" + + +def _run_step(file_name, run, step_name): + task = run[step_name].task + with Runner(file_name).spin(step_name, spin_pathspec=task.pathspec) as spin: + print("-" * 50) + print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}") + _assert_artifacts(task, spin.task) + + +def 
test_runtime_flow(complex_dag_run): + print(f"Running test for runtime flow: {complex_dag_run}") + for step in complex_dag_run.steps(): + print("-" * 100) + _run_step("runtime_dag_flow.py", complex_dag_run, step.id) + + +def test_merge_artifacts_flow(merge_artifacts_run): + print(f"Running test for merge artifacts flow: {merge_artifacts_run}") + for step in merge_artifacts_run.steps(): + print("-" * 100) + _run_step("merge_artifacts_flow.py", merge_artifacts_run, step.id) + + +def test_artifacts_module(complex_dag_run): + print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_a" + task = complex_dag_run[step_name].task + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + artifacts_module="./artifacts/complex_dag_step_a.py", + ) as spin: + print("-" * 50) + print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + spin_task = spin.task + print(f"my_output: {spin_task['my_output']}") + assert spin_task["my_output"].data == [10, 11, 12, 3] + + +def test_artifacts_module_join_step(complex_dag_run): + print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_d" + task = complex_dag_run[step_name].task + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + artifacts_module="./artifacts/complex_dag_step_d.py", + ) as spin: + print("-" * 50) + print(f"Running test for step: step_a with task pathspec: {task.pathspec}") + spin_task = spin.task + print(f"my_output: {spin_task['my_output']}") + assert spin_task["my_output"].data == [-1] From f6dae383170a3f695f1b581ddf0b23b6498590b1 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Mon, 9 Jun 2025 23:57:11 -0700 Subject: [PATCH 4/5] Fix bug in persisting artifacts --- metaflow/cli_components/step_cmd.py | 5 -- metaflow/datastore/flow_datastore.py | 5 -- metaflow/datastore/spin_datastore.py | 2 + metaflow/datastore/task_datastore.py | 1 - metaflow/task.py | 33 ++++++++++- test/unit/spin/complex_dag_flow.py | 3 +- test/unit/spin/test_spin.py | 88 ++++++++++++++++++++++++---- 7 files changed, 110 insertions(+), 27 deletions(-) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index a5261fb0c5b..e8b91f639e2 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -283,11 +283,6 @@ def spin_step( namespace(opt_namespace or None) input_paths = decompress_list(input_paths) if input_paths else [] - # echo_always( - # f"Spinning a task, *{step_name}* with previous task pathspec: {spin_pathspec}", - # fg="magenta", - # bold=True, - # ) whitelist_decorators = ( decompress_list(whitelist_decorators) if whitelist_decorators else [] diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 3de7a90f087..1e7d1c102d1 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -247,11 +247,6 @@ def get_task_datastore( if join_type is not None: # If join_type is specified, we need to use the artifacts corresponding # to that particular join index, specified by the parent task pathspec. 
-            print(f"Spin Artifacts: {spin_artifacts}")
-            print(f"pathspec: {run_id}/{step_name}/{task_id}")
-            print(
-                f"Spin Artifacts tp: {spin_artifacts.get(f'{run_id}/{step_name}/{task_id}')}"
-            )
             spin_artifacts = spin_artifacts.get(
                 f"{run_id}/{step_name}/{task_id}", {}
             )
diff --git a/metaflow/datastore/spin_datastore.py b/metaflow/datastore/spin_datastore.py
index 1f7e215543f..3b7421a77ab 100644
--- a/metaflow/datastore/spin_datastore.py
+++ b/metaflow/datastore/spin_datastore.py
@@ -50,6 +50,8 @@ def __init__(
             self._objects[artifact.id] = artifact.sha
             # Fulfills the contract for _info: name -> metadata
             self._info[artifact.id] = {
+                # Do not save the type of the data
+                # "type": str(type(artifact.data)),
                 "size": artifact.size,
                 "encoding": artifact._object["content_type"],
             }
diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py
index f5e63a27410..a63c89480ca 100644
--- a/metaflow/datastore/task_datastore.py
+++ b/metaflow/datastore/task_datastore.py
@@ -167,7 +167,6 @@ def __init__(
                 data_obj = self.load_metadata([self.METADATA_DATA_SUFFIX])
                 data_obj = data_obj[self.METADATA_DATA_SUFFIX]
             elif self._attempt is None or not allow_not_done:
-                print("HERE in self._mode == 'r'")
                 raise DataException(
                     "No completed attempts of the task was found for task '%s'"
                     % self._path
diff --git a/metaflow/task.py b/metaflow/task.py
index 53315d8f0fc..961b7c29a0a 100644
--- a/metaflow/task.py
+++ b/metaflow/task.py
@@ -133,7 +133,6 @@ def _init_data(self, run_id, join_type, input_paths):
         # (via TaskDataStoreSet) only with more than 4 datastores, because
         # the baseline overhead of using the set is ~1.5s and each datastore
         # init takes ~200-300ms when run sequentially.
-        # print(f"I am in in _init_data")
         if len(input_paths) > 4:
             prefetch_data_artifacts = None
             if join_type and join_type == "foreach":
@@ -456,7 +455,7 @@ def run_step(
         step_func = getattr(self.flow, step_name)
         decorators = step_func.decorators
         if self.spin_metadata:
-            # We filter only the whitelisted decorators
+            # For a spin step, we keep only the whitelisted decorators.
             decorators = [
                 deco for deco in decorators if deco.name in whitelist_decorators
             ]
@@ -642,11 +641,31 @@ def run_step(
                         "inputs." % step_name
                     )
                 self.flow._set_datastore(inputs[0])
+                # Iterate over all artifacts in the parent pathspec and set them on
+                # the current flow. We need to do this explicitly since we want to
+                # persist even those attributes that are not used / redefined in the
+                # spin step.
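+                # A value supplied through the spin artifacts module
+                # (self.spin_artifacts) takes precedence; otherwise the artifact is
+                # loaded from the parent task's datastore before being persisted below.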
+ if self.spin_metadata and persist: + st_time = time.time() + for artifact_name in self.flow._datastore._objects.keys(): + # This is highly inefficient since we are loading data + # that we don't need, but there is no better way to + # support this now + artifact_data = self.spin_artifacts.get( + artifact_name, self.flow._datastore[artifact_name] + ) + setattr( + self.flow, + artifact_name, + artifact_data, + ) + print( + f"Time taken to load all artifacts: {time.time() - st_time:.2f} seconds" + ) if input_paths: # initialize parameters (if they exist) # We take Parameter values from the first input, # which is always safe since parameters are read-only - # print(f"self.flow._graph_info: {self.flow._graph_info}") current._update_env( { "parameter_names": self._init_parameters( @@ -685,10 +704,14 @@ def run_step( self.ubf_context, ) + st_time = time.time() if join_type: self._exec_step_function(step_func, input_obj) else: self._exec_step_function(step_func) + print( + f"Time taken to run the step function: {time.time() - st_time:.2f} seconds" + ) for deco in decorators: deco.task_post_step( @@ -745,7 +768,11 @@ def run_step( ) try: # persisting might fail due to unpicklable artifacts. + st_time = time.time() output.persist(self.flow) + print( + f"Time taken to persist the output: {time.time() - st_time:.2f} seconds" + ) except Exception as ex: self.flow._task_ok = False raise ex diff --git a/test/unit/spin/complex_dag_flow.py b/test/unit/spin/complex_dag_flow.py index 733df2d60d7..04b185fe40f 100644 --- a/test/unit/spin/complex_dag_flow.py +++ b/test/unit/spin/complex_dag_flow.py @@ -88,13 +88,14 @@ def step_l(self): print("My output is: ", self.my_output) self.next(self.step_m) - @conda(libraries={"scikit-learn": ""}) + @conda(libraries={"scikit-learn": "1.3.0"}) @step def step_m(self, inputs): import sklearn self.sklearn_version = sklearn.__version__ self.my_output = sorted([inp.my_output for inp in inputs])[0] + print("Sklearn version: ", self.sklearn_version) print("My output is: ", self.my_output) self.next(self.step_n) diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py index 2940b453053..215575e1a0e 100644 --- a/test/unit/spin/test_spin.py +++ b/test/unit/spin/test_spin.py @@ -6,7 +6,7 @@ def complex_dag_run(): # with Runner('complex_dag_flow.py').run() as running: # yield running.run - return Run("ComplexDAGFlow/2", _namespace_check=False) + return Run("ComplexDAGFlow/3", _namespace_check=False) @pytest.fixture @@ -30,19 +30,33 @@ def _assert_artifacts(task, spin_task): ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}" -def _run_step(file_name, run, step_name): +def _run_step(file_name, run, step_name, is_conda=False): task = run[step_name].task - with Runner(file_name).spin(step_name, spin_pathspec=task.pathspec) as spin: - print("-" * 50) - print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}") - _assert_artifacts(task, spin.task) - - -def test_runtime_flow(complex_dag_run): - print(f"Running test for runtime flow: {complex_dag_run}") + if not is_conda: + with Runner(file_name).spin(step_name, spin_pathspec=task.pathspec) as spin: + print("-" * 50) + print( + f"Running test for step: {step_name} with task pathspec: {task.pathspec}" + ) + _assert_artifacts(task, spin.task) + else: + with Runner(file_name, environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + ) as spin: + print("-" * 50) + print( + f"Running test for step: {step_name} with task pathspec: 
{task.pathspec}" + ) + print(f"Spin task artifacts: {spin.task.artifacts}") + _assert_artifacts(task, spin.task) + + +def test_complex_dag_flow(complex_dag_run): + print(f"Running test for ComplexDAGFlow flow: {complex_dag_run}") for step in complex_dag_run.steps(): print("-" * 100) - _run_step("runtime_dag_flow.py", complex_dag_run, step.id) + _run_step("complex_dag_flow.py", complex_dag_run, step.id, is_conda=True) def test_merge_artifacts_flow(merge_artifacts_run): @@ -80,5 +94,55 @@ def test_artifacts_module_join_step(complex_dag_run): print("-" * 50) print(f"Running test for step: step_a with task pathspec: {task.pathspec}") spin_task = spin.task - print(f"my_output: {spin_task['my_output']}") assert spin_task["my_output"].data == [-1] + + +def test_skip_decorators(complex_dag_run): + print(f"Running test for skip decorator in ComplexDAGFlow: {complex_dag_run}") + step_name = "step_m" + task = complex_dag_run[step_name].task + # Check if sklearn is available in the outer environment + # If not, this test will fail as it requires sklearn to be installed and skip_decorator + # is set to True + is_sklearn = True + try: + import sklearn + except ImportError: + is_sklearn = False + if is_sklearn: + # We verify that the sklearn version is the same as the one in the outside environment + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + skip_decorators=True, + ) as spin: + print("-" * 50) + print( + f"Running test for step: {step_name} with task pathspec: {task.pathspec}" + ) + spin_task = spin.task + import sklearn + + expected_version = sklearn.__version__ + assert ( + spin_task["sklearn_version"].data == expected_version + ), f"Expected sklearn version {expected_version} but got {spin_task['sklearn_version']}" + else: + # We assert that an exception is raised when trying to run the step with skip_decorators=True + with pytest.raises(Exception) as exc_info: + with Runner("complex_dag_flow.py", environment="conda").spin( + step_name, + spin_pathspec=task.pathspec, + skip_decorators=True, + ): + pass + + +# with Runner('complex_dag_flow.py', environment="conda").spin("step_g") as spin: +# print("-" * 50) +# print(f"Running test for step: step_g with task pathspec: {spin.task.pathspec}") +# spin_task = spin.task +# print(spin_task) +# print(spin_task.artifacts) +# for artifact in spin_task.artifacts: +# print(f"Artifact {artifact.id}: {artifact.data}") From bbc2ce5f30b92cc0572df23de2f063fe535b2873 Mon Sep 17 00:00:00 2001 From: Shashank Srikanth Date: Tue, 10 Jun 2025 00:08:15 -0700 Subject: [PATCH 5/5] Remove superfluous print statement --- metaflow/cli_components/run_cmds.py | 3 --- metaflow/runtime.py | 2 -- test/unit/spin/test_spin.py | 10 ---------- 3 files changed, 15 deletions(-) diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index e2e351d5ab2..9455be47caa 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -51,9 +51,6 @@ def before_run(obj, tags, decospecs): obj.check(obj.graph, obj.flow, obj.environment, pylint=obj.pylint) # obj.environment.init_environment(obj.logger) - # decorators._init_step_decorators( - # obj.flow, obj.graph, obj.environment, obj.effective_flow_datastore, obj.logger - # ) decorators._init_step_decorators( obj.flow, obj.graph, obj.environment, obj.flow_datastore, obj.logger ) diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 5a6c797e34a..6d1ed1b5fa6 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ 
-260,9 +260,7 @@ def _launch_and_monitor_task(self):
                     poll.remove(event.fd)
                     active_fds.remove(event.fd)
 
-        # print("I am after while loop")
         returncode = worker.terminate()
-        # print(f"Return code: {returncode}")
         if returncode != 0:
             raise TaskFailed(self.task, f"Task failed with return code {returncode}")
 
diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py
index 215575e1a0e..75577f9c624 100644
--- a/test/unit/spin/test_spin.py
+++ b/test/unit/spin/test_spin.py
@@ -136,13 +136,3 @@ def test_skip_decorators(complex_dag_run):
                 skip_decorators=True,
             ):
                 pass
-
-
-# with Runner('complex_dag_flow.py', environment="conda").spin("step_g") as spin:
-#     print("-" * 50)
-#     print(f"Running test for step: step_g with task pathspec: {spin.task.pathspec}")
-#     spin_task = spin.task
-#     print(spin_task)
-#     print(spin_task.artifacts)
-#     for artifact in spin_task.artifacts:
-#         print(f"Artifact {artifact.id}: {artifact.data}")
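
Usage sketch (illustrative, not part of the patches): the snippet below follows the
patterns exercised in test/unit/spin/test_spin.py above. The flow file, run id, step
name, and artifacts-module path are placeholders taken from those tests, not fixed
values.

    from metaflow import Run, Runner

    # Pick the task whose step we want to re-execute in isolation.
    task = Run("ComplexDAGFlow/3", _namespace_check=False)["step_a"].task

    # Spin up just that step, seeding it from the parent task's artifacts and
    # optionally overriding selected inputs via an artifacts module.
    with Runner("complex_dag_flow.py", environment="conda").spin(
        "step_a",
        spin_pathspec=task.pathspec,
        artifacts_module="./artifacts/complex_dag_step_a.py",
    ) as spin:
        # The spun task exposes its artifacts like any other Metaflow task.
        for artifact in spin.task.artifacts:
            print(artifact.id, artifact.data)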