Skip to content

Commit 5e2d6a2

Browse files
Fix code style in dspy/dsp (#8176)
* enable style check
* update rules
* allow wildcard imports
* fix code style for dspy/dsp
1 parent fa69844 commit 5e2d6a2

File tree

4 files changed: +161 additions, -126 deletions

dspy/dsp/colbertv2.py

Lines changed: 70 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,10 @@ def __init__(
2121
self.url = f"{url}:{port}" if port else url
2222

2323
def __call__(
24-
self, query: str, k: int = 10, simplify: bool = False,
24+
self,
25+
query: str,
26+
k: int = 10,
27+
simplify: bool = False,
2528
) -> Union[list[str], list[dotdict]]:
2629
if self.post_requests:
2730
topk: list[dict[str, Any]] = colbertv2_post_request(self.url, query, k)
@@ -36,9 +39,7 @@ def __call__(
3639

3740
@request_cache()
3841
def colbertv2_get_request_v2(url: str, query: str, k: int):
39-
assert (
40-
k <= 100
41-
), "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."
42+
assert k <= 100, "Only k <= 100 is supported for the hosted ColBERTv2 server at the moment."
4243

4344
payload = {"query": query, "k": k}
4445
res = requests.get(url, params=payload, timeout=10)
@@ -72,88 +73,112 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs):
7273

7374
colbertv2_post_request = colbertv2_post_request_v2_wrapped
7475

76+
7577
class ColBERTv2RetrieverLocal:
76-
def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False):
78+
def __init__(self, passages: List[str], colbert_config=None, load_only: bool = False):
7779
"""Colbertv2 retriever module
7880
7981
Args:
8082
passages (List[str]): list of passages
8183
colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to None.
8284
load_only (bool, optional): whether to load the index or build and then load. Defaults to False.
8385
"""
84-
assert colbert_config is not None, "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it"
86+
assert (
87+
colbert_config is not None
88+
), "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it"
8589
self.colbert_config = colbert_config
8690

87-
assert self.colbert_config.checkpoint is not None, "Please pass a valid checkpoint like colbert-ir/colbertv2.0, which you can modify in the ColBERTConfig with attribute name checkpoint"
91+
assert (
92+
self.colbert_config.checkpoint is not None
93+
), "Please pass a valid checkpoint like colbert-ir/colbertv2.0, which you can modify in the ColBERTConfig with attribute name checkpoint"
8894
self.passages = passages
89-
90-
assert self.colbert_config.index_name is not None, "Please pass a valid index_name, which you can modify in the ColBERTConfig with attribute name index_name"
95+
96+
assert (
97+
self.colbert_config.index_name is not None
98+
), "Please pass a valid index_name, which you can modify in the ColBERTConfig with attribute name index_name"
9199
self.passages = passages
92100

93101
if not load_only:
94-
print(f"Building the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
102+
print(
103+
f"Building the index for experiment {self.colbert_config.experiment} with index name "
104+
f"{self.colbert_config.index_name}"
105+
)
95106
self.build_index()
96-
97-
print(f"Loading the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
107+
108+
print(
109+
f"Loading the index for experiment {self.colbert_config.experiment} with index name "
110+
f"{self.colbert_config.index_name}"
111+
)
98112
self.searcher = self.get_index()
99113

100114
def build_index(self):
101-
102115
try:
103-
import colbert # noqa: F401
116+
import colbert # noqa: F401
104117
except ImportError:
105-
print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
118+
print(
119+
"Colbert not found. Please check your installation or install the module using pip install "
120+
"colbert-ai[faiss-gpu,torch]."
121+
)
106122

107123
from colbert import Indexer
108124
from colbert.infra import Run, RunConfig
109-
with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)):
125+
126+
with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)):
110127
indexer = Indexer(checkpoint=self.colbert_config.checkpoint, config=self.colbert_config)
111128
indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True)
112129

113130
def get_index(self):
114131
try:
115-
import colbert # noqa: F401
132+
import colbert # noqa: F401
116133
except ImportError:
117-
print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
134+
print(
135+
"Colbert not found. Please check your installation or install the module using pip install "
136+
"colbert-ai[faiss-gpu,torch]."
137+
)
118138

