Skip to content

Commit 2bbfef2

Browse files
committed
Adds OpenAI-based additions to the flask app and waitress.
1 parent c41cbe6 commit 2bbfef2

17 files changed

+939
-12
lines changed

.dockerignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pybin

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
config.txt
2+
__pycache__

Dockerfile

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
FROM python:3.11-slim
22

33
WORKDIR /app
4-
COPY ./src /app
4+
COPY requirements.txt .
55

6-
RUN pip install Flask
6+
RUN pip install -r requirements.txt
7+
8+
COPY ./test /app/test
9+
COPY ./lib /app/lib
10+
COPY ./src /app/src
711

812
EXPOSE 5000
9-
CMD ["python", "app.py"]
13+
CMD ["waitress-serve", "--host=0.0.0.0", "--port=5000", "--call", "src:create_app"]

config.txt.sample

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
OPENAI_API_KEY=

docker-compose.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ version: '3'
22
services:
33
python-proxy:
44
build: .
5+
env_file: "config.txt"
56
ports:
67
- "5000:5000"

lib/__init__.py

Whitespace-only changes.

lib/assess.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Normal imports
2+
import csv, glob, json, time, os
3+
from multiprocessing import Pool
4+
import concurrent.futures
5+
import io
6+
import json
7+
8+
# Import our support classes
9+
from lib.config import SUPPORTED_MODELS, VALID_GRADES
10+
from lib.grade import Grade
11+
from lib.report import Report
12+
from lib.rubric_tester import (
13+
read_inputs,
14+
get_expected_grades,
15+
get_examples,
16+
get_passing_grades,
17+
get_student_files,
18+
validate_rubrics,
19+
validate_students,
20+
grade_student_work,
21+
compute_accuracy
22+
)
23+
24+
def grade(code, prompt, rubric, api_key='', llm_model='gpt-4', num_responses=1, temperature=0.2, num_passing_grades=2, remove_comments=False):
    """Grade one piece of student code against a rubric using the LLM grader.

    Args:
        code: The student's source code; sent as the final user message.
        prompt: The system prompt given to the model.
        rubric: The rubric text appended to the system prompt.
        api_key: OpenAI API key. When non-empty it is written to the
            ``OPENAI_API_KEY`` environment variable (a process-wide side
            effect); when empty, an already-set ``OPENAI_API_KEY`` is used.
        llm_model: Model name forwarded to the grader.
        num_responses: Number of completions to request.
        temperature: Sampling temperature forwarded to the grader.
        num_passing_grades: Accepted for interface compatibility; it is not
            used by this function.
        remove_comments: When true, ``//`` comments are stripped from the
            student code before it is sent.

    Returns:
        Whatever ``Grade.grade_student_work`` returns for the request, or an
        empty dict when no API key is available.
    """
    # NOTE: earlier revisions reassigned llm_model, num_responses, temperature
    # and num_passing_grades to hard-coded constants here, which silently
    # discarded whatever the caller passed. The defaults in the signature
    # carry the same values, so callers relying on defaults are unaffected.

    # Set the key
    if api_key:
        os.environ['OPENAI_API_KEY'] = api_key
    elif 'OPENAI_API_KEY' not in os.environ:
        print("Must set OPENAI_API_KEY!")
        return {}
    else:
        print("Using set OPENAI_API_KEY")

    # Use a distinct local name so the function's own name is not shadowed.
    grader = Grade()
    return grader.grade_student_work(
        prompt, rubric, code, "student", [],
        use_cached=False,
        write_cached=False,
        num_responses=num_responses,
        temperature=temperature,
        llm_model=llm_model,
        remove_comments=remove_comments
    )

lib/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Grade labels a response row may carry in its "Grade" column.
VALID_GRADES = [
    "Extensive Evidence",
    "Convincing Evidence",
    "Limited Evidence",
    "No Evidence",
]

# Model identifiers; presumably the set of models callers may request
# (TODO confirm against the code that consumes this list).
SUPPORTED_MODELS = [
    'gpt-4',
    'gpt-4-0314',
    'gpt-4-32k',
    'gpt-4-32k-0314',
]

lib/grade.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import os
2+
import json
3+
import csv
4+
import time
5+
import requests
6+
7+
from typing import List, Dict, Any
8+
from lib.config import VALID_GRADES
9+
10+
from io import StringIO
11+
12+
class InvalidResponseError(Exception):
    """Signals that a model response failed validation against the rubric."""
