Skip to content

Commit 1ea510a

Browse files
Style fix for dspy/datasets (#8175)
* enable style check
* update rules
* fix style for dspy/datasets
1 parent 5e2d6a2 commit 1ea510a

File tree

9 files changed

+211
-46
lines changed

9 files changed

+211
-46
lines changed

dspy/clients/lm.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ async def async_stream_completion():
298298
return async_stream_completion
299299

300300

301-
def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
301+
def litellm_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
302+
cache = cache or {"no-cache": True, "no-store": True}
302303
stream_completion = _get_stream_completion_fn(request, cache, sync=True)
303304
if stream_completion is None:
304305
return litellm.completion(
@@ -311,7 +312,8 @@ def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cac
311312
return stream_completion()
312313

313314

314-
def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
315+
def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
316+
cache = cache or {"no-cache": True, "no-store": True}
315317
# Extract the provider and model from the model string.
316318
# TODO: Not all the models are in the format of "provider/model"
317319
model = request.pop("model").split("/", 1)
@@ -336,7 +338,8 @@ def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"n
336338
)
337339

338340

339-
async def alitellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
341+
async def alitellm_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
342+
cache = cache or {"no-cache": True, "no-store": True}
340343
stream_completion = _get_stream_completion_fn(request, cache, sync=False)
341344
if stream_completion is None:
342345
return await litellm.acompletion(
@@ -349,9 +352,8 @@ async def alitellm_completion(request: Dict[str, Any], num_retries: int, cache={
349352
return await stream_completion()
350353

351354

352-
async def alitellm_text_completion(
353-
request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}
354-
):
355+
async def alitellm_text_completion(request: Dict[str, Any], num_retries: int, cache: Optional[Dict[str, Any]] = None):
356+
cache = cache or {"no-cache": True, "no-store": True}
355357
model = request.pop("model").split("/", 1)
356358
provider, model = model[0] if len(model) > 1 else "openai", model[-1]
357359

dspy/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1+
from dspy.datasets.alfworld import AlfWorld
12
from dspy.datasets.colors import Colors
23
from dspy.datasets.dataloader import DataLoader
34
from dspy.datasets.dataset import Dataset
45
from dspy.datasets.hotpotqa import HotPotQA
56
from dspy.datasets.math import MATH
6-
from dspy.datasets.alfworld import AlfWorld
77

88
__all__ = [
99
"Colors",

dspy/datasets/alfworld/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from dspy.datasets.alfworld.alfworld import AlfWorld
1+
from dspy.datasets.alfworld.alfworld import AlfWorld

dspy/datasets/alfworld/alfworld.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import queue
33
import random
44

5+
56
def env_worker(inq, outq):
67
"""
78
Worker process: creates a single AlfredTWEnv instance,
@@ -10,16 +11,19 @@ def env_worker(inq, outq):
1011

1112
try:
1213
import io
13-
import yaml
14+
from contextlib import redirect_stderr, redirect_stdout
15+
1416
import alfworld.agents.environment as environment
15-
from contextlib import redirect_stdout, redirect_stderr
17+
import yaml
1618
except ImportError:
17-
raise ImportError("alfworld is not installed. " \
18-
"Please install it via `pip install alfworld==0.3.5` then run `alfworld-download`.")
19+
raise ImportError(
20+
"alfworld is not installed. "
21+
"Please install it via `pip install alfworld==0.3.5` then run `alfworld-download`."
22+
)
1923

2024
buf = io.StringIO()
2125
base_dir = os.path.dirname(os.path.abspath(__file__))
22-
config_path = os.path.join(base_dir, 'base_config.yml')
26+
config_path = os.path.join(base_dir, "base_config.yml")
2327

2428
with open(config_path) as f:
2529
config = yaml.safe_load(f)
@@ -30,19 +34,19 @@ def env_worker(inq, outq):
3034
env = None
3135
while True:
3236
cmd, data = inq.get()
33-
if cmd == 'init':
37+
if cmd == "init":
3438
env = base_env.init_env(batch_size=1)
3539
env.skip(data)
3640
task_def, info = env.reset()
3741
outq.put((task_def[0], info))
38-
elif cmd == 'step':
42+
elif cmd == "step":
3943
obs, rew, done, info = env.step([data])
4044
outq.put((obs, rew, done, info))
41-
elif cmd == 'close':
42-
outq.put('CLOSED')
45+
elif cmd == "close":
46+
outq.put("CLOSED")
4347
break
4448
else:
45-
outq.put('UNKNOWN_CMD')
49+
outq.put("UNKNOWN_CMD")
4650

4751

4852
class EnvPool:
@@ -54,6 +58,7 @@ class EnvPool:
5458
obs, rew, done, info = sess.step("go north")
5559
...
5660
"""
61+
5762
def __init__(self, size=2):
5863
self.size = size
5964
self.workers = []
@@ -62,8 +67,7 @@ def __init__(self, size=2):
6267
try:
6368
import multiprocess as mp
6469
except ImportError:
65-
raise ImportError("multiprocess is not installed. " \
66-
"Please install it via `pip install multiprocess`.")
70+
raise ImportError("multiprocess is not installed. " "Please install it via `pip install multiprocess`.")
6771

6872
# Must call set_start_method('spawn') here, before creating any processes
6973
try:
@@ -93,7 +97,7 @@ def close_all(self):
9397
while not self.available.empty():
9498
wid = self.available.get()
9599
inq, outq, proc = self.workers[wid]
96-
inq.put(('close', None))
100+
inq.put(("close", None))
97101
outq.get() # Wait 'CLOSED'
98102
inq.close()
99103
outq.close()
@@ -109,6 +113,7 @@ class _EnvSession:
109113
A context manager that acquires a worker from the pool,
110114
provides .init(idx) and .step(action), then releases the worker.
111115
"""
116+
112117
def __init__(self, pool: EnvPool):
113118
self.pool = pool
114119
self.wid = None
@@ -123,11 +128,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
123128
self.pool._release(self.wid)
124129

125130
def init(self, idx):
126-
self.inq.put(('init', idx))
131+
self.inq.put(("init", idx))
127132
return self.outq.get() # (task_def, info)
128133

129134
def step(self, action):
130-
self.inq.put(('step', action))
135+
self.inq.put(("step", action))
131136
return self.outq.get() # (obs, rew, done, info)
132137

133138

@@ -136,7 +141,8 @@ def __init__(self, max_threads=20):
136141
self.POOL = EnvPool(size=max_threads)
137142

138143
import dspy
139-
dataset = [dspy.Example(idx=idx).with_inputs('idx') for idx in range(3500)]
144+
145+
dataset = [dspy.Example(idx=idx).with_inputs("idx") for idx in range(3500)]
140146
random.Random(0).shuffle(dataset)
141147

142148
trainset, devset = dataset[:3000], dataset[-500:]

dspy/datasets/colors.py

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,146 @@
33
from dspy.datasets.dataset import Dataset
44

55
### A bunch of colors, originally from matplotlib
6-
all_colors = ['alice blue', 'dodger blue', 'light sky blue', 'deep sky blue', 'sky blue', 'steel blue', 'light steel blue', 'medium blue', 'navy blue', 'blue', 'royal blue', 'cadet blue', 'cornflower blue', 'medium slate blue', 'slate blue', 'dark slate blue', 'powder blue', 'turquoise', 'dark turquoise', 'medium turquoise', 'pale turquoise', 'light sea green', 'medium sea green', 'sea green', 'forest green', 'green yellow', 'lime green', 'dark green', 'green', 'lime', 'chartreuse', 'lawn green', 'yellow green', 'olive green', 'dark olive green', 'medium spring green', 'spring green', 'medium aquamarine', 'aquamarine', 'aqua', 'cyan', 'dark cyan', 'teal', 'medium orchid', 'dark orchid', 'orchid', 'blue violet', 'violet', 'dark violet', 'plum', 'thistle', 'magenta', 'fuchsia', 'dark magenta', 'medium purple', 'purple', 'rebecca purple', 'dark red', 'fire brick', 'indian red', 'light coral', 'dark salmon', 'light salmon', 'salmon', 'red', 'crimson', 'tomato', 'coral', 'orange red', 'dark orange', 'orange', 'yellow', 'gold', 'light goldenrod yellow', 'pale goldenrod', 'goldenrod', 'dark goldenrod', 'beige', 'moccasin', 'blanched almond', 'navajo white', 'antique white', 'bisque', 'burlywood', 'dark khaki', 'khaki', 'tan', 'wheat', 'snow', 'floral white', 'old lace', 'ivory', 'linen', 'seashell', 'honeydew', 'mint cream', 'azure', 'lavender', 'ghost white', 'white smoke', 'gainsboro', 'light gray', 'silver', 'dark gray', 'gray', 'dim gray', 'slate gray', 'light slate gray', 'dark slate gray', 'black', 'medium violet red', 'pale violet red', 'deep pink', 'hot pink', 'light pink', 'pink', 'peach puff', 'rosy brown', 'saddle brown', 'sandy brown', 'chocolate', 'peru', 'sienna', 'brown', 'maroon', 'white', 'misty rose', 'lavender blush', 'papaya whip', 'lemon chiffon', 'light yellow', 'corn silk', 'pale green', 'light green', 'olive drab', 'olive', 'dark sea green']
6+
all_colors = [
7+
"alice blue",
8+
"dodger blue",
9+
"light sky blue",
10+
"deep sky blue",
11+
"sky blue",
12+
"steel blue",
13+
"light steel blue",
14+
"medium blue",
15+
"navy blue",
16+
"blue",
17+
"royal blue",
18+
"cadet blue",
19+
"cornflower blue",
20+
"medium slate blue",
21+
"slate blue",
22+
"dark slate blue",
23+
"powder blue",
24+
"turquoise",
25+
"dark turquoise",
26+
"medium turquoise",
27+
"pale turquoise",
28+
"light sea green",
29+
"medium sea green",
30+
"sea green",
31+
"forest green",
32+
"green yellow",
33+
"lime green",
34+
"dark green",
35+
"green",
36+
"lime",
37+
"chartreuse",
38+
"lawn green",
39+
"yellow green",
40+
"olive green",
41+
"dark olive green",
42+
"medium spring green",
43+
"spring green",
44+
"medium aquamarine",
45+
"aquamarine",
46+
"aqua",
47+
"cyan",
48+
"dark cyan",
49+
"teal",
50+
"medium orchid",
51+
"dark orchid",
52+
"orchid",
53+
"blue violet",
54+
"violet",
55+
"dark violet",
56+
"plum",
57+
"thistle",
58+
"magenta",
59+
"fuchsia",
60+
"dark magenta",
61+
"medium purple",
62+
"purple",
63+
"rebecca purple",
64+
"dark red",
65+
"fire brick",
66+
"indian red",
67+
"light coral",
68+
"dark salmon",
69+
"light salmon",
70+
"salmon",
71+
"red",
72+
"crimson",
73+
"tomato",
74+
"coral",
75+
"orange red",
76+
"dark orange",
77+
"orange",
78+
"yellow",
79+
"gold",
80+
"light goldenrod yellow",
81+
"pale goldenrod",
82+
"goldenrod",
83+
"dark goldenrod",
84+
"beige",
85+
"moccasin",
86+
"blanched almond",
87+
"navajo white",
88+
"antique white",
89+
"bisque",
90+
"burlywood",
91+
"dark khaki",
92+
"khaki",
93+
"tan",
94+
"wheat",
95+
"snow",
96+
"floral white",
97+
"old lace",
98+
"ivory",
99+
"linen",
100+
"seashell",
101+
"honeydew",
102+
"mint cream",
103+
"azure",
104+
"lavender",
105+
"ghost white",
106+
"white smoke",
107+
"gainsboro",
108+
"light gray",
109+
"silver",
110+
"dark gray",
111+
"gray",
112+
"dim gray",
113+
"slate gray",
114+
"light slate gray",
115+
"dark slate gray",
116+
"black",
117+
"medium violet red",
118+
"pale violet red",
119+
"deep pink",
120+
"hot pink",
121+
"light pink",
122+
"pink",
123+
"peach puff",
124+
"rosy brown",
125+
"saddle brown",
126+
"sandy brown",
127+
"chocolate",
128+
"peru",
129+
"sienna",
130+
"brown",
131+
"maroon",
132+
"white",
133+
"misty rose",
134+
"lavender blush",
135+
"papaya whip",
136+
"lemon chiffon",
137+
"light yellow",
138+
"corn silk",
139+
"pale green",
140+
"light green",
141+
"olive drab",
142+
"olive",
143+
"dark sea green",
144+
]
145+
7146

8147
class Colors(Dataset):
9148
def __init__(self, sort_by_suffix=True, *args, **kwargs) -> None:
@@ -12,22 +151,24 @@ def __init__(self, sort_by_suffix=True, *args, **kwargs) -> None:
12151
self.sort_by_suffix = sort_by_suffix
13152
colors = self.sorted_by_suffix(all_colors)
14153

15-
train_size = int(len(colors) * 0.6) # chosen to ensure that similar colors aren't repeated between train and dev
154+
train_size = int(
155+
len(colors) * 0.6
156+
) # chosen to ensure that similar colors aren't repeated between train and dev
16157
train_colors, dev_colors = colors[:train_size], colors[train_size:]
17158

18-
self._train = [dict(color=color) for color in train_colors]
19-
self._dev = [dict(color=color) for color in dev_colors]
159+
self._train = [{"color": color} for color in train_colors]
160+
self._dev = [{"color": color} for color in dev_colors]
20161

21162
random.Random(0).shuffle(self._train)
22163
random.Random(0).shuffle(self._dev)
23-
164+
24165
def sorted_by_suffix(self, colors):
25166
if not self.sort_by_suffix:
26167
return colors
27168

28169
if isinstance(colors[0], str):
29170
sorted_colors = sorted(colors, key=lambda x: x[::-1])
30171
else:
31-
sorted_colors = sorted(colors, key=lambda x: x['color'][::-1])
172+
sorted_colors = sorted(colors, key=lambda x: x["color"][::-1])
32173

33174
return sorted_colors

0 commit comments

Comments
 (0)