diff --git a/metaflow/cli.py b/metaflow/cli.py index d24829e6db6..2d6df48688e 100644 --- a/metaflow/cli.py +++ b/metaflow/cli.py @@ -1,3 +1,4 @@ +import os import functools import inspect import sys @@ -23,6 +24,8 @@ DEFAULT_METADATA, DEFAULT_MONITOR, DEFAULT_PACKAGE_SUFFIXES, + DATASTORE_SYSROOT_SPIN, + DATASTORE_LOCAL_DIR, ) from .metaflow_current import current from metaflow.system import _system_monitor, _system_logger @@ -114,6 +117,8 @@ def logger(body="", system_msg=False, head="", bad=False, timestamp=True, nl=Tru "step": "metaflow.cli_components.step_cmd.step", "run": "metaflow.cli_components.run_cmds.run", "resume": "metaflow.cli_components.run_cmds.resume", + "spin": "metaflow.cli_components.run_cmds.spin", + "spin-step": "metaflow.cli_components.step_cmd.spin_step", }, ) def cli(ctx): @@ -440,14 +445,10 @@ def start( ctx.obj.event_logger = LOGGING_SIDECARS[event_logger]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.event_logger.start() - _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) ctx.obj.monitor = MONITOR_SIDECARS[monitor]( flow=ctx.obj.flow, env=ctx.obj.environment ) - ctx.obj.monitor.start() - _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == metadata][0]( ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor @@ -462,6 +463,45 @@ def start( ) ctx.obj.config_options = config_options + ctx.obj.is_spin = False + + # Override values for spin + if hasattr(ctx, "saved_args") and ctx.saved_args and "spin" in ctx.saved_args[0]: + # To minimize side-effects for spin, we will only use the following: + # - local metadata provider, + # - local datastore, + # - local environment, + # - null event logger, + # - null monitor + ctx.obj.is_spin = True + ctx.obj.event_logger = LOGGING_SIDECARS["nullSidecarLogger"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.monitor = MONITOR_SIDECARS["nullSidecarMonitor"]( + flow=ctx.obj.flow, env=ctx.obj.environment + ) + ctx.obj.metadata = [m for m in METADATA_PROVIDERS if m.TYPE == "local"][0]( + ctx.obj.environment, ctx.obj.flow, ctx.obj.event_logger, ctx.obj.monitor + ) + ctx.obj.datastore_impl = [d for d in DATASTORES if d.TYPE == "local"][0] + # Set datastore_root to be DATASTORE_SYSROOT_SPIN if not provided + datastore_root = os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) + ctx.obj.datastore_impl.datastore_root = datastore_root + ctx.obj.flow_datastore = FlowDataStore( + ctx.obj.flow.name, + ctx.obj.environment, # Same environment as run/resume + ctx.obj.metadata, # local metadata + ctx.obj.event_logger, # null event logger + ctx.obj.monitor, # null monitor + storage_impl=ctx.obj.datastore_impl, + ) + + # Start event logger and monitor + ctx.obj.event_logger.start() + _system_logger.init_system_logger(ctx.obj.flow.name, ctx.obj.event_logger) + + ctx.obj.monitor.start() + _system_monitor.init_system_monitor(ctx.obj.flow.name, ctx.obj.monitor) decorators._init(ctx.obj.flow) @@ -478,7 +518,7 @@ def start( deco_options, ) - # In the case of run/resume, we will want to apply the TL decospecs + # In the case of run/resume/spin, we will want to apply the TL decospecs # *after* the run decospecs so that they don't take precedence. 
In other # words, for the same decorator, we want `myflow.py run --with foo` to # take precedence over any other `foo` decospec @@ -506,7 +546,7 @@ def start( if ( hasattr(ctx, "saved_args") and ctx.saved_args - and ctx.saved_args[0] not in ("run", "resume") + and ctx.saved_args[0] not in ("run", "resume", "spin") ): # run/resume are special cases because they can add more decorators with --with, # so they have to take care of themselves. diff --git a/metaflow/cli_components/run_cmds.py b/metaflow/cli_components/run_cmds.py index 681fc8f94f9..9455be47caa 100644 --- a/metaflow/cli_components/run_cmds.py +++ b/metaflow/cli_components/run_cmds.py @@ -10,7 +10,7 @@ from ..metaflow_current import current from ..metaflow_config import DEFAULT_DECOSPECS from ..package import MetaflowPackage -from ..runtime import NativeRuntime +from ..runtime import NativeRuntime, SpinRuntime from ..system import _system_logger from ..tagging_util import validate_tags @@ -20,7 +20,7 @@ def before_run(obj, tags, decospecs): validate_tags(tags) - # There's a --with option both at the top-level and for the run + # There's a --with option both at the top-level and for the run/resume/spin # subcommand. Why? # # "run --with shoes" looks so much better than "--with shoes run". @@ -40,7 +40,7 @@ def before_run(obj, tags, decospecs): + list(obj.environment.decospecs() or []) ) if all_decospecs: - # These decospecs are the ones from run/resume PLUS the ones from the + # These decospecs are the ones from run/resume/spin PLUS the ones from the # environment (for example the @conda) decorators._attach_decorators(obj.flow, all_decospecs) decorators._init(obj.flow) @@ -65,6 +65,29 @@ def before_run(obj, tags, decospecs): ) +def common_runner_options(func): + @click.option( + "--run-id-file", + default=None, + show_default=True, + type=str, + help="Write the ID of this run to the file specified.", + ) + @click.option( + "--runner-attribute-file", + default=None, + show_default=True, + type=str, + help="Write the metadata and pathspec of this run to the file specified. Used internally " + "for Metaflow's Runner API.", + ) + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + def write_file(file_path, content): if file_path is not None: with open(file_path, "w", encoding="utf-8") as f: @@ -129,20 +152,6 @@ def common_run_options(func): "in steps.", callback=config_callback, ) - @click.option( - "--run-id-file", - default=None, - show_default=True, - type=str, - help="Write the ID of this run to the file specified.", - ) - @click.option( - "--runner-attribute-file", - default=None, - show_default=True, - type=str, - help="Write the metadata and pathspec of this run to the file specified. 
Used internally for Metaflow's Runner API.", - ) @wraps(func) def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -187,6 +196,7 @@ def wrapper(*args, **kwargs): @click.command(help="Resume execution of a previous run of this flow.") @tracing.cli("cli/resume") @common_run_options +@common_runner_options @click.pass_obj def resume( obj, @@ -305,6 +315,7 @@ def resume( @click.command(help="Run the workflow locally.") @tracing.cli("cli/run") @common_run_options +@common_runner_options @click.option( "--namespace", "user_namespace", @@ -380,3 +391,106 @@ def run( ) with runtime.run_heartbeat(): runtime.execute() + + +@click.command(help="Spins up a task for a given step from a previous run locally.") +@click.argument("step-name") +@click.option( + "--spin-pathspec", + default=None, + type=str, + help="Use specified task pathspec from a previous run to spin up the step.", +) +@click.option( + "--skip-decorators/--no-skip-decorators", + is_flag=True, + default=False, + show_default=True, + help="Skip decorators attached to the step.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. The artifacts should " + "be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the " + "artifact values. The artifact values will overwrite the default values of the artifacts used in " + "the spun step.", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. If set to false, the artifacts will not" + " be persisted and will not be available in the spun step's datastore.", +) +@click.option( + "--max-log-size", + default=10, + show_default=True, + help="Maximum size of stdout and stderr captured in " + "megabytes. If a step outputs more than this to " + "stdout/stderr, its output will be truncated.", +) +@common_runner_options +@click.pass_obj +def spin( + obj, + step_name, + spin_pathspec=None, + persist=True, + artifacts_module=None, + skip_decorators=False, + max_log_size=None, + run_id_file=None, + runner_attribute_file=None, + **kwargs +): + before_run(obj, [], []) + obj.echo(f"Spinning up step *{step_name}* locally for flow *{obj.flow.name}*") + obj.flow._set_constants(obj.graph, kwargs, obj.config_options) + step_func = getattr(obj.flow, step_name) + + spin_runtime = SpinRuntime( + obj.flow, + obj.graph, + obj.flow_datastore, + obj.metadata, + obj.environment, + obj.package, + obj.logger, + obj.entrypoint, + obj.event_logger, + obj.monitor, + step_func, + step_name, + spin_pathspec, + skip_decorators, + artifacts_module, + persist, + max_log_size * 1024 * 1024, + ) + + write_latest_run_id(obj, spin_runtime.run_id) + write_file(run_id_file, spin_runtime.run_id) + + # datastore_root is os.path.join(DATASTORE_SYSROOT_SPIN, DATASTORE_LOCAL_DIR) + # We only neeed the root for the metadata, i.e. 
the portion before DATASTORE_LOCAL_DIR + datastore_root = spin_runtime._flow_datastore._storage_impl.datastore_root + spin_metadata_root = datastore_root.rsplit("/", 1)[0] + spin_runtime.execute() + + if runner_attribute_file: + with open(runner_attribute_file, "w") as f: + json.dump( + { + "task_id": spin_runtime.task.task_id, + "step_name": step_name, + "run_id": spin_runtime.run_id, + "flow_name": obj.flow.name, + # Store metadata in a format that can be used by the Runner API + "metadata": f"{obj.metadata.__class__.TYPE}@{spin_metadata_root}", + }, + f, + ) diff --git a/metaflow/cli_components/step_cmd.py b/metaflow/cli_components/step_cmd.py index f4bef099e42..e8b91f639e2 100644 --- a/metaflow/cli_components/step_cmd.py +++ b/metaflow/cli_components/step_cmd.py @@ -6,7 +6,7 @@ from ..exception import CommandException from ..task import MetaflowTask from ..unbounded_foreach import UBF_CONTROL, UBF_TASK -from ..util import decompress_list +from ..util import decompress_list, read_artifacts_module import metaflow.tracing as tracing @@ -176,3 +176,142 @@ def step( ) echo("Success", fg="green", bold=True, indent=True) + + +@click.command(help="Internal command to spin a single task.", hidden=True) +@click.argument("step-name") +@click.option( + "--run-id", + default=None, + required=True, + help="Run ID for the step that's about to be spun", +) +@click.option( + "--task-id", + default=None, + required=True, + help="Task ID for the step that's about to be spun", +) +@click.option( + "--spin-metadata", + default=None, + show_default=True, + help="Spin metadata provider to be used for fetching artifacts/data for the input datastore", +) +@click.option( + "--spin-pathspec", + default=None, + show_default=True, + help="Task Pathspec to be used in the spun step.", +) +@click.option( + "--input-paths", + help="A comma-separated list of pathspecs specifying inputs for this step.", +) +@click.option( + "--split-index", + type=int, + default=None, + show_default=True, + help="Index of this foreach split.", +) +@click.option( + "--retry-count", + default=0, + help="How many times we have attempted to run this task.", +) +@click.option( + "--max-user-code-retries", + default=0, + help="How many times we should attempt running the user code.", +) +@click.option( + "--namespace", + "opt_namespace", + default=None, + help="Change namespace from the default (your username) to the specified tag.", +) +@click.option( + "--whitelist-decorators", + help="A comma-separated list of whitelisted decorators to use for the spin step", +) +@click.option( + "--persist/--no-persist", + "persist", + default=True, + show_default=True, + help="Whether to persist the artifacts in the spun step. If set to false, the artifacts will not" + " be persisted and will not be available in the spun step's datastore.", +) +@click.option( + "--artifacts-module", + default=None, + show_default=True, + help="Path to a module that contains artifacts to be used in the spun step. The artifacts should " + "be defined as a dictionary called ARTIFACTS with keys as the artifact names and values as the " + "artifact values. 
The artifact values will overwrite the default values of the artifacts used in " + "the spun step.", +) +@click.pass_context +def spin_step( + ctx, + step_name, + run_id=None, + task_id=None, + spin_metadata=None, + spin_pathspec=None, + input_paths=None, + split_index=None, + retry_count=None, + max_user_code_retries=None, + opt_namespace=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, +): + import time + + start = time.time() + import sys + + if ctx.obj.is_quiet: + echo = echo_dev_null + else: + echo = echo_always + + if opt_namespace is not None: + namespace(opt_namespace or None) + + input_paths = decompress_list(input_paths) if input_paths else [] + + whitelist_decorators = ( + decompress_list(whitelist_decorators) if whitelist_decorators else [] + ) + spin_artifacts = read_artifacts_module(artifacts_module) if artifacts_module else {} + + task = MetaflowTask( + ctx.obj.flow, + ctx.obj.flow_datastore, + ctx.obj.metadata, + ctx.obj.environment, + echo, + ctx.obj.event_logger, + ctx.obj.monitor, + None, # no unbounded foreach context + spin_metadata=spin_metadata, + spin_artifacts=spin_artifacts, + ) + task.run_step( + step_name, + run_id, + task_id, + None, + input_paths, + split_index, + retry_count, + max_user_code_retries, + whitelist_decorators, + persist, + ) + + echo_always(f"Time taken for the whole thing: {time.time() - start}") diff --git a/metaflow/client/core.py b/metaflow/client/core.py index 4edbcdac00c..f1848e38799 100644 --- a/metaflow/client/core.py +++ b/metaflow/client/core.py @@ -1191,142 +1191,186 @@ def _iter_filter(self, x): # exclude private data artifacts return x.id[0] != "_" - def _iter_matching_tasks(self, steps, metadata_key, metadata_pattern): + def _get_matching_pathspecs(self, steps, metadata_key, metadata_pattern): """ - Yield tasks from specified steps matching a foreach path pattern. + Yield pathspecs of tasks from specified steps that match a given metadata pattern. Parameters ---------- steps : List[str] - List of step names to search for tasks - pattern : str - Regex pattern to match foreach-indices metadata + List of Step objects to search for tasks. + metadata_key : str + Metadata key to filter tasks on (e.g., 'foreach-execution-path'). + metadata_pattern : str + Regular expression pattern to match against the metadata value. - Returns - ------- - Iterator[Task] - Tasks matching the foreach path pattern + Yields + ------ + str + Pathspec of each task whose metadata value for the specified key matches the pattern. """ flow_id, run_id, _, _ = self.path_components - for step in steps: task_pathspecs = self._metaflow.metadata.filter_tasks_by_metadata( - flow_id, run_id, step.id, metadata_key, metadata_pattern + flow_id, run_id, step, metadata_key, metadata_pattern ) for task_pathspec in task_pathspecs: - yield Task(pathspec=task_pathspec, _namespace_check=False) + yield task_pathspec + + @staticmethod + def _get_previous_steps(graph_info, step_name): + # Get the parent steps + steps = [] + for node_name, attributes in graph_info["steps"].items(): + if step_name in attributes["next"]: + steps.append(node_name) + return steps @property - def parent_tasks(self) -> Iterator["Task"]: + def parent_task_pathspecs(self) -> Iterator[str]: """ - Yields all parent tasks of the current task if one exists. + Yields pathspecs of all parent tasks of the current task. 
Yields ------ - Task - Parent task of the current task - + str + Pathspec of the parent task of the current task """ - flow_id, run_id, _, _ = self.path_components + flow_id, run_id, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data - steps = list(self.parent.parent_steps) - if not steps: - return [] - - current_path = self.metadata_dict.get("foreach-execution-path", "") + # Get the parent steps + steps = self._get_previous_steps(graph_info, step_name) + node_type = graph_info["steps"][step_name]["type"] + current_path = metadata_dict.get("foreach-execution-path") if len(steps) > 1: # Static join - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Parent task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach join - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No parent steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach split or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) + current_depth = len(current_path.split(",")) + if node_type == "join": + # Foreach join + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach split or linear step + # Pattern: "A:10,B:13" + parent_step_type = graph_info["steps"][steps[0]]["type"] + target_depth = current_depth + if parent_step_type == "split-foreach" and current_depth == 1: + # (Current task, "A:10") and (Parent task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Parent task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Parent task, "A:10,B:13") + if parent_step_type == "split-foreach": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + metadata_key = "foreach-execution-path" + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec @property - def child_tasks(self) -> Iterator["Task"]: + def child_task_pathspecs(self) -> Iterator[str]: """ - Yield all child tasks of the current task if one exists. + Yields pathspecs of all child tasks of the current task. 
Yields ------ - Task - Child task of the current task + str + Pathspec of the child task of the current task """ - flow_id, run_id, _, _ = self.path_components - steps = list(self.parent.child_steps) - if not steps: - return [] + flow_id, run_id, step_name, _ = self.path_components + metadata_dict = self.metadata_dict + graph_info = self["_graph_info"].data + + # Get the child steps + steps = graph_info["steps"][step_name]["next"] - current_path = self.metadata_dict.get("foreach-execution-path", "") + node_type = graph_info["steps"][step_name]["type"] + current_path = self.metadata_dict.get("foreach-execution-path") if len(steps) > 1: # Static split - use exact path matching pattern = current_path or ".*" - yield from self._iter_matching_tasks( - steps, "foreach-execution-path", pattern - ) - return - - # Handle single step case - target_task = Step( - f"{flow_id}/{run_id}/{steps[0].id}", _namespace_check=False - ).task - target_path = target_task.metadata_dict.get("foreach-execution-path") - - if not target_path or not current_path: - # (Current task, "A:10") and (Child task, "") - # Pattern: ".*" - pattern = ".*" else: - current_depth = len(current_path.split(",")) - target_depth = len(target_path.split(",")) - - if current_depth < target_depth: - # Foreach split - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") - # Pattern: "A:10,B:13,.*" - pattern = f"{current_path},.*" + if not steps: + return # No child steps, yield nothing + + if not current_path: + # Current task is not part of a foreach + # Pattern: ".*" + pattern = ".*" else: - # Foreach join or linear step - # Option 1: - # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") - # Option 2: - # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") - # Pattern: "A:10,B:13" - pattern = ",".join(current_path.split(",")[:target_depth]) - - yield from self._iter_matching_tasks(steps, "foreach-execution-path", pattern) + current_depth = len(current_path.split(",")) + if node_type == "split-foreach": + # Foreach split + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13,C:21") + # Pattern: "A:10,B:13,.*" + pattern = f"{current_path},.*" + else: + # Foreach join or linear step + # Pattern: "A:10,B:13" + child_step_type = graph_info["steps"][steps[0]]["type"] + + # We need to know if the child step is a foreach join or a static join + child_step_prev_steps = self._get_previous_steps( + graph_info, steps[0] + ) + if len(child_step_prev_steps) > 1: + child_step_type = "static-join" + target_depth = current_depth + if child_step_type == "join" and current_depth == 1: + # (Current task, "A:10") and (Child task, "") + pattern = ".*" + else: + # (Current task, "A:10,B:13,C:21") and (Child task, "A:10,B:13") + # (Current task, "A:10,B:13") and (Child task, "A:10,B:13") + if child_step_type == "join": + target_depth = current_depth - 1 + pattern = ",".join(current_path.split(",")[:target_depth]) + + metadata_key = "foreach-execution-path" + for pathspec in self._get_matching_pathspecs(steps, metadata_key, pattern): + yield pathspec + + @property + def parent_tasks(self) -> Iterator["Task"]: + """ + Yields all parent tasks of the current task if one exists. + + Yields + ------ + Task + Parent task of the current task + """ + for pathspec in self.parent_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) + + @property + def child_tasks(self) -> Iterator["Task"]: + """ + Yields all child tasks of the current task if one exists. 
+ + Yields + ------ + Task + Child task of the current task + """ + for pathspec in self.child_task_pathspecs: + yield Task(pathspec=pathspec, _namespace_check=False) @property def metadata(self) -> List[Metadata]: diff --git a/metaflow/datastore/__init__.py b/metaflow/datastore/__init__.py index 793251b0cff..65bb33b0eb9 100644 --- a/metaflow/datastore/__init__.py +++ b/metaflow/datastore/__init__.py @@ -2,3 +2,4 @@ from .flow_datastore import FlowDataStore from .datastore_set import TaskDataStoreSet from .task_datastore import TaskDataStore +from .spin_datastore import SpinTaskDatastore diff --git a/metaflow/datastore/datastore_set.py b/metaflow/datastore/datastore_set.py index f60642de73f..403eadbfd4d 100644 --- a/metaflow/datastore/datastore_set.py +++ b/metaflow/datastore/datastore_set.py @@ -21,9 +21,18 @@ def __init__( pathspecs=None, prefetch_data_artifacts=None, allow_not_done=False, + join_type=None, + spin_metadata=None, + spin_artifacts=None, ): self.task_datastores = flow_datastore.get_task_datastores( - run_id, steps=steps, pathspecs=pathspecs, allow_not_done=allow_not_done + run_id, + steps=steps, + pathspecs=pathspecs, + allow_not_done=allow_not_done, + join_type=join_type, + spin_metadata=spin_metadata, + spin_artifacts=spin_artifacts, ) if prefetch_data_artifacts: diff --git a/metaflow/datastore/flow_datastore.py b/metaflow/datastore/flow_datastore.py index 16318ed7693..1e7d1c102d1 100644 --- a/metaflow/datastore/flow_datastore.py +++ b/metaflow/datastore/flow_datastore.py @@ -5,6 +5,7 @@ from .content_addressed_store import ContentAddressedStore from .task_datastore import TaskDataStore +from .spin_datastore import SpinTaskDatastore class FlowDataStore(object): @@ -76,6 +77,9 @@ def get_task_datastores( attempt=None, include_prior=False, mode="r", + join_type=None, + spin_metadata=None, + spin_artifacts=None, ): """ Return a list of TaskDataStore for a subset of the tasks. @@ -106,6 +110,16 @@ def get_task_datastores( If True, returns all attempts up to and including attempt. mode : str, default "r" Mode to initialize the returned TaskDataStores in. + join_type : str, optional + If specified, the join type for the task. This is used to determine + the user specified artifacts for the task in case of a spin task. + spin_metadata : str, optional + The metadata provider in case of a spin task. If provided, the + returned TaskDataStore will be a SpinTaskDatastore instead of a + TaskDataStore. + spin_artifacts : Dict[str, Any], optional + Artifacts provided by user that can override the artifacts fetched via the + spin pathspec. Returns ------- @@ -198,7 +212,18 @@ def get_task_datastores( else (latest_started_attempts & done_attempts) ) latest_to_fetch = [ - (v[0], v[1], v[2], v[3], data_objs.get(v), mode, allow_not_done) + ( + v[0], + v[1], + v[2], + v[3], + data_objs.get(v), + mode, + allow_not_done, + join_type, + spin_metadata, + spin_artifacts, + ) for v in latest_to_fetch ] return list(itertools.starmap(self.get_task_datastore, latest_to_fetch)) @@ -212,7 +237,27 @@ def get_task_datastore( data_metadata=None, mode="r", allow_not_done=False, + join_type=None, + spin_metadata=None, + spin_artifacts=None, + persist=True, ): + if spin_metadata is not None: + # In spin step subprocess, use SpinTaskDatastore for accessing artifacts + if join_type is not None: + # If join_type is specified, we need to use the artifacts corresponding + # to that particular join index, specified by the parent task pathspec. 
+ spin_artifacts = spin_artifacts.get( + f"{run_id}/{step_name}/{task_id}", {} + ) + return SpinTaskDatastore( + self.flow_name, + run_id, + step_name, + task_id, + spin_metadata, + spin_artifacts, + ) return TaskDataStore( self, run_id, @@ -222,6 +267,7 @@ def get_task_datastore( data_metadata=data_metadata, mode=mode, allow_not_done=allow_not_done, + persist=persist, ) def save_data(self, data_iter, len_hint=0): diff --git a/metaflow/datastore/spin_datastore.py b/metaflow/datastore/spin_datastore.py new file mode 100644 index 00000000000..3b7421a77ab --- /dev/null +++ b/metaflow/datastore/spin_datastore.py @@ -0,0 +1,111 @@ +from typing import Dict, Any +from .task_datastore import require_mode + + +class SpinTaskDatastore(object): + def __init__( + self, + flow_name: str, + run_id: str, + step_name: str, + task_id: str, + spin_metadata: str, + spin_artifacts: Dict[str, Any], + ): + """ + SpinTaskDatastore is a datastore for a task that is used to retrieve + artifacts and attributes for a spin step. It uses the task pathspec + from a previous execution of the step to access the artifacts and attributes. + + Parameters + ---------- + flow_name : str + Name of the flow + run_id : str + Run ID of the flow + step_name : str + Name of the step + task_id : str + Task ID of the step + spin_metadata : str + Metadata for the spin task, typically a URI to the metadata service. + spin_artifacts : Dict[str, Any] + User provided artifacts that are to be used in the spin task. This is a dictionary + where keys are artifact names and values are the actual data or metadata. + """ + self.flow_name = flow_name + self.run_id = run_id + self.step_name = step_name + self.task_id = task_id + self.spin_metadata = spin_metadata + self.spin_artifacts = spin_artifacts + self._task = None + + # Update _objects and _info in order to persist artifacts + # See `persist` method in `TaskDataStore` for more details + self._objects = {} + self._info = {} + + for artifact in self.task.artifacts: + self._objects[artifact.id] = artifact.sha + # Fulfills the contract for _info: name -> metadata + self._info[artifact.id] = { + # Do not save the type of the data + # "type": str(type(artifact.data)), + "size": artifact.size, + "encoding": artifact._object["content_type"], + } + + @property + def task(self): + if self._task is None: + # Lazily construct the Metaflow Task object for the spin pathspec + from metaflow import Task + + # print(f"Setting task with metadata: {self.spin_metadata} and pathspec: {self.run_id}/{self.step_name}/{self.task_id}") + self._task = Task( + f"{self.flow_name}/{self.run_id}/{self.step_name}/{self.task_id}", + _namespace_check=False, + # We need to get this from the task pathspec somehow + _current_metadata=self.spin_metadata, + ) + return self._task + + @require_mode(None) + def __getitem__(self, name): + try: + # Check if it's an artifact in the spin_artifacts + return self.spin_artifacts[name] + except Exception: + try: + # Check if it's an attribute of the task + # _foreach_stack, _foreach_index, ... + return self.task.__getitem__(name).data + except Exception: + # If not an attribute, check if it's an artifact + try: + return getattr(self.task.artifacts, name).data + except AttributeError: + raise AttributeError( + f"Attribute '{name}' not found in the previous execution of the task for " + f"`{self.step_name}`."
+ ) + + @require_mode(None) + def is_none(self, name): + val = self.__getitem__(name) + return val is None + + @require_mode(None) + def __contains__(self, name): + try: + _ = self.__getitem__(name) + return True + except AttributeError: + return False + + @require_mode(None) + def items(self): + if self._objects: + return self._objects.items() + return {} diff --git a/metaflow/datastore/task_datastore.py b/metaflow/datastore/task_datastore.py index 325cc1ea1ae..a63c89480ca 100644 --- a/metaflow/datastore/task_datastore.py +++ b/metaflow/datastore/task_datastore.py @@ -98,6 +98,7 @@ def __init__( data_metadata=None, mode="r", allow_not_done=False, + persist=True, ): self._storage_impl = flow_datastore._storage_impl @@ -114,6 +115,7 @@ def __init__( self._attempt = attempt self._metadata = flow_datastore.metadata self._parent = flow_datastore + self._persist = persist # The GZIP encodings are for backward compatibility self._encodings = {"pickle-v2", "gzip+pickle-v2"} @@ -681,6 +683,8 @@ def persist(self, flow): flow : FlowSpec Flow to persist """ + if not self._persist: + return if flow._datastore: self._objects.update(flow._datastore._objects) diff --git a/metaflow/metaflow_config.py b/metaflow/metaflow_config.py index eaf230bc383..4c030d76945 100644 --- a/metaflow/metaflow_config.py +++ b/metaflow/metaflow_config.py @@ -47,6 +47,13 @@ "DEFAULT_FROM_DEPLOYMENT_IMPL", "argo-workflows" ) +### +# Spin configuration +### +SPIN_ALLOWED_DECORATORS = from_conf( + "SPIN_ALLOWED_DECORATORS", ["conda", "pypi", "environment"] +) + ### # User configuration ### @@ -57,6 +64,7 @@ # Datastore configuration ### DATASTORE_SYSROOT_LOCAL = from_conf("DATASTORE_SYSROOT_LOCAL") +DATASTORE_SYSROOT_SPIN = from_conf("DATASTORE_SYSROOT_SPIN", "/tmp/metaflow") # S3 bucket and prefix to store artifacts for 's3' datastore. DATASTORE_SYSROOT_S3 = from_conf("DATASTORE_SYSROOT_S3") # Azure Blob Storage container and blob prefix diff --git a/metaflow/plugins/cards/card_decorator.py b/metaflow/plugins/cards/card_decorator.py index 7006997e5ee..e01fc4feb19 100644 --- a/metaflow/plugins/cards/card_decorator.py +++ b/metaflow/plugins/cards/card_decorator.py @@ -137,6 +137,7 @@ def step_init( self._flow_datastore = flow_datastore self._environment = environment self._logger = logger + self.card_options = None # We check for configuration options. We do this here before they are diff --git a/metaflow/runner/metaflow_runner.py b/metaflow/runner/metaflow_runner.py index 2de20e0a1e0..9e69342d14a 100644 --- a/metaflow/runner/metaflow_runner.py +++ b/metaflow/runner/metaflow_runner.py @@ -17,27 +17,33 @@ from .subprocess_manager import CommandManager, SubprocessManager -class ExecutingRun(object): +class ExecutingProcess(object): """ - This class contains a reference to a `metaflow.Run` object representing - the currently executing or finished run, as well as metadata related - to the process. + This is a base class for `ExecutingRun` and `ExecutingTask` classes. + The `ExecutingRun` and `ExecutingTask` classes are returned by methods + in `Runner` and `NBRunner`, and they are subclasses of this class. - `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not - meant to be instantiated directly. + The `ExecutingRun` class for instance contains a reference to a `metaflow.Run` + object representing the currently executing or finished run, as well as the metadata + related to the process. 
+ + Similarly, the `ExecutingTask` class contains a reference to a `metaflow.Task` + object representing the currently executing or finished task, as well as the metadata + related to the process. + + This class or its subclasses are not meant to be instantiated directly. The class + works as a context manager, allowing you to use a pattern like: - This class works as a context manager, allowing you to use a pattern like ```python with Runner(...).run() as running: ... ``` - Note that you should use either this object as the context manager or - `Runner`, not both in a nested manner. + + Note that you should use either this object as the context manager or `Runner`, not both + in a nested manner. """ - def __init__( - self, runner: "Runner", command_obj: CommandManager, run_obj: Run - ) -> None: + def __init__(self, runner: "Runner", command_obj: CommandManager) -> None: """ Create a new ExecutingRun -- this should not be done by the user directly but instead user Runner.run() @@ -53,7 +59,6 @@ def __init__( """ self.runner = runner self.command_obj = command_obj - self.run = run_obj def __enter__(self) -> "ExecutingRun": return self @@ -189,6 +194,76 @@ async def stream_log( yield position, line +class ExecutingTask(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Task` object representing + the currently executing or finished task, as well as metadata related + to the process. + `ExecutingTask` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).spin() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, task_obj: "metaflow.Task" + ) -> None: + """ + Create a new ExecutingTask -- this should not be done by the user directly but + instead user Runner.spin() + Parameters + ---------- + runner : Runner + Parent runner for this task. + command_obj : CommandManager + CommandManager containing the subprocess executing this task. + task_obj : Task + Task object corresponding to this task. + """ + super().__init__(runner, command_obj) + self.task = task_obj + + +class ExecutingRun(ExecutingProcess): + """ + This class contains a reference to a `metaflow.Run` object representing + the currently executing or finished run, as well as metadata related + to the process. + `ExecutingRun` is returned by methods in `Runner` and `NBRunner`. It is not + meant to be instantiated directly. + This class works as a context manager, allowing you to use a pattern like + ```python + with Runner(...).run() as running: + ... + ``` + Note that you should use either this object as the context manager or + `Runner`, not both in a nested manner. + """ + + def __init__( + self, runner: "Runner", command_obj: CommandManager, run_obj: Run + ) -> None: + """ + Create a new ExecutingRun -- this should not be done by the user directly but + instead user Runner.run() + Parameters + ---------- + runner : Runner + Parent runner for this run. + command_obj : CommandManager + CommandManager containing the subprocess executing this run. + run_obj : Run + Run object corresponding to this run. 
+ """ + super().__init__(runner, command_obj) + self.run = run_obj + + class RunnerMeta(type): def __new__(mcs, name, bases, dct): cls = super().__new__(mcs, name, bases, dct) @@ -257,7 +332,7 @@ def __init__( env: Optional[Dict[str, str]] = None, cwd: Optional[str] = None, file_read_timeout: int = 3600, - **kwargs + **kwargs, ): # these imports are required here and not at the top # since they interfere with the user defined Parameters @@ -373,6 +448,73 @@ def run(self, **kwargs) -> ExecutingRun: return self.__get_executing_run(attribute_file_fd, command_obj) + def __get_executing_task(self, attribute_file_fd, command_obj): + content = handle_timeout(attribute_file_fd, command_obj, self.file_read_timeout) + + command_obj.sync_wait() + + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + async def __async_get_executing_task(self, attribute_file_fd, command_obj): + content = await async_handle_timeout( + attribute_file_fd, command_obj, self.file_read_timeout + ) + content = json.loads(content) + pathspec = f"{content.get('flow_name')}/{content.get('run_id')}/{content.get('step_name')}/{content.get('task_id')}" + + # Set the correct metadata from the runner_attribute file corresponding to this run. + metadata_for_flow = content.get("metadata") + + from metaflow import Task + + task_object = Task( + pathspec, _namespace_check=False, _current_metadata=metadata_for_flow + ) + return ExecutingTask(self, command_obj, task_object) + + def spin(self, step_name, **kwargs): + """ + Blocking spin execution of the run. + This method will wait until the spun run has completed execution. + Parameters + ---------- + step_name : str + The name of the step to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + Returns + ------- + ExecutingTask + ExecutingTask containing the results of the spun task. + """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = self.spm.run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + show_output=self.show_output, + ) + command_obj = self.spm.get(pid) + + return self.__get_executing_task(attribute_file_fd, command_obj) + def resume(self, **kwargs) -> ExecutingRun: """ Blocking resume execution of the run. @@ -468,6 +610,42 @@ async def async_resume(self, **kwargs) -> ExecutingRun: return await self.__async_get_executing_run(attribute_file_fd, command_obj) + async def async_spin(self, step_name, spin_pathspec, **kwargs) -> ExecutingTask: + """ + Non-blocking spin execution of the run. + This method will return as soon as the spun task has launched. + + Note that this method is asynchronous and needs to be `await`ed. + + Parameters + ---------- + step_name : str + The name of the step to spin. + **kwargs : Any + Additional arguments that you would pass to `python ./myflow.py` after + the `spin` command. + + Returns + ------- + ExecutingTask + ExecutingTask representing the spun task that was started. 
+ """ + with temporary_fifo() as (attribute_file_path, attribute_file_fd): + command = self.api(**self.top_level_kwargs).spin( + step_name=step_name, + runner_attribute_file=attribute_file_path, + **kwargs, + ) + + pid = await self.spm.async_run_command( + [sys.executable, *command], + env=self.env_vars, + cwd=self.cwd, + ) + command_obj = self.spm.get(pid) + + return await self.__async_get_executing_task(attribute_file_fd, command_obj) + def __exit__(self, exc_type, exc_value, traceback): self.spm.cleanup() diff --git a/metaflow/runtime.py b/metaflow/runtime.py index 7e9269841fb..6d1ed1b5fa6 100644 --- a/metaflow/runtime.py +++ b/metaflow/runtime.py @@ -24,7 +24,7 @@ from . import get_namespace from .metadata_provider import MetaDatum -from .metaflow_config import MAX_ATTEMPTS, UI_URL +from .metaflow_config import MAX_ATTEMPTS, UI_URL, SPIN_ALLOWED_DECORATORS from .exception import ( MetaflowException, MetaflowInternalError, @@ -36,7 +36,7 @@ from .decorators import flow_decorators from .flowspec import _FlowState from .mflog import mflog, RUNTIME_LOG_SOURCE -from .util import to_unicode, compress_list, unicode_type +from .util import to_unicode, compress_list, unicode_type, get_latest_task_pathspec from .clone_util import clone_task_helper from .unbounded_foreach import ( CONTROL_TASK_TAG, @@ -73,6 +73,201 @@ # TODO option: output dot graph periodically about execution +class SpinRuntime(object): + def __init__( + self, + flow, + graph, + flow_datastore, + metadata, + environment, + package, + logger, + entrypoint, + event_logger, + monitor, + step_func, + step_name, + spin_pathspec, + skip_decorators=False, + artifacts_module=None, + persist=True, + max_log_size=MAX_LOG_SIZE, + ): + from metaflow import Task + + self._flow = flow + self._graph = graph + self._flow_datastore = flow_datastore + self._metadata = metadata + self._environment = environment + self._package = package + self._logger = logger + self._entrypoint = entrypoint + self._event_logger = event_logger + self._monitor = monitor + + self._step_func = step_func + + # Verify whether the user has provided step-name or spin-pathspec + if not spin_pathspec: + task = get_latest_task_pathspec(flow.name, step_name) + else: + # The user already provided a spin-pathspec; verify that it is valid + try: + task = Task(spin_pathspec, _namespace_check=False) + except Exception: + raise MetaflowException( + f"Invalid spin-pathspec: {spin_pathspec} for step: {step_name}" + ) + + spin_pathspec = task.pathspec + spin_metadata = task._metaflow.metadata.metadata_str() + self._spin_metadata = spin_metadata + self._spin_pathspec = spin_pathspec + self._persist = persist + self._spin_task = task + self._input_paths = None + self._split_index = None + self._whitelist_decorators = None + self._config_file_name = None + self._skip_decorators = skip_decorators + self._artifacts_module = artifacts_module + self._max_log_size = max_log_size + self._encoding = sys.stdout.encoding or "UTF-8" + + # Create a new run_id for the spin task + self.run_id = self._metadata.new_run_id() + for deco in self.whitelist_decorators: + deco.runtime_init(flow, graph, package, self.run_id) + + @property + def split_index(self): + """ + Returns the split index, caching the result after the first access.
+ """ + if self._split_index is None: + self._split_index = getattr(self._spin_task, "index", None) + + return self._split_index + + @property + def input_paths(self): + st = time.time() + + def _format_input_paths(task_pathspec): + _, run_id, step_name, task_id = task_pathspec.split("/") + return f"{run_id}/{step_name}/{task_id}" + + if self._input_paths: + return self._input_paths + + if self._step_func.name == "start": + from metaflow import Step + + flow_name, run_id, _, _ = self._spin_pathspec.split("/") + task = Step( + f"{flow_name}/{run_id}/_parameters", _namespace_check=False + ).task + self._input_paths = [_format_input_paths(task.pathspec)] + else: + self._input_paths = [ + _format_input_paths(task_pathspec) + for task_pathspec in self._spin_task.parent_task_pathspecs + ] + et = time.time() + print(f"Time taken to get input paths: {et - st}") + return self._input_paths + + @property + def whitelist_decorators(self): + if self._skip_decorators: + return [] + if self._whitelist_decorators: + return self._whitelist_decorators + self._whitelist_decorators = [ + deco + for deco in self._step_func.decorators + if any(deco.name.startswith(prefix) for prefix in SPIN_ALLOWED_DECORATORS) + ] + return self._whitelist_decorators + + def _new_task(self, step, input_paths=None, **kwargs): + return Task( + flow_datastore=self._flow_datastore, + flow=self._flow, + step=step, + run_id=self.run_id, + metadata=self._metadata, + environment=self._environment, + entrypoint=self._entrypoint, + event_logger=self._event_logger, + monitor=self._monitor, + input_paths=input_paths, + decos=self.whitelist_decorators, + logger=self._logger, + split_index=self.split_index, + **kwargs, + ) + + def execute(self): + exception = None + with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") as config_file: + config_value = dump_config_values(self._flow) + if config_value: + json.dump(config_value, config_file) + config_file.flush() + self._config_file_name = config_file.name + else: + self._config_file_name = None + + self.task = self._new_task(self._step_func.name, self.input_paths) + try: + self._launch_and_monitor_task() + except Exception as ex: + self._logger("Task failed.", system_msg=True, bad=True) + exception = ex + raise + finally: + for deco in self.whitelist_decorators: + deco.runtime_finished(exception) + + def _launch_and_monitor_task(self): + worker = Worker( + self.task, + self._max_log_size, + self._config_file_name, + spin_metadata=self._spin_metadata, + spin_pathspec=self._spin_pathspec, + whitelist_decorators=self.whitelist_decorators, + artifacts_module=self._artifacts_module, + persist=self._persist, + ) + + poll = procpoll.make_poll() + fds = worker.fds() + for fd in fds: + poll.add(fd) + + active_fds = set(fds) + + while active_fds: + events = poll.poll(POLL_TIMEOUT) + for event in events: + if event.can_read: + worker.read_logline(event.fd) + if event.is_terminated: + poll.remove(event.fd) + active_fds.remove(event.fd) + + returncode = worker.terminate() + + if returncode != 0: + raise TaskFailed(self.task, f"Task failed with return code {returncode}") + else: + self._logger("Task finished successfully.", system_msg=True) + + class NativeRuntime(object): def __init__( self, @@ -1508,8 +1703,21 @@ class CLIArgs(object): for step execution in StepDecorator.runtime_step_cli(). 
""" - def __init__(self, task): + def __init__( + self, + task, + spin_metadata=None, + spin_pathspec=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, + ): self.task = task + self.spin_metadata = spin_metadata + self.spin_pathspec = spin_pathspec + self.whitelist_decorators = whitelist_decorators + self.artifacts_module = artifacts_module + self.persist = persist self.entrypoint = list(task.entrypoint) self.top_level_options = { "quiet": True, @@ -1542,21 +1750,50 @@ def __init__(self, task): (k, ConfigInput.make_key_name(k)) for k in configs ] + if spin_pathspec: + self.spin_args() + else: + self.default_args() + + def default_args(self): self.commands = ["step"] self.command_args = [self.task.step] self.command_options = { - "run-id": task.run_id, - "task-id": task.task_id, - "input-paths": compress_list(task.input_paths), - "split-index": task.split_index, - "retry-count": task.retries, - "max-user-code-retries": task.user_code_retries, - "tag": task.tags, + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "tag": self.task.tags, "namespace": get_namespace() or "", - "ubf-context": task.ubf_context, + "ubf-context": self.task.ubf_context, } self.env = {} + def spin_args(self): + self.commands = ["spin-step"] + self.command_args = [self.task.step] + + whitelist_decos = [deco.name for deco in self.whitelist_decorators] + + self.command_options = { + "run-id": self.task.run_id, + "task-id": self.task.task_id, + "input-paths": compress_list(self.task.input_paths), + "split-index": self.task.split_index, + "retry-count": self.task.retries, + "max-user-code-retries": self.task.user_code_retries, + "namespace": get_namespace() or "", + "spin-metadata": self.spin_metadata, + "spin-pathspec": self.spin_pathspec, + "whitelist-decorators": compress_list(whitelist_decos), + "artifacts-module": self.artifacts_module, + } + if self.persist: + self.command_options["persist"] = True + self.env = {} + def get_args(self): # TODO: Make one with dict_to_cli_options; see cli_args.py for more detail def _options(mapping): @@ -1595,9 +1832,24 @@ def __str__(self): class Worker(object): - def __init__(self, task, max_logs_size, config_file_name): + def __init__( + self, + task, + max_logs_size, + config_file_name, + spin_metadata=None, + spin_pathspec=None, + whitelist_decorators=None, + artifacts_module=None, + persist=True, + ): self.task = task self._config_file_name = config_file_name + self._spin_metadata = spin_metadata + self._spin_pathspec = spin_pathspec + self._whitelist_decorators = whitelist_decorators + self._artifacts_module = artifacts_module + self._persist = persist self._proc = self._launch() if task.retries > task.user_code_retries: @@ -1629,7 +1881,14 @@ def __init__(self, task, max_logs_size, config_file_name): # not it is properly shut down) def _launch(self): - args = CLIArgs(self.task) + args = CLIArgs( + self.task, + spin_metadata=self._spin_metadata, + spin_pathspec=self._spin_pathspec, + whitelist_decorators=self._whitelist_decorators, + artifacts_module=self._artifacts_module, + persist=self._persist, + ) env = dict(os.environ) if self.task.clone_run_id: @@ -1663,6 +1922,7 @@ def _launch(self): # print('running', args) cmdline = args.get_args() debug.subcommand_exec(cmdline) + # print(f"Command: {cmdline}") return subprocess.Popen( cmdline, env=env, @@ -1784,13 
+2044,14 @@ def terminate(self): else: self.emit_log(b"Task failed.", self._stderr, system_msg=True) else: - num = self.task.results["_foreach_num_splits"] - if num: - self.task.log( - "Foreach yields %d child steps." % num, - system_msg=True, - pid=self._proc.pid, - ) + if not self._spin_pathspec: + num = self.task.results["_foreach_num_splits"] + if num: + self.task.log( + "Foreach yields %d child steps." % num, + system_msg=True, + pid=self._proc.pid, + ) self.task.log( "Task finished successfully.", system_msg=True, pid=self._proc.pid ) diff --git a/metaflow/task.py b/metaflow/task.py index 414b7e54710..961b7c29a0a 100644 --- a/metaflow/task.py +++ b/metaflow/task.py @@ -47,6 +47,8 @@ def __init__( event_logger, monitor, ubf_context, + spin_metadata=None, + spin_artifacts=None, ): self.flow = flow self.flow_datastore = flow_datastore @@ -56,6 +58,8 @@ def __init__( self.event_logger = event_logger self.monitor = monitor self.ubf_context = ubf_context + self.spin_metadata = spin_metadata + self.spin_artifacts = spin_artifacts def _exec_step_function(self, step_function, input_obj=None): if input_obj is None: @@ -120,7 +124,6 @@ def property_setter( lambda _, parameter_ds=parameter_ds: parameter_ds["_graph_info"], ) all_vars.append("_graph_info") - if passdown: self.flow._datastore.passdown_partial(parameter_ds, all_vars) return param_only_vars @@ -147,6 +150,9 @@ def _init_data(self, run_id, join_type, input_paths): run_id, pathspecs=input_paths, prefetch_data_artifacts=prefetch_data_artifacts, + join_type=join_type, + spin_metadata=self.spin_metadata, + spin_artifacts=self.spin_artifacts, ) ds_list = [ds for ds in datastore_set] if len(ds_list) != len(input_paths): @@ -160,7 +166,14 @@ def _init_data(self, run_id, join_type, input_paths): for input_path in input_paths: run_id, step_name, task_id = input_path.split("/") ds_list.append( - self.flow_datastore.get_task_datastore(run_id, step_name, task_id) + self.flow_datastore.get_task_datastore( + run_id, + step_name, + task_id, + join_type=join_type, + spin_metadata=self.spin_metadata, + spin_artifacts=self.spin_artifacts, + ) ) if not ds_list: # this guards against errors in input paths @@ -382,6 +395,8 @@ def run_step( split_index, retry_count, max_user_code_retries, + whitelist_decorators=None, + persist=True, ): if run_id and task_id: self.metadata.register_run_id(run_id) @@ -398,7 +413,6 @@ def run_step( raise MetaflowInternalError( "Too many task attempts (%d)! MAX_ATTEMPTS exceeded." % retry_count ) - metadata_tags = ["attempt_id:{0}".format(retry_count)] metadata = [ @@ -440,6 +454,11 @@ def run_step( step_func = getattr(self.flow, step_name) decorators = step_func.decorators + if self.spin_metadata: + # We filter only the whitelisted decorators in case of spin step. + decorators = [ + deco for deco in decorators if deco.name in whitelist_decorators + ] node = self.flow._graph[step_name] join_type = None @@ -448,7 +467,7 @@ def run_step( # 1. 
initialize output datastore output = self.flow_datastore.get_task_datastore( - run_id, step_name, task_id, attempt=retry_count, mode="w" + run_id, step_name, task_id, attempt=retry_count, mode="w", persist=persist ) output.init_task() @@ -461,7 +480,6 @@ def run_step( self._init_foreach(step_name, join_type, inputs, split_index) # Add foreach stack to metadata of the task - foreach_stack = ( self.flow._foreach_stack if hasattr(self.flow, "_foreach_stack") and self.flow._foreach_stack @@ -583,7 +601,6 @@ def run_step( # should either be set prior to running the user code or listed in # FlowSpec._EPHEMERAL to allow for proper merging/importing of # user artifacts in the user's step code. - if join_type: # Join step: @@ -624,6 +641,27 @@ def run_step( "inputs." % step_name ) self.flow._set_datastore(inputs[0]) + # Iterate over all artifacts in the parent pathspec and add them + # to the current flow's datastore. We need to do this explicitly + # since we want to persist even those attributes that are not + # used / redefined in the spin step. + if self.spin_metadata and persist: + st_time = time.time() + for artifact_name in self.flow._datastore._objects.keys(): + # This is highly inefficient since we are loading data + # that we don't need, but there is no better way to + # support this now + artifact_data = self.spin_artifacts.get( + artifact_name, self.flow._datastore[artifact_name] + ) + setattr( + self.flow, + artifact_name, + artifact_data, + ) + print( + f"Time taken to load all artifacts: {time.time() - st_time:.2f} seconds" + ) if input_paths: # initialize parameters (if they exist) # We take Parameter values from the first input, @@ -666,10 +704,14 @@ def run_step( self.ubf_context, ) + st_time = time.time() if join_type: self._exec_step_function(step_func, input_obj) else: self._exec_step_function(step_func) + print( + f"Time taken to run the step function: {time.time() - st_time:.2f} seconds" + ) for deco in decorators: deco.task_post_step( @@ -726,7 +768,11 @@ def run_step( ) try: # persisting might fail due to unpicklable artifacts. + st_time = time.time() output.persist(self.flow) + print( + f"Time taken to persist the output: {time.time() - st_time:.2f} seconds" + ) except Exception as ex: self.flow._task_ok = False raise ex diff --git a/metaflow/util.py b/metaflow/util.py index f9051aff589..4666678d5e9 100644 --- a/metaflow/util.py +++ b/metaflow/util.py @@ -7,6 +7,7 @@ from functools import wraps from io import BytesIO from itertools import takewhile +from typing import Dict, Any import re from metaflow.exception import MetaflowUnknownUser, MetaflowInternalError @@ -177,6 +178,45 @@ def resolve_identity(): return "%s:%s" % (identity_type, identity_value) +def get_latest_task_pathspec(flow_name: str, step_name: str): + """ + Returns a Task object from the latest run of the flow for the queried step. + If the queried step has several tasks, the first task is returned. + + Parameters + ---------- + flow_name : str + The name of the flow. + step_name : str + The name of the step. + + Returns + ------- + Task + A Metaflow Task instance containing the latest task for the queried step. + + Raises + ------ + MetaflowNotFound + If no task or run is found for the queried step.
+ """ + from metaflow import Flow, Step + from metaflow.exception import MetaflowNotFound + + run = Flow(flow_name, _namespace_check=False).latest_run + + if run is None: + raise MetaflowNotFound(f"No run found for the flow {flow_name}") + + try: + task = Step(f"{flow_name}/{run.id}/{step_name}", _namespace_check=False).task + return task + except Exception: + raise MetaflowNotFound( + f"No step *{step_name}* found in run *{run.id}* for flow *{flow_name}*" + ) + + def get_latest_run_id(echo, flow_name): from metaflow.plugins.datastores.local_storage import LocalStorage @@ -467,6 +507,41 @@ def to_pod(value): return str(value) +def read_artifacts_module(file_path: str) -> Dict[str, Any]: + """ + Read a Python module from the given file path and return its ARTIFACTS variable. + + Parameters + ---------- + file_path : str + The path to the Python file containing the ARTIFACTS variable. + + Returns + ------- + Dict[str, Any] + A dictionary containing the ARTIFACTS variable from the module. + + Raises + ------- + MetaflowInternalError + If the file cannot be read or does not contain the ARTIFACTS variable. + """ + import importlib.util + + try: + spec = importlib.util.spec_from_file_location("artifacts_module", file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + variables = vars(module) + if "ARTIFACTS" not in variables: + raise MetaflowInternalError( + f"Module {file_path} does not contain ARTIFACTS variable" + ) + return variables.get("ARTIFACTS") + except Exception as e: + raise MetaflowInternalError(f"Error reading file {file_path}") from e + + if sys.version_info[:2] > (3, 5): from metaflow._vendor.packaging.version import parse as version_parse else: diff --git a/test/unit/spin/artifacts/complex_dag_step_a.py b/test/unit/spin/artifacts/complex_dag_step_a.py new file mode 100644 index 00000000000..b7e81bf1b6f --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_a.py @@ -0,0 +1 @@ +ARTIFACTS = {"my_output": [10, 11, 12]} diff --git a/test/unit/spin/artifacts/complex_dag_step_d.py b/test/unit/spin/artifacts/complex_dag_step_d.py new file mode 100644 index 00000000000..5aa40d64766 --- /dev/null +++ b/test/unit/spin/artifacts/complex_dag_step_d.py @@ -0,0 +1,11 @@ +from metaflow import Run + + +def _get_artifact(): + task = Run("ComplexDAGFlow/2")["step_d"].task + task_pathspec = next(task.parent_task_pathspecs) + _, inp_path = task_pathspec.split("/", 1) + return {inp_path: {"my_output": [-1]}} + + +ARTIFACTS = _get_artifact() diff --git a/test/unit/spin/complex_dag_flow.py b/test/unit/spin/complex_dag_flow.py new file mode 100644 index 00000000000..04b185fe40f --- /dev/null +++ b/test/unit/spin/complex_dag_flow.py @@ -0,0 +1,116 @@ +from metaflow import FlowSpec, step, project, conda, Task, pypi + + +class ComplexDAGFlow(FlowSpec): + @step + def start(self): + self.split_start = [1, 2, 3] + self.my_output = [] + print("My output is: ", self.my_output) + self.next(self.step_a, foreach="split_start") + + @step + def step_a(self): + self.split_a = [4, 5] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_b, foreach="split_a") + + @step + def step_b(self): + self.split_b = [6, 7, 8] + self.my_output = self.my_output + [self.input] + print("My output is: ", self.my_output) + self.next(self.step_c, foreach="split_b") + + @conda(libraries={"numpy": "2.1.1"}) + @step + def step_c(self): + import numpy as np + + self.np_version = np.__version__ + print(f"numpy version: 
diff --git a/test/unit/spin/complex_dag_flow.py b/test/unit/spin/complex_dag_flow.py
new file mode 100644
index 00000000000..04b185fe40f
--- /dev/null
+++ b/test/unit/spin/complex_dag_flow.py
@@ -0,0 +1,116 @@
+from metaflow import FlowSpec, step, project, conda, Task, pypi
+
+
+class ComplexDAGFlow(FlowSpec):
+    @step
+    def start(self):
+        self.split_start = [1, 2, 3]
+        self.my_output = []
+        print("My output is: ", self.my_output)
+        self.next(self.step_a, foreach="split_start")
+
+    @step
+    def step_a(self):
+        self.split_a = [4, 5]
+        self.my_output = self.my_output + [self.input]
+        print("My output is: ", self.my_output)
+        self.next(self.step_b, foreach="split_a")
+
+    @step
+    def step_b(self):
+        self.split_b = [6, 7, 8]
+        self.my_output = self.my_output + [self.input]
+        print("My output is: ", self.my_output)
+        self.next(self.step_c, foreach="split_b")
+
+    @conda(libraries={"numpy": "2.1.1"})
+    @step
+    def step_c(self):
+        import numpy as np
+
+        self.np_version = np.__version__
+        print(f"numpy version: {self.np_version}")
+        self.my_output = self.my_output + [self.input] + [9, 10]
+        print("My output is: ", self.my_output)
+        self.next(self.step_d)
+
+    @step
+    def step_d(self, inputs):
+        self.my_output = sorted([inp.my_output for inp in inputs])[0]
+        print("My output is: ", self.my_output)
+        self.next(self.step_e)
+
+    @step
+    def step_e(self):
+        print(f"I am step E. Input is: {self.input}")
+        self.split_e = [9, 10]
+        print("My output is: ", self.my_output)
+        self.next(self.step_f, foreach="split_e")
+
+    @step
+    def step_f(self):
+        self.my_output = self.my_output + [self.input]
+        print("My output is: ", self.my_output)
+        self.next(self.step_g)
+
+    @step
+    def step_g(self):
+        print("My output is: ", self.my_output)
+        self.next(self.step_h)
+
+    @step
+    def step_h(self, inputs):
+        self.my_output = sorted([inp.my_output for inp in inputs])[0]
+        print("My output is: ", self.my_output)
+        self.next(self.step_i)
+
+    @step
+    def step_i(self, inputs):
+        self.my_output = sorted([inp.my_output for inp in inputs])[0]
+        print("My output is: ", self.my_output)
+        self.next(self.step_j)
+
+    @step
+    def step_j(self):
+        print("My output is: ", self.my_output)
+        self.next(self.step_k, self.step_l)
+
+    @step
+    def step_k(self):
+        self.my_output = self.my_output + [11]
+        print("My output is: ", self.my_output)
+        self.next(self.step_m)
+
+    @step
+    def step_l(self):
+        print(f"I am step L. Input is: {self.input}")
+        self.my_output = self.my_output + [12]
+        print("My output is: ", self.my_output)
+        self.next(self.step_m)
+
+    @conda(libraries={"scikit-learn": "1.3.0"})
+    @step
+    def step_m(self, inputs):
+        import sklearn
+
+        self.sklearn_version = sklearn.__version__
+        self.my_output = sorted([inp.my_output for inp in inputs])[0]
+        print("Sklearn version: ", self.sklearn_version)
+        print("My output is: ", self.my_output)
+        self.next(self.step_n)
+
+    @step
+    def step_n(self, inputs):
+        self.my_output = sorted([inp.my_output for inp in inputs])[0]
+        print("My output is: ", self.my_output)
+        self.next(self.end)
+
+    @step
+    def end(self):
+        self.my_output = self.my_output + [13]
+        print("My output is: ", self.my_output)
+        print("Flow is complete!")
+
+
+if __name__ == "__main__":
+    ComplexDAGFlow()
diff --git a/test/unit/spin/merge_artifacts_flow.py b/test/unit/spin/merge_artifacts_flow.py
new file mode 100644
index 00000000000..59f1390e052
--- /dev/null
+++ b/test/unit/spin/merge_artifacts_flow.py
@@ -0,0 +1,63 @@
+from metaflow import FlowSpec, step
+
+
+class MergeArtifactsFlow(FlowSpec):
+
+    @step
+    def start(self):
+        self.pass_down = "a"
+        self.next(self.a, self.b)
+
+    @step
+    def a(self):
+        self.common = 5
+        self.x = 1
+        self.y = 3
+        self.from_a = 6
+        self.next(self.join)
+
+    @step
+    def b(self):
+        self.common = 5
+        self.x = 2
+        self.y = 4
+        self.next(self.join)
+
+    @step
+    def join(self, inputs):
+        print(f"In join step, self._datastore: {type(self._datastore)}")
+        self.x = inputs.a.x
+        self.merge_artifacts(inputs, exclude=["y"])
+        print("x is %s" % self.x)
+        print("pass_down is %s" % self.pass_down)
+        print("common is %d" % self.common)
+        print("from_a is %d" % self.from_a)
+        self.next(self.c)
+
+    @step
+    def c(self):
+        self.next(self.d, self.e)
+
+    @step
+    def d(self):
+        self.conflicting = 7
+        self.next(self.join2)
+
+    @step
+    def e(self):
+        self.conflicting = 8
+        self.next(self.join2)
+
+    @step
+    def join2(self, inputs):
+        self.merge_artifacts(inputs, include=["pass_down", "common"])
+        print("Only pass_down and common exist here")
+        self.next(self.end)
+
+    @step
+    def end(self):
+        pass
+
+
+if __name__ == "__main__":
+    MergeArtifactsFlow()
diff --git a/test/unit/spin/test_spin.py b/test/unit/spin/test_spin.py
new file mode 100644
index 00000000000..75577f9c624
--- /dev/null
+++ b/test/unit/spin/test_spin.py
@@ -0,0 +1,138 @@
+import pytest
+from metaflow import Runner, Run
+
+
+@pytest.fixture
+def complex_dag_run():
+    # with Runner('complex_dag_flow.py').run() as running:
+    #     yield running.run
+    return Run("ComplexDAGFlow/3", _namespace_check=False)
+
+
+@pytest.fixture
+def merge_artifacts_run():
+    # with Runner('merge_artifacts_flow.py').run() as running:
+    #     yield running.run
+    return Run("MergeArtifactsFlow/55", _namespace_check=False)
+
+
+def _assert_artifacts(task, spin_task):
+    spin_task_artifacts = {
+        artifact.id: artifact.data for artifact in spin_task.artifacts
+    }
+    print(f"Spin task artifacts: {spin_task_artifacts}")
+    for artifact in task.artifacts:
+        assert (
+            artifact.id in spin_task_artifacts
+        ), f"Artifact {artifact.id} not found in spin task"
+        assert (
+            artifact.data == spin_task_artifacts[artifact.id]
+        ), f"Expected {artifact.data} but got {spin_task_artifacts[artifact.id]} for artifact {artifact.id}"
+
+
+def _run_step(file_name, run, step_name, is_conda=False):
+    task = run[step_name].task
+    if not is_conda:
+        with Runner(file_name).spin(step_name, spin_pathspec=task.pathspec) as spin:
+            print("-" * 50)
+            print(
+                f"Running test for step: {step_name} with task pathspec: {task.pathspec}"
+            )
+            _assert_artifacts(task, spin.task)
+    else:
+        with Runner(file_name, environment="conda").spin(
+            step_name,
+            spin_pathspec=task.pathspec,
+        ) as spin:
+            print("-" * 50)
+            print(
+                f"Running test for step: {step_name} with task pathspec: {task.pathspec}"
+            )
+            print(f"Spin task artifacts: {spin.task.artifacts}")
+            _assert_artifacts(task, spin.task)
+
+
+def test_complex_dag_flow(complex_dag_run):
+    print(f"Running test for ComplexDAGFlow flow: {complex_dag_run}")
+    for step in complex_dag_run.steps():
+        print("-" * 100)
+        _run_step("complex_dag_flow.py", complex_dag_run, step.id, is_conda=True)
+
+
+def test_merge_artifacts_flow(merge_artifacts_run):
+    print(f"Running test for merge artifacts flow: {merge_artifacts_run}")
+    for step in merge_artifacts_run.steps():
+        print("-" * 100)
+        _run_step("merge_artifacts_flow.py", merge_artifacts_run, step.id)
+
+
+def test_artifacts_module(complex_dag_run):
+    print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}")
+    step_name = "step_a"
+    task = complex_dag_run[step_name].task
+    with Runner("complex_dag_flow.py", environment="conda").spin(
+        step_name,
+        spin_pathspec=task.pathspec,
+        artifacts_module="./artifacts/complex_dag_step_a.py",
+    ) as spin:
+        print("-" * 50)
+        print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}")
+        spin_task = spin.task
+        print(f"my_output: {spin_task['my_output']}")
+        assert spin_task["my_output"].data == [10, 11, 12, 3]
+
+
+def test_artifacts_module_join_step(complex_dag_run):
+    print(f"Running test for artifacts module in ComplexDAGFlow: {complex_dag_run}")
+    step_name = "step_d"
+    task = complex_dag_run[step_name].task
+    with Runner("complex_dag_flow.py", environment="conda").spin(
+        step_name,
+        spin_pathspec=task.pathspec,
+        artifacts_module="./artifacts/complex_dag_step_d.py",
+    ) as spin:
+        print("-" * 50)
+        print(f"Running test for step: {step_name} with task pathspec: {task.pathspec}")
+        spin_task = spin.task
+        assert spin_task["my_output"].data == [-1]
+
+
+def test_skip_decorators(complex_dag_run):
+    print(f"Running test for skip decorator in ComplexDAGFlow: {complex_dag_run}")
+    step_name = "step_m"
+    task = complex_dag_run[step_name].task
+    # If sklearn is available in the outer environment, spinning step_m with
+    # skip_decorators=True should pick up that installation; if it is not
+    # available, the spin is expected to fail because the @conda decorator is
+    # skipped and sklearn cannot be imported.
+    is_sklearn = True
+    try:
+        import sklearn
+    except ImportError:
+        is_sklearn = False
+    if is_sklearn:
+        # We verify that the sklearn version is the same as the one in the outside environment
+        with Runner("complex_dag_flow.py", environment="conda").spin(
+            step_name,
+            spin_pathspec=task.pathspec,
+            skip_decorators=True,
+        ) as spin:
+            print("-" * 50)
+            print(
+                f"Running test for step: {step_name} with task pathspec: {task.pathspec}"
+            )
+            spin_task = spin.task
+            import sklearn
+
+            expected_version = sklearn.__version__
+            assert (
+                spin_task["sklearn_version"].data == expected_version
+            ), f"Expected sklearn version {expected_version} but got {spin_task['sklearn_version'].data}"
+    else:
+        # We assert that an exception is raised when trying to run the step with skip_decorators=True
+        with pytest.raises(Exception) as exc_info:
+            with Runner("complex_dag_flow.py", environment="conda").spin(
+                step_name,
+                spin_pathspec=task.pathspec,
+                skip_decorators=True,
+            ):
+                pass
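The fixtures above currently point at hard-coded run ids; a sketch of producing fresh fixture runs with the Runner API instead, following the commented-out lines in the fixtures.

from metaflow import Runner

# Run ids and pathspecs are assigned by Metaflow; this replaces the hard-coded
# ComplexDAGFlow/3 and MergeArtifactsFlow/55 references above.
with Runner("merge_artifacts_flow.py").run() as running:
    print(running.run.pathspec)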