Skip to content

Commit 0b93860

Browse files
committed
feat: add support for flow and node instance naming
1 parent b7444fa commit 0b93860

File tree

3 files changed

+316
-26
lines changed

3 files changed

+316
-26
lines changed

pocketflow/__init__.py

Lines changed: 78 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,29 @@
1-
import asyncio, warnings, copy, time
1+
import asyncio, warnings, copy, time, sys
22

33
class BaseNode:
4-
def __init__(self): self.params,self.successors={},{}
4+
def __init__(self):
5+
self.params, self.successors = {}, {}
6+
self.name = self.get_instance_name() or f"node_{hash(self)}"
7+
self.flow = None # Will be set by Flow._propagate_flow
8+
self.parent = None # Will be set by Flow._propagate_flow
9+
10+
def get_instance_name(self):
11+
"""Find the variable name this instance is assigned to, if any"""
12+
try:
13+
frame = sys._getframe(1)
14+
while frame:
15+
for scope in (frame.f_locals, frame.f_globals):
16+
for key, value in scope.items():
17+
if value is self and not key.startswith('_') and key != 'self':
18+
return key
19+
frame = frame.f_back
20+
except (AttributeError, ValueError):
21+
pass
22+
return None
23+
24+
def _get_name(self):
25+
"""Return the instance name, either from name attribute or lookup"""
26+
return self.name or self.get_instance_name() or f"node_{hash(self)}"
527
def set_params(self,params): self.params=params
628
def add_successor(self,node,action="default"):
729
if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'")
@@ -11,8 +33,8 @@ def exec(self,prep_res): pass
1133
def post(self,shared,prep_res,exec_res): pass
1234
def _exec(self,prep_res): return self.exec(prep_res)
1335
def _run(self,shared): p=self.prep(shared);e=self._exec(p);return self.post(shared,p,e)
14-
def run(self,shared):
15-
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
36+
def run(self,shared):
37+
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
1638
return self._run(shared)
1739
def __rshift__(self,other): return self.add_successor(other)
1840
def __sub__(self,action):
@@ -37,22 +59,45 @@ class BatchNode(Node):
3759
def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in items]
3860

3961
class Flow(BaseNode):
40-
def __init__(self,start): super().__init__();self.start=start
62+
def __init__(self, start, name=None):
63+
super().__init__()
64+
self.start = start
65+
self.name = name or self.get_instance_name() or f"flow_{hash(self)}"
66+
self._propagate_flow(self.start)
67+
68+
def _propagate_flow(self, node, visited=None):
69+
"""Set flow and parent references on all nodes in the flow"""
70+
if visited is None:
71+
visited = set()
72+
73+
if node is None or id(node) in visited:
74+
return
75+
76+
visited.add(id(node))
77+
node.flow = self
78+
node.parent = self.parent if hasattr(self, 'parent') else None
79+
80+
for successor in node.successors.values():
81+
self._propagate_flow(successor, visited)
4182
def get_next_node(self,curr,action):
4283
nxt=curr.successors.get(action or "default")
4384
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
4485
return nxt
45-
def _orch(self,shared,params=None):
46-
curr,p=copy.copy(self.start),(params or {**self.params})
47-
while curr: curr.set_params(p);c=curr._run(shared);curr=copy.copy(self.get_next_node(curr,c))
86+
def _orch(self, shared, params=None):
87+
curr, p = copy.copy(self.start), (params or {**self.params})
88+
while curr:
89+
curr.set_params(p)
90+
c = curr._run(shared)
91+
curr = copy.copy(self.get_next_node(curr, c))
4892
def _run(self,shared): pr=self.prep(shared);self._orch(shared);return self.post(shared,pr,None)
4993
def exec(self,prep_res): raise RuntimeError("Flow can't exec.")
5094

5195
class BatchFlow(Flow):
52-
def _run(self,shared):
53-
pr=self.prep(shared) or []
54-
for bp in pr: self._orch(shared,{**self.params,**bp})
55-
return self.post(shared,pr,None)
96+
def _run(self, shared):
97+
pr = self.prep(shared) or []
98+
for bp in pr:
99+
self._orch(shared, {**self.params, **bp})
100+
return self.post(shared, pr, None)
56101

