Skip to content

Commit 4aa9e98

Browse files
authored
fix signature input type (#8413)
1 parent 58914cb commit 4aa9e98

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

dspy/predict/chain_of_thought.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
class ChainOfThought(Module):
1212
def __init__(
1313
self,
14-
signature: Type[Signature],
14+
signature: Union[str, Type[Signature]],
1515
rationale_field: Optional[Union[OutputField, FieldInfo]] = None,
1616
rationale_field_type: Type = str,
1717
**config: dict[str, Any],

dspy/predict/code_act.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import inspect
22
import logging
3-
from inspect import Signature
43
from typing import Callable, Optional, Type, Union
54

65
import dspy
76
from dspy.adapters.types.tool import Tool
87
from dspy.predict.program_of_thought import ProgramOfThought
98
from dspy.predict.react import ReAct
109
from dspy.primitives.python_interpreter import PythonInterpreter
11-
from dspy.signatures.signature import ensure_signature
10+
from dspy.signatures.signature import Signature, ensure_signature
1211

1312
logger = logging.getLogger(__name__)
1413

dspy/predict/predict.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import random
3+
from typing import Optional, Type, Union
34

45
from pydantic import BaseModel
56

@@ -10,13 +11,14 @@
1011
from dspy.predict.parameter import Parameter
1112
from dspy.primitives.module import Module
1213
from dspy.primitives.prediction import Prediction
13-
from dspy.signatures.signature import ensure_signature
14+
from dspy.signatures.signature import Signature, ensure_signature
15+
from dspy.utils.callback import BaseCallback
1416

1517
logger = logging.getLogger(__name__)
1618

1719

1820
class Predict(Module, Parameter):
19-
def __init__(self, signature, callbacks=None, **config):
21+
def __init__(self, signature: Union[str, Type[Signature]], callbacks: Optional[list[BaseCallback]] = None, **config):
2022
super().__init__(callbacks=callbacks)
2123
self.stage = random.randbytes(8).hex()
2224
self.signature = ensure_signature(signature)

0 commit comments

Comments
 (0)