|
| 1 | +import os |
| 2 | +import json |
| 3 | +import csv |
| 4 | +import time |
| 5 | +import requests |
| 6 | + |
| 7 | +from typing import List, Dict, Any |
| 8 | +from lib.config import VALID_GRADES |
| 9 | + |
| 10 | +from io import StringIO |
| 11 | + |
class InvalidResponseError(Exception):
    """Raised when a model response fails validation as a grading table."""
| 14 | + |
class Grade:
    """Grade student code against a rubric via the OpenAI chat completions API.

    The model is asked to answer with a table whose columns are
    "Key Concept", "Observations", "Grade" and "Reason". Responses are
    parsed (TSV, markdown table or CSV), cleaned, validated against the
    rubric and, when several choices are returned, merged by majority vote.
    """

    def __init__(self):
        """Create a grader; no instance state is kept."""

    def grade_student_work(self, prompt, rubric, student_code, student_id, examples=None, use_cached=False, write_cached=True, num_responses=0, temperature=0.0, llm_model="", remove_comments=False):
        """Request a graded table for one student's code.

        Args:
            prompt: Instructions placed in the system message.
            rubric: Rubric text appended to the system message; also parsed
                as CSV with a "Key Concept" column during validation.
            student_code: The code to grade.
            student_id: Identifier used in log output and the cache filename.
            examples: Optional iterable of (code, expected_response) pairs
                sent as few-shot user/assistant turns.
            use_cached: Return cached_responses/<student_id>.json if present.
            write_cached: Write a valid result to that same cache file.
            num_responses: Value sent as the API's ``n`` parameter.
            temperature: Sampling temperature sent to the API.
            llm_model: Model name sent to the API.
            remove_comments: Strip ``//`` comments from the code first.

        Returns:
            A list of row dicts, or None when the request timed out, the API
            returned a non-200 status, or no choice passed validation.
        """
        cache_path = f"cached_responses/{student_id}.json"
        if use_cached and os.path.exists(cache_path):
            with open(cache_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        api_url = 'https://api.openai.com/v1/chat/completions'
        headers = {
            'Content-Type': 'application/json',
            'Authorization': f"Bearer {os.getenv('OPENAI_API_KEY')}"
        }

        # Sanitize student code
        student_code = self.sanitize_code(student_code, remove_comments=remove_comments)

        data = {
            'model': llm_model,
            'temperature': temperature,
            'messages': self.compute_messages(prompt, rubric, student_code, examples=examples),
            # NOTE(review): the default num_responses=0 is forwarded as n=0;
            # confirm the API accepts that, or that callers always pass >= 1.
            'n': num_responses,
        }

        start_time = time.time()
        try:
            response = requests.post(api_url, headers=headers, json=data, timeout=120)
        except requests.exceptions.Timeout:
            # Timeout covers both connect and read timeouts.
            print(f"{student_id} request timed out in {(time.time() - start_time):.0f} seconds.")
            return None

        if response.status_code != 200:
            print(f"{student_id} Error calling the API: {response.status_code}")
            print(f"{student_id} Response body: {response.text}")
            return None

        # Decode the body once and reuse it.
        payload = response.json()
        tokens = payload['usage']['total_tokens']
        print(f"{student_id} request succeeded in {(time.time() - start_time):.0f} seconds. {tokens} tokens used.")

        # Keep only the choices that have content and pass validation.
        tsv_data_choices = []
        for index, choice in enumerate(payload['choices']):
            content = choice['message']['content']
            if not content:
                continue
            parsed = self.get_tsv_data_if_valid(content, rubric, student_id, choice_index=index)
            if parsed:
                tsv_data_choices.append(parsed)

        if not tsv_data_choices:
            tsv_data = None
        elif len(tsv_data_choices) == 1:
            tsv_data = tsv_data_choices[0]
        else:
            tsv_data = self.get_consensus_response(tsv_data_choices, student_id)

        # only write to cache if the response is valid
        if write_cached and tsv_data:
            # The cache directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'w', encoding='utf-8') as f:
                json.dump(tsv_data, f, indent=4)

        return tsv_data

    def sanitize_code(self, student_code, remove_comments=False):
        """Return the student code, optionally with ``//`` comments removed.

        Removal is purely textual: everything from the first ``//`` on a line
        is dropped, including a ``//`` that sits inside a string literal.
        """
        if not remove_comments:
            return student_code
        return "\n".join(line.split("//", 1)[0] for line in student_code.split('\n'))

    def compute_messages(self, prompt, rubric, student_code, examples=None):
        """Build the chat message list for one grading request.

        The system message holds the prompt and rubric, each example becomes
        a user/assistant pair, and the student code is the final user turn.
        """
        messages = [
            {'role': 'system', 'content': f"{prompt}\n\nRubric:\n{rubric}"}
        ]
        for example_js, example_rubric in examples or []:
            messages.append({'role': 'user', 'content': example_js})
            messages.append({'role': 'assistant', 'content': example_rubric})
        messages.append({'role': 'user', 'content': student_code})
        return messages

    def get_tsv_data_if_valid(self, response_text, rubric, student_id, choice_index=None):
        """Parse one model response into rows, or return None if it is invalid.

        The text is parsed as TSV when it contains a tab, as a markdown table
        when it contains ' | ', and as CSV otherwise. The rows are then
        cleaned and validated against the rubric.
        """
        choice_text = f"Choice {choice_index}: " if choice_index is not None else ''
        if not response_text:
            print(f"{student_id} {choice_text} Invalid response: empty response")
            return None
        text = response_text.strip()

        # Remove anything up to the first column name
        header_at = text.find("\nKey Concept")
        if header_at != -1:
            text = text[header_at:].strip()

        # Replace escaped tabs, then collapse double tabs
        text = text.replace("\\t", "\t").replace("\t\t", "\t")

        if '\t' in text:
            delimiter = '\t'
        elif ' | ' in text:
            # Markdown table: '|' is the delimiter and the '---' separator
            # lines carry no data.
            text = "\n".join(line for line in text.split('\n') if "---" not in line)
            print("response was markdown and not tsv, delimiting by '|'")
            delimiter = '|'
        else:
            # Let's assume it is CSV
            print("response had no tabs so is not tsv, delimiting by ','")
            delimiter = ','

        try:
            tsv_data = list(csv.DictReader(StringIO(text), delimiter=delimiter))
            self.sanitize_server_response(tsv_data)
            self.validate_server_response(tsv_data, rubric)
            return list(tsv_data)
        except (InvalidResponseError, csv.Error) as e:
            # A response the csv module cannot parse is just another invalid one.
            print(f"{student_id} {choice_text} Invalid response: {str(e)}\n{response_text}")
            return None

    def parse_tsv(self, tsv_text):
        """Parse tab-separated text with a header line into a list of dicts.

        Empty lines after the header (such as one left by a trailing
        newline) are skipped.
        """
        header_line, *body = tsv_text.split("\n")
        header = header_line.split("\t")
        return [dict(zip(header, row.split("\t"))) for row in body if row]

    def sanitize_server_response(self, tsv_data):
        """Clean parsed rows in place.

        Strips whitespace and double quotes from string values, strips
        whitespace from column names, and drops rows whose "Key Concept" does
        not start with an alphanumeric character. Non-list input is ignored.
        """
        if not isinstance(tsv_data, list):
            return

        # Strip whitespace and quotes from fields
        for row in tsv_data:
            for key in list(row.keys()):
                if isinstance(row[key], str):
                    row[key] = row[key].strip().strip('"')

                if isinstance(key, str):
                    clean_key = key.strip()
                    if clean_key != key:
                        row[clean_key] = row[key]
                        del row[key]

        # Remove rows that don't start with reasonable things. The list is
        # rebuilt in one step: removing items while iterating over the same
        # list skips the row that follows each removed row.
        tsv_data[:] = [row for row in tsv_data if self._has_plausible_key_concept(row)]

    @staticmethod
    def _has_plausible_key_concept(row):
        """Return False only for a row whose "Key Concept" looks like junk."""
        if "Key Concept" not in row:
            return True
        concept = row["Key Concept"]
        # A short row can leave the value as None; treat that as junk too.
        return isinstance(concept, str) and concept[0:1].isalnum()

    def validate_server_response(self, tsv_data, rubric):
        """Raise InvalidResponseError unless the rows match the rubric.

        Checks that the data is a list, that every row has the four expected
        columns, that the set of key concepts equals the rubric's, and that
        every grade is one of VALID_GRADES.
        """
        expected_columns = {"Key Concept", "Observations", "Grade", "Reason"}

        rubric_key_concepts = {row['Key Concept'] for row in csv.DictReader(rubric.splitlines())}

        if not isinstance(tsv_data, list):
            raise InvalidResponseError('invalid format')

        if not all(expected_columns.issubset(row.keys()) for row in tsv_data):
            raise InvalidResponseError('incorrect column names')

        key_concepts_from_response = {row["Key Concept"] for row in tsv_data}
        if rubric_key_concepts != key_concepts_from_response:
            raise InvalidResponseError('invalid or missing key concept')

        for row in tsv_data:
            if row["Grade"] not in VALID_GRADES:
                raise InvalidResponseError(f"invalid grade value: '{row['Grade']}'")

    def get_consensus_response(self, choices, student_id):
        """Merge several validated responses into one by majority vote.

        For each key concept the most common grade wins (Counter.most_common
        decides ties). The observations and reason come from the first row
        that gave the winning grade, and the reason is prefixed with the list
        of all votes.
        """
        from collections import Counter, defaultdict

        key_concept_to_grades = defaultdict(list)
        for choice in choices:
            for row in choice:
                key_concept_to_grades[row['Key Concept']].append(row['Grade'])

        key_concept_to_majority_grade = {}
        for key_concept, grades in key_concept_to_grades.items():
            majority_grade = Counter(grades).most_common(1)[0][0]
            key_concept_to_majority_grade[key_concept] = majority_grade
            if majority_grade != grades[0]:
                print(f"outvoted {student_id} Key Concept: {key_concept} first grade: {grades[0]} majority grade: {majority_grade}")

        # The first row that agrees with the majority supplies the text.
        key_concept_to_observations = {}
        key_concept_to_reason = {}
        for choice in choices:
            for row in choice:
                key_concept = row['Key Concept']
                if key_concept in key_concept_to_observations:
                    continue
                if key_concept_to_majority_grade[key_concept] == row['Grade']:
                    key_concept_to_observations[key_concept] = row['Observations']
                    key_concept_to_reason[key_concept] = row['Reason']

        consensus = []
        for key_concept, grade in key_concept_to_majority_grade.items():
            votes = ', '.join(key_concept_to_grades[key_concept])
            consensus.append({
                'Key Concept': key_concept,
                'Observations': key_concept_to_observations[key_concept],
                'Grade': grade,
                'Reason': f"<b>Votes: [{votes}]</b><br>{key_concept_to_reason[key_concept]}",
            })
        return consensus
0 commit comments