@@ -490,13 +490,14 @@ def _embedding_pricing(self): pass
490
490
491
491
492
492
class Indexer (Expression ):
493
- def __init__ (self , index_name : str = 'data-index' , top_k : int = 10 , batch_size : int = 20 ):
493
+ def __init__ (self , index_name : str = 'data-index' , top_k : int = 8 , batch_size : int = 20 ):
494
494
super ().__init__ ()
495
- self .index_name = index_name
496
- self .elements = []
497
- self .batch_size = batch_size
498
- self .top_k = top_k
495
+ self .index_name = index_name
496
+ self .elements = []
497
+ self .batch_size = batch_size
498
+ self .top_k = top_k
499
499
self .NEWLINES_RE = re .compile (r"\n{2,}" ) # two or more "\n" characters
500
+ self .retrieval = None
500
501
501
502
def split_paragraphs (self , input_text = "" ):
502
503
no_newlines = input_text .strip ("\n " ) # remove leading and trailing "\n"
@@ -508,7 +509,7 @@ def split_paragraphs(self, input_text=""):
508
509
509
510
return paragraphs
510
511
511
- def split_huge_paragraphs (self , input_text : List [str ], max_length = 400 ):
512
+ def split_huge_paragraphs (self , input_text : List [str ], max_length = 300 ):
512
513
paragraphs = []
513
514
for text in input_text :
514
515
words = text .split ()
@@ -535,6 +536,7 @@ def forward(self, query: Optional[Symbol] = None, *args, **kwargs) -> Symbol:
535
536
def _func (query ):
536
537
res = that .get (Symbol (query ).embed ().value , index_top_k = that .top_k ).ast ()
537
538
res = [v ['metadata' ]['text' ] for v in res ['matches' ]]
539
+ that .retrieval = res
538
540
sym = that ._to_symbol (res )
539
541
rsp = sym .query (query , max_tokens = 2000 )
540
542
return rsp
0 commit comments