119139
from colbert import Searcher
120140
from colbert.infra import Run, RunConfig
121-
141+
122142
with Run().context(RunConfig(experiment=self.colbert_config.experiment)):
123143
searcher = Searcher(index=self.colbert_config.index_name, collection=self.passages)
124144
return searcher
125-
145+
126146
def __call__(self, *args: Any, **kwargs: Any) -> Any:
127147
return self.forward(*args, **kwargs)
128148

129-
def forward(self,query:str,k:int=7,**kwargs):
149+
def forward(self, query: str, k: int = 7, **kwargs):
130150
import torch
131-
151+
132152
if kwargs.get("filtered_pids"):
133153
filtered_pids = kwargs.get("filtered_pids")
134154
assert type(filtered_pids) == List[int], "The filtered pids should be a list of integers"
135155
device = "cuda" if torch.cuda.is_available() else "cpu"
136156
results = self.searcher.search(
137157
query,
138-
#Number of passages to receive
139-
k=k,
140-
#Passing the filter function of relevant
158+
# Number of passages to receive
159+
k=k,
160+
# Passing the filter function of relevant
141161
filter_fn=lambda pids: torch.tensor(
142-
[pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device))
162+
[pid for pid in pids if pid in filtered_pids], dtype=torch.int32
163+
).to(device),
164+
)
143165
else:
144166
searcher_results = self.searcher.search(query, k=k)
145167
results = []
146-
for pid,rank,score in zip(*searcher_results):
147-
results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid}))
168+
for pid, rank, score in zip(*searcher_results): # noqa: B007
169+
results.append(dotdict({"long_text": self.searcher.collection[pid], "score": score, "pid": pid}))
148170
return results
149171

172+
150173
class ColBERTv2RerankerLocal:
151-
152-
def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'):
174+
def __init__(self, colbert_config=None, checkpoint: str = "bert-base-uncased"):
153175
try:
154-
import colbert # noqa: F401
176+
import colbert # noqa: F401
155177
except ImportError:
156-
print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
178+
print(
179+
"Colbert not found. Please check your installation or install the module using pip install "
180+
"colbert-ai[faiss-gpu,torch]."
181+
)
157182
"""_summary_
158183
159184
Args:
@@ -167,24 +192,25 @@ def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'):
167192
def __call__(self, *args: Any, **kwargs: Any) -> Any:
168193
return self.forward(*args, **kwargs)
169194

170-
def forward(self,query:str,passages:List[str]=[]):
195+
def forward(self, query: str, passages: Optional[List[str]] = None):
171196
assert len(passages) > 0, "Passages should not be empty"
172197

173198
import numpy as np
174199
from colbert.modeling.colbert import ColBERT
175200
from colbert.modeling.tokenization.doc_tokenization import DocTokenizer
176201
from colbert.modeling.tokenization.query_tokenization import QueryTokenizer
177-
202+
203+
passages = passages or []
178204
self.colbert_config.nway = len(passages)
179-
query_tokenizer = QueryTokenizer(self.colbert_config,verbose=1)
205+
query_tokenizer = QueryTokenizer(self.colbert_config, verbose=1)
180206
doc_tokenizer = DocTokenizer(self.colbert_config)
181-
query_ids,query_masks = query_tokenizer.tensorize([query])
182-
doc_ids,doc_masks = doc_tokenizer.tensorize(passages)
183-
184-
col = ColBERT(self.checkpoint,self.colbert_config)
185-
Q = col.query(query_ids,query_masks)
186-
DOC_IDS,DOC_MASKS = col.doc(doc_ids,doc_masks,keep_dims='return_mask')
187-
Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous()
188-
tensor_scores = col.score(Q_duplicated,DOC_IDS,DOC_MASKS)
207+
query_ids, query_masks = query_tokenizer.tensorize([query])
208+
doc_ids, doc_masks = doc_tokenizer.tensorize(passages)
209+
210+
col = ColBERT(self.checkpoint, self.colbert_config)
211+
q = col.query(query_ids, query_masks)
212+
doc_ids, doc_masks = col.doc(doc_ids, doc_masks, keep_dims="return_mask")
213+
q_duplicated = q.repeat_interleave(len(passages), dim=0).contiguous()
214+
tensor_scores = col.score(q_duplicated, doc_ids, doc_masks)
189215
passage_score_arr = np.array([score.cpu().detach().numpy().tolist() for score in tensor_scores])
190-
return passage_score_arr
216+
return passage_score_arr

0 commit comments

Comments
 (0)