Skip to content

Commit 2787453

Browse files
authored
add quote rule (#8383)
1 parent 077fc2c commit 2787453

File tree

9 files changed

+41
-40
lines changed

9 files changed

+41
-40
lines changed

dspy/adapters/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def translate_field_type(field_name, field_info):
9090
elif field_type in (int, float):
9191
desc = f"must be a single {field_type.__name__} value"
9292
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
93-
enum_vals = '; '.join(str(member.value) for member in field_type)
93+
enum_vals = "; ".join(str(member.value) for member in field_type)
9494
desc = f"must be one of: {enum_vals}"
9595
elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
9696
desc = (

dspy/primitives/python_interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
if deno_command:
5757
self.deno_command = list(deno_command)
5858
else:
59-
args = ['deno', 'run', '--allow-read']
59+
args = ["deno", "run", "--allow-read"]
6060
self._env_arg = ""
6161
if self.enable_env_vars:
6262
user_vars = [str(v).strip() for v in self.enable_env_vars]

dspy/propose/dataset_summary_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def reorder_keys(match):
3636
# Extracting the keys from the match
3737
keys_str = match.group(1)
3838
# Splitting the keys, stripping extra spaces, and sorting them
39-
keys = sorted(key.strip() for key in keys_str.split(','))
39+
keys = sorted(key.strip() for key in keys_str.split(","))
4040
# Formatting the sorted keys back into the expected structure
4141
return f"input_keys={{{', '.join(keys)}}}"
4242

dspy/propose/grounded_proposer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def gather_examples_from_sets(candidate_sets, max_examples):
224224
outputs = []
225225
for field_name, field in get_signature(program.predictors()[pred_i]).fields.items():
226226
# Access the '__dspy_field_type' from the extra metadata
227-
dspy_field_type = field.json_schema_extra.get('__dspy_field_type')
227+
dspy_field_type = field.json_schema_extra.get("__dspy_field_type")
228228

229229
# Based on the '__dspy_field_type', append to the respective list
230230
if dspy_field_type == "input":

dspy/propose/utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616

1717
def strip_prefix(text):
18-
pattern = r'^[\*\s]*(([\w\'\-]+\s+){0,4}[\w\'\-]+):\s*'
19-
modified_text = re.sub(pattern, '', text)
20-
return modified_text.strip("\"")
18+
pattern = r"^[\*\s]*(([\w\'\-]+\s+){0,4}[\w\'\-]+):\s*"
19+
modified_text = re.sub(pattern, "", text)
20+
return modified_text.strip('"')
2121

2222
def create_instruction_set_history_string(base_program, trial_logs, top_n):
2323
program_history = []
@@ -42,7 +42,7 @@ def create_instruction_set_history_string(base_program, trial_logs, top_n):
4242
unique_program_history.append(entry)
4343

4444
# Get the top n programs from program history
45-
top_n_program_history = sorted(unique_program_history, key=lambda x: x['score'], reverse=True)[:top_n]
45+
top_n_program_history = sorted(unique_program_history, key=lambda x: x["score"], reverse=True)[:top_n]
4646
top_n_program_history.reverse()
4747

4848
# Create formatted string
@@ -71,7 +71,7 @@ def get_program_instruction_set_string(program):
7171
instruction_list = []
7272
for _, pred in enumerate(program.predictors()):
7373
pred_instructions = get_signature(pred).instructions
74-
instruction_list.append(f"\"{pred_instructions}\"")
74+
instruction_list.append(f'"{pred_instructions}"')
7575
# Joining the list into a single string that looks like a list
7676
return f"[{', '.join(instruction_list)}]"
7777

@@ -97,15 +97,15 @@ def create_predictor_level_history_string(base_program, predictor_i, trial_logs,
9797
score = history_item["score"]
9898

9999
if instruction in instruction_aggregate:
100-
instruction_aggregate[instruction]['total_score'] += score
101-
instruction_aggregate[instruction]['count'] += 1
100+
instruction_aggregate[instruction]["total_score"] += score
101+
instruction_aggregate[instruction]["count"] += 1
102102
else:
103-
instruction_aggregate[instruction] = {'total_score': score, 'count': 1}
103+
instruction_aggregate[instruction] = {"total_score": score, "count": 1}
104104

105105
# Calculate average score for each instruction and prepare for sorting
106106
predictor_history = []
107107
for instruction, data in instruction_aggregate.items():
108-
average_score = data['total_score'] / data['count']
108+
average_score = data["total_score"] / data["count"]
109109
predictor_history.append((instruction, average_score))
110110

111111
# Deduplicate and sort by average score, then select top N
@@ -141,7 +141,7 @@ def create_example_string(fields, example):
141141
output.append(field_str)
142142

143143
# Joining all the field strings
144-
return '\n'.join(output)
144+
return "\n".join(output)
145145

146146
def get_dspy_source_code(module):
147147
header = []
@@ -169,18 +169,18 @@ def get_dspy_source_code(module):
169169
if item in completed_set:
170170
continue
171171
if isinstance(item, Parameter):
172-
if hasattr(item, 'signature') and item.signature is not None and item.signature.__pydantic_parent_namespace__['signature_name'] + "_sig" not in completed_set:
172+
if hasattr(item, "signature") and item.signature is not None and item.signature.__pydantic_parent_namespace__["signature_name"] + "_sig" not in completed_set:
173173
try:
174174
header.append(inspect.getsource(item.signature))
175175
print(inspect.getsource(item.signature))
176176
except (TypeError, OSError):
177177
header.append(str(item.signature))
178-
completed_set.add(item.signature.__pydantic_parent_namespace__['signature_name'] + "_sig")
178+
completed_set.add(item.signature.__pydantic_parent_namespace__["signature_name"] + "_sig")
179179
if isinstance(item, dspy.Module):
180180
code = get_dspy_source_code(item).strip()
181181
if code not in completed_set:
182182
header.append(code)
183183
completed_set.add(code)
184184
completed_set.add(item)
185185

186-
return '\n\n'.join(header) + '\n\n' + base_code
186+
return "\n\n".join(header) + "\n\n" + base_code

dspy/teleprompt/avatar_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def compile(self, student, *, trainset):
187187
best_score = -999 if self.optimize_for == "max" else 999
188188

189189
for i in range(self.max_iters):
190-
print(20*'=')
190+
print(20*"=")
191191
print(f"Iteration {i+1}/{self.max_iters}")
192192

193193
score, pos_inputs, neg_inputs = self._get_pos_neg_results(best_actor, trainset)

dspy/teleprompt/grpo.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def __init__(
4040
report_train_scores: bool = False,
4141
failure_score: float = 0,
4242
format_failure_score: float = -1,
43-
variably_invoked_predictor_grouping_mode: Union[Literal['truncate'], Literal['fill'], Literal['ragged']] = 'truncate',
44-
variably_invoked_predictor_fill_strategy: Optional[Union[Literal['randint'], Literal['max']]] = None,
43+
variably_invoked_predictor_grouping_mode: Union[Literal["truncate"], Literal["fill"], Literal["ragged"]] = "truncate",
44+
variably_invoked_predictor_fill_strategy: Optional[Union[Literal["randint"], Literal["max"]]] = None,
4545
):
4646
super().__init__(train_kwargs=train_kwargs)
4747
self.metric = metric
@@ -70,9 +70,9 @@ def __init__(
7070
# The backend will be called with a batch of (num_dspy_examples_per_grpo_step * num_rollouts_per_grpo_step * num_predictors) per training set if multitask is True
7171
# If multitask is False, the backend will be called with a batch of (num_dspy_examples_per_grpo_step * num_rollouts_per_grpo_step) per training job
7272
self.variably_invoked_predictor_grouping_mode = variably_invoked_predictor_grouping_mode
73-
if variably_invoked_predictor_grouping_mode == 'fill':
73+
if variably_invoked_predictor_grouping_mode == "fill":
7474
assert variably_invoked_predictor_fill_strategy is not None, "variably_invoked_predictor_fill_strategy must be set when variably_invoked_predictor_grouping_mode is 'fill'"
75-
assert variably_invoked_predictor_fill_strategy in ['randint', 'max'], "variably_invoked_predictor_fill_strategy must be either 'randint' or 'max'"
75+
assert variably_invoked_predictor_fill_strategy in ["randint", "max"], "variably_invoked_predictor_fill_strategy must be either 'randint' or 'max'"
7676
self.variably_invoked_predictor_fill_strategy = variably_invoked_predictor_fill_strategy
7777

7878
self.shuffled_trainset_ids = []
@@ -374,7 +374,7 @@ def compile(
374374
format_failure_score=self.format_failure_score,
375375
)
376376
for data_dict in round_data:
377-
example_ind_in_subsample = data_dict['example_ind'] % len(subsample_training_dataset)
377+
example_ind_in_subsample = data_dict["example_ind"] % len(subsample_training_dataset)
378378
data_dict["example_ind"] = example_ind_in_subsample
379379
trace_data[example_ind_in_subsample][tind].append(data_dict)
380380

@@ -425,10 +425,10 @@ def compile(
425425
logger.warning(f"Skipping example {example_ind} for predictor {pred_id} as it has no invocations.")
426426
continue
427427

428-
if self.variably_invoked_predictor_grouping_mode == 'truncate':
428+
if self.variably_invoked_predictor_grouping_mode == "truncate":
429429
predictor_example_invocations = [invocation[:min_len] for invocation in predictor_example_invocations]
430-
elif self.variably_invoked_predictor_grouping_mode == 'fill':
431-
if self.variably_invoked_predictor_fill_strategy == 'randint':
430+
elif self.variably_invoked_predictor_grouping_mode == "fill":
431+
if self.variably_invoked_predictor_fill_strategy == "randint":
432432
selector = lambda l: self.rng.choice(l) # noqa: E731, E741
433433
else:
434434
selector = lambda l: l[-1] # noqa: E731, E741
@@ -437,7 +437,7 @@ def compile(
437437
for invocation in predictor_example_invocations
438438
]
439439
else:
440-
assert self.variably_invoked_predictor_grouping_mode == 'ragged', f"Unknown variably invoked predictor grouping mode {self.variably_invoked_predictor_grouping_mode}"
440+
assert self.variably_invoked_predictor_grouping_mode == "ragged", f"Unknown variably invoked predictor grouping mode {self.variably_invoked_predictor_grouping_mode}"
441441
max_len = max([len(predictor_example_invocations[i]) for i in range(len(predictor_example_invocations))])
442442

443443
example_training_data: List[GRPOGroup] = [[] for _ in range(max_len)]
@@ -481,7 +481,7 @@ def compile(
481481
inputs=trace_instance[1],
482482
outputs=trace_instance[2],
483483
demos=[] # TODO: Add support for demos
484-
)['messages']
484+
)["messages"]
485485

486486
assert all_messages[:-1] == inp_messages, f"Input messages {inp_messages} do not match the expected messages {all_messages[:-1]}"
487487

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ select = [
137137
"UP", # pyupgrade
138138
"N", # pep8-naming
139139
"RUF", # ruff-specific rules
140+
"Q", # flake8-quotes
140141
]
141142

142143
ignore = [

tests/predict/test_code_act.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ def test_codeact_code_generation():
3939
res = program(question="What is 1+1?")
4040
assert res.answer == "2"
4141
assert res.trajectory == {
42-
'code_output_0': '"2\\n"',
43-
'generated_code_0': 'result = add(1,1)\nprint(result)',
42+
"code_output_0": '"2\\n"',
43+
"generated_code_0": "result = add(1,1)\nprint(result)",
4444
}
4545
assert program.interpreter.deno_process is None
4646

@@ -72,8 +72,8 @@ def test_codeact_support_multiple_fields():
7272
assert res.maximum == "6"
7373
assert res.minimum == "2"
7474
assert res.trajectory == {
75-
'code_output_0': '"{\'maximum\': 6.0, \'minimum\': 2.0}\\n"',
76-
'generated_code_0': "result = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)",
75+
"code_output_0": '"{\'maximum\': 6.0, \'minimum\': 2.0}\\n"',
76+
"generated_code_0": "result = extract_maximum_minimum('2, 3, 5, 6')\nprint(result)",
7777
}
7878
assert program.interpreter.deno_process is None
7979

@@ -100,10 +100,10 @@ def test_codeact_code_parse_failure():
100100
res = program(question="What is 1+1?")
101101
assert res.answer == "2"
102102
assert res.trajectory == {
103-
'generated_code_0': 'parse(error',
104-
'observation_0': 'Failed to execute the generated code: Invalid Python syntax. message: ',
105-
'generated_code_1': 'result = add(1,1)\nprint(result)',
106-
'code_output_1': '"2\\n"',
103+
"generated_code_0": "parse(error",
104+
"observation_0": "Failed to execute the generated code: Invalid Python syntax. message: ",
105+
"generated_code_1": "result = add(1,1)\nprint(result)",
106+
"code_output_1": '"2\\n"',
107107
}
108108
assert program.interpreter.deno_process is None
109109

@@ -130,10 +130,10 @@ def test_codeact_code_execution_failure():
130130
res = program(question="What is 1+1?")
131131
assert res.answer == "2"
132132
assert res.trajectory == {
133-
'generated_code_0': 'unknown+1',
134-
'observation_0': 'Failed to execute the generated code: NameError: ["name \'unknown\' is not defined"]',
135-
'generated_code_1': 'result = add(1,1)\nprint(result)',
136-
'code_output_1': '"2\\n"',
133+
"generated_code_0": "unknown+1",
134+
"observation_0": 'Failed to execute the generated code: NameError: ["name \'unknown\' is not defined"]',
135+
"generated_code_1": "result = add(1,1)\nprint(result)",
136+
"code_output_1": '"2\\n"',
137137
}
138138
assert program.interpreter.deno_process is None
139139

0 commit comments

Comments
 (0)