57102
class AsyncNode(Node):
58103
def prep(self,shared): raise RuntimeError("Use prep_async.")
@@ -64,14 +109,14 @@ async def prep_async(self,shared): pass
64109
async def exec_async(self,prep_res): pass
65110
async def exec_fallback_async(self,prep_res,exc): raise exc
66111
async def post_async(self,shared,prep_res,exec_res): pass
67-
async def _exec(self,prep_res):
112+
async def _exec(self,prep_res):
68113
for i in range(self.max_retries):
69114
try: return await self.exec_async(prep_res)
70115
except Exception as e:
71116
if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e)
72117
if self.wait>0: await asyncio.sleep(self.wait)
73-
async def run_async(self,shared):
74-
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
118+
async def run_async(self,shared):
119+
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
75120
return await self._run_async(shared)
76121
async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e)
77122

@@ -82,19 +127,26 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode):
82127
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
83128

84129
class AsyncFlow(Flow,AsyncNode):
85-
async def _orch_async(self,shared,params=None):
86-
curr,p=copy.copy(self.start),(params or {**self.params})
87-
while curr:curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared);curr=copy.copy(self.get_next_node(curr,c))
130+
async def _orch_async(self, shared, params=None):
131+
curr, p = copy.copy(self.start), (params or {**self.params})
132+
while curr:
133+
curr.set_params(p)
134+
if isinstance(curr, AsyncNode):
135+
c = await curr._run_async(shared)
136+
else:
137+
c = curr._run(shared)
138+
curr = copy.copy(self.get_next_node(curr, c))
88139
async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None)
89140

90141
class AsyncBatchFlow(AsyncFlow,BatchFlow):
91-
async def _run_async(self,shared):
92-
pr=await self.prep_async(shared) or []
93-
for bp in pr: await self._orch_async(shared,{**self.params,**bp})
94-
return await self.post_async(shared,pr,None)
142+
async def _run_async(self, shared):
143+
pr = await self.prep_async(shared) or []
144+
for bp in pr:
145+
await self._orch_async(shared, {**self.params, **bp})
146+
return await self.post_async(shared, pr, None)
95147

96148
class AsyncParallelBatchFlow(AsyncFlow,BatchFlow):
97-
async def _run_async(self,shared):
98-
pr=await self.prep_async(shared) or []
99-
await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr))
100-
return await self.post_async(shared,pr,None)
149+
async def _run_async(self, shared):
150+
pr = await self.prep_async(shared) or []
151+
await asyncio.gather(*(self._orch_async(shared, {**self.params, **bp}) for bp in pr))
152+
return await self.post_async(shared, pr, None)

pocketflow/example.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fasdasdasd

