Skip to content

Get reformulator working #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: public
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -387,13 +387,14 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a
Click to read more
</summary>

First, create an _API_KEYS/_ directory and place your API key in a separate file.
First, ensure that you API keys are set in you env variables.

Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it.

#### Reformulator

Next... create a database? it expects a sqlite db in databases/reformulator/reformulator?
* Can handle it in code

Next... something about adding a field called `AnkiReformulator` to notes you want to change?
* Do you have to create a special note type for this to work?
Expand All @@ -402,7 +403,7 @@ The Reformulator can be run from the command line:

```bash
python reformulator.py \
--query "(rated:2:1 OR rated:2:2) -is:suspended" \
--query "note:Basic (rated:2:1 OR rated:2:2) -is:suspended" \
--dataset_path "examples/reformulator_dataset.txt" \
--string_formatting "examples/string_formatting.py" \
--ntfy_url "ntfy.sh/YOUR_TOPIC" \
Expand Down
134 changes: 66 additions & 68 deletions reformulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import litellm

from utils.misc import load_formatting_funcs, replace_media
from utils.llm import load_api_keys, llm_price, tkn_len, chat, model_name_matcher
from utils.llm import llm_price, tkn_len, chat, model_name_matcher
from utils.anki import anki, sync_anki, addtags, removetags, updatenote
from utils.logger import create_loggers
from utils.datasets import load_dataset, semantic_prompt_filtering
Expand All @@ -51,9 +51,6 @@
d = datetime.datetime.today()
today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}"

whi("Loading api keys")
load_api_keys()


# status string
STAT_CHANGED_CONT = "Content has been changed"
Expand Down Expand Up @@ -210,8 +207,7 @@ def handle_exception(exc_type, exc_value, exc_traceback):
litellm.set_verbose = verbose

# arg sanity check and storing
# TODO: Is this needed? The example in the readme doesn't set it
# assert "note:" in query, f"You have to specify a notetype in the query ({query})"
assert "note:" in query, f"You have to specify a notetype in the query ({query})"
assert mode in ["reformulate", "reset"], "Invalid value for 'mode'"
assert isinstance(exclude_done, bool), "exclude_done must be a boolean"
assert isinstance(exclude_version, bool), "exclude_version must be a boolean"
Expand Down Expand Up @@ -266,21 +262,26 @@ def handle_exception(exc_type, exc_value, exc_traceback):
query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\""

# load db just in case

# TODO: How is the user supposed to create the database in the first place?
# self.db_content = self.load_db()
# if not self.db_content:
# red("Empty database. If you have already ran anki_reformulator "
# "before then something went wrong!")
# else:
# self.compute_cost(self.db_content)
self.db_content = self.load_db()
if not self.db_content:
red("Empty database. If you have already ran anki_reformulator "
"before then something went wrong!")
whi("Trying to create a new database")
self.save_to_db({})
self.db_content = self.load_db()
assert self.db_content, "Could not create database"

# TODO: What should be in the database normally? This fails with an empty database
whi("Computing estimated costs")
# self.compute_cost(self.db_content)

# load dataset
whi("Loading dataset")
dataset = load_dataset(dataset_path)
# check that each note is valid but exclude the system prompt
for id, d in enumerate(dataset):
if id != 0:
dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
# check that each note is valid but exclude the system prompt, which is
# the first entry
for id, d in enumerate(dataset[1:]):
dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"]
assert len(dataset) % 2 == 1, "Even number of examples in dataset"
self.dataset = dataset

Expand All @@ -293,15 +294,17 @@ def handle_exception(exc_type, exc_value, exc_traceback):
if nids:
red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}")

# find notes ids for the first time
nids = anki(action="findNotes", query=query)
# find notes ids for the specific note type
nids = anki(action="findNotes", query="note:AnkiAITest")
# nids = anki(action="findNotes", query=query)
assert nids, f"No notes found for the query '{query}'"

# find the model field names
# find the field names for this note type
fields = anki(action="notesInfo",
notes=[int(nids[0])])[0]["fields"]
# assert "AnkiReformulator" in fields.keys(), \
# "The notetype to edit must have a field called 'AnkiReformulator'"
assert "AnkiReformulator" in fields.keys(), \
"The notetype to edit must have a field called 'AnkiReformulator'"
# NOTE: This gets the first field. Is that what we want? Or do we specifically want the AnkiReformulator field?
self.field_name = list(fields.keys())[0]

if self.exclude_media:
Expand All @@ -323,13 +326,13 @@ def handle_exception(exc_type, exc_value, exc_traceback):
anki(action="notesInfo", notes=nids)
).set_index("noteId")
self.notes = self.notes.loc[nids]
assert not self.notes.empty, "Empty notes df"
assert not self.notes.empty, "Empty notes"

assert len(set(self.notes["modelName"].tolist())) == 1, \
"Contains more than 1 note type"

# check absence of image and sounds in the main field
# as well incorrect tags
# as well as incorrect tags
for nid, note in self.notes.iterrows():
if self.exclude_media:
_, media = replace_media(
Expand All @@ -355,43 +358,31 @@ def handle_exception(exc_type, exc_value, exc_traceback):
assert not tag.lower().startswith("ankireformulator")

# check if too many tokens
tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset])
tkn_sum += sum(
tkn_len(
replace_media(
content=note["fields"][self.field_name]["value"],
media=None,
mode="remove_media",
)[0]
)
for _, note in self.notes.iterrows()
)
if tkn_sum > tkn_warn_limit:
raise Exception(
f"Found {tkn_sum} tokens to process, which is "
f"higher than the limit of {tkn_warn_limit}"
)
tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset)
tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"],
media=None,
mode="remove_media")[0])
for _, note in self.notes.iterrows())
assert tkn_sum <= tkn_warn_limit, (f"Found {tkn_sum} tokens to process, which is "
f"higher than the limit of {tkn_warn_limit}")

if len(self.notes) > n_note_limit:
raise Exception(
f"Found {len(self.notes)} notes to process "
f"which is higher than the limit of {n_note_limit}"
)
assert len(self.notes) <= n_note_limit, (f"Found {len(self.notes)} notes to process "
f"which is higher than the limit of {n_note_limit}")

if self.mode == "reformulate":
func = self.reformulate
elif self.mode == "reset":
func = self.reset
else:
raise ValueError(self.mode)
raise ValueError(f"Unknown mode {self.mode}")

def error_wrapped_func(*args, **kwargs):
"""Wrapper that catches exceptions and marks failed notes with appropriate tags."""
try:
return func(*args, **kwargs)
except Exception as err:
addtags(nid=note.name, tags="AnkiReformulator::FAILED")
red(f"Error when running self.{self.mode}: '{err}'")
red(f"Error when running self.{func.__name__}: '{err}'")
return str(err)

# getting all the new values in parallel and using caching
Expand All @@ -407,11 +398,9 @@ def error_wrapped_func(*args, **kwargs):
)
)

failed_runs = [
self.notes.iloc[i_nv]
for i_nv in range(len(new_values))
if isinstance(new_values[i_nv], str)
]
failed_runs = [self.notes.iloc[i_nv]
for i_nv in range(len(new_values))
if isinstance(new_values[i_nv], str)]
if failed_runs:
red(f"Found {len(failed_runs)} failed notes")
failed_run_index = pd.DataFrame(failed_runs).index
Expand All @@ -421,6 +410,7 @@ def error_wrapped_func(*args, **kwargs):
assert len(new_values) == len(self.notes)

# applying the changes
whi("Applying changes")
for values in tqdm(new_values, desc="Applying changes to anki"):
if self.mode == "reformulate":
self.apply_reformulate(values)
Expand All @@ -429,8 +419,10 @@ def error_wrapped_func(*args, **kwargs):
else:
raise ValueError(self.mode)

whi("Clearing unused tags")
anki(action="clearUnusedTags")

# TODO: Why add and them remove them?
# add and remove the tag TODO to make it easier to re add by the user
# as it was cleared by calling 'clearUnusedTags'
nid, note = next(self.notes.iterrows())
Expand All @@ -439,7 +431,7 @@ def error_wrapped_func(*args, **kwargs):

sync_anki()

# display again the total cost at the end
# display the total cost again at the end
db = self.load_db()
assert db, "Empty database at the end of the run. Something went wrong?"
self.compute_cost(db)
Expand All @@ -450,10 +442,11 @@ def compute_cost(self, db_content: List[Dict]) -> None:
This is used to know if something went wrong.
"""
n_db = len(db_content)
red(f"Number of entries in databases/reformulator.db: {n_db}")
red(f"Number of entries in databases/reforumulator/reformulator.db: {n_db}")
dol_costs = []
dol_missing = 0
for dic in db_content:
# TODO: Mode isn't a field in the reformulator database dictionaries table
if dic["mode"] != "reformulate":
continue
try:
Expand All @@ -478,14 +471,14 @@ def compute_cost(self, db_content: List[Dict]) -> None:

def reformulate(self, nid: int, note: pd.Series) -> Dict:
"""Generate a reformulated version of a note's content using an LLM.

Parameters
----------
nid : int
Note ID from Anki
note : pd.Series
Row from the notes DataFrame containing the note data

Returns
-------
Dict
Expand Down Expand Up @@ -529,8 +522,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
elif d["role"] == "user":
newcontent = self.dataset[i + 1]["content"]
else:
raise ValueError(
f"Unexpected role of message in dataset: {d}")
raise ValueError(f"Unexpected role of message in dataset: {d}")
skip_llm = True
break

Expand Down Expand Up @@ -606,8 +598,11 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:
log["note_field_formattednewcontent"] = formattednewcontent
log["status"] = STAT_OK_REFORM

if iscloze(content + newcontent + formattednewcontent):
if iscloze(content) and iscloze( newcontent + formattednewcontent):
# check that no cloze were lost
# TODO: Bug here: `iscloze` can return true if the new content is a
# close, but if the original content is not a cloze, then this
# fails
for cl in getclozes(content):
cl = cl.split("::")[0] + "::"
assert cl.startswith("{{c") and cl in content
Expand All @@ -622,7 +617,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict:

def apply_reformulate(self, log: Dict) -> None:
"""Apply reformulation changes to an Anki note and update its metadata.

Parameters
----------
log : Dict
Expand Down Expand Up @@ -692,14 +687,14 @@ def apply_reformulate(self, log: Dict) -> None:

def reset(self, nid: int, note: pd.Series) -> Dict:
"""Reset a note back to its state before reformulation.

Parameters
----------
nid : int
Note ID from Anki
note : pd.Series
Row from the notes DataFrame containing the note data

Returns
-------
Dict
Expand Down Expand Up @@ -873,7 +868,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict:

def apply_reset(self, log: Dict) -> None:
"""Apply reset changes to an Anki note and update its metadata.

Parameters
----------
log : Dict
Expand Down Expand Up @@ -940,12 +935,12 @@ def apply_reset(self, log: Dict) -> None:

def save_to_db(self, dictionnary: Dict) -> bool:
"""Save a log dictionary to the SQLite database.

Parameters
----------
dictionnary : Dict
Log dictionary to save

Returns
-------
bool
Expand All @@ -970,7 +965,7 @@ def save_to_db(self, dictionnary: Dict) -> bool:

def load_db(self) -> Dict:
"""Load all log dictionaries from the SQLite database.

Returns
-------
Dict
Expand Down Expand Up @@ -999,5 +994,8 @@ def load_db(self) -> Dict:
whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'")
AnkiReformulator(*args, **kwargs)
sync_anki()
except Exception:
except AssertionError as e:
red(e)
except Exception as e:
red(e)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure about this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a personal thing, we can change it.

Basically, I make it so that if we fail an assert, it's probably an issue with the invocation, and we log it. If it's another type of exception, then we probably don't expect it, so we log it, but we also re-raise the exception to print the stack trace. I can remove it, but it's helping me with debugging.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever helps you helps all of us here so no biggy

raise
2 changes: 1 addition & 1 deletion utils/cloze_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def iscloze(text: str) -> bool:

def getclozes(text: str) -> List[str]:
"return the cloze found in the text. Should only be called on cloze notes"
assert iscloze(text)
assert iscloze(text), f"Text '{text}' does not contain a cloze"
return re.findall(CLOZE_REGEX, text)


Expand Down
Loading