diff --git a/README.md b/README.md index e3b4f536..7864c7c0 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ Install the [PyPI](https://pypi.org/project/stability-sdk/) package via: - `pyenv/bin/activate` to use the venv. - Set the `STABILITY_HOST` environment variable. This is by default set to the production endpoint `grpc.stability.ai:443`. - Set the `STABILITY_KEY` environment variable. +- Optional, set the `DEFAULT_ENGINE` environment variable. This is by default set to `stable-diffusion-xl-1024-v1-0`. +- Optional, set the `DEFAULT_UPSCALE_ENGINE` environment variable. This is by default set to `esrgan-v1-x2plus`. Then to invoke: diff --git a/src/stability_sdk/client.py b/src/stability_sdk/client.py index 82c4ffd8..905df0c9 100644 --- a/src/stability_sdk/client.py +++ b/src/stability_sdk/client.py @@ -95,8 +95,8 @@ def __init__( self, host: str = "grpc.stability.ai:443", key: str = "", - engine: str = "stable-diffusion-xl-1024-v1-0", - upscale_engine: str = "esrgan-v1-x2plus", + engine: str = None, + upscale_engine: str = None, verbose: bool = False, wait_for_ready: bool = True, ): @@ -105,16 +105,20 @@ def __init__( :param host: Host to connect to. :param key: Key to use for authentication. - :param engine: Engine to use. - :param upscale_engine: Upscale engine to use. + :param engine: Engine to use. Defaults to the value from the environment + variable DEFAULT_ENGINE, or "stable-diffusion-xl-1024-v1-0" if the + variable is not set. + :param upscale_engine: Upscale engine to use. Defaults to the value from the + environment variable UPSCALE_ENGINE, or "esrgan-v1-x2plus" if the variable + is not set. :param verbose: Whether to print debug messages. :param wait_for_ready: Whether to wait for the server to be ready, or to fail immediately. """ self.verbose = verbose - self.engine = engine - self.upscale_engine = upscale_engine - + self.engine = engine or os.getenv("DEFAULT_ENGINE", "stable-diffusion-xl-1024-v1-0") + self.upscale_engine = upscale_engine or os.getenv("DEFAULT_UPSCALE_ENGINE", "esrgan-v1-x2plus") + self.grpc_args = {"wait_for_ready": wait_for_ready} if verbose: logger.info(f"Opening channel to {host}") @@ -500,7 +504,7 @@ def process_cli( "-e", type=str, help="engine to use for upscale", - default="esrgan-v1-x2plus", + default=None, ) parser_upscale.add_argument( "prompt", nargs="*" @@ -573,7 +577,7 @@ def process_cli( "-e", type=str, help="engine to use for inference", - default="stable-diffusion-xl-1024-v1-0", + default=None, ) parser_generate.add_argument( "--init_image",