pocketflow/rework.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
2+
# from pocketflow import *
3+
from __init__ import *
4+
import os
5+
6+
def call_llm(prompt):
7+
# Your API logic here
8+
return prompt
9+
10+
class LoadFile(Node):
11+
def __init__(self, name=None):
12+
super().__init__()
13+
# Use provided name or fall back to automatic lookup
14+
self.name = name or self.name
15+
def prep(self, shared):
16+
print(f" In : {self.__class__.__name__}")
17+
"""Load file from disk"""
18+
filename = self.params["filename"]
19+
with open(filename, "r") as file:
20+
return file.read()
21+
22+
def exec(self, prep_res):
23+
"""Return file content"""
24+
return prep_res
25+
26+
def post(self, shared, prep_res, exec_res):
27+
"""Store file content in shared"""
28+
shared["file_content"] = exec_res
29+
return "default"
30+
31+
32+
class GetOpinion(Node):
33+
def __init__(self, name=None):
34+
super().__init__()
35+
# Use provided name or fall back to automatic lookup
36+
self.name = name or self.name
37+
38+
def prep(self, shared):
39+
print(f" In : {self.__class__.__name__}")
40+
print(f"My name is: {self.name} (instance of {self.__class__.__name__})")
41+
if self.flow:
42+
print(f"Flow name: {self.flow.name}")
43+
if self.flow.parent:
44+
print(f"Parent flow: {self.flow.parent.name}")
45+
"""Get file content from shared"""
46+
if not shared.get("reworked_file_content"):
47+
return shared["file_content"]
48+
else:
49+
return "Original text :\n" + shared["file_content"] + "Revised version:\n" + shared["reworked_file_content"]
50+
51+
def exec(self, prep_res):
52+
"""Ask LLM for opinion on file content"""
53+
prompt = f"What's your opinion on this text: {prep_res}. Provide opinion on how to make it better."
54+
return call_llm(prompt)
55+
56+
def post(self, shared, prep_res, exec_res):
57+
"""Store opinion in shared"""
58+
shared["opinion"] = exec_res
59+
return "default"
60+
61+
class GetValidation(Node):
62+
def __init__(self, name=None):
63+
super().__init__()
64+
# Use provided name or fall back to automatic lookup
65+
self.name = name or self.name
66+
def prep(self, shared):
67+
print(f" In : {self.__class__.__name__}")
68+
"""Get file content from shared"""
69+
shared['discussion'] = shared["file_content"] + shared["opinion"] + "Final revised text : " + shared["reworked_file_content"]
70+
return
71+
72+
def exec(self, prep_res):
73+
"""Ask LLM for opinion on file content"""
74+
prompt = f"Validate that the final revised text is valid and reflects the changes proposed in opinion : {prep_res}. \nReply `IS VALID` if it is of `NOT VALID` if it needs some more work."
75+
return call_llm(prompt)
76+
77+
def post(self, shared, prep_res, exec_res):
78+
"""Store rework count in shared"""
79+
if "IS VALID" in exec_res:
80+
return "default"
81+
else:
82+
return "invalid"
83+
84+
85+
class ReworkFile(Node):
86+
def __init__(self, name=None):
87+
super().__init__()
88+
# Use provided name or fall back to automatic lookup
89+
self.name = name or self.name
90+
def prep(self, shared):
91+
print(f" In : {self.__class__.__name__}")
92+
"""Get file content and opinion from shared"""
93+
return shared["file_content"], shared["opinion"]
94+
95+
def exec(self, prep_res):
96+
"""Ask LLM to rework file based on opinion"""
97+
file_content, opinion = prep_res
98+
prompt = f"Rework this text based on the opinion: {opinion}\n\nOriginal text: {file_content}"
99+
return call_llm(prompt)
100+
101+
def post(self, shared, prep_res, exec_res):
102+
"""Store reworked file content in shared"""
103+
if "rework2_flow_min_count" in self.params:
104+
rework_count = self.params["rework2_flow_min_count"]
105+
shared["reworked_file_content"] = exec_res
106+
if not shared.get("reworked_file_content_count"):
107+
shared["reworked_file_content_count"] = 1
108+
elif shared.get("reworked_file_content_count"):
109+
shared["reworked_file_content_count"] += 1
110+
111+
if shared["reworked_file_content_count"] < rework_count:
112+
print(f"Less than {self.params["rework2_flow_min_count"]} rework for rework2_flow, so going for pass #{shared["reworked_file_content_count"]}.")
113+
return "rework"
114+
else:
115+
return "default"
116+
else:
117+
shared["reworked_file_content"] = exec_res
118+
119+
120+
class SaveFile(Node):
121+
def __init__(self, name=None):
122+
super().__init__()
123+
# Use provided name or fall back to automatic lookup
124+
self.name = name or self.name
125+
def prep(self, shared):
126+
print(f" In : {self.__class__.__name__}")
127+
"""Get reworked file content and original filename from shared"""
128+
filename = self.params["filename"]
129+
if "reworked_file_content" in shared:
130+
return shared["reworked_file_content"], filename
131+
else:
132+
print("Error")
133+
134+
def exec(self, prep_res):
135+
"""Save reworked file content to new file"""
136+
reworked_file_content, filename = prep_res
137+
new_filename = f"{filename.split('.')[0]}_v2.{filename.split('.')[-1]}"
138+
with open(new_filename, "w") as file:
139+
file.write(reworked_file_content)
140+
return reworked_file_content
141+
142+
def post(self, shared, prep_res, exec_res):
143+
filename = self.params["filename"]
144+
"""Return success message"""
145+
print(f"Saved to {filename} the content : \n{exec_res}")
146+
147+
148+
# # # Comment this from here
149+
# # First flow
150+
# Create nodes
151+
load_Node = LoadFile(name="load_Node")
152+
opinion_Node = GetOpinion(name="opinion_Node")
153+
rework_Node = ReworkFile(name="rework_Node")
154+
save_Node = SaveFile(name="save_Node")
155+
156+
# Connect nodes
157+
load_Node >> opinion_Node >> rework_Node >> save_Node
158+
159+
# Create flow
160+
rework_Flow = Flow(start=load_Node,name="rework_Flow")
161+
162+
# Set flow params
163+
rework_Flow.set_params({"filename": "example.txt"})
164+
# Run flow
165+
shared = {}
166+
rework_Flow.run(shared)
167+
# # # To here for second workflow to work
168+
169+
# # Second flow
170+
# Create nodes
171+
load2_Node = LoadFile(name="load2_Node")
172+
opinion2_Node = GetOpinion(name="opinion2_Node")
173+
rework2_Node = ReworkFile(name="rework2_Node")
174+
valid2_Node = GetValidation(name="valid2_Node")
175+
save2_Node = SaveFile(name="save2_Node")
176+
177+
print(f" NAME is : {opinion2_Node.name}")
178+
179+
# Connect nodes
180+
load2_Node >> opinion2_Node
181+
opinion2_Node >> rework2_Node
182+
183+
rework2_Node - "default" >> valid2_Node
184+
rework2_Node - "rework" >> opinion2_Node
185+
186+
# Get second opinion it if rework asked because in rework_flow2 and less than 2 rework
187+
valid2_Node - "invalid" >> opinion2_Node
188+
valid2_Node - "default" >> save2_Node
189+
190+
# Create flow with explicit name
191+
rework2_Flow = Flow(start=load2_Node, name="rework2_Flow")
192+
# rework2_Flow.name = "rework2_Flow" # Set explicit name
193+
194+
# Set flow params
195+
# This will not set params if class Flow was already initialized with other params ?
196+
rework2_Flow.set_params({"filename": "example.txt", "rework2_flow_min_count" : 3})
197+
198+
# Run flow
199+
shared2 = {}
200+
rework2_Flow.run(shared2)
201+
202+
def build_mermaid(start):
203+
visited, lines = set(), ["graph LR"]
204+
205+
def get_name(n):
206+
"""Get the node's name for use in the diagram"""
207+
if isinstance(n, Flow):
208+
return n._get_name()
209+
return n._get_name().replace(' ', '_') # Mermaid needs no spaces in node names
210+
211+
def link(a, b):
212+
lines.append(f" {get_name(a)} --> {get_name(b)}")
213+
def walk(node, parent=None):
214+
if node in visited:
215+
return parent and link(parent, node)
216+
visited.add(node)
217+
if isinstance(node, Flow):
218+
node.start and parent and link(parent, node)
219+
# Add flow name and class name to subgraph label
220+
flow_label = f"{node._get_name()} ({type(node).__name__})"
221+
lines.append(f"\n subgraph {get_name(node)}[\"{flow_label}\"]")
222+
node.start and walk(node.start)
223+
for nxt in node.successors.values():
224+
node.start and walk(nxt, node.start) or (parent and link(parent, nxt)) or walk(nxt)
225+
lines.append(" end\n")
226+
else:
227+
# Add both instance name and class name to node label
228+
node_label = f"{node._get_name()} ({type(node).__name__})"
229+
lines.append(f" {get_name(node)}[\"{node_label}\"]")
230+
parent and link(parent, node)
231+
[walk(nxt, node) for nxt in node.successors.values()]
232+
walk(start)
233+
return "\n".join(lines)
234+
235+
print(build_mermaid(start=rework_Flow))
236+
237+
print(build_mermaid(start=rework2_Flow))

0 commit comments

Comments
 (0)