From f7a97ce82aca90c108ae48307c5c496ee14f9387 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Tue, 7 Jan 2025 21:28:17 -0500 Subject: [PATCH 01/10] wip: Make sense of everything --- README.md | 16 +++++++++++--- reformulator.py | 56 ++++++++++++++++++++++--------------------------- utils/llm.py | 1 + utils/logger.py | 2 +- utils/shared.py | 1 - 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 81bec5b..9e75926 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ This collection of scripts is the culmination of my efforts to contributes the A ## Tools -### Illustrator +### Illustrator Creates custom mnemonic images for your cards using AI image generation. It: - Analyzes card content to identify key concepts - Generates creative visual memory hooks @@ -387,14 +387,24 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a Click to read more +First, create an _API_KEYS/_ directory and place your API key in a separate file. + +Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it. + #### Reformulator + +Next... create a database? it expects a sqlite db in databases/reformulator/reformulator? + +Next... something about adding a field called `AnkiReformulator` to notes you want to change? +* Do you have to create a special note type for this to work? + The Reformulator can be run from the command line: ```bash python reformulator.py \ --query "(rated:2:1 OR rated:2:2) -is:suspended" \ - --dataset_path "data/reformulator_dataset.txt" \ - --string_formatting "data/string_formatting.py" \ + --dataset_path "examples/reformulator_dataset.txt" \ + --string_formatting "examples/string_formatting.py" \ --ntfy_url "ntfy.sh/YOUR_TOPIC" \ --main_field_index 0 \ --llm "openai/gpt-4" \ diff --git a/reformulator.py b/reformulator.py index 6bd554c..a55fd32 100644 --- a/reformulator.py +++ b/reformulator.py @@ -51,6 +51,7 @@ d = datetime.datetime.today() today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}" +whi("Loading api keys") load_api_keys() @@ -184,7 +185,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): [print(line) for line in traceback.format_tb(exc_traceback)] print(str(exc_value)) print(str(exc_type)) - print("\n--verbose was used so opening debug console at the " + print("\n--debug was used so opening debug console at the " "appropriate frame. Press 'c' to continue to the frame " "of this print.") pdb.post_mortem(exc_traceback) @@ -203,13 +204,14 @@ def handle_exception(exc_type, exc_value, exc_traceback): print(json.dumps(db_content, ensure_ascii=False, indent=4)) return else: - sync_anki() + # sync_anki() assert query is not None, "Must specify --query" assert dataset_path is not None, "Must specify --dataset_path" litellm.set_verbose = verbose # arg sanity check and storing - assert "note:" in query, "You have to specify a notetype in the query" + # TODO: Is this needed? 
The example in the readme doesn't set it + # assert "note:" in query, f"You have to specify a notetype in the query ({query})" assert mode in ["reformulate", "reset"], "Invalid value for 'mode'" assert isinstance(exclude_done, bool), "exclude_done must be a boolean" assert isinstance(exclude_version, bool), "exclude_version must be a boolean" @@ -225,7 +227,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): main_field_index = int(main_field_index) assert main_field_index >= 0, "invalid field_index" self.mode = mode - if string_formatting is not None: + if string_formatting: red(f"Loading specific string formatting from {string_formatting}") cloze_input_parser, cloze_output_parser = load_formatting_funcs( path=string_formatting, @@ -264,14 +266,14 @@ def handle_exception(exc_type, exc_value, exc_traceback): query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\"" # load db just in case - self.db_content = self.load_db() - if not self.db_content: - red( - "Empty database. If you have already ran anki_reformulator " - "before then something went wrong!" - ) - else: - self.compute_cost(self.db_content) + + # TODO: How is the user supposed to create the database in the first place? + # self.db_content = self.load_db() + # if not self.db_content: + # red("Empty database. If you have already ran anki_reformulator " + # "before then something went wrong!") + # else: + # self.compute_cost(self.db_content) # load dataset dataset = load_dataset(dataset_path) @@ -286,9 +288,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): nids = anki(action="findNotes", query="tag:AnkiReformulator::RESETTING") if nids: - red( - f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}" - ) + red(f"Found {len(nids)} notes with tag AnkiReformulator::RESETTING : {nids}") nids = anki(action="findNotes", query="tag:AnkiReformulator::DOING") if nids: red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}") @@ -298,13 +298,10 @@ def handle_exception(exc_type, exc_value, exc_traceback): assert nids, f"No notes found for the query '{query}'" # find the model field names - fields = anki( - action="notesInfo", - notes=[int(nids[0])] - )[0]["fields"] - assert ( - "AnkiReformulator" in fields.keys() - ), "The notetype to edit must have a field called 'AnkiReformulator'" + fields = anki(action="notesInfo", + notes=[int(nids[0])])[0]["fields"] + # assert "AnkiReformulator" in fields.keys(), \ + # "The notetype to edit must have a field called 'AnkiReformulator'" self.field_name = list(fields.keys())[0] if self.exclude_media: @@ -328,9 +325,8 @@ def handle_exception(exc_type, exc_value, exc_traceback): self.notes = self.notes.loc[nids] assert not self.notes.empty, "Empty notes df" - assert ( - len(set(self.notes["modelName"].tolist())) == 1 - ), "Contains more than 1 note type" + assert len(set(self.notes["modelName"].tolist())) == 1, \ + "Contains more than 1 note type" # check absence of image and sounds in the main field # as well incorrect tags @@ -358,11 +354,9 @@ def handle_exception(exc_type, exc_value, exc_traceback): else: assert not tag.lower().startswith("ankireformulator") - # check if too many tokens tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset]) tkn_sum += sum( - [ tkn_len( replace_media( content=note["fields"][self.field_name]["value"], @@ -371,7 +365,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): )[0] ) for _, note in self.notes.iterrows() - ]) + ) if tkn_sum > tkn_warn_limit: raise Exception( f"Found {tkn_sum} tokens to 
process, which is " @@ -983,7 +977,7 @@ def load_db(self) -> Dict: All log dictionaries from the database, or False if database not found """ if not (REFORMULATOR_DIR / "reformulator.db").exists(): - red("db not found: '$REFORMULATOR_DIR/reformulator.db'") + red(f"db not found: '{REFORMULATOR_DIR}/reformulator.db'") return False conn = sqlite3.connect(str((REFORMULATOR_DIR / "reformulator.db").absolute())) cursor = conn.cursor() @@ -1000,10 +994,10 @@ def load_db(self) -> Dict: try: args, kwargs = fire.Fire(lambda *args, **kwargs: [args, kwargs]) if "help" in kwargs: - print(help(AnkiReformulator)) + print(help(AnkiReformulator), file=sys.stderr) else: whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'") AnkiReformulator(*args, **kwargs) + sync_anki() except Exception: - sync_anki() raise diff --git a/utils/llm.py b/utils/llm.py index 5188fe3..2293d55 100644 --- a/utils/llm.py +++ b/utils/llm.py @@ -28,6 +28,7 @@ def load_api_keys() -> Dict: Path("API_KEYS").mkdir(exist_ok=True) if not list(Path("API_KEYS").iterdir()): shared.red("## No API_KEYS found in API_KEYS") + raise Exception("Need to write API KEYS to API_KEYS/") api_keys = {} for apifile in Path("API_KEYS").iterdir(): keyname = f"{apifile.stem.upper()}_API_KEY" diff --git a/utils/logger.py b/utils/logger.py index d2f3db3..c6ff952 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -85,6 +85,6 @@ def create_loggers(local_file: Union[str, PosixPath], colors: List[str]): out = [] for col in colors: log = coloured_logger(col) - setattr(shared, "col", log) + setattr(shared, col, log) out.append(log) return out diff --git a/utils/shared.py b/utils/shared.py index 757dd9a..6fafacd 100644 --- a/utils/shared.py +++ b/utils/shared.py @@ -39,4 +39,3 @@ def __setattr__(self, name: str, value) -> None: raise TypeError(f'SharedModule forbids the creation of unexpected attribute "{name}"') shared = SharedModule() - From a3fdc97a7009441d3a635ed1401570bc256ee9b2 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Wed, 8 Jan 2025 15:42:23 -0500 Subject: [PATCH 02/10] More work --- README.md | 5 +- reformulator.py | 119 +++++++++++++++++++++++----------------------- utils/datasets.py | 28 +++++------ 3 files changed, 75 insertions(+), 77 deletions(-) diff --git a/README.md b/README.md index 9e75926..7df1d63 100644 --- a/README.md +++ b/README.md @@ -387,13 +387,14 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) a Click to read more -First, create an _API_KEYS/_ directory and place your API key in a separate file. +First, create an _API_KEYS/_ directory and place your API key in a separate file, or ensure that you API keys are set in you env variables. Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it. #### Reformulator Next... create a database? it expects a sqlite db in databases/reformulator/reformulator? +* Can handle it in code Next... something about adding a field called `AnkiReformulator` to notes you want to change? * Do you have to create a special note type for this to work? 
@@ -402,7 +403,7 @@ The Reformulator can be run from the command line: ```bash python reformulator.py \ - --query "(rated:2:1 OR rated:2:2) -is:suspended" \ + --query "note:Basic (rated:2:1 OR rated:2:2) -is:suspended" \ --dataset_path "examples/reformulator_dataset.txt" \ --string_formatting "examples/string_formatting.py" \ --ntfy_url "ntfy.sh/YOUR_TOPIC" \ diff --git a/reformulator.py b/reformulator.py index a55fd32..305e18d 100644 --- a/reformulator.py +++ b/reformulator.py @@ -210,8 +210,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): litellm.set_verbose = verbose # arg sanity check and storing - # TODO: Is this needed? The example in the readme doesn't set it - # assert "note:" in query, f"You have to specify a notetype in the query ({query})" + assert "note:" in query, f"You have to specify a notetype in the query ({query})" assert mode in ["reformulate", "reset"], "Invalid value for 'mode'" assert isinstance(exclude_done, bool), "exclude_done must be a boolean" assert isinstance(exclude_version, bool), "exclude_version must be a boolean" @@ -266,21 +265,26 @@ def handle_exception(exc_type, exc_value, exc_traceback): query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\"" # load db just in case - - # TODO: How is the user supposed to create the database in the first place? - # self.db_content = self.load_db() - # if not self.db_content: - # red("Empty database. If you have already ran anki_reformulator " - # "before then something went wrong!") - # else: - # self.compute_cost(self.db_content) + self.db_content = self.load_db() + if not self.db_content: + red("Empty database. If you have already ran anki_reformulator " + "before then something went wrong!") + whi("Trying to create a new database") + self.save_to_db({}) + self.db_content = self.load_db() + assert self.db_content, "Could not create database" + + # TODO: What should be in the database normally? This fails with an empty database + whi("Computing estimated costs") + # self.compute_cost(self.db_content) # load dataset + whi("Loading dataset") dataset = load_dataset(dataset_path) - # check that each note is valid but exclude the system prompt - for id, d in enumerate(dataset): - if id != 0: - dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"] + # check that each note is valid but exclude the system prompt, which is + # the first entry + for id, d in enumerate(dataset[1:]): + dataset[id]["content"] = self.cloze_input_parser(d["content"]) if iscloze(d["content"]) else d["content"] assert len(dataset) % 2 == 1, "Even number of examples in dataset" self.dataset = dataset @@ -293,15 +297,17 @@ def handle_exception(exc_type, exc_value, exc_traceback): if nids: red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}") - # find notes ids for the first time - nids = anki(action="findNotes", query=query) + # find notes ids for the specific note type + nids = anki(action="findNotes", query="note:AnkiAITest") + # nids = anki(action="findNotes", query=query) assert nids, f"No notes found for the query '{query}'" - # find the model field names + # find the field names for this note type fields = anki(action="notesInfo", notes=[int(nids[0])])[0]["fields"] - # assert "AnkiReformulator" in fields.keys(), \ - # "The notetype to edit must have a field called 'AnkiReformulator'" + assert "AnkiReformulator" in fields.keys(), \ + "The notetype to edit must have a field called 'AnkiReformulator'" + # NOTE: This gets the first field. Is that what we want? 
Or do we specifically want the AnkiReformulator field? self.field_name = list(fields.keys())[0] if self.exclude_media: @@ -323,13 +329,13 @@ def handle_exception(exc_type, exc_value, exc_traceback): anki(action="notesInfo", notes=nids) ).set_index("noteId") self.notes = self.notes.loc[nids] - assert not self.notes.empty, "Empty notes df" + assert not self.notes.empty, "Empty notes" assert len(set(self.notes["modelName"].tolist())) == 1, \ "Contains more than 1 note type" # check absence of image and sounds in the main field - # as well incorrect tags + # as well as incorrect tags for nid, note in self.notes.iterrows(): if self.exclude_media: _, media = replace_media( @@ -355,35 +361,23 @@ def handle_exception(exc_type, exc_value, exc_traceback): assert not tag.lower().startswith("ankireformulator") # check if too many tokens - tkn_sum = sum([tkn_len(d["content"]) for d in self.dataset]) - tkn_sum += sum( - tkn_len( - replace_media( - content=note["fields"][self.field_name]["value"], - media=None, - mode="remove_media", - )[0] - ) - for _, note in self.notes.iterrows() - ) - if tkn_sum > tkn_warn_limit: - raise Exception( - f"Found {tkn_sum} tokens to process, which is " - f"higher than the limit of {tkn_warn_limit}" - ) + tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset) + tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"], + media=None, + mode="remove_media")[0]) + for _, note in self.notes.iterrows()) + assert tkn_sum <= tkn_warn_limit, (f"Found {tkn_sum} tokens to process, which is " + f"higher than the limit of {tkn_warn_limit}") - if len(self.notes) > n_note_limit: - raise Exception( - f"Found {len(self.notes)} notes to process " - f"which is higher than the limit of {n_note_limit}" - ) + assert len(self.notes) <= n_note_limit, (f"Found {len(self.notes)} notes to process " + f"which is higher than the limit of {n_note_limit}") if self.mode == "reformulate": func = self.reformulate elif self.mode == "reset": func = self.reset else: - raise ValueError(self.mode) + raise ValueError(f"Unknown mode {self.mode}") def error_wrapped_func(*args, **kwargs): """Wrapper that catches exceptions and marks failed notes with appropriate tags.""" @@ -407,11 +401,9 @@ def error_wrapped_func(*args, **kwargs): ) ) - failed_runs = [ - self.notes.iloc[i_nv] - for i_nv in range(len(new_values)) - if isinstance(new_values[i_nv], str) - ] + failed_runs = [self.notes.iloc[i_nv] + for i_nv in range(len(new_values)) + if isinstance(new_values[i_nv], str)] if failed_runs: red(f"Found {len(failed_runs)} failed notes") failed_run_index = pd.DataFrame(failed_runs).index @@ -421,6 +413,7 @@ def error_wrapped_func(*args, **kwargs): assert len(new_values) == len(self.notes) # applying the changes + whi("Applying changes") for values in tqdm(new_values, desc="Applying changes to anki"): if self.mode == "reformulate": self.apply_reformulate(values) @@ -429,8 +422,10 @@ def error_wrapped_func(*args, **kwargs): else: raise ValueError(self.mode) + whi("Clearing unused tags") anki(action="clearUnusedTags") + # TODO: Why add and them remove them? # add and remove the tag TODO to make it easier to re add by the user # as it was cleared by calling 'clearUnusedTags' nid, note = next(self.notes.iterrows()) @@ -439,7 +434,7 @@ def error_wrapped_func(*args, **kwargs): sync_anki() - # display again the total cost at the end + # display the total cost again at the end db = self.load_db() assert db, "Empty database at the end of the run. Something went wrong?" 
self.compute_cost(db) @@ -450,10 +445,11 @@ def compute_cost(self, db_content: List[Dict]) -> None: This is used to know if something went wrong. """ n_db = len(db_content) - red(f"Number of entries in databases/reformulator.db: {n_db}") + red(f"Number of entries in databases/reforumulator/reformulator.db: {n_db}") dol_costs = [] dol_missing = 0 for dic in db_content: + # TODO: Mode isn't a field in the reformulator database dictionaries table if dic["mode"] != "reformulate": continue try: @@ -478,14 +474,14 @@ def compute_cost(self, db_content: List[Dict]) -> None: def reformulate(self, nid: int, note: pd.Series) -> Dict: """Generate a reformulated version of a note's content using an LLM. - + Parameters ---------- nid : int Note ID from Anki note : pd.Series Row from the notes DataFrame containing the note data - + Returns ------- Dict @@ -622,7 +618,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict: def apply_reformulate(self, log: Dict) -> None: """Apply reformulation changes to an Anki note and update its metadata. - + Parameters ---------- log : Dict @@ -692,14 +688,14 @@ def apply_reformulate(self, log: Dict) -> None: def reset(self, nid: int, note: pd.Series) -> Dict: """Reset a note back to its state before reformulation. - + Parameters ---------- nid : int Note ID from Anki note : pd.Series Row from the notes DataFrame containing the note data - + Returns ------- Dict @@ -873,7 +869,7 @@ def reset(self, nid: int, note: pd.Series) -> Dict: def apply_reset(self, log: Dict) -> None: """Apply reset changes to an Anki note and update its metadata. - + Parameters ---------- log : Dict @@ -940,12 +936,12 @@ def apply_reset(self, log: Dict) -> None: def save_to_db(self, dictionnary: Dict) -> bool: """Save a log dictionary to the SQLite database. - + Parameters ---------- dictionnary : Dict Log dictionary to save - + Returns ------- bool @@ -970,7 +966,7 @@ def save_to_db(self, dictionnary: Dict) -> bool: def load_db(self) -> Dict: """Load all log dictionaries from the SQLite database. - + Returns ------- Dict @@ -999,5 +995,8 @@ def load_db(self) -> Dict: whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'") AnkiReformulator(*args, **kwargs) sync_anki() - except Exception: + except AssertionError as e: + red(e) + except Exception as e: + red(e) raise diff --git a/utils/datasets.py b/utils/datasets.py index df46e84..6ba4278 100644 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -1,3 +1,4 @@ +import collections import json import pandas as pd import litellm @@ -51,9 +52,10 @@ def load_dataset( Returns ------- - Dict + List List of message dictionaries with 'role' and 'content' keys, - validated according to check_dataset() rules + validated according to check_dataset() rules. + First message is the system message. Raises ------ @@ -366,25 +368,21 @@ def semantic_prompt_filtering( if len(output_pr) != len(prompt_messages): red(f"Tokens of the kept prompts after {cnt} iterations: {tkns} (of all prompts: {all_tkns} tokens) Number of prompts: {len(output_pr)}/{len(prompt_messages)}") - # check no duplicate prompts + # TODO: This complains about duplicates. It looks like for some reason the + # last one is added as assistant AND user, but we only compare the content. 
contents = [pm["content"] for pm in output_pr] - dupli = [dp for dp in contents if contents.count(dp) > 1] + dupli = [k for k,v in collections.Counter(contents).items() if v>1] if dupli: raise Exception(f"{len(dupli)} duplicate prompts found in memory.py: {dupli}") - # remove unwanted keys - for i, d in enumerate(output_pr): - keys = [k for k in d.keys()] - for k in keys: - if k not in ["content", "role"]: - del d[k] - output_pr[i] = d + # Keep only the content and the role keys for each prompt + new_output = [{k: v for k, v in pk.items() if k in {"content", "role"}} for pk in output_pr] - assert curr_mess not in output_pr - assert output_pr, "No prompt were selected!" - check_dataset(output_pr, **check_args) + assert curr_mess not in new_output + assert new_output, "No prompt were selected!" + check_dataset(new_output, **check_args) - return output_pr + return new_output def format_anchor_key(key: str) -> str: From 212b88c165c157e2ceb53755e1b7f92d983df768 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Wed, 8 Jan 2025 15:58:17 -0500 Subject: [PATCH 03/10] more --- reformulator.py | 10 ++++++---- utils/cloze_utils.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/reformulator.py b/reformulator.py index 305e18d..abb6107 100644 --- a/reformulator.py +++ b/reformulator.py @@ -385,7 +385,7 @@ def error_wrapped_func(*args, **kwargs): return func(*args, **kwargs) except Exception as err: addtags(nid=note.name, tags="AnkiReformulator::FAILED") - red(f"Error when running self.{self.mode}: '{err}'") + red(f"Error when running self.{func.__name__}: '{err}'") return str(err) # getting all the new values in parallel and using caching @@ -525,8 +525,7 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict: elif d["role"] == "user": newcontent = self.dataset[i + 1]["content"] else: - raise ValueError( - f"Unexpected role of message in dataset: {d}") + raise ValueError(f"Unexpected role of message in dataset: {d}") skip_llm = True break @@ -602,8 +601,11 @@ def reformulate(self, nid: int, note: pd.Series) -> Dict: log["note_field_formattednewcontent"] = formattednewcontent log["status"] = STAT_OK_REFORM - if iscloze(content + newcontent + formattednewcontent): + if iscloze(content) and iscloze( newcontent + formattednewcontent): # check that no cloze were lost + # TODO: Bug here: `iscloze` can return true if the new content is a + # close, but if the original content is not a cloze, then this + # fails for cl in getclozes(content): cl = cl.split("::")[0] + "::" assert cl.startswith("{{c") and cl in content diff --git a/utils/cloze_utils.py b/utils/cloze_utils.py index e9a1a6c..05e3382 100644 --- a/utils/cloze_utils.py +++ b/utils/cloze_utils.py @@ -18,7 +18,7 @@ def iscloze(text: str) -> bool: def getclozes(text: str) -> List[str]: "return the cloze found in the text. Should only be called on cloze notes" - assert iscloze(text) + assert iscloze(text), f"Text '{text}' does not contain a cloze" return re.findall(CLOZE_REGEX, text) From c364342f849dab6b533e5d40e9594a910aeb96a1 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Wed, 8 Jan 2025 16:03:20 -0500 Subject: [PATCH 04/10] Remove API key loading --- README.md | 2 +- reformulator.py | 5 +---- utils/llm.py | 24 ------------------------ 3 files changed, 2 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 7df1d63..1ca408e 100644 --- a/README.md +++ b/README.md @@ -387,7 +387,7 @@ Dataset files (like `explainer_dataset.txt`, `reformulator_dataset.txt`, etc.) 
a Click to read more -First, create an _API_KEYS/_ directory and place your API key in a separate file, or ensure that you API keys are set in you env variables. +First, ensure that you API keys are set in you env variables. Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it. diff --git a/reformulator.py b/reformulator.py index abb6107..70d72ba 100644 --- a/reformulator.py +++ b/reformulator.py @@ -31,7 +31,7 @@ import litellm from utils.misc import load_formatting_funcs, replace_media -from utils.llm import load_api_keys, llm_price, tkn_len, chat, model_name_matcher +from utils.llm import llm_price, tkn_len, chat, model_name_matcher from utils.anki import anki, sync_anki, addtags, removetags, updatenote from utils.logger import create_loggers from utils.datasets import load_dataset, semantic_prompt_filtering @@ -51,9 +51,6 @@ d = datetime.datetime.today() today = f"{d.day:02d}_{d.month:02d}_{d.year:04d}" -whi("Loading api keys") -load_api_keys() - # status string STAT_CHANGED_CONT = "Content has been changed" diff --git a/utils/llm.py b/utils/llm.py index 2293d55..24becb9 100644 --- a/utils/llm.py +++ b/utils/llm.py @@ -13,30 +13,6 @@ litellm.drop_params = True -def load_api_keys() -> Dict: - """Load API keys from files in the API_KEYS directory. - - Creates API_KEYS directory if it doesn't exist. - Each file in API_KEYS/ should contain a single API key. - The filename (without extension) becomes part of the environment variable name. - - Returns - ------- - Dict - Dictionary mapping environment variable names to API key values - """ - Path("API_KEYS").mkdir(exist_ok=True) - if not list(Path("API_KEYS").iterdir()): - shared.red("## No API_KEYS found in API_KEYS") - raise Exception("Need to write API KEYS to API_KEYS/") - api_keys = {} - for apifile in Path("API_KEYS").iterdir(): - keyname = f"{apifile.stem.upper()}_API_KEY" - key = apifile.read_text().strip() - os.environ[keyname] = key - api_keys[keyname] = key - return api_keys - llm_price = {} for k, v in litellm.model_cost.items(): From 0c0aec0a4f7ebd6efb31286099ae8d4f59d3a6ee Mon Sep 17 00:00:00 2001 From: Grazfather Date: Wed, 8 Jan 2025 16:03:33 -0500 Subject: [PATCH 05/10] more --- utils/llm.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/utils/llm.py b/utils/llm.py index 24becb9..adc27c2 100644 --- a/utils/llm.py +++ b/utils/llm.py @@ -18,20 +18,21 @@ for k, v in litellm.model_cost.items(): llm_price[k] = v -embedding_models = [ - "openai/text-embedding-3-large", - "openai/text-embedding-3-small", - "mistral/mistral-embed", - ] +embedding_models = ["openai/text-embedding-3-large", + "openai/text-embedding-3-small", + "mistral/mistral-embed"] # steps : price -sd_price = { - "15": 0.001, - "30": 0.002, - "50": 0.004, - "100": 0.007, - "150": "0.01", -} +sd_price = {"15": 0.001, + "30": 0.002, + "50": 0.004, + "100": 0.007, + # NOTE: Why is this one a string? 
+ "150": "0.01"} + +tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") +llm_cache = Memory(".cache", verbose=0) + def llm_cost_compute( input_cost: int, @@ -56,9 +57,6 @@ def llm_cost_compute( return input_cost * price["input_cost_per_token"] + output_cost * price["output_cost_per_token"] -tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") - - def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]): if isinstance(message, str): return len(tokenizer.encode(dedent(message))) @@ -67,7 +65,6 @@ def tkn_len(message: Union[str, List[Union[str, Dict]], Dict]): elif isinstance(message, list): return sum([tkn_len(subel) for subel in message]) -llm_cache = Memory(".cache", verbose=0) @llm_cache.cache def chat( @@ -89,6 +86,7 @@ def chat( assert all(a["finish_reason"] == "stop" for a in answer["choices"]), f"Found bad finish_reason: '{answer}'" return answer + def wrapped_model_name_matcher(model: str) -> str: "find the best match for a modelname (wrapped to make some check)" # find the currently set api keys to avoid matching models from @@ -124,10 +122,11 @@ def wrapped_model_name_matcher(model: str) -> str: return match[0] else: print(f"Couldn't match the modelname {model} to any known model. " - "Continuing but this will probably crash DocToolsLLM further " - "down the code.") + "Continuing but this will probably crash DocToolsLLM further " + "down the code.") return model + def model_name_matcher(model: str) -> str: """find the best match for a modelname (wrapper that checks if the matched model has a known cost and print the matched name)""" From cf2fca542fafcdcd4502fcdad3a8386defb83097 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Thu, 9 Jan 2025 09:29:07 -0500 Subject: [PATCH 06/10] More fixes and Qs --- README.md | 10 +++------- reformulator.py | 23 ++++++++++------------- utils/cloze_utils.py | 8 +++++--- utils/llm.py | 6 ++---- 4 files changed, 20 insertions(+), 27 deletions(-) diff --git a/README.md b/README.md index 1ca408e..22f5171 100644 --- a/README.md +++ b/README.md @@ -391,19 +391,15 @@ First, ensure that you API keys are set in you env variables. Next, install the [AnkiConnect](https://ankiweb.net/shared/info/2055492159) Anki addon if you don't already have it. -#### Reformulator - -Next... create a database? it expects a sqlite db in databases/reformulator/reformulator? -* Can handle it in code -Next... something about adding a field called `AnkiReformulator` to notes you want to change? -* Do you have to create a special note type for this to work? +#### Reformulator +The reformulator expects the notes you modify to have a specific field present so that it can save the old versions and add logging. Modify the note type you want to reformulate by adding a `AnkiReformulator` field to it. 
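+
+If you prefer to script that step, here is a minimal sketch using this repo's AnkiConnect wrapper (assumptions: the `modelFieldAdd` action, which needs a reasonably recent AnkiConnect version, and a note type named "Cloze" — substitute your own; adding the field by hand in Anki's note type editor works just as well):
+
+```python
+from utils.anki import anki
+
+# Add the field the reformulator uses for logging/backups to an existing note type.
+anki(action="modelFieldAdd", modelName="Cloze", fieldName="AnkiReformulator")
+```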
The Reformulator can be run from the command line: ```bash python reformulator.py \ - --query "note:Basic (rated:2:1 OR rated:2:2) -is:suspended" \ + --query "note:Cloze (rated:2:1 OR rated:2:2) -is:suspended" \ --dataset_path "examples/reformulator_dataset.txt" \ --string_formatting "examples/string_formatting.py" \ --ntfy_url "ntfy.sh/YOUR_TOPIC" \ diff --git a/reformulator.py b/reformulator.py index 70d72ba..c691812 100644 --- a/reformulator.py +++ b/reformulator.py @@ -201,7 +201,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): print(json.dumps(db_content, ensure_ascii=False, indent=4)) return else: - # sync_anki() + sync_anki() assert query is not None, "Must specify --query" assert dataset_path is not None, "Must specify --dataset_path" litellm.set_verbose = verbose @@ -271,9 +271,8 @@ def handle_exception(exc_type, exc_value, exc_traceback): self.db_content = self.load_db() assert self.db_content, "Could not create database" - # TODO: What should be in the database normally? This fails with an empty database whi("Computing estimated costs") - # self.compute_cost(self.db_content) + self.compute_cost(self.db_content) # load dataset whi("Loading dataset") @@ -295,8 +294,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): red(f"Found {len(nids)} notes with tag AnkiReformulator::DOING : {nids}") # find notes ids for the specific note type - nids = anki(action="findNotes", query="note:AnkiAITest") - # nids = anki(action="findNotes", query=query) + nids = anki(action="findNotes", query=query) assert nids, f"No notes found for the query '{query}'" # find the field names for this note type @@ -304,8 +302,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): notes=[int(nids[0])])[0]["fields"] assert "AnkiReformulator" in fields.keys(), \ "The notetype to edit must have a field called 'AnkiReformulator'" - # NOTE: This gets the first field. Is that what we want? Or do we specifically want the AnkiReformulator field? - self.field_name = list(fields.keys())[0] + self.field_name = list(fields.keys())[self.field_index] if self.exclude_media: # now find notes ids after excluding the img in the important field @@ -316,7 +313,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): query += f' -{self.field_name}:"*http://*"' query += f' -{self.field_name}:"*https://*"' - whi(f"Query to find note: {query}") + whi(f"Query to find note: '{query}'") nids = anki(action="findNotes", query=query) assert nids, f"No notes found for the query '{query}'" whi(f"Found {len(nids)} notes") @@ -357,7 +354,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): else: assert not tag.lower().startswith("ankireformulator") - # check if too many tokens + # check if required tokens are higher than our limits tkn_sum = sum(tkn_len(d["content"]) for d in self.dataset) tkn_sum += sum(tkn_len(replace_media(content=note["fields"][self.field_name]["value"], media=None, @@ -442,12 +439,11 @@ def compute_cost(self, db_content: List[Dict]) -> None: This is used to know if something went wrong. 
""" n_db = len(db_content) - red(f"Number of entries in databases/reforumulator/reformulator.db: {n_db}") + red(f"Number of entries in databases/reformulator/reformulator.db: {n_db}") dol_costs = [] dol_missing = 0 for dic in db_content: - # TODO: Mode isn't a field in the reformulator database dictionaries table - if dic["mode"] != "reformulate": + if self.mode != "reformulate": continue try: dol = float(dic["dollar_price"]) @@ -640,7 +636,7 @@ def apply_reformulate(self, log: Dict) -> None: new_minilog = rtoml.dumps(minilog, pretty=True) new_minilog = new_minilog.strip().replace("\n", "
") - previous_minilog = note["fields"]["AnkiReformulator"]["value"].strip() + previous_minilog = note["fields"].get("AnkiReformulator", {}).get("value", "").strip() if previous_minilog: new_minilog += "" new_minilog += "

Older minilog" @@ -670,6 +666,7 @@ def apply_reformulate(self, log: Dict) -> None: nid, fields={ self.field_name: log["note_field_formattednewcontent"], + # TODO: Might be nice to not require this "AnkiReformulator": new_minilog, }, ) diff --git a/utils/cloze_utils.py b/utils/cloze_utils.py index 05e3382..fae06e9 100644 --- a/utils/cloze_utils.py +++ b/utils/cloze_utils.py @@ -27,6 +27,7 @@ def cloze_input_parser(cloze: str) -> str: if you use weird formatting that mess with LLMs""" assert iscloze(cloze), f"Invalid cloze: {cloze}" + # TODO: What is this? cloze = cloze.replace("\xa0", " ") # make newlines consistent @@ -37,7 +38,6 @@ def cloze_input_parser(cloze: str) -> str: # make spaces consitent cloze = cloze.replace(" ", " ") - # misc cloze = cloze.replace(">", ">") cloze = cloze.replace("≥", ">=") @@ -57,9 +57,12 @@ def cloze_output_parser(cloze: str) -> str: cloze = cloze.strip() # make sure all newlines are consistent for now + # TODO: You mean
<br>?
     cloze = cloze.replace("<br/>", "<br>")
+    cloze = cloze.replace("</br>", "<br>")
     cloze = cloze.replace("\r", "<br>")
-    cloze = cloze.replace("<br>", "\n")
+    # TODO: Not needed
+    # cloze = cloze.replace("<br>", "\n")
 
     # make sure all spaces are consistent
     cloze = cloze.replace("&nbsp;", " ")
 
     # now replace all newlines by <br>
     cloze = cloze.replace("\n", "<br>
") return cloze - diff --git a/utils/llm.py b/utils/llm.py index adc27c2..b399a4c 100644 --- a/utils/llm.py +++ b/utils/llm.py @@ -27,8 +27,7 @@ "30": 0.002, "50": 0.004, "100": 0.007, - # NOTE: Why is this one a string? - "150": "0.01"} + "150": 0.01} tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") llm_cache = Memory(".cache", verbose=0) @@ -122,8 +121,7 @@ def wrapped_model_name_matcher(model: str) -> str: return match[0] else: print(f"Couldn't match the modelname {model} to any known model. " - "Continuing but this will probably crash DocToolsLLM further " - "down the code.") + "Continuing but this will probably crash further down the code.") return model From 811d4e0f64c8d41d5cdeaee8446f92dcc61dba2e Mon Sep 17 00:00:00 2001 From: Grazfather Date: Thu, 9 Jan 2025 09:32:49 -0500 Subject: [PATCH 07/10] Better error for main field index --- reformulator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/reformulator.py b/reformulator.py index c691812..595396b 100644 --- a/reformulator.py +++ b/reformulator.py @@ -302,7 +302,11 @@ def handle_exception(exc_type, exc_value, exc_traceback): notes=[int(nids[0])])[0]["fields"] assert "AnkiReformulator" in fields.keys(), \ "The notetype to edit must have a field called 'AnkiReformulator'" - self.field_name = list(fields.keys())[self.field_index] + try: + self.field_name = list(fields.keys())[self.field_index] + except IndexError: + raise AssertionError(f"main_field_index {self.field_index} is invalid. " + f"Note only has {len(fields.keys())} fields!") if self.exclude_media: # now find notes ids after excluding the img in the important field From f34de0491ea56fda8a82fea243d60b54ef1fe320 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Thu, 9 Jan 2025 09:41:58 -0500 Subject: [PATCH 08/10] Add reformulate method so all work is not done in init --- reformulator.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/reformulator.py b/reformulator.py index 595396b..f2de377 100644 --- a/reformulator.py +++ b/reformulator.py @@ -222,7 +222,12 @@ def handle_exception(exc_type, exc_value, exc_traceback): parallel = int(parallel) main_field_index = int(main_field_index) assert main_field_index >= 0, "invalid field_index" + self.base_query = query + self.dataset_path = dataset_path self.mode = mode + self.exclude_done = exclude_done + self.exclude_version = exclude_version + if string_formatting: red(f"Loading specific string formatting from {string_formatting}") cloze_input_parser, cloze_output_parser = load_formatting_funcs( @@ -254,11 +259,14 @@ def handle_exception(exc_type, exc_value, exc_traceback): else: raise Exception(f"{llm} not found in llm_price") self.verbose = verbose - if mode == "reformulate": - if exclude_done: + + def reformulate(self): + query = self.base_query + if self.mode == "reformulate": + if self.exclude_done: query += " -AnkiReformulator::Done::*" - if exclude_version: + if self.exclude_version: query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\"" # load db just in case @@ -276,7 +284,7 @@ def handle_exception(exc_type, exc_value, exc_traceback): # load dataset whi("Loading dataset") - dataset = load_dataset(dataset_path) + dataset = load_dataset(self.dataset_path) # check that each note is valid but exclude the system prompt, which is # the first entry for id, d in enumerate(dataset[1:]): @@ -371,9 +379,9 @@ def handle_exception(exc_type, exc_value, exc_traceback): f"which is higher than the limit of {n_note_limit}") if self.mode == "reformulate": 
- func = self.reformulate + func = self.reformulate_note elif self.mode == "reset": - func = self.reset + func = self.reset_note else: raise ValueError(f"Unknown mode {self.mode}") @@ -469,7 +477,7 @@ def compute_cost(self, db_content: List[Dict]) -> None: elif dol_costs: self._cost_so_far = dol_total - def reformulate(self, nid: int, note: pd.Series) -> Dict: + def reformulate_note(self, nid: int, note: pd.Series) -> Dict: """Generate a reformulated version of a note's content using an LLM. Parameters @@ -686,7 +694,7 @@ def apply_reformulate(self, log: Dict) -> None: # remove DOING tag removetags(nid, "AnkiReformulator::DOING") - def reset(self, nid: int, note: pd.Series) -> Dict: + def reset_note(self, nid: int, note: pd.Series) -> Dict: """Reset a note back to its state before reformulation. Parameters @@ -993,7 +1001,8 @@ def load_db(self) -> Dict: print(help(AnkiReformulator), file=sys.stderr) else: whi(f"Launching reformulator.py with args '{args}' and kwargs '{kwargs}'") - AnkiReformulator(*args, **kwargs) + r = AnkiReformulator(*args, **kwargs) + r.reformulate() sync_anki() except AssertionError as e: red(e) From 25813f69bdbc1709225f61ad7aeb8d3dd8406f43 Mon Sep 17 00:00:00 2001 From: Grazfather Date: Thu, 9 Jan 2025 15:29:21 -0500 Subject: [PATCH 09/10] more cleanup --- reformulator.py | 54 +++++++++++++++++--------------------------- utils/cloze_utils.py | 9 ++++---- 2 files changed, 26 insertions(+), 37 deletions(-) diff --git a/reformulator.py b/reformulator.py index f2de377..dbd0817 100644 --- a/reformulator.py +++ b/reformulator.py @@ -269,12 +269,12 @@ def reformulate(self): if self.exclude_version: query += f" -AnkiReformulator:\"*version*=*'{self.VERSION}'*\"" - # load db just in case + # load db just in case, and create one if it doesn't already exist self.db_content = self.load_db() if not self.db_content: red("Empty database. If you have already ran anki_reformulator " "before then something went wrong!") - whi("Trying to create a new database") + whi("Creating a empty database") self.save_to_db({}) self.db_content = self.load_db() assert self.db_content, "Could not create database" @@ -507,7 +507,7 @@ def reformulate_note(self, nid: int, note: pd.Series) -> Dict: # reformulate the content content = note["fields"][self.field_name]["value"] log["note_field_content"] = content - formattedcontent = self.cloze_input_parser(content) if iscloze(content) else content + formattedcontent = self.cloze_input_parser(content) log["note_field_formattedcontent"] = formattedcontent # if the card is in the dataset, just take the dataset value directly @@ -537,11 +537,13 @@ def reformulate_note(self, nid: int, note: pd.Series) -> Dict: fc, media = replace_media( content=formattedcontent, media=None, - mode="remove_media", - ) + mode="remove_media") log["media"] = media - if not skip_llm: + if skip_llm: + log["llm_answer"] = {"Skipped": True} + log["dollar_price"] = 0 + else: dataset = copy.deepcopy(self.dataset) curr_mess = [{"role": "user", "content": fc}] dataset = semantic_prompt_filtering( @@ -553,8 +555,7 @@ def reformulate_note(self, nid: int, note: pd.Series) -> Dict: embedding_model=self.embedding_model, whi=whi, yel=yel, - red=red, - ) + red=red) dataset += curr_mess assert dataset[0]["role"] == "system", "First message is not from system!" @@ -597,20 +598,14 @@ def reformulate_note(self, nid: int, note: pd.Series) -> Dict: ) else: log["dollar_price"] = "?" 
- else: - log["llm_answer"] = {"Skipped": True} - log["dollar_price"] = 0 log["note_field_newcontent"] = newcontent - formattednewcontent = self.cloze_output_parser(newcontent) if iscloze(newcontent) else newcontent + formattednewcontent = self.cloze_output_parser(newcontent) log["note_field_formattednewcontent"] = formattednewcontent log["status"] = STAT_OK_REFORM if iscloze(content) and iscloze( newcontent + formattednewcontent): # check that no cloze were lost - # TODO: Bug here: `iscloze` can return true if the new content is a - # close, but if the original content is not a cloze, then this - # fails for cl in getclozes(content): cl = cl.split("::")[0] + "::" assert cl.startswith("{{c") and cl in content @@ -734,18 +729,14 @@ def reset_note(self, nid: int, note: pd.Series) -> Dict: ] if not entries: - red( - f"Entry not found for note {nid}. Looking for the content of " - "the field AnkiReformulator" - ) + red(f"Entry not found for note {nid}. Looking for the content of " + "the field AnkiReformulator") logfield = note["fields"]["AnkiReformulator"]["value"] logfield = logfield.split( "")[0] # keep most recent if not logfield.strip(): - raise Exception( - f"Note {nid} was not found in the db and its " - "AnkiReformulator field was empty." - ) + raise Exception(f"Note {nid} was not found in the db and its " + "AnkiReformulator field was empty.") # replace the [[c1::cloze]] by {{c1::cloze}} logfield = logfield.replace("]]", "}}") @@ -755,7 +746,7 @@ def reset_note(self, nid: int, note: pd.Series) -> Dict: # parse old content buffer = [] - for i, line in enumerate(logfield.split("
")): + for line in logfield.split("
"): if buffer: try: _ = rtoml.loads("".join(buffer + [line])) @@ -774,10 +765,12 @@ def reset_note(self, nid: int, note: pd.Series) -> Dict: # parse new content at the time buffer = [] - for i, line in enumerate(logfield.split("
")): + for line in logfield.split("
"): if buffer: try: - _ = rtoml.loads("".join(buffer + [line])) + # TODO: What are you trying to do here? Just check that adding the line keeps valid toml? + # If so, you should catch the specific exception that the load function raises on error + rtoml.loads("".join(buffer + [line])) buffer.append(line) continue except Exception: @@ -931,10 +924,8 @@ def apply_reset(self, log: Dict) -> None: # remove TO_RESET tag if present removetags(nid, "AnkiReformulator::TO_RESET") - # remove Done tag removetags(nid, "AnkiReformulator::Done") - # remove DOING tag removetags(nid, "AnkiReformulator::RESETTING") @@ -987,11 +978,8 @@ def load_db(self) -> Dict: cursor = conn.cursor() cursor.execute("SELECT data FROM dictionaries") rows = cursor.fetchall() - dictionaries = [] - for row in rows: - dictionary = json.loads(zlib.decompress(row[0])) - dictionaries.append(dictionary) - return dictionaries + # TODO: Why do you compress? This just makes it more difficult to debug + return [json.loads(zlib.decompress(row[0]) for row in rows] if __name__ == "__main__": diff --git a/utils/cloze_utils.py b/utils/cloze_utils.py index fae06e9..52782c3 100644 --- a/utils/cloze_utils.py +++ b/utils/cloze_utils.py @@ -23,11 +23,12 @@ def getclozes(text: str) -> List[str]: def cloze_input_parser(cloze: str) -> str: - """edits the cloze from anki before sending it to the LLM. This is useful - if you use weird formatting that mess with LLMs""" - assert iscloze(cloze), f"Invalid cloze: {cloze}" + """edit the cloze from anki before sending it to the LLM. This is useful + if you use weird formatting that mess with LLMs. + If the note content is not a cloze, then return it unmodified.""" + if not iscloze(cloze): + return cloze - # TODO: What is this? cloze = cloze.replace("\xa0", " ") # make newlines consistent From f907ea1d5d5ab6a8d67e3af8fe53e17fe31d32cc Mon Sep 17 00:00:00 2001 From: Grazfather Date: Fri, 10 Jan 2025 11:46:47 -0500 Subject: [PATCH 10/10] fix --- reformulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reformulator.py b/reformulator.py index dbd0817..25ae566 100644 --- a/reformulator.py +++ b/reformulator.py @@ -979,7 +979,7 @@ def load_db(self) -> Dict: cursor.execute("SELECT data FROM dictionaries") rows = cursor.fetchall() # TODO: Why do you compress? This just makes it more difficult to debug - return [json.loads(zlib.decompress(row[0]) for row in rows] + return [json.loads(zlib.decompress(row[0])) for row in rows] if __name__ == "__main__":