Skip to content

Commit 56efe71

Browse files
authored
Show warning if forward is directly called (#8380)
* show warning if forward is directly called * lint
1 parent 0a6b50e commit 56efe71

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

dspy/primitives/program.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import inspect
2+
import logging
13
from typing import Optional
24

35
import magicattr
@@ -9,6 +11,7 @@
911
from dspy.utils.inspect_history import pretty_print_history
1012
from dspy.utils.usage_tracker import track_usage
1113

14+
logger = logging.getLogger(__name__)
1215

1316
class ProgramMeta(type):
1417
"""Metaclass ensuring every ``dspy.Module`` instance is properly initialised."""
@@ -155,6 +158,20 @@ def batch(
155158
results = parallel_executor.forward(exec_pairs)
156159
return results
157160

161+
def __getattribute__(self, name):
162+
attr = super().__getattribute__(name)
163+
164+
if name == 'forward' and callable(attr):
165+
# Check if forward iscalled through __call__
166+
stack = inspect.stack()
167+
called_via_call = len(stack) > 1 and stack[1].function == '__call__'
168+
169+
if not called_via_call:
170+
logger.warning(f"Calling {self.__class__.__name__}.forward() directly is discouraged. "
171+
f"Use {self.__class__.__name__}() instead.")
172+
173+
return attr
174+
158175

159176
def set_attribute_by_name(obj, name, value):
160177
magicattr.set(obj, name, value)

tests/primitives/test_module.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,3 +489,25 @@ async def aforward(self, question: str, **kwargs) -> str:
489489
assert len(program.history) == 2
490490
assert len(program.cot.history) == 3
491491
assert len(program.cot.predict.history) == 3
492+
493+
494+
def test_forward_direct_call_warning(capsys):
495+
class TestModule(dspy.Module):
496+
def forward(self, x):
497+
return x
498+
499+
module = TestModule()
500+
module.forward("test")
501+
captured = capsys.readouterr()
502+
assert "Calling TestModule.forward() directly is discouraged" in captured.err
503+
504+
505+
def test_forward_through_call_no_warning(capsys):
506+
class TestModule(dspy.Module):
507+
def forward(self, x):
508+
return x
509+
510+
module = TestModule()
511+
module(x="test")
512+
captured = capsys.readouterr()
513+
assert "Calling TestModule.forward() directly is discouraged" not in captured.err

0 commit comments

Comments
 (0)