Skip to content

Commit a8dfcb1

Browse files
authored
Merge pull request #766 from k-ivey/prompt_augmentation
Add support for prompt augmentation
2 parents 29f38b2 + 6469138 commit a8dfcb1

16 files changed

+320
-1
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,27 @@ You can also create your own augmenter from scratch by importing transformations
392392
['What I cannot creae, I do not understand.', 'What I cannot creat, I do not understand.', 'What I cannot create, I do not nderstand.', 'What I cannot create, I do nt understand.', 'Wht I cannot create, I do not understand.']
393393
```
394394

395+
#### Prompt Augmentation
396+
In addition to augmenting regular text, you can augment prompts and then generate responses to
397+
the augmented prompts using a large language model (LLM). The augmentation is performed using the same
398+
`Augmenter` as above. To generate responses, you can use your own LLM, a HuggingFace LLM, or an OpenAI LLM.
399+
Here's an example using a pretrained HuggingFace LLM:
400+
401+
```python
402+
>>> from textattack.augmentation import EmbeddingAugmenter
403+
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
404+
>>> from textattack.llms import HuggingFaceLLMWrapper
405+
>>> from textattack.prompt_augmentation import PromptAugmentationPipeline
406+
>>> augmenter = EmbeddingAugmenter(transformations_per_example=3)
407+
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
408+
>>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
409+
>>> model_wrapper = HuggingFaceLLMWrapper(model, tokenizer)
410+
>>> pipeline = PromptAugmentationPipeline(augmenter, model_wrapper)
411+
>>> pipeline("Classify the following piece of text as `positive` or `negative`: This movie is great!")
412+
[('Classify the following piece of text as `positive` or `negative`: This film is great!', ['positive']), ('Classify the following piece of text as `positive` or `negative`: This movie is fabulous!', ['positive']), ('Classify the following piece of text as `positive` or `negative`: This movie is wonderful!', ['positive'])]
413+
```
414+
415+
395416
### Training Models: `textattack train`
396417

397418
Our model training code is available via `textattack train` to help you train LSTMs,

docs/apidoc/textattack.constraints.pre_transformation.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,15 @@ textattack.constraints.pre\_transformation package
4343
:members:
4444
:undoc-members:
4545
:show-inheritance:
46+
47+
48+
.. automodule:: textattack.constraints.pre_transformation.unmodifiable_indices
49+
:members:
50+
:undoc-members:
51+
:show-inheritance:
52+
53+
54+
.. automodule:: textattack.constraints.pre_transformation.unmodifiable_phrases
55+
:members:
56+
:undoc-members:
57+
:show-inheritance:

docs/apidoc/textattack.llms.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
textattack.llms package
2+
=========================
3+
4+
.. automodule:: textattack.llms
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+
9+
10+
.. automodule:: textattack.llms.huggingface_llm_wrapper
11+
:members:
12+
:undoc-members:
13+
:show-inheritance:
14+
15+
16+
.. automodule:: textattack.llms.chat_gpt_wrapper
17+
:members:
18+
:undoc-members:
19+
:show-inheritance:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
textattack.prompt_augmentation package
2+
=======================================
3+
4+
.. automodule:: textattack.prompt_augmentation
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+
9+
10+
.. automodule:: textattack.prompt_augmentation.prompt_augmentation_pipeline
11+
:members:
12+
:undoc-members:
13+
:show-inheritance:

docs/apidoc/textattack.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ textattack package
1919
textattack.datasets
2020
textattack.goal_function_results
2121
textattack.goal_functions
22+
textattack.llms
2223
textattack.loggers
2324
textattack.metrics
2425
textattack.models
26+
textattack.prompt_augmentation
2527
textattack.search_methods
2628
textattack.shared
2729
textattack.transformations

tests/test_constraints/test_pretransformation_constraints.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,34 @@ def test_stopword_modification(
103103
set(range(len(entailment_attacked_text.words)))
104104
- {1, 2, 3, 8, 9, 11, 16, 17, 20, 22, 25, 31, 34, 39, 40, 41, 43, 44}
105105
)
106+
107+
def test_unmodifiable_indices(
    self, sentence_attacked_text, entailment_attacked_text
):
    """Check that ``UnmodifiableIndices`` protects the same words after edits.

    The constraint is built with original indices 4 and 5. After a word is
    deleted before them, or text is inserted before them, the protected
    positions in the edited text are expected to shift accordingly.
    """
    constraint = textattack.constraints.pre_transformation.UnmodifiableIndices(
        [4, 5]
    )

    def all_but(text, blocked):
        # Every word index of ``text`` except the ``blocked`` ones.
        return set(range(len(text.words))) - blocked

    # Unedited single sentence: indices are used as given.
    modifiable = constraint._get_modifiable_indices(sentence_attacked_text)
    assert modifiable == all_but(sentence_attacked_text, {4, 5})

    # Deleting word 2 moves the protected words one position to the left.
    shortened = sentence_attacked_text.delete_word_at_index(2)
    modifiable = constraint._get_modifiable_indices(shortened)
    assert modifiable == all_but(shortened, {3, 4})

    # Unedited entailment pair: indices are used as given.
    modifiable = constraint._get_modifiable_indices(entailment_attacked_text)
    assert modifiable == all_but(entailment_attacked_text, {4, 5})

    # Inserting two words after word 0 moves the protected words right by two.
    lengthened = entailment_attacked_text.insert_text_after_word_index(
        0, "two words"
    )
    modifiable = constraint._get_modifiable_indices(lengthened)
    assert modifiable == all_but(lengthened, {6, 7})
129+
130+
def test_unmodifiable_phrases(self, sentence_attacked_text):
    """Check that words belonging to the listed phrases are not modifiable."""
    protected_phrases = ["South Korea's", "oil", "monday"]
    constraint = textattack.constraints.pre_transformation.UnmodifablePhrases(
        protected_phrases
    )

    every_index = set(range(len(sentence_attacked_text.words)))
    # Word positions the three phrases are expected to cover in the fixture.
    expected_modifiable = every_index - {0, 1, 9, 22}

    assert (
        constraint._get_modifiable_indices(sentence_attacked_text)
        == expected_modifiable
    )

tests/test_prompt_augmentation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
def test_prompt_augmentation_pipeline():
    """End-to-end check of ``PromptAugmentationPipeline`` with a small HF model.

    Augments a sentiment-classification prompt with ``CheckListAugmenter``
    while protecting some word indices, then asserts on the single augmented
    prompt and on the response generated for it.
    """
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from textattack.augmentation.recipes import CheckListAugmenter
    from textattack.constraints.pre_transformation import UnmodifiableIndices
    from textattack.llms import HuggingFaceLLMWrapper
    from textattack.prompt_augmentation import PromptAugmentationPipeline

    checkpoint = "google/flan-t5-small"
    llm = HuggingFaceLLMWrapper(
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
    pipeline = PromptAugmentationPipeline(CheckListAugmenter(), llm)

    prompt = "As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: Poor Ben Bratt couldn't find stardom if MapQuest emailed him point-to-point driving directions."
    protected = [UnmodifiableIndices([2, 3, 10, 12, 14])]

    output = pipeline(prompt, protected)

    # Exactly one (augmented prompt, response) pair is expected.
    assert len(output) == 1
    assert len(output[0]) == 2

    augmented_prompt, response = output[0]
    assert "could not" in augmented_prompt
    assert "negative" in response

textattack/constraints/pre_transformation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
from .max_num_words_modified import MaxNumWordsModified
1414
from .min_word_length import MinWordLength
1515
from .max_modification_rate import MaxModificationRate
16+
from .unmodifiable_indices import UnmodifiableIndices
17+
from .unmodifiable_phrases import UnmodifablePhrases
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from textattack.constraints import PreTransformationConstraint
2+
3+
4+
class UnmodifiableIndices(PreTransformationConstraint):
    """A constraint that prevents the modification of certain words at specific
    indices.

    The indices are passed through ``convert_from_original_idxs`` on every
    call, so they are interpreted relative to the original text and mapped
    onto the current (possibly already edited) text.

    Args:
        indices (list(int)): A list of indices which are unmodifiable
    """

    def __init__(self, indices):
        self.unmodifiable_indices = indices

    def _get_modifiable_indices(self, current_text):
        """Return the set of word indices of ``current_text`` that may be
        modified, i.e. every index except the protected ones."""
        # Build a set once so the exclusion is a plain set difference rather
        # than a linear scan per word.
        blocked = set(
            current_text.convert_from_original_idxs(self.unmodifiable_indices)
        )
        return set(range(len(current_text.words))) - blocked

    def extra_repr_keys(self):
        # Must name an attribute that exists on the instance; the previous
        # value "indices" is never assigned by ``__init__``.
        # NOTE(review): presumably consumed by the shared repr helper, which
        # would fail on a missing key - confirm against the base class.
        return ["unmodifiable_indices"]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from collections import defaultdict
2+
3+
from textattack.constraints import PreTransformationConstraint
4+
5+
6+
class UnmodifablePhrases(PreTransformationConstraint):
    """A constraint that prevents the modification of specified phrases or
    words.

    Matching is case-insensitive and works on whole words: a phrase of ``k``
    whitespace-separated tokens protects every run of ``k`` consecutive words
    in the text whose space-joined, lower-cased form equals the phrase.

    Args:
        phrases (list(str)): A list of strings that cannot be modified
    """

    def __init__(self, phrases):
        # Kept so ``extra_repr_keys`` refers to an attribute that exists.
        self.phrases = list(phrases)

        # Phrases grouped by token count, so each window size is scanned once.
        self.length_to_phrases = defaultdict(set)
        for phrase in self.phrases:
            tokens = phrase.lower().split()
            if not tokens:
                # An empty / whitespace-only phrase cannot protect anything.
                continue
            # Re-join with single spaces so phrases written with irregular
            # whitespace still equal the space-joined window built below.
            self.length_to_phrases[len(tokens)].add(" ".join(tokens))

    def _get_modifiable_indices(self, current_text):
        """Return the set of word indices of ``current_text`` that are not part
        of any protected phrase."""
        # Stored phrases are lower-cased, so the text must be lower-cased too;
        # otherwise capitalised words could never match.
        words = [word.lower() for word in current_text.words]
        phrase_indices = set()

        for phrase_length, phrases in self.length_to_phrases.items():
            for start in range(len(words) - phrase_length + 1):
                end = start + phrase_length
                if " ".join(words[start:end]) in phrases:
                    phrase_indices.update(range(start, end))

        return set(range(len(words))) - phrase_indices

    def extra_repr_keys(self):
        return ["phrases"]

textattack/llms/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
Large Language Models
3+
======================
4+
5+
TextAttack can generate responses to prompts using LLMs, which take in a list of strings and output a list of responses.
6+
7+
We've provided an implementation around two common LLM patterns:
8+
9+
1. `HuggingFaceLLMWrapper` for LLMs in HuggingFace
10+
2. `ChatGptWrapper` for OpenAI's ChatGPT model
11+
12+
13+
"""
14+
15+
from .chat_gpt_wrapper import ChatGptWrapper
16+
from .huggingface_llm_wrapper import HuggingFaceLLMWrapper

textattack/llms/chat_gpt_wrapper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
3+
from textattack.models.wrappers import ModelWrapper
4+
5+
6+
class ChatGptWrapper(ModelWrapper):
    """A wrapper around OpenAI's ChatGPT model. Note that you must provide your
    own API key to use this wrapper.

    Args:
        model_name (:obj:`str`): The name of the GPT model to use. See the OpenAI documentation
            for a list of latest model names
        key_environment_variable (:obj:`str`, `optional`, defaults to :obj:`OPENAI_API_KEY`):
            The environment variable that the API key is set to
    """

    def __init__(
        self, model_name="gpt-3.5-turbo", key_environment_variable="OPENAI_API_KEY"
    ):
        # Imported lazily so the ``openai`` package is only needed when this
        # wrapper is actually instantiated.
        from openai import OpenAI

        self.model_name = model_name
        self.client = OpenAI(api_key=os.getenv(key_environment_variable))

    def __call__(self, text_input_list):
        """Returns a list of responses to the given input list.

        Args:
            text_input_list (:obj:`str` or list of :obj:`str`): one prompt, or a
                list of prompts. A single string is treated as a one-element list.

        Returns:
            A list with one response text per prompt, in input order. Each
            prompt is sent as its own single-message chat completion request.
        """
        if isinstance(text_input_list, str):
            text_input_list = [text_input_list]

        outputs = []
        for text in text_input_list:
            completion = self.client.chat.completions.create(
                model=self.model_name, messages=[{"role": "user", "content": text}]
            )
            # Keep the text of the first choice, not the message object, so the
            # result really is a list of responses.
            outputs.append(completion.choices[0].message.content)

        return outputs
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from textattack.models.wrappers import ModelWrapper
2+
3+
4+
class HuggingFaceLLMWrapper(ModelWrapper):
    """A wrapper around HuggingFace for LLMs.

    Args:
        model: A HuggingFace pretrained LLM
        tokenizer: A HuggingFace pretrained tokenizer
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, text_input_list):
        """Returns a list of responses to the given input list.

        Args:
            text_input_list: the prompt(s) handed to the tokenizer.

        Returns:
            The decoded generations. When ``len(text_input_list) == 1`` the
            single decoded string is returned instead of a list.
        """
        model_device = next(self.model.parameters()).device
        # NOTE(review): no padding is requested here, so a batch of prompts
        # with different token lengths may fail to tensorise - confirm.
        input_ids = self.tokenizer(text_input_list, return_tensors="pt").input_ids
        # ``Tensor.to`` returns a new tensor rather than moving in place, so
        # the result must be kept for the inputs to land on the model's device.
        input_ids = input_ids.to(model_device)

        # Generation is capped at 512 new tokens; the tokenizer's EOS id is
        # supplied as the pad token id.
        outputs = self.model.generate(
            input_ids, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id
        )

        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if len(text_input_list) == 1:
            return responses[0]
        return responses
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""
2+
Prompt Augmentation
3+
=====================
4+
5+
This package includes functions used to augment a prompt for an LLM
6+
7+
"""
8+
9+
from .prompt_augmentation_pipeline import PromptAugmentationPipeline
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from textattack.constraints import PreTransformationConstraint
2+
3+
4+
class PromptAugmentationPipeline:
    """A prompt augmentation pipeline to augment a prompt and obtain the
    responses from a LLM on the augmented prompts.

    Args:
        augmenter (textattack.Augmenter): the augmenter to use to
            augment the prompt
        llm (textattack.ModelWrapper): the LLM to generate responses
            to the augmented data
    """

    def __init__(self, augmenter, llm):
        self.augmenter = augmenter
        self.llm = llm

    def __call__(self, prompt, prompt_constraints=None):
        """Augments the given prompt using the augmenter and generates
        responses using the LLM.

        Args:
            prompt (:obj:`str`): the prompt to augment and generate responses
            prompt_constraints (List(textattack.constraints.PreTransformationConstraint)): a list of pretransformation
                constraints to apply to the given prompt. ``None`` (the default) means no extra constraints.

        Returns a list of tuples, where the first element in the pair is the augmented prompt and the second
            is the response to the augmented prompt from the LLM

        Raises:
            ValueError: if any entry of ``prompt_constraints`` is not a
                ``PreTransformationConstraint``.
        """
        constraints = [] if prompt_constraints is None else list(prompt_constraints)

        # Validate everything before touching the augmenter, so a bad entry
        # cannot leave earlier constraints attached to it.
        for constraint in constraints:
            if not isinstance(constraint, PreTransformationConstraint):
                raise ValueError(
                    "Prompt constraints must be of type PreTransformationConstraint"
                )

        # The constraints apply to this call only: attach them, augment, and
        # always detach them again, even if augmentation raises.
        active_constraints = self.augmenter.pre_transformation_constraints
        active_constraints.extend(constraints)
        try:
            augmented_prompts = self.augmenter.augment(prompt)
        finally:
            if constraints:
                del active_constraints[-len(constraints) :]

        return [
            (augmented_prompt, self.llm(augmented_prompt))
            for augmented_prompt in augmented_prompts
        ]

textattack/shared/attacked_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def convert_from_original_idxs(self, idxs: Iterable[int]) -> List[int]:
317317
elif isinstance(idxs, set):
318318
idxs = list(idxs)
319319

320-
elif not isinstance(idxs, [list, np.ndarray]):
320+
elif not isinstance(idxs, (list, np.ndarray)):
321321
raise TypeError(
322322
f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
323323
)

0 commit comments

Comments
 (0)