Skip to content

Commit d99c51b

Browse files
feat(nodes): improved pydantic type annotation massaging
When we do our field type overrides to allow invocations to be instantiated without all required fields, we were not modifying the annotation of the field but did set the default value of the field to `None`. This results in an error when doing a ser/de round trip. Here's what we end up doing: ```py from pydantic import BaseModel, Field class MyModel(BaseModel): foo: str = Field(default=None) ``` And here is a simple round-trip, which should not error but which does: ```py MyModel(**MyModel().model_dump()) # ValidationError: 1 validation error for MyModel # foo # Input should be a valid string [type=string_type, input_value=None, input_type=NoneType] # For further information visit https://errors.pydantic.dev/2.11/v/string_type ``` To fix this, we now check every incoming field and update its annotation to match its default value. In other words, when we override the default field value to `None`, we make its type annotation `<original type> | None`. This prevents the error during deserialization. This slightly alters the schema for all invocations and outputs - the values of all fields without default values are now typed as `<original type> | None`, reflecting the overrides. This means the autogenerated types for fields have also changed for fields without defaults: ```ts // Old image?: components["schemas"]["ImageField"]; // New image?: components["schemas"]["ImageField"] | null; ``` This does not break anything on the frontend.
1 parent b3ee906 commit d99c51b

File tree

2 files changed

+102
-20
lines changed

2 files changed

+102
-20
lines changed

invokeai/app/invocations/baseinvocation.py

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import inspect
66
import re
77
import sys
8+
import types
9+
import typing
810
import warnings
911
from abc import ABC, abstractmethod
1012
from enum import Enum
@@ -475,6 +477,18 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
475477
return None
476478

477479

480+
def is_optional(annotation: Any) -> bool:
481+
"""
482+
Checks if the given annotation is optional (i.e. Optional[X], Union[X, None] or X | None).
483+
"""
484+
origin = typing.get_origin(annotation)
485+
# PEP 604 unions (int|None) have origin types.UnionType
486+
is_union = origin is typing.Union or origin is types.UnionType
487+
if not is_union:
488+
return False
489+
return any(arg is type(None) for arg in typing.get_args(annotation))
490+
491+
478492
def invocation(
479493
invocation_type: str,
480494
title: Optional[str] = None,
@@ -507,6 +521,18 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
507521

508522
validate_fields(cls.model_fields, invocation_type)
509523

524+
fields: dict[str, tuple[Any, FieldInfo]] = {}
525+
526+
for field_name, field_info in cls.model_fields.items():
527+
annotation = field_info.annotation
528+
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
529+
assert isinstance(field_info.json_schema_extra, dict), (
530+
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
531+
)
532+
if field_info.default is None and not is_optional(annotation):
533+
annotation = annotation | None
534+
fields[field_name] = (annotation, field_info)
535+
510536
# Add OpenAPI schema extras
511537
uiconfig: dict[str, Any] = {}
512538
uiconfig["title"] = title
@@ -539,11 +565,17 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
539565
# Unfortunately, because the `GraphInvocation` uses a forward ref in its `graph` field's annotation, this does
540566
# not work. Instead, we have to create a new class with the type field and patch the original class with it.
541567

542-
invocation_type_annotation = Literal[invocation_type] # type: ignore
568+
invocation_type_annotation = Literal[invocation_type]
543569
invocation_type_field = Field(
544570
title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
545571
)
546572

573+
# pydantic's Field function returns a FieldInfo, but they annotate it as returning a type so that type-checkers
574+
# don't get confused by something like this:
575+
# foo: str = Field() <-- this is a FieldInfo, not a str
576+
# Unfortunately this means we need to use type: ignore here to avoid type-checker errors
577+
fields["type"] = (invocation_type_annotation, invocation_type_field) # type: ignore
578+
547579
# Validate the `invoke()` method is implemented
548580
if "invoke" in cls.__abstractmethods__:
549581
raise ValueError(f'Invocation "{invocation_type}" must implement the "invoke" method')
@@ -565,17 +597,12 @@ def wrapper(cls: Type[TBaseInvocation]) -> Type[TBaseInvocation]:
565597
)
566598

567599
docstring = cls.__doc__
568-
cls = create_model(
569-
cls.__qualname__,
570-
__base__=cls,
571-
__module__=cls.__module__,
572-
type=(invocation_type_annotation, invocation_type_field),
573-
)
574-
cls.__doc__ = docstring
600+
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
601+
new_class.__doc__ = docstring
575602

576-
InvocationRegistry.register_invocation(cls)
603+
InvocationRegistry.register_invocation(new_class)
577604

578-
return cls
605+
return new_class
579606

580607
return wrapper
581608

@@ -600,23 +627,32 @@ def wrapper(cls: Type[TBaseInvocationOutput]) -> Type[TBaseInvocationOutput]:
600627

601628
validate_fields(cls.model_fields, output_type)
602629

630+
fields: dict[str, tuple[Any, FieldInfo]] = {}
631+
632+
for field_name, field_info in cls.model_fields.items():
633+
annotation = field_info.annotation
634+
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
635+
assert isinstance(field_info.json_schema_extra, dict), (
636+
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
637+
)
638+
if field_info.default is not PydanticUndefined and is_optional(annotation):
639+
annotation = annotation | None
640+
fields[field_name] = (annotation, field_info)
641+
603642
# Add the output type to the model.
604-
output_type_annotation = Literal[output_type] # type: ignore
643+
output_type_annotation = Literal[output_type]
605644
output_type_field = Field(
606645
title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
607646
)
608647

648+
fields["type"] = (output_type_annotation, output_type_field) # type: ignore
649+
609650
docstring = cls.__doc__
610-
cls = create_model(
611-
cls.__qualname__,
612-
__base__=cls,
613-
__module__=cls.__module__,
614-
type=(output_type_annotation, output_type_field),
615-
)
616-
cls.__doc__ = docstring
651+
new_class = create_model(cls.__qualname__, __base__=cls, __module__=cls.__module__, **fields)
652+
new_class.__doc__ = docstring
617653

618-
InvocationRegistry.register_output(cls)
654+
InvocationRegistry.register_output(new_class)
619655

620-
return cls
656+
return new_class
621657

622658
return wrapper
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Any, Literal, Optional, Union
2+
3+
import pytest
4+
from pydantic import BaseModel
5+
6+
7+
class TestModel(BaseModel):
8+
foo: Literal["bar"] = "bar"
9+
10+
11+
@pytest.mark.parametrize(
12+
"input_type, expected",
13+
[
14+
(str, False),
15+
(list[str], False),
16+
(list[dict[str, Any]], False),
17+
(list[None], False),
18+
(list[dict[str, None]], False),
19+
(Any, False),
20+
(True, False),
21+
(False, False),
22+
(Union[str, False], False),
23+
(Union[str, True], False),
24+
(None, False),
25+
(str | None, True),
26+
(Union[str, None], True),
27+
(Optional[str], True),
28+
(str | int | None, True),
29+
(None | str | int, True),
30+
(Union[None, str], True),
31+
(Optional[str], True),
32+
(Optional[int], True),
33+
(Optional[str], True),
34+
(TestModel | None, True),
35+
(Union[TestModel, None], True),
36+
(Optional[TestModel], True),
37+
],
38+
)
39+
def test_is_optional(input_type: Any, expected: bool) -> None:
40+
"""
41+
Test the is_optional function.
42+
"""
43+
from invokeai.app.invocations.baseinvocation import is_optional
44+
45+
result = is_optional(input_type)
46+
assert result == expected, f"Expected {expected} but got {result} for input type {input_type}"

0 commit comments

Comments
 (0)