Commit 3abb0fb

Merge pull request #14 from lpm0073/next
add unit tests for command line examples
2 parents: 9bff947 + c945753

15 files changed: +128 −65 lines

CHANGELOG.md

Lines changed: 1 addition & 2 deletions

@@ -7,10 +7,9 @@

 ## [1.1.1](https://github.com/lpm0073/netec-llm/compare/v1.1.0...v1.1.1) (2023-12-01)

-
 ### Bug Fixes

-* had to switch to bm25_encoder so that vector store is searchable ([bad6994](https://github.com/lpm0073/netec-llm/commit/bad699481d217dde81877d85124395529652dabe))
+- had to switch to bm25_encoder so that vector store is searchable ([bad6994](https://github.com/lpm0073/netec-llm/commit/bad699481d217dde81877d85124395529652dabe))


 # [1.1.0](https://github.com/lpm0073/netec-llm/compare/v1.0.0...v1.1.0) (2023-12-01)

Makefile

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@ SHELL := /bin/bash
 ifneq ("$(wildcard .env)","")
 include .env
 else
-$(shell echo -e "OPENAI_API_ORGANIZATION=PLEASE-ADD-ME\nOPENAI_API_KEY=PLEASE-ADD-ME\nPINECONE_API_KEY=PLEASE-ADD-ME\nPINECONE_ENVIRONMENT=gcp-starter\nPINECONE_INDEX_NAME=netec-ssm\nDEBUG_MODE=True\n" >> .env)
+$(shell echo -e "OPENAI_API_ORGANIZATION=PLEASE-ADD-ME\nOPENAI_API_KEY=PLEASE-ADD-ME\nPINECONE_API_KEY=PLEASE-ADD-ME\nPINECONE_ENVIRONMENT=gcp-starter\nPINECONE_INDEX_NAME=hsr\nOPENAI_CHAT_MODEL_NAME=gpt-3.5-turbo\nOPENAI_PROMPT_MODEL_NAME=text-davinci-003\nOPENAI_CHAT_TEMPERATURE=0.0\nOPENAI_CHAT_MAX_RETRIES=3\nDEBUG_MODE=True\n" >> .env)
 endif

 .PHONY: analyze init activate test lint clean
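
The seeded keys are read back out of the environment on the Python side. A minimal sketch of that hand-off, assuming python-dotenv as the loader (the repository's actual loading code is not part of this diff):

import os
from dotenv import load_dotenv  # assumption: python-dotenv is the loader used

load_dotenv()  # copies .env key/value pairs into os.environ
print(os.environ["PINECONE_INDEX_NAME"])     # "hsr" per the Makefile default
print(os.environ["OPENAI_CHAT_MODEL_NAME"])  # "gpt-3.5-turbo"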

models/const.py

Lines changed: 17 additions & 2 deletions

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=too-few-public-methods
-"""Sales Support Model (SSM) for the LangChain project."""
+"""Sales Support Model (hsr) for the LangChain project."""

 import os

@@ -15,11 +15,26 @@
     OPENAI_API_ORGANIZATION = os.environ["OPENAI_API_ORGANIZATION"]
     PINECONE_API_KEY = os.environ["PINECONE_API_KEY"]
     PINECONE_ENVIRONMENT = os.environ["PINECONE_ENVIRONMENT"]
-    PINECONE_INDEX_NAME = os.environ["PINECONE_INDEX_NAME"]
+    PINECONE_INDEX_NAME = os.environ.get("PINECONE_INDEX_NAME", "hsr")
+    OPENAI_CHAT_MODEL_NAME = os.environ.get("OPENAI_CHAT_MODEL_NAME", "gpt-3.5-turbo")
+    OPENAI_PROMPT_MODEL_NAME = os.environ.get("OPENAI_PROMPT_MODEL_NAME", "text-davinci-003")
+    OPENAI_CHAT_TEMPERATURE = float(os.environ.get("OPENAI_CHAT_TEMPERATURE", 0.0))
+    OPENAI_CHAT_MAX_RETRIES = int(os.environ.get("OPENAI_CHAT_MAX_RETRIES", 3))
+    OPENAI_CHAT_CACHE = bool(os.environ.get("OPENAI_CHAT_CACHE", True))
 else:
     raise FileNotFoundError("No .env file found in root directory of repository")


+class Config:
+    """Configuration parameters."""
+
+    OPENAI_CHAT_MODEL_NAME: str = OPENAI_CHAT_MODEL_NAME
+    OPENAI_PROMPT_MODEL_NAME: str = OPENAI_PROMPT_MODEL_NAME
+    OPENAI_CHAT_TEMPERATURE: float = OPENAI_CHAT_TEMPERATURE
+    OPENAI_CHAT_MAX_RETRIES: int = OPENAI_CHAT_MAX_RETRIES
+    OPENAI_CHAT_CACHE: bool = OPENAI_CHAT_CACHE
+
+
 class Credentials:
     """Credentials."""
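
One caveat with the new defaults: in Python, bool() of any non-empty string is True, so OPENAI_CHAT_CACHE=False in .env would still evaluate truthy. A sketch of a stricter parse (env_bool is a hypothetical helper, not part of this commit):

import os

def env_bool(name: str, default: bool = True) -> bool:
    """Parse a boolean environment value strictly; note bool("False") is True."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

OPENAI_CHAT_CACHE = env_bool("OPENAI_CHAT_CACHE", True)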

models/examples/load.py

Lines changed: 3 additions & 3 deletions

@@ -1,15 +1,15 @@
 # -*- coding: utf-8 -*-
-"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
+"""Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
 import argparse

 from models.hybrid_search_retreiver import HybridSearchRetriever


-ssm = HybridSearchRetriever()
+hsr = HybridSearchRetriever()

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="RAG example")
     parser.add_argument("filepath", type=str, help="Location of PDF documents")
     args = parser.parse_args()

-    ssm.load(filepath=args.filepath)
+    hsr.load(filepath=args.filepath)
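
The unit tests the commit title refers to are among the files not shown in this excerpt. A hypothetical sketch of how a command line example like this one could be exercised under pytest (the runpy/mock approach, test name, and paths are assumptions, not the commit's actual tests):

import runpy
import sys
from unittest.mock import patch

def test_load_example_invokes_loader():
    # fake argv so argparse sees the required filepath argument
    with patch.object(sys, "argv", ["load.py", "./data/pdfs"]):
        # stub the heavy method so no OpenAI or Pinecone calls are made
        with patch("models.hybrid_search_retreiver.HybridSearchRetriever.load") as mock_load:
            # execute the example as if it were run from the command line
            runpy.run_module("models.examples.load", run_name="__main__")
    mock_load.assert_called_once_with(filepath="./data/pdfs")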

models/examples/prompt.py

Lines changed: 4 additions & 4 deletions

@@ -1,18 +1,18 @@
 # -*- coding: utf-8 -*-
-"""Sales Support Model (SSM)"""
+"""Sales Support Model (hsr)"""
 import argparse

 from models.hybrid_search_retreiver import HybridSearchRetriever


-ssm = HybridSearchRetriever()
+hsr = HybridSearchRetriever()


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="SSM examples")
+    parser = argparse.ArgumentParser(description="hsr examples")
     parser.add_argument("system_prompt", type=str, help="A system prompt to send to the model.")
     parser.add_argument("human_prompt", type=str, help="A human prompt to send to the model.")
     args = parser.parse_args()

-    result = ssm.cached_chat_request(args.system_prompt, args.human_prompt)
+    result = hsr.cached_chat_request(args.system_prompt, args.human_prompt)
     print(result)
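
The body of cached_chat_request is not part of this diff; a hedged sketch of what the call presumably does with the ChatOpenAI instance configured in models/hybrid_search_retreiver.py below (the message contents here are invented):

from langchain.schema import HumanMessage, SystemMessage

from models.hybrid_search_retreiver import HybridSearchRetriever

hsr = HybridSearchRetriever()
messages = [
    SystemMessage(content="You are a helpful sales assistant."),   # invented prompt
    HumanMessage(content="What Oracle courses does Netec offer?"),  # invented prompt
]
response = hsr.chat(messages)  # identical repeat calls are served by the InMemoryCache
print(response.content)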

models/examples/rag.py

Lines changed: 3 additions & 3 deletions

@@ -1,16 +1,16 @@
 # -*- coding: utf-8 -*-
-"""Sales Support Model (SSM) Retrieval Augmented Generation (RAG)"""
+"""Sales Support Model (hsr) Retrieval Augmented Generation (RAG)"""
 import argparse

 from models.hybrid_search_retreiver import HybridSearchRetriever


-ssm = HybridSearchRetriever()
+hsr = HybridSearchRetriever()

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="RAG example")
     parser.add_argument("prompt", type=str, help="A question about the PDF contents")
     args = parser.parse_args()

-    result = ssm.rag(prompt=args.prompt)
+    result = hsr.rag(prompt=args.prompt)
     print(result)

models/examples/training_services.py

Lines changed: 4 additions & 4 deletions

@@ -1,19 +1,19 @@
 # -*- coding: utf-8 -*-
-"""Sales Support Model (SSM) for the LangChain project."""
+"""Sales Support Model (hsr) for the LangChain project."""
 import argparse

 from models.hybrid_search_retreiver import HybridSearchRetriever
 from models.prompt_templates import NetecPromptTemplates


-ssm = HybridSearchRetriever()
+hsr = HybridSearchRetriever()
 templates = NetecPromptTemplates()

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="SSM examples")
+    parser = argparse.ArgumentParser(description="hsr examples")
     parser.add_argument("concept", type=str, help="A kind of training that Netec provides.")
     args = parser.parse_args()

     prompt = templates.training_services
-    result = ssm.prompt_with_template(prompt=prompt, concept=args.concept)
+    result = hsr.prompt_with_template(prompt=prompt, concept=args.concept)
     print(result)
Lines changed: 4 additions & 4 deletions

@@ -1,19 +1,19 @@
 # -*- coding: utf-8 -*-
-"""Sales Support Model (SSM) for the LangChain project."""
+"""Sales Support Model (hsr) for the LangChain project."""
 import argparse

 from models.hybrid_search_retreiver import HybridSearchRetriever
 from models.prompt_templates import NetecPromptTemplates


-ssm = HybridSearchRetriever()
+hsr = HybridSearchRetriever()
 templates = NetecPromptTemplates()

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="SSM Oracle examples")
+    parser = argparse.ArgumentParser(description="hsr Oracle examples")
     parser.add_argument("concept", type=str, help="An Oracle certification exam prep")
     args = parser.parse_args()

     prompt = templates.oracle_training_services
-    result = ssm.prompt_with_template(prompt=prompt, concept=args.concept)
+    result = hsr.prompt_with_template(prompt=prompt, concept=args.concept)
     print(result)

models/hybrid_search_retreiver.py

Lines changed: 11 additions & 30 deletions

@@ -20,7 +20,6 @@
 import glob
 import os
 import textwrap
-from typing import List

 # pinecone integration
 import pinecone
@@ -44,25 +43,23 @@
 from pinecone_text.sparse import BM25Encoder

 # this project
-from models.const import Credentials
+from models.const import Config, Credentials


 ###############################################################################
 # initializations
 ###############################################################################
-DEFAULT_MODEL_NAME = "text-davinci-003"
+DEFAULT_MODEL_NAME = Config.OPENAI_PROMPT_MODEL_NAME
 pinecone.init(api_key=Credentials.PINECONE_API_KEY, environment=Credentials.PINECONE_ENVIRONMENT)
 set_llm_cache(InMemoryCache())


 class TextSplitter:
     """
-    Custom text splitter that add metadata to the Document object
+    Custom text splitter that adds metadata to the Document object
     which is required by PineconeHybridSearchRetriever.
     """

-    # ...
-
     def create_documents(self, texts):
         """Create documents"""
         documents = []
@@ -74,16 +71,16 @@ def create_documents(self, texts):


 class HybridSearchRetriever:
-    """Sales Support Model (SSM)."""
+    """Hybrid Search Retriever (OpenAI + Pinecone)"""

     # prompting wrapper
     chat = ChatOpenAI(
         api_key=Credentials.OPENAI_API_KEY,
         organization=Credentials.OPENAI_API_ORGANIZATION,
-        cache=True,
-        max_retries=3,
-        model="gpt-3.5-turbo",
-        temperature=0.0,
+        cache=Config.OPENAI_CHAT_CACHE,
+        max_retries=Config.OPENAI_CHAT_MAX_RETRIES,
+        model=Config.OPENAI_CHAT_MODEL_NAME,
+        temperature=Config.OPENAI_CHAT_TEMPERATURE,
     )

     # embeddings
@@ -112,22 +109,6 @@ def prompt_with_template(self, prompt: PromptTemplate, concept: str, model: str
         retval = llm(prompt.format(concept=concept))
         return retval

-    def fit_tf_idf_values(self, corpus: List[str]):
-        """Fit TF-IDF values.
-        1. Fit the BM25 encoder on the corpus
-        2. Encode the corpus
-        3. Store the encoded corpus in Pinecone
-        """
-        corpus = ["foo", "bar", "world", "hello"]
-
-        # fit tf-idf values on your corpus
-        self.bm25_encoder.fit(corpus)
-
-        # persist the values to a json file
-        self.bm25_encoder.dump("bm25_values.json")
-        self.bm25_encoder = BM25Encoder().load("bm25_values.json")
-        self.bm25_encoder.fit(corpus)
-
     def load(self, filepath: str):
         """
         Embed PDF.
@@ -201,9 +182,9 @@ def rag(self, prompt: str):
         document_texts = [doc.page_content for doc in documents]
         leader = textwrap.dedent(
             """\
-            You can assume that the following is true,
-            and you should attempt to incorporate these facts
-            in your response:
+            \n\nYou can assume that the following is true.
+            You should attempt to incorporate these facts
+            into your response:\n\n
             """
         )
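
For context on what the renamed class wraps: LangChain's PineconeHybridSearchRetriever scores each query against both dense OpenAI embeddings and BM25 sparse vectors. A minimal sketch of that pattern, assuming the langchain and pinecone-text APIs this module already imports (index creation and the index name are omitted/assumed):

import pinecone
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import PineconeHybridSearchRetriever
from pinecone_text.sparse import BM25Encoder

index = pinecone.Index("hsr")           # assumes a dotproduct index already exists
embeddings = OpenAIEmbeddings()         # dense side
bm25_encoder = BM25Encoder.default()    # sparse side, pretrained on MS MARCO

retriever = PineconeHybridSearchRetriever(
    embeddings=embeddings, sparse_encoder=bm25_encoder, index=index
)
documents = retriever.get_relevant_documents("Oracle certification exam prep")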

models/prompt_templates.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=too-few-public-methods
-"""Sales Support Model (SSM) prompt templates"""
+"""Sales Support Model (hsr) prompt templates"""

 from langchain.prompts import PromptTemplate
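
The templates themselves are unchanged by this commit and not shown here. A hypothetical example of the shape NetecPromptTemplates exposes (the wording is invented; only the "concept" input variable is confirmed by prompt_with_template above):

from langchain.prompts import PromptTemplate

# hypothetical template; only the "concept" variable is confirmed by the diff
training_services = PromptTemplate(
    input_variables=["concept"],
    template="Describe the training services that Netec offers for {concept}.",
)

print(training_services.format(concept="AWS"))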
