
Commit 6cd4096

Depth Anything V2 (#6674)
- Updated the previous DepthAnything manual implementation to use the `transformers` implementation instead, so we can get upstream features.
- Plugged the DepthAnything models into Invoke's Model Manager.
- The `small_v2` model will use DepthAnythingV2. It has been added as a new model option and is now also the default in the Linear UI.

![opera_TxRhmbFole](https://github.com/user-attachments/assets/2a25abe3-ba0b-4f97-b75a-2ce5fd6246e6)

# Merge

Review and merge.
2 parents 140670d + 408a1d6 commit 6cd4096
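
For reference, the upstream `transformers` depth-estimation pipeline that this commit delegates to can be exercised standalone. A minimal sketch (not part of the diff; the input/output file names are hypothetical, and the first run downloads the model from the Hugging Face Hub):

from PIL import Image
from transformers import pipeline

# Same task and model id that the commit wires into Invoke below.
pipe = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
image = Image.open("input.png")  # hypothetical input path
depth_map = pipe(image)["depth"]  # a PIL.Image depth map
depth_map.save("depth.png")       # hypothetical output path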

File tree

14 files changed: +190 −786 lines

invokeai/app/invocations/controlnet_image_processors.py

Lines changed: 27 additions & 14 deletions
@@ -21,6 +21,8 @@
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator
+from transformers import pipeline
+from transformers.pipelines import DepthEstimationPipeline

 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -44,13 +46,12 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
+from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
-from invokeai.backend.util.devices import TorchDevice


 class ControlField(BaseModel):
@@ -592,36 +593,48 @@ def run_processor(self, image: Image.Image) -> Image.Image:
     return color_map


-DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small", "small_v2"]
+# DepthAnything V2 Small model is licensed under Apache 2.0 but not the base and large models.
+DEPTH_ANYTHING_MODELS = {
+    "large": "LiheYoung/depth-anything-large-hf",
+    "base": "LiheYoung/depth-anything-base-hf",
+    "small": "LiheYoung/depth-anything-small-hf",
+    "small_v2": "depth-anything/Depth-Anything-V2-Small-hf",
+}


 @invocation(
     "depth_anything_image_processor",
     title="Depth Anything Processor",
     tags=["controlnet", "depth", "depth anything"],
     category="controlnet",
-    version="1.1.2",
+    version="1.1.3",
 )
 class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     """Generates a depth map based on the Depth Anything algorithm"""

     model_size: DEPTH_ANYTHING_MODEL_SIZES = InputField(
-        default="small", description="The size of the depth model to use"
+        default="small_v2", description="The size of the depth model to use"
     )
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

     def run_processor(self, image: Image.Image) -> Image.Image:
-        def loader(model_path: Path):
-            return DepthAnythingDetector.load_model(
-                model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
-            )
+        def load_depth_anything(model_path: Path):
+            depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
+            assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
+            return DepthAnythingPipeline(depth_anything_pipeline)

         with self._context.models.load_remote_model(
-            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader
-        ) as model:
-            depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
-            processed_image = depth_anything_detector(image=image, resolution=self.resolution)
-            return processed_image
+            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
+        ) as depth_anything_detector:
+            assert isinstance(depth_anything_detector, DepthAnythingPipeline)
+            depth_map = depth_anything_detector.generate_depth(image)
+
+            # Resizing to user target specified size
+            new_height = int(image.size[1] * (self.resolution / image.size[0]))
+            depth_map = depth_map.resize((self.resolution, new_height))
+
+            return depth_map


 @invocation(
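
A note on the new resize step in `run_processor`: `resolution` is treated as the target width, and the height is scaled to preserve the input's aspect ratio. A worked sketch of the arithmetic (not part of the diff; the values are hypothetical):

# Sketch of the resize arithmetic from run_processor above.
image_size = (1024, 768)  # hypothetical input (width, height)
resolution = 512          # the invocation's default target width
new_height = int(image_size[1] * (resolution / image_size[0]))
assert new_height == 384  # the depth map comes back as 512x384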

invokeai/backend/image_util/depth_anything/__init__.py

Lines changed: 0 additions & 90 deletions
This file was deleted.
invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+from typing import Optional
+
+import torch
+from PIL import Image
+from transformers.pipelines import DepthEstimationPipeline
+
+from invokeai.backend.raw_model import RawModel
+
+
+class DepthAnythingPipeline(RawModel):
+    """Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
+    for Invoke's Model Management System"""
+
+    def __init__(self, pipeline: DepthEstimationPipeline) -> None:
+        self._pipeline = pipeline
+
+    def generate_depth(self, image: Image.Image) -> Image.Image:
+        depth_map = self._pipeline(image)["depth"]
+        assert isinstance(depth_map, Image.Image)
+        return depth_map
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+        if device is not None and device.type not in {"cpu", "cuda"}:
+            device = None
+        self._pipeline.model.to(device=device, dtype=dtype)
+        self._pipeline.device = self._pipeline.model.device
+
+    def calc_size(self) -> int:
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self._pipeline.model)
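
Note that `to()` deliberately drops any device whose type is not `cpu` or `cuda` (passing `device=None` leaves the model where it is), so other backends such as MPS fall back to the pipeline's current placement. A usage sketch for the wrapper (illustrative only; it pulls the model from the Hub instead of a Model Manager path, and the input file name is hypothetical):

from PIL import Image
from transformers import pipeline
from transformers.pipelines import DepthEstimationPipeline

from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline

# Build the upstream pipeline, then wrap it for Invoke's model cache.
hf_pipeline = pipeline(task="depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
assert isinstance(hf_pipeline, DepthEstimationPipeline)
wrapper = DepthAnythingPipeline(hf_pipeline)

depth = wrapper.generate_depth(Image.open("input.png"))  # hypothetical input
size_bytes = wrapper.calc_size()  # model footprint reported to the model cache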

invokeai/backend/image_util/depth_anything/model/blocks.py

Lines changed: 0 additions & 145 deletions
This file was deleted.