14+
15+
class Grade:
    """Grades student code against a rubric via the OpenAI chat-completions API.

    The model is expected to answer with a table (TSV, markdown or CSV) whose
    columns include "Key Concept", "Observations", "Grade" and "Reason". Each
    answer is parsed, cleaned and validated against the rubric; when several
    valid answers come back, a per-key-concept majority vote is taken.
    """

    def __init__(self):
        pass

    def grade_student_work(self, prompt, rubric, student_code, student_id, examples=None, use_cached=False, write_cached=True, num_responses=0, temperature=0.0, llm_model="", remove_comments=False):
        """Request grades for one student's code and return the parsed rows.

        Args:
            prompt: System prompt text.
            rubric: Rubric as CSV text with a "Key Concept" column.
            student_code: The code to grade.
            student_id: Identifier used in log lines and the cache file name.
            examples: Optional iterable of (code, graded_response) pairs sent
                as few-shot user/assistant turns. ``None`` means no examples.
            use_cached: Return ``cached_responses/<student_id>.json`` when it
                exists instead of calling the API.
            write_cached: Write a valid result to that same cache file.
            num_responses: Value of the API's ``n`` parameter.
            temperature: Sampling temperature.
            llm_model: Model name.
            remove_comments: Strip ``//`` comments before sending the code.

        Returns:
            A list of row dicts, or ``None`` when the request timed out,
            returned a non-200 status, or produced no valid response.
        """
        cache_dir = "cached_responses"
        cache_path = f"{cache_dir}/{student_id}.json"

        if use_cached and os.path.exists(cache_path):
            with open(cache_path, 'r') as f:
                return json.load(f)

        api_url = 'https://api.openai.com/v1/chat/completions'
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {os.getenv('OPENAI_API_KEY')}"
        }

        # Sanitize student code
        student_code = self.sanitize_code(student_code, remove_comments=remove_comments)

        messages = self.compute_messages(prompt, rubric, student_code, examples=examples)
        data = {
            'model': llm_model,
            'temperature': temperature,
            'messages': messages,
            'n': num_responses,
        }

        start_time = time.time()
        try:
            response = requests.post(api_url, headers=headers, json=data, timeout=120)
        except requests.exceptions.Timeout:
            # Timeout is the parent of both ReadTimeout and ConnectTimeout, so
            # a slow connect is reported the same way as a slow read.
            print(f"{student_id} request timed out in {(time.time() - start_time):.0f} seconds.")
            return None

        if response.status_code != 200:
            print(f"{student_id} Error calling the API: {response.status_code}")
            print(f"{student_id} Response body: {response.text}")
            return None

        # Decode the body once and reuse it.
        payload = response.json()
        tokens = payload['usage']['total_tokens']
        print(f"{student_id} request succeeded in {(time.time() - start_time):.0f} seconds. {tokens} tokens used.")

        # Keep only the choices that have content and pass validation.
        tsv_data_choices = []
        for index, choice in enumerate(payload['choices']):
            content = choice['message']['content']
            if not content:
                continue
            parsed = self.get_tsv_data_if_valid(content, rubric, student_id, choice_index=index)
            if parsed:
                tsv_data_choices.append(parsed)

        if not tsv_data_choices:
            tsv_data = None
        elif len(tsv_data_choices) == 1:
            tsv_data = tsv_data_choices[0]
        else:
            tsv_data = self.get_consensus_response(tsv_data_choices, student_id)

        # only write to cache if the response is valid
        if write_cached and tsv_data:
            # The cache directory may not exist yet on a fresh checkout.
            os.makedirs(cache_dir, exist_ok=True)
            with open(cache_path, 'w') as f:
                json.dump(tsv_data, f, indent=4)

        return tsv_data

    def sanitize_code(self, student_code, remove_comments=False):
        """Return the student code, optionally with ``//`` comments removed.

        Comment removal is purely textual: everything from the first ``//`` on
        each line is dropped, including a ``//`` inside a string literal.
        """
        if remove_comments:
            student_code = "\n".join(
                line.partition("//")[0] for line in student_code.split('\n')
            )

        return student_code

    def compute_messages(self, prompt, rubric, student_code, examples=None):
        """Build the chat message list: system prompt, examples, student code.

        Args:
            prompt: System prompt text; the rubric is appended to it.
            rubric: Rubric text.
            student_code: Sent as the final user message.
            examples: Optional iterable of (user_content, assistant_content)
                pairs inserted before the student code.
        """
        messages = [
            {'role': 'system', 'content': f"{prompt}\n\nRubric:\n{rubric}"}
        ]
        for example_js, example_rubric in (examples or []):
            messages.append({'role': 'user', 'content': example_js})
            messages.append({'role': 'assistant', 'content': example_rubric})
        messages.append({'role': 'user', 'content': student_code})
        return messages

    def get_tsv_data_if_valid(self, response_text, rubric, student_id, choice_index=None):
        """Parse one model answer into row dicts, or return ``None`` if invalid.

        The delimiter is inferred: tab when any tab is present, ``|`` when the
        text looks like a markdown table, otherwise comma.
        """
        choice_text = f"Choice {choice_index}: " if choice_index is not None else ''
        if not response_text:
            print(f"{student_id} {choice_text} Invalid response: empty response")
            return None
        text = response_text.strip()

        # Remove anything up to the first column name
        if "\nKey Concept" in text:
            index = text.index("\nKey Concept")
            text = text[index:].strip()

        # Replace escaped tabs
        if '\\t' in text:
            text = text.replace("\\t", "\t")

        # Replace double tabs... ugh
        text = text.replace("\t\t", "\t")

        if '\t' in text:
            # If there is a tab, it is probably TSV
            delimiter = '\t'
        elif ' | ' in text:
            # Ok, sometimes it does markdown sequence... which means it does '|'
            # as a delimiter and has lines with '---' in them
            text = "\n".join(line for line in text.split('\n') if "---" not in line)
            print("response was markdown and not tsv, delimiting by '|'")
            delimiter = '|'
        else:
            # Let's assume it is CSV
            print("response had no tabs so is not tsv, delimiting by ','")
            delimiter = ','

        tsv_data = list(csv.DictReader(StringIO(text), delimiter=delimiter))

        try:
            self.sanitize_server_response(tsv_data)
            self.validate_server_response(tsv_data, rubric)
            return list(tsv_data)
        except InvalidResponseError as e:
            print(f"{student_id} {choice_text} Invalid response: {str(e)}\n{response_text}")
            return None

    def parse_tsv(self, tsv_text):
        """Split tab-separated text into a list of dicts keyed by the header row."""
        rows = tsv_text.split("\n")
        header = rows.pop(0).split("\t")
        return [dict(zip(header, row.split("\t"))) for row in rows]

    def sanitize_server_response(self, tsv_data):
        """Clean parsed rows in place.

        Strips whitespace and surrounding double quotes from string values,
        strips whitespace from column names, and drops rows whose
        "Key Concept" is not a string starting with an alphanumeric character.
        Non-list input is left untouched.
        """
        if not isinstance(tsv_data, list):
            return

        # Strip whitespace and quotes from fields
        for row in tsv_data:
            for key in list(row.keys()):
                if isinstance(row[key], str):
                    row[key] = row[key].strip().strip('"')

                if isinstance(key, str) and key.strip() != key:
                    row[key.strip()] = row.pop(key)

        def starts_reasonably(row):
            if "Key Concept" not in row:
                return True
            concept = row["Key Concept"]
            # A short row leaves the field as None; treat it like an empty one.
            return isinstance(concept, str) and concept[0:1].isalnum()

        # Remove rows that don't start with reasonable things. Rebuild the
        # list in place rather than calling remove() while iterating, which
        # skips the element following each removed row.
        tsv_data[:] = [row for row in tsv_data if starts_reasonably(row)]

    def validate_server_response(self, tsv_data, rubric):
        """Raise ``InvalidResponseError`` unless the rows match the rubric.

        Checks that every row has the expected columns, that the set of key
        concepts equals the rubric's, and that every grade is in
        ``VALID_GRADES``.
        """
        expected_columns = {"Key Concept", "Observations", "Grade", "Reason"}

        rubric_key_concepts = {row['Key Concept'] for row in csv.DictReader(rubric.splitlines())}

        if not isinstance(tsv_data, list):
            raise InvalidResponseError('invalid format')

        if not all(expected_columns <= set(row.keys()) for row in tsv_data):
            raise InvalidResponseError('incorrect column names')

        # Compare as sets: same result as comparing sorted unique lists, but
        # it cannot fail when a value is not a string.
        key_concepts_from_response = {row["Key Concept"] for row in tsv_data}
        if rubric_key_concepts != key_concepts_from_response:
            raise InvalidResponseError('invalid or missing key concept')

        for row in tsv_data:
            if row["Grade"] not in VALID_GRADES:
                raise InvalidResponseError(f"invalid grade value: '{row['Grade']}'")

    def get_consensus_response(self, choices, student_id):
        """Merge several valid answers into one by majority vote per key concept.

        For each key concept the most common grade wins (ties go to the grade
        seen first). Observations and reason are taken from the first row that
        carries the winning grade, and the reason is prefixed with the votes.
        """
        from collections import Counter, defaultdict

        key_concept_to_grades = defaultdict(list)
        for choice in choices:
            for row in choice:
                key_concept_to_grades[row['Key Concept']].append(row['Grade'])

        key_concept_to_majority_grade = {}
        for key_concept, grades in key_concept_to_grades.items():
            majority_grade = Counter(grades).most_common(1)[0][0]
            key_concept_to_majority_grade[key_concept] = majority_grade
            if majority_grade != grades[0]:
                print(f"outvoted {student_id} Key Concept: {key_concept} first grade: {grades[0]} majority grade: {majority_grade}")

        # First row per key concept that agrees with the majority grade.
        key_concept_to_winning_row = {}
        for choice in choices:
            for row in choice:
                key_concept = row['Key Concept']
                if key_concept_to_majority_grade[key_concept] == row['Grade']:
                    key_concept_to_winning_row.setdefault(key_concept, row)

        consensus = []
        for key_concept, grade in key_concept_to_majority_grade.items():
            winning_row = key_concept_to_winning_row[key_concept]
            votes = ', '.join(key_concept_to_grades[key_concept])
            consensus.append({
                'Key Concept': key_concept,
                'Observations': winning_row['Observations'],
                'Grade': grade,
                'Reason': f"<b>Votes: [{votes}]</b><br>{winning_row['Reason']}",
            })
        return consensus

0 commit comments

Comments
 (0)