-
Notifications
You must be signed in to change notification settings - Fork 26
Get reformulator working #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: public
Are you sure you want to change the base?
Changes from 4 commits
f7a97ce
a3fdc97
212b88c
c364342
0c0aec0
cf2fca5
811d4e0
f34de04
25813f6
f907ea1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,7 @@ | |
import litellm | ||
|
||
from utils.misc import load_formatting_funcs, replace_media | ||
from utils.llm import load_api_keys, llm_price, tkn_len, chat, model_name_matcher | ||
from utils.llm import llm_price, tkn_len, chat, model_name_matcher | ||
from utils.anki import anki, sync_anki, addtags, removetags, updatenote | ||
from utils.logger import create_loggers | ||
from utils.datasets import load_dataset, semantic_prompt_filtering | ||
|
@@ -51,9 +51,6 @@ | |
d = datetime.datetime.today() | ||
today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}" | ||
|
||
whi("Loading api keys") | ||
load_api_keys() | ||
|
||
|
||
# status string | ||
STAT_CHANGED_CONT = "Content has been changed" | ||
|
@@ -210,8 +207,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): | |
litellm.set_verbose = verbose | ||
|
||
# arg sanity check and storing | ||
# TODO: Is this needed? The example in the readme doesn't set it | ||
# assert "note:" in query, f"You have to specify a notetype in the query ({query})" | ||
assert "note:" in query, f"You have to specify a notetype in the query ({query})" | ||
assert mode in ["reformulate", "reset"], "Invalid value for 'mode'" | ||
assert isinstance(exclude_done, bool), "exclude_done must be a boolean" | ||
assert isinstance(exclude_version, bool), "exclude_version must be a boolean" | ||
|
@@ -266,21 +262,26 @@ def handle_exception(exc_type, exc_value, exc_traceback): | |
query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\"" | ||
|
||
# load db just in case | ||
|
||
# TODO: How is the user supposed to create the database in the first place? | ||
# self.db_content = self.load_db() | ||
# if not self.db_content: | ||
# red("Empty database. If you have already ran anki_reformulator " | ||
# "before then something went wrong!") | ||
# else: | ||
# self.compute_cost(self.db_content) | ||
self.db_content = self.load_db() | ||
if not self.db_content: | ||
red("Empty database. If you have already ran anki_reformulator " | ||
"before then something went wrong!") | ||
whi("Trying to create a new database") | ||
self.save_to_db({}) | ||
self.db_content = self.load_db() | ||
assert self.db_content, "Could not create database" | ||
|
||
# TODO: What should be in the database normally? This fails with an empty database | ||
whi("Computing estimated costs") | ||
# self.compute_cost(self.db_content) | ||
Grazfather marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# load dataset | ||
whi("Loading dataset") | ||
dataset = load_dataset(dataset_path) | ||
# check that each note is valid but exclude the system prompt | ||
for id, d in enumerate(dataset): | ||
if id != 0: | ||
dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"] | ||
# check that each note is valid but exclude the system prompt, which is | ||
# the first entry | ||
for id, d in enumerate(dataset[1:]): | ||
dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"] | ||
assert len(dataset) % 2 == 1, "Even number of examples in dataset" | ||
self.dataset = dataset | ||
|
||
|
@@ -293,15 +294,17 @@ def handle_exception(exc_type, exc_value, exc_traceback): | |
if nids: | ||
red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}") | ||
|
||
# find notes ids for the first time | ||
nids = anki(action="findNotes", query=query) | ||
# find notes ids for the specific note type | ||
nids = anki(action="findNotes", query="note:AnkiAITest") | ||
# nids = anki(action="findNotes", query=query) | ||
assert nids, f"No notes found for the query '{query}'" | ||
|
||
# find the model field names | ||
# find the field names for this note type | ||
fields = anki(action="notesInfo", | ||
notes=[int(nids[0])])[0]["fields"] | ||
# assert "AnkiReformulator" in fields.keys(), \ | ||
# "The notetype to edit must have a field called 'AnkiReformulator'" | ||
assert "AnkiReformulator" in fields.keys(), \ | ||
"The notetype to edit must have a field called 'AnkiReformulator'" | ||
# NOTE: This gets the first field. Is that what we want? Or do we specifically want the AnkiReformulator field? | ||
thiswillbeyourgithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.field_name = list(fields.keys())[0] | ||
|
||
if self.exclude_media: | ||
|
@@ -323,13 +326,13 @@ def handle_exception(exc_type, exc_value, exc_traceback): | |
anki(action="notesInfo", notes=nids) | ||
).set_index("noteId") | ||
self.notes = self.notes.loc[nids] | ||
assert not self.notes.empty, "Empty notes df" | ||
assert not self.notes.empty, "Empty notes" | ||
thiswillbeyourgithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
assert len(set(self.notes["modelName"].tolist())) == 1, \ | ||
"Contains more than 1 note type" | ||
|
||
# check absence of image and sounds in the main field | ||
# as well incorrect tags | ||
# as well as incorrect tags | ||
for nid, note in self.notes.iterrows(): | ||
if self.exclude_media: | ||
_, media = replace_media( | ||
|
@@ -355,43 +358,31 @@ def handle_exception(exc_type, exc_value, exc_traceback): | |
assert not tag.lower().startswith("ankireformulator") | ||
|
||
# check if too many tokens | ||
tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset]) | ||
tkn_sum += sum( | ||
tkn_len( | ||
replace_media( | ||
content=note["fields"][self.field_name]["value"], | ||
media=None, | ||
mode="remove_media", | ||
)[0] | ||
) | ||
for _, note in self.notes.iterrows() | ||
) | ||
if tkn_sum > tkn_warn_limit: | ||
raise Exception( | ||
f"Found {tkn_sum} tokens to process, which is " | ||
f"higher than the limit of {tkn_warn_limit}" | ||
) | ||
tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset) | ||
tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"], | ||
media=None, | ||
mode="remove_media")[0]) | ||
for _, note in self.notes.iterrows()) | ||
assert tkn_sum <= tkn_warn_limit, (f"Found {tkn_sum} tokens to process, which is " | ||
f"higher than the limit of {tkn_warn_limit}") | ||
|
||
if len(self.notes) > n_note_limit: | ||
raise Exception( | ||
f"Found {len(self.notes)} notes to process " | ||
f"which is higher than the limit of {n_note_limit}" | ||
) | ||
assert len(self.notes) <= n_note_limit, (f"Found {len(self.notes)} notes to process " | ||
f"which is higher than the limit of {n_note_limit}") | ||
|
||
if self.mode == "reformulate": | ||
func = self.reformulate | ||
elif self.mode == "reset": | ||
func = self.reset | ||
else: | ||
raise ValueError(self.mode) | ||
raise ValueError(f"Unknown mode {self.mode}") | ||
|
||
def error_wrapped_func(*args, **kwargs): | ||
"""Wrapper that catches exceptions and marks failed notes with appropriate tags.""" | ||
try: | ||
return func(*args, **kwargs) | ||
except Exception as err: | ||
addtags(nid=note.name, tags="AnkiReformulator::FAILED") | ||
red(f"Error when running self.{self.mode}: '{err}'") | ||
red(f"Error when running self.{func.__name__}: '{err}'") | ||
return str(err) | ||
|
||
# getting all the new values in parallel and using caching | ||
|
@@ -407,11 +398,9 @@ def error_wrapped_func(*args, **kwargs): | |
) | ||
) | ||
|
||
failed_runs = [ | ||
self.notes.iloc[i_nv] | ||
for i_nv in range(len(new_values)) | ||
if isinstance(new_values[i_nv], str) | ||
] | ||
failed_runs = [self.notes.iloc[i_nv] | ||
for i_nv in range(len(new_values)) | ||
if isinstance(new_values[i_nv], str)] | ||
if failed_runs: | ||
red(f"Found {len(failed_runs)} failed notes") | ||
failed_run_index = pd.DataFrame(failed_runs).index | ||
|
@@ -421,6 +410,7 @@ def error_wrapped_func(*args, **kwargs): | |
assert len(new_values) == len(self.notes) | ||
|
||
# applying the changes | ||
whi("Applying changes") | ||
for values in tqdm(new_values, desc="Applying changes to anki"): | ||
if self.mode == "reformulate": | ||
self.apply_reformulate(values) | ||
|
@@ -429,8 +419,10 @@ def error_wrapped_func(*args, **kwargs): | |
else: | ||
raise ValueError(self.mode) | ||
|
||
whi("Clearing unused tags") | ||
anki(action="clearUnusedTags") | ||
|
||
# TODO: Why add and then remove them? | ||
# add and remove the tag TODO to make it easier to re add by the user | ||
# as it was cleared by calling 'clearUnusedTags' | ||
nid, note = next(self.notes.iterrows()) | ||
|
@@ -439,7 +431,7 @@ def error_wrapped_func(*args, **kwargs): | |
|
||
sync_anki() | ||
|
||
# display again the total cost at the end | ||
# display the total cost again at the end | ||
db = self.load_db() | ||
assert db, "Empty database at the end of the run. Something went wrong?" | ||
self.compute_cost(db) | ||
|
@@ -450,10 +442,11 @@ def compute_cost(self, db_content: List[Dict]) -> None: | |
This is used to know if something went wrong. | ||
""" | ||
n_db = len(db_content) | ||
red(f"Number of entries in databases/reformulator.db: {n_db}") | ||
red(f"Number of entries in databases/reforumulator/reformulator.db: {n_db}") | ||
dol_costs = [] | ||
dol_missing = 0 | ||
for dic in db_content: | ||
# TODO: Mode isn't a field in the reformulator database dictionaries table | ||
thiswillbeyourgithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if dic["mode"] != "reformulate": | ||
continue | ||
try: | ||
|
@@ -478,14 +471,14 @@ def compute_cost(self, db_content: List[Dict]) -> None: | |
|
||
def reformulate(self, nid: int, note: pd.Series) -> Dict: | ||
"""Generate a reformulated version of a note's content using an LLM. | ||
|
||
Parameters | ||
---------- | ||
nid : int | ||
Note ID from Anki | ||
note : pd.Series | ||
Row from the notes DataFrame containing the note data | ||
|
||
Returns | ||
------- | ||
Dict | ||
|
@@ -529,8 +522,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict: | |
elif d["role"] == "user": | ||
newcontent = self.dataset[i + 1]["content"] | ||
else: | ||
raise ValueError( | ||
f"Unexpected role of message in dataset: {d}") | ||
raise ValueError(f"Unexpected role of message in dataset: {d}") | ||
skip_llm = True | ||
break | ||
|
||
|
@@ -606,8 +598,11 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict: | |
log["note_field_formattednewcontent"] = formattednewcontent | ||
log["status"] = STAT_OK_REFORM | ||
|
||
if iscloze(content + newcontent + formattednewcontent): | ||
if iscloze(content) and iscloze( newcontent + formattednewcontent): | ||
# check that no cloze were lost | ||
# TODO: Bug here: `iscloze` can return true if the new content is a | ||
# cloze, but if the original content is not a cloze, then this | ||
# fails | ||
thiswillbeyourgithub marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for cl in getclozes(content): | ||
cl = cl.split("::")[0] + "::" | ||
assert cl.startswith("{{c") and cl in content | ||
|
@@ -622,7 +617,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict: | |
|
||
def apply_reformulate(self, log: Dict) -> None: | ||
"""Apply reformulation changes to an Anki note and update its metadata. | ||
|
||
Parameters | ||
---------- | ||
log : Dict | ||
|
@@ -692,14 +687,14 @@ def apply_reformulate(self, log: Dict) -> None: | |
|
||
def reset(self, nid: int, note: pd.Series) -> Dict: | ||
"""Reset a note back to its state before reformulation. | ||
|
||
Parameters | ||
---------- | ||
nid : int | ||
Note ID from Anki | ||
note : pd.Series | ||
Row from the notes DataFrame containing the note data | ||
|
||
Returns | ||
------- | ||
Dict | ||
|
@@ -873,7 +868,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict: | |
|
||
def apply_reset(self, log: Dict) -> None: | ||
"""Apply reset changes to an Anki note and update its metadata. | ||
|
||
Parameters | ||
---------- | ||
log : Dict | ||
|
@@ -940,12 +935,12 @@ def apply_reset(self, log: Dict) -> None: | |
|
||
def save_to_db(self, dictionnary: Dict) -> bool: | ||
"""Save a log dictionary to the SQLite database. | ||
|
||
Parameters | ||
---------- | ||
dictionnary : Dict | ||
Log dictionary to save | ||
|
||
Returns | ||
------- | ||
bool | ||
|
@@ -970,7 +965,7 @@ def save_to_db(self, dictionnary: Dict) -> bool: | |
|
||
def load_db(self) -> Dict: | ||
"""Load all log dictionaries from the SQLite database. | ||
|
||
Returns | ||
------- | ||
Dict | ||
|
@@ -999,5 +994,8 @@ def load_db(self) -> Dict: | |
whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'") | ||
AnkiReformulator(*args, **kwargs) | ||
sync_anki() | ||
except Exception: | ||
except AssertionError as e: | ||
red(e) | ||
except Exception as e: | ||
red(e) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are you sure about this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a personal thing, we can change it. Basically, I make it so that if we fail an assert, it's probably an issue with the invocation, and we log it. If it's another type of exception, then we probably don't expect it, so we log it, but we also re-raise the exception to print the stack trace. I can remove it, but it's helping me with debugging. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Whatever helps you helps all of us here so no biggy |
||
raise |
Uh oh!
There was an error while loading. Please reload this page.