diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py index f74a70726ba..001d03e54a1 100644 --- a/metaflow/plugins/__init__.py +++ b/metaflow/plugins/__init__.py @@ -18,6 +18,7 @@ ("tag", ".tag_cli.cli"), ("spot-metadata", ".kubernetes.spot_metadata_cli.cli"), ("logs", ".logs_cli.cli"), + ("environment", ".pypi.environment_cli.cli"), ] # Add additional commands to the runner here diff --git a/metaflow/plugins/pypi/conda_environment.py b/metaflow/plugins/pypi/conda_environment.py index 75a954a3023..962eddbbb90 100644 --- a/metaflow/plugins/pypi/conda_environment.py +++ b/metaflow/plugins/pypi/conda_environment.py @@ -5,6 +5,7 @@ import io import json import os +import shutil import tarfile import threading from concurrent.futures import ThreadPoolExecutor, as_completed @@ -32,6 +33,7 @@ def __init__(self, msg): class CondaEnvironment(MetaflowEnvironment): TYPE = "conda" _filecache = None + _disable_cache = False def __init__(self, flow): self.flow = flow @@ -107,7 +109,10 @@ def solve(id_, environment, type_): return ( id_, ( - self.read_from_environment_manifest([id_, platform, type_]) + ( + not self._disable_cache + and self.read_from_environment_manifest([id_, platform, type_]) + ) or self.write_to_environment_manifest( [id_, platform, type_], self.solvers[type_].solve(id_, **environment), @@ -153,7 +158,7 @@ def _path(url, local_path): _meta = copy.deepcopy(local_packages) for id_, packages, _, _ in results: for package in packages: - if package.get("path"): + if package.get("path") and not self._disable_cache: # Cache only those packages that manifest is unaware of local_packages.pop(package["url"], None) else: @@ -186,7 +191,7 @@ def _path(url, local_path): storage.save_bytes( list_of_path_and_filehandle, len_hint=len(list_of_path_and_filehandle), - # overwrite=True, + overwrite=self._disable_cache, ) for id_, packages, _, platform in results: if id_ in dirty: @@ -290,6 +295,9 @@ def pypi_solve(env): self.logger("Virtual environment(s) bootstrapped!") + def disable_cache(self): + self._disable_cache = True + def executable(self, step_name, default=None): step = next((step for step in self.flow if step.name == step_name), None) if step is None: @@ -319,6 +327,21 @@ def is_disabled(self, step): return str(disabled).lower() == "true" return False + def delete_environment(self, step): + env = self.get_environment(step) + paths = [] + for solver in self.solvers.keys(): + if solver not in env: + continue + for platform in env[solver].get("platforms", [None]): + paths.append( + self.solvers[solver].path_to_environment(env["id_"], platform) + ) + + # delete collected paths + for path in paths: + shutil.rmtree(path, ignore_errors=True) + @functools.lru_cache(maxsize=None) def get_environment(self, step): environment = {} diff --git a/metaflow/plugins/pypi/environment_cli.py b/metaflow/plugins/pypi/environment_cli.py new file mode 100644 index 00000000000..0a08629f298 --- /dev/null +++ b/metaflow/plugins/pypi/environment_cli.py @@ -0,0 +1,50 @@ +from metaflow._vendor import click +from metaflow.exception import MetaflowException + + +@click.group() +def cli(): + pass + + +@cli.group(help="Commands related to managing the conda/pypi environments") +@click.pass_context +def environment(ctx): + # the logger is configured in cli.py + global echo + echo = ctx.obj.echo + + +@environment.command(help="Resolve the environment(s)") +@click.option( + "--step", + "steps", + multiple=True, + default=[], + help="Steps to resolve the environment for", +) +@click.option( + "--force/--no-force", + default=False, + is_flag=True, + help="Force re-resolving the environment(s)", +) +@click.pass_obj +def resolve(obj, steps, force=False): + # possibly limiting steps to resolve. make sure its a list and not a tuple + step_names = list(steps) + + steps = [step for step in obj.flow if (step.name in step_names) or not step_names] + + # Delete existing environments if we are rebuilding. + if force: + for step in steps: + obj.environment.delete_environment(step) + + if not hasattr(obj.environment, "disable_cache"): + raise MetaflowException("The environment does not support disabling the cache.") + + # Disable the cache before initializing if we are rebuilding. + if force: + obj.environment.disable_cache() + obj.environment.init_environment(echo, only_steps=step_names)