diff --git a/examples/Detoxify.ipynb b/examples/Detoxify.ipynb index 09bbe2f..f23d8a8 100644 --- a/examples/Detoxify.ipynb +++ b/examples/Detoxify.ipynb @@ -188,6 +188,14 @@ "text" ] }, + { + "cell_type": "markdown", + "id": "1eb7719e30054304", + "metadata": {}, + "source": [ + "## Initializing TMaRCo" + ] + }, { "cell_type": "code", "execution_count": 5, @@ -198,6 +206,64 @@ "tmarco = TMaRCo()" ] }, + { + "cell_type": "markdown", + "id": "3e16ee305f4983d9", + "metadata": {}, + "source": [ + "This will initialize `TMaRCo` using the default models, taken from HuggingFace.\n", + "
\n", + "To use local models with `TMaRCo`, the pre-initialized models must be in local storage that is accessible to `TMaRCo`.\n", + "
\n", + "For instance, to use the default `facebook/bart-large` model locally, we first need to retrieve it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "614c9ff6f46a0ea9", + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import snapshot_download\n", + "\n", + "snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")" + ] + }, + { + "cell_type": "markdown", + "id": "95bd792e757205d6", + "metadata": {}, + "source": [ + "We now initialize the base model and tokenizer from local files and pass them to `TMaRCo`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0f24485822a7c3f", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import BartForConditionalGeneration, BartTokenizer\n", + "\n", + "tokenizer = BartTokenizer.from_pretrained(\n", + " \"models/bart\", # Or directory where the local model is stored\n", + " is_split_into_words=True, add_prefix_space=True\n", + ")\n", + "\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id\n", + "\n", + "base = BartForConditionalGeneration.from_pretrained(\n", + " \"models/bart\", # Or directory where the local model is stored\n", + " max_length=150,\n", + " forced_bos_token_id=tokenizer.bos_token_id,\n", + ")\n", + "\n", + "# Initialize TMaRCo with local models\n", + "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)" + ] + }, { "cell_type": "code", "execution_count": 7, @@ -223,6 +289,32 @@ "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])" ] }, + { + "cell_type": "markdown", + "id": "c113208c527c342e", + "metadata": {}, + "source": [ + "
\n", + "To use local expert/anti-expert models with `TMaRCo`, they must likewise be in local storage that is accessible to `TMaRCo`.\n", + "\n", + "However, we don't need to initialize them separately; we can pass their directories directly.\n", + "
\n", + "For example, to download the default `gminus`/`gplus` models and load them from local directories:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfa288dcb60102c", + "metadata": {}, + "outputs": [], + "source": [ + "snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n", + "snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n", + "\n", + "tmarco.load_models([\"models/gminus\", \"models/gplus\"])" + ] + }, { "cell_type": "code", "execution_count": 13, @@ -362,6 +454,25 @@ "tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])" ] }, + { + "cell_type": "markdown", + "id": "b0738c324227f57", + "metadata": {}, + "source": [ + "As noted previously, to use local models, pass the initialized tokenizer and base model to the constructor, and pass the local paths of the expert/anti-expert models to `load_models`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b929e21a97ea914e", + "metadata": {}, + "outputs": [], + "source": [ + "tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n", + "tmarco.load_models([\"models/gminus\", \"models/gplus\"])" + ] + }, { "cell_type": "markdown", "id": "5303f56b-85ff-40da-99bf-6962cf2f3395",