RHOAIENG-12606: Add notes on using local models #43


Merged: 2 commits, Oct 7, 2024
111 changes: 111 additions & 0 deletions examples/Detoxify.ipynb
@@ -188,6 +188,14 @@
"text"
]
},
{
"cell_type": "markdown",
"id": "1eb7719e30054304",
"metadata": {},
"source": [
"## Initializing TMaRCo"
]
},
{
"cell_type": "code",
"execution_count": 5,
@@ -198,6 +206,64 @@
"tmarco = TMaRCo()"
]
},
{
"cell_type": "markdown",
"id": "3e16ee305f4983d9",
"metadata": {},
"source": [
"This will initialize `TMaRCo` with the default models, downloaded from Hugging Face.\n",
"<div class=\"alert alert-info\">\n",
"To use local models with TMaRCo, the pre-initialized models must be in local storage that is accessible to TMaRCo.\n",
"</div>\n",
"For instance, to use the default `facebook/bart-large` model locally, we first need to retrieve it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "614c9ff6f46a0ea9",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import snapshot_download\n",
"\n",
"snapshot_download(repo_id=\"facebook/bart-large\", local_dir=\"models/bart\")"
]
},
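Because `snapshot_download` can be interrupted, it may be worth sanity-checking that the snapshot actually landed before pointing TMaRCo at it. The helper below is a hypothetical sketch (not part of TMaRCo or `huggingface_hub`); it only checks for `config.json` plus a weights file, which is what a typical Hub model snapshot contains.

```python
import os

def model_dir_ready(path: str) -> bool:
    """Heuristic check that `path` looks like a complete HF model snapshot.

    Assumption: a usable snapshot contains a config.json and at least one
    weights file (.bin or .safetensors).
    """
    if not os.path.isdir(path):
        return False
    names = set(os.listdir(path))
    has_weights = any(n.endswith((".bin", ".safetensors")) for n in names)
    return "config.json" in names and has_weights
```

For example, `model_dir_ready("models/bart")` should return `True` once the download above has completed.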
{
"cell_type": "markdown",
"id": "95bd792e757205d6",
"metadata": {},
"source": [
"We now initialize the base model and tokenizer from local files and pass them to `TMaRCo`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0f24485822a7c3f",
"metadata": {},
"outputs": [],
"source": [
"from transformers import BartForConditionalGeneration, BartTokenizer\n",
"\n",
"tokenizer = BartTokenizer.from_pretrained(\n",
" \"models/bart\", # Or directory where the local model is stored\n",
" is_split_into_words=True, add_prefix_space=True\n",
")\n",
"\n",
"tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"\n",
"base = BartForConditionalGeneration.from_pretrained(\n",
" \"models/bart\", # Or directory where the local model is stored\n",
" max_length=150,\n",
" forced_bos_token_id=tokenizer.bos_token_id,\n",
")\n",
"\n",
"# Initialize TMaRCo with local models\n",
"tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)"
]
},
{
"cell_type": "code",
"execution_count": 7,
@@ -223,6 +289,32 @@
"tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
]
},
{
"cell_type": "markdown",
"id": "c113208c527c342e",
"metadata": {},
"source": [
"<div class=\"alert alert-info\">\n",
"To use local expert/anti-expert models with TMaRCo, they must be in local storage accessible to TMaRCo, as before.\n",
"\n",
"However, they do not need to be initialized separately; the directories can be passed to `load_models` directly.\n",
"</div>\n",
"If we want to use local copies of the default `gminus`/`gplus` models with `TMaRCo`:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfa288dcb60102c",
"metadata": {},
"outputs": [],
"source": [
"snapshot_download(repo_id=\"trustyai/gminus\", local_dir=\"models/gminus\")\n",
"snapshot_download(repo_id=\"trustyai/gplus\", local_dir=\"models/gplus\")\n",
"\n",
"tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
]
},
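Since `load_models` accepts either Hub repo ids or local paths, a notebook can fall back to the Hub when a local snapshot is absent. The helper below is an illustrative sketch (the function name and check are assumptions, not TMaRCo API); it returns the local directory when it looks populated and the repo id otherwise.

```python
import os

def resolve_model(local_dir: str, repo_id: str) -> str:
    """Return `local_dir` if it holds a downloaded snapshot, else `repo_id`.

    Assumption: presence of config.json is enough to treat the directory
    as a usable local model.
    """
    if os.path.isdir(local_dir) and "config.json" in os.listdir(local_dir):
        return local_dir
    return repo_id
```

This would let the call above be written as `tmarco.load_models([resolve_model("models/gminus", "trustyai/gminus"), resolve_model("models/gplus", "trustyai/gplus")])`.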
{
"cell_type": "code",
"execution_count": 13,
@@ -362,6 +454,25 @@
"tmarco.load_models([\"trustyai/gminus\", \"trustyai/gplus\"])"
]
},
{
"cell_type": "markdown",
"id": "b0738c324227f57",
"metadata": {},
"source": [
"As noted previously, to use local models, pass the initialized tokenizer and base model to the constructor, and the local paths as the expert/anti-expert models:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b929e21a97ea914e",
"metadata": {},
"outputs": [],
"source": [
"tmarco = TMaRCo(tokenizer=tokenizer, base_model=base)\n",
"tmarco.load_models([\"models/gminus\", \"models/gplus\"])"
]
},
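When the whole pipeline is meant to run from local files, it can help to make that explicit: `huggingface_hub` and `transformers` document the `HF_HUB_OFFLINE` and `TRANSFORMERS_OFFLINE` environment variables, which disable network access so a missing local file fails fast instead of silently re-downloading. A minimal sketch (set these before loading any models):

```python
import os

# Refuse all Hub network access; only local files will be used.
os.environ["HF_HUB_OFFLINE"] = "1"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
```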
{
"cell_type": "markdown",
"id": "5303f56b-85ff-40da-99bf-6962cf2f3395",