diff --git a/examples/Detoxify.ipynb b/examples/Detoxify.ipynb
index 09bbe2f..f23d8a8 100644
--- a/examples/Detoxify.ipynb
+++ b/examples/Detoxify.ipynb
@@ -188,6 +188,14 @@
"text"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "1eb7719e30054304",
+ "metadata": {},
+ "source": [
+ "## Initializing TMaRCo"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 5,
@@ -198,6 +206,64 @@
"tmarco = TMaRCo()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "3e16ee305f4983d9",
+ "metadata": {},
+ "source": [
+ "This will initialize `TMaRCo` with the default models, downloaded from Hugging Face.\n",
+ "\n",
+ "To use local models with TMaRCo, the pre-initialized models must be stored locally, in a location accessible to TMaRCo.\n",
+ "\n",
+ "For instance, to use the default `facebook/bart-large` model locally, we first need to retrieve it:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "614c9ff6f46a0ea9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95bd792e757205d6",
+ "metadata": {},
+ "source": [
+ "We now initialize the base model and tokenizer from local files and pass them to `TMaRCo`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f0f24485822a7c3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import BartForConditionalGeneration, BartTokenizer\n",
+ "\n",
+ "tokenizer = BartTokenizer.from_pretrained(\n",
+ " \"models/bart\", # Or directory where the local model is stored\n",
+ " is_split_into_words=True, add_prefix_space=True\n",
+ ")\n",
+ "\n",
+ "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
+ "\n",
+ "base = BartForConditionalGeneration.from_pretrained(\n",
+ " \"models/bart\", # Or directory where the local model is stored\n",
+ " max_length=150,\n",
+ " forced_bos_token_id=tokenizer.bos_token_id,\n",
+ ")\n",
+ "\n",
+ "# Initialize TMaRCo with local models\n",
+ "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 7,
@@ -223,6 +289,32 @@
"tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "c113208c527c342e",
+ "metadata": {},
+ "source": [
+ "\n",
+ "To use local expert/anti-expert models with TMaRCo, they must likewise be stored locally, in a location accessible to TMaRCo.\n",
+ "\n",
+ "However, unlike the base model, they don't need to be initialized separately; we can pass the local directories directly.\n",
+ "\n",
+ "For instance, to use local copies of the default `gminus`/`gplus` experts:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dfa288dcb60102c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n",
+ "snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n",
+ "\n",
+ "tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 13,
@@ -362,6 +454,25 @@
"tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "b0738c324227f57",
+ "metadata": {},
+ "source": [
+ "As noted previously, to use local models, pass the initialized tokenizer and base model to the constructor, and the local paths as the expert/anti-expert models:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b929e21a97ea914e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n",
+ "tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
+ ]
+ },
{
"cell_type": "markdown",
"id": "5303f56b-85ff-40da-99bf-6962cf2f3395",