Skip to content

Commit 2867d94

Browse files
committed
fix(nft): improve TraitAttribute validation
1 parent 5e34954 commit 2867d94

File tree

2 files changed

+135
-18
lines changed

2 files changed

+135
-18
lines changed

app/api/nft/models.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1+
import json
12
from enum import Enum
23
from typing import Any
34

4-
from pydantic import BaseModel, ConfigDict, Field, field_validator
5+
from pydantic import (
6+
BaseModel,
7+
ConfigDict,
8+
Field,
9+
field_validator,
10+
model_validator,
11+
)
512
from pydantic.alias_generators import to_camel
613

714

@@ -11,6 +18,52 @@ def strip_trailing_slash_validator(v: str | None) -> str | None:
1118
return v
1219

1320

21+
class TraitAttribute(BaseModel):
22+
trait_type: str
23+
value: str | bool | int | float | None = None
24+
25+
@model_validator(mode="before")
26+
def check_trait_type_omitted(cls, data):
27+
if not isinstance(data, dict):
28+
raise ValueError("TraitAttribute data must be a dictionary")
29+
30+
trait_type = data.get("trait_type", "").strip() or data.get("name", "").strip()
31+
32+
if not trait_type and not data.get("value"):
33+
raise ValueError(
34+
"Either trait_type or value must be provided for TraitAttribute"
35+
)
36+
37+
data["trait_type"] = trait_type or "Unknown"
38+
39+
# Handle complex value types (dict, list) by serializing to JSON.
40+
# If serialization fails, the TraitAttribute will be invalid and skipped.
41+
value = data.get("value")
42+
if isinstance(value, (dict, list)):
43+
data["value"] = json.dumps(value)
44+
45+
return data
46+
47+
48+
class AttributesValidationMixin:
49+
"""Mixin to provide shared attribute validation logic for metadata classes."""
50+
51+
@field_validator("attributes", mode="before")
52+
@classmethod
53+
def validate_attributes(cls, v):
54+
if not isinstance(v, list):
55+
return []
56+
57+
# Filter out invalid attributes, keeping only valid ones
58+
def is_valid_attribute(attr_data):
59+
try:
60+
return TraitAttribute.model_validate(attr_data)
61+
except Exception:
62+
return False
63+
64+
return [attr_data for attr_data in v if is_valid_attribute(attr_data)]
65+
66+
1467
class AlchemyTokenType(str, Enum):
1568
ERC721 = "ERC721"
1669
ERC1155 = "ERC1155"
@@ -42,12 +95,7 @@ def validate_urls(cls, v: str | None) -> str | None:
4295
model_config = ConfigDict(alias_generator=to_camel)
4396

4497

45-
class TraitAttribute(BaseModel):
46-
trait_type: str
47-
value: str | bool | int | float | None = None
48-
49-
50-
class AlchemyRawMetadata(BaseModel):
98+
class AlchemyRawMetadata(BaseModel, AttributesValidationMixin):
5199
name: str | None = None
52100
description: str | None = None
53101
image: str | None = None
@@ -59,16 +107,6 @@ class AlchemyRawMetadata(BaseModel):
59107
def validate_urls(cls, v: str | None) -> str | None:
60108
return strip_trailing_slash_validator(v)
61109

62-
@field_validator("attributes", mode="before")
63-
@classmethod
64-
def validate_attributes(cls, v):
65-
if isinstance(v, list):
66-
# Already in the correct format
67-
return v
68-
else:
69-
# Return empty list for any other type (dict, string, etc.)
70-
return []
71-
72110
model_config = ConfigDict(alias_generator=to_camel)
73111

74112

@@ -198,7 +236,7 @@ def validate_image(cls, v: Any) -> str | None:
198236
)
199237

200238

201-
class SolanaAssetContentMetadata(BaseModel):
239+
class SolanaAssetContentMetadata(BaseModel, AttributesValidationMixin):
202240
name: str
203241
symbol: str | None = None
204242
description: str | None = None

app/api/nft/test_models.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import pytest
2+
from pydantic import ValidationError
3+
4+
from app.api.nft.models import TraitAttribute
5+
6+
7+
@pytest.mark.parametrize(
8+
"input_data,expected_trait_type,expected_value",
9+
[
10+
({"trait_type": "Color", "value": "Red"}, "Color", "Red"),
11+
({"name": "Shape", "value": "Round"}, "Shape", "Round"),
12+
({"value": "Some Value"}, "Unknown", "Some Value"),
13+
({"trait_type": "", "value": "Empty String"}, "Unknown", "Empty String"),
14+
({"trait_type": " ", "value": "Whitespace"}, "Unknown", "Whitespace"),
15+
({"trait_type": "Size"}, "Size", None),
16+
({"trait_type": "Name", "value": "Test"}, "Name", "Test"),
17+
({"trait_type": "Count", "value": 42}, "Count", 42),
18+
({"trait_type": "Price", "value": 1.99}, "Price", 1.99),
19+
({"trait_type": "Rare", "value": True}, "Rare", True),
20+
({"trait_type": "Test", "value": None}, "Test", None),
21+
({"trait_type": "Test", "value": ""}, "Test", ""),
22+
(
23+
{
24+
"trait_type": "parent",
25+
"value": {"parent": {"PARENT_CANNOT_SET_TTL": False}},
26+
},
27+
"parent",
28+
'{"parent": {"PARENT_CANNOT_SET_TTL": false}}',
29+
),
30+
(
31+
{"trait_type": "tags", "value": ["tag1", "tag2", "tag3"]},
32+
"tags",
33+
'["tag1", "tag2", "tag3"]',
34+
),
35+
],
36+
)
37+
def test_trait_attribute_valid_cases(input_data, expected_trait_type, expected_value):
38+
attr = TraitAttribute.model_validate(input_data)
39+
assert attr.trait_type == expected_trait_type
40+
assert attr.value == expected_value
41+
42+
43+
@pytest.mark.parametrize(
44+
"input_data,expected_error",
45+
[
46+
("not a dict", "TraitAttribute data must be a dictionary"),
47+
({}, "Either trait_type or value must be provided"),
48+
({"invalid_field": "value"}, "Either trait_type or value must be provided"),
49+
(
50+
{"trait_type": "", "value": None},
51+
"Either trait_type or value must be provided for TraitAttribute",
52+
),
53+
(
54+
{"trait_type": "", "value": ""},
55+
"Either trait_type or value must be provided for TraitAttribute",
56+
),
57+
(
58+
{"name": "", "value": None},
59+
"Either trait_type or value must be provided for TraitAttribute",
60+
),
61+
(
62+
{"name": "", "value": ""},
63+
"Either trait_type or value must be provided for TraitAttribute",
64+
),
65+
(
66+
{"value": None},
67+
"Either trait_type or value must be provided for TraitAttribute",
68+
),
69+
(
70+
{"value": ""},
71+
"Either trait_type or value must be provided for TraitAttribute",
72+
),
73+
],
74+
)
75+
def test_trait_attribute_invalid_cases(input_data, expected_error):
76+
with pytest.raises(ValidationError) as exc_info:
77+
TraitAttribute.model_validate(input_data)
78+
79+
assert expected_error in str(exc_info.value)

0 commit comments

Comments
 (0)