Skip to content

Commit 576bb4a

Browse files
feat(mm): support generic API tokens via regex/token pairs in config
A list of regex and token pairs is accepted. As a file is downloaded by the model installer, the URL is tested against the provided regex/token pairs. The token for the first matching regex is used during download, added as a bearer token.
1 parent b6065d6 commit 576bb4a

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

invokeai/app/services/config/config_default.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,12 @@ class InvokeBatch(InvokeAISettings):
170170
from __future__ import annotations
171171

172172
import os
173+
import re
173174
from pathlib import Path
174175
from typing import Any, ClassVar, Dict, List, Literal, Optional
175176

176177
from omegaconf import DictConfig, OmegaConf
177-
from pydantic import Field
178+
from pydantic import BaseModel, Field, field_validator
178179
from pydantic.config import JsonDict
179180
from pydantic_settings import SettingsConfigDict
180181

@@ -205,6 +206,21 @@ class Categories(object):
205206
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
206207

207208

209+
class URLRegexToken(BaseModel):
210+
url_regex: str = Field(description="Regular expression to match against the URL")
211+
token: str = Field(description="Token to use when the URL matches the regex")
212+
213+
@field_validator("url_regex")
214+
@classmethod
215+
def validate_url_regex(cls, v: str) -> str:
216+
"""Validate that the value is a valid regex."""
217+
try:
218+
re.compile(v)
219+
except re.error as e:
220+
raise ValueError(f"Invalid regex: {e}")
221+
return v
222+
223+
208224
class InvokeAIAppConfig(InvokeAISettings):
209225
"""Configuration object for InvokeAI App."""
210226

@@ -288,7 +304,7 @@ class InvokeAIAppConfig(InvokeAISettings):
288304
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
289305

290306
# MODEL IMPORT
291-
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other)
307+
remote_api_tokens : Optional[list[URLRegexToken]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.", json_schema_extra=Categories.Other)
292308

293309
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
294310
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)

invokeai/app/services/download/download_default.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,15 +241,12 @@ def _download_next_item(self) -> None:
241241
def _do_download(self, job: DownloadJob) -> None:
242242
"""Do the actual download."""
243243
url = job.source
244-
query_params = url.query_params()
245-
if job.access_token:
246-
query_params.append(("access_token", job.access_token))
247244
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
248245
open_mode = "wb"
249246

250247
# Make a streaming request. This will retrieve headers including
251248
# content-length and content-disposition, but not fetch any content itself
252-
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True)
249+
resp = self._requests.get(str(url), headers=header, stream=True)
253250
if not resp.ok:
254251
raise HTTPError(resp.reason)
255252

invokeai/app/services/model_install/model_install_default.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,16 @@ def heuristic_import(
197197
access_token=access_token,
198198
)
199199
elif re.match(r"^https?://[^/]+", source):
200+
# Pull the token from config if it exists and matches the URL
201+
_token = access_token
202+
if _token is None:
203+
for pair in self.app_config.remote_api_tokens or []:
204+
if re.search(pair.url_regex, source):
205+
_token = pair.token
206+
break
200207
source_obj = URLModelSource(
201208
url=AnyHttpUrl(source),
202-
access_token=self.app_config.remote_repo_api_key,
209+
access_token=_token,
203210
)
204211
else:
205212
raise ValueError(f"Unsupported model source: '{source}'")

0 commit comments

Comments
 (0)