Skip to content

Commit 762b520

Browse files
authored
refactor: Make target spec a discriminated union (#103)
1 parent 2fe0f68 commit 762b520

File tree

1 file changed

+43
-24
lines changed

1 file changed

+43
-24
lines changed

letta_evals/models.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class Sample(BaseModel):
2727
# Config models
2828

2929

30-
class TargetSpec(BaseModel):
31-
"""Target configuration for evaluation."""
30+
class BaseTargetSpec(BaseModel):
31+
"""Base target configuration with common fields."""
3232

3333
kind: TargetKind = Field(description="Type of target (agent)")
3434
base_url: str = Field(default="http://localhost:8283", description="Letta server URL")
@@ -37,12 +37,6 @@ class TargetSpec(BaseModel):
3737
project_id: Optional[str] = Field(default=None, description="Letta project ID")
3838
max_retries: int = Field(default=0, description="Maximum number of retries for failed create_stream calls")
3939

40-
agent_id: Optional[str] = Field(default=None, description="ID of existing agent to use")
41-
agent_file: Optional[Path] = Field(default=None, description="Path to .af agent file to upload")
42-
agent_script: Optional[str] = Field(
43-
default=None, description="Path to Python script with AgentFactory (e.g., script.py:FactoryClass)"
44-
)
45-
4640
# model configs to test (names without .json extension)
4741
model_configs: Optional[List[str]] = Field(
4842
default=None, description="List of model config names from llm_model_configs directory"
@@ -53,32 +47,57 @@ class TargetSpec(BaseModel):
5347
default=None, description="List of model handles (e.g., 'openai/gpt-4.1') for cloud deployments"
5448
)
5549

56-
# letta_code specific fields
57-
working_dir: Optional[Path] = Field(default=None, description="Working directory for letta code execution")
58-
allowed_tools: Optional[List[str]] = Field(
59-
default=None, description="List of allowed tools for letta code (e.g., ['Bash', 'Read'])"
60-
)
61-
disallowed_tools: Optional[List[str]] = Field(default=None, description="List of disallowed tools for letta code")
62-
6350
# internal field for path resolution
6451
base_dir: Optional[Path] = Field(default=None, exclude=True)
6552

53+
54+
class LettaAgentTargetSpec(BaseTargetSpec):
55+
"""Letta agent target configuration."""
56+
57+
kind: Literal[TargetKind.LETTA_AGENT] = TargetKind.LETTA_AGENT
58+
59+
agent_id: Optional[str] = Field(default=None, description="ID of existing agent to use")
60+
agent_file: Optional[Path] = Field(default=None, description="Path to .af agent file to upload")
61+
agent_script: Optional[str] = Field(
62+
default=None, description="Path to Python script with AgentFactory (e.g., script.py:FactoryClass)"
63+
)
64+
6665
@field_validator("agent_file")
66+
@classmethod
6767
def validate_agent_file(cls, v: Optional[Path]) -> Optional[Path]:
6868
if v and not str(v).endswith(".af"):
6969
raise ValueError("Agent file must have .af extension")
7070
return v
7171

72-
def __init__(self, **data):
73-
super().__init__(**data)
74-
if self.kind == TargetKind.LETTA_AGENT:
75-
sources = [self.agent_id, self.agent_file, self.agent_script]
76-
provided = sum(1 for s in sources if s is not None)
72+
@model_validator(mode="after")
73+
def validate_agent_source(self):
74+
sources = [self.agent_id, self.agent_file, self.agent_script]
75+
provided = sum(1 for s in sources if s is not None)
76+
77+
if provided == 0:
78+
raise ValueError("Agent target requires one of: agent_id, agent_file, or agent_script")
79+
if provided > 1:
80+
raise ValueError("Agent target can only have one of: agent_id, agent_file, or agent_script")
81+
82+
return self
83+
84+
85+
class LettaCodeTargetSpec(BaseTargetSpec):
86+
"""Letta code target configuration."""
87+
88+
kind: Literal[TargetKind.LETTA_CODE] = TargetKind.LETTA_CODE
7789

78-
if provided == 0:
79-
raise ValueError("Agent target requires one of: agent_id, agent_file, or agent_script")
80-
if provided > 1:
81-
raise ValueError("Agent target can only have one of: agent_id, agent_file, or agent_script")
90+
working_dir: Optional[Path] = Field(default=None, description="Working directory for letta code execution")
91+
allowed_tools: Optional[List[str]] = Field(
92+
default=None, description="List of allowed tools for letta code (e.g., ['Bash', 'Read'])"
93+
)
94+
disallowed_tools: Optional[List[str]] = Field(default=None, description="List of disallowed tools for letta code")
95+
96+
97+
TargetSpec = Annotated[
98+
Union[LettaAgentTargetSpec, LettaCodeTargetSpec],
99+
Field(discriminator="kind"),
100+
]
82101

83102

84103
class BaseGraderSpec(BaseModel):

0 commit comments

Comments
 (0)