Skip to content

Commit 1922d1a

Browse files
committed
add linear finetune
1 parent b013415 commit 1922d1a

File tree

14 files changed

+1357
-95
lines changed

14 files changed

+1357
-95
lines changed

docs/dummy_script.html

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ <h1 class="title">Module <code>imodelsx.dummy_script</code></h1>
3434
if name == &#39;chandan&#39; and random.random() &gt;= 0.1:
3535
raise ValueError(&#39;chandan is not a valid name&#39;)
3636
time.sleep(1)
37-
return f&#34;Hello {name}!&#34;
37+
print(f&#34;Hello {name}!&#34;)
38+
return
3839

3940

4041
if __name__ == &#39;__main__&#39;:
@@ -61,7 +62,8 @@ <h2 class="section-title" id="header-functions">Functions</h2>
6162
if name == &#39;chandan&#39; and random.random() &gt;= 0.1:
6263
raise ValueError(&#39;chandan is not a valid name&#39;)
6364
time.sleep(1)
64-
return f&#34;Hello {name}!&#34;</code></pre>
65+
print(f&#34;Hello {name}!&#34;)
66+
return</code></pre>
6567
</details>
6668
</dd>
6769
</dl>

docs/embgam/embed.html

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ <h1 class="title">Module <code>imodelsx.embgam.embed</code></h1>
8686
fit_with_ngram_decomposition: bool = True,
8787
):
8888
&#34;&#34;&#34;Get summed embeddings for a single example
89-
Note: this function gets called many times, so don&#39;t want to do things like load a model here
90-
9189

9290
Params
9391
------
@@ -104,7 +102,7 @@ <h1 class="title">Module <code>imodelsx.embgam.embed</code></h1>
104102
nlp_chunks
105103
if parsing is not empty string, a parser that extracts specific ngrams
106104
fit_with_ngram_decomposition
107-
whether to fit the model with ngram decomposition (if not just use the stsandard sentence)
105+
whether to fit the model with ngram decomposition (if not just use the standard sentence)
108106
&#34;&#34;&#34;
109107
if dataset_key_text is not None:
110108
sentence = example[dataset_key_text]
@@ -170,8 +168,7 @@ <h2 class="section-title" id="header-functions">Functions</h2>
170168
<span>def <span class="ident">embed_and_sum_function</span></span>(<span>example, model, ngrams: int, tokenizer_embeddings, tokenizer_ngrams, checkpoint: str, dataset_key_text: str = None, layer: str = 'last_hidden_state', padding: bool = True, parsing: str = '', nlp_chunks=None, all_ngrams: bool = False, fit_with_ngram_decomposition: bool = True)</span>
171169
</code></dt>
172170
<dd>
173-
<div class="desc"><p>Get summed embeddings for a single example
174-
Note: this function gets called many times, so don't want to do things like load a model here</p>
171+
<div class="desc"><p>Get summed embeddings for a single example</p>
175172
<h2 id="params">Params</h2>
176173
<p>ngrams: int
177174
What order of ngrams to use (1 for unigrams, 2 for bigrams, &hellip;)
@@ -186,7 +183,7 @@ <h2 id="params">Params</h2>
186183
nlp_chunks
187184
if parsing is not empty string, a parser that extracts specific ngrams
188185
fit_with_ngram_decomposition
189-
whether to fit the model with ngram decomposition (if not just use the stsandard sentence)</p></div>
186+
whether to fit the model with ngram decomposition (if not just use the standard sentence)</p></div>
190187
<details class="source">
191188
<summary>
192189
<span>Expand source code</span>
@@ -207,8 +204,6 @@ <h2 id="params">Params</h2>
207204
fit_with_ngram_decomposition: bool = True,
208205
):
209206
&#34;&#34;&#34;Get summed embeddings for a single example
210-
Note: this function gets called many times, so don&#39;t want to do things like load a model here
211-
212207

213208
Params
214209
------
@@ -225,7 +220,7 @@ <h2 id="params">Params</h2>
225220
nlp_chunks
226221
if parsing is not empty string, a parser that extracts specific ngrams
227222
fit_with_ngram_decomposition
228-
whether to fit the model with ngram decomposition (if not just use the stsandard sentence)
223+
whether to fit the model with ngram decomposition (if not just use the standard sentence)
229224
&#34;&#34;&#34;
230225
if dataset_key_text is not None:
231226
sentence = example[dataset_key_text]

docs/index.html

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ <h1 id="related-work">Related work</h1>
171171
<section>
172172
<h2 class="section-title" id="header-submodules">Sub-modules</h2>
173173
<dl>
174+
<dt><code class="name"><a title="imodelsx.cache_save_utils" href="cache_save_utils.html">imodelsx.cache_save_utils</a></code></dt>
175+
<dd>
176+
<div class="desc"></div>
177+
</dd>
174178
<dt><code class="name"><a title="imodelsx.d3" href="d3/index.html">imodelsx.d3</a></code></dt>
175179
<dd>
176180
<div class="desc"></div>
@@ -191,10 +195,18 @@ <h2 class="section-title" id="header-submodules">Sub-modules</h2>
191195
<dd>
192196
<div class="desc"></div>
193197
</dd>
198+
<dt><code class="name"><a title="imodelsx.linear_finetune" href="linear_finetune.html">imodelsx.linear_finetune</a></code></dt>
199+
<dd>
200+
<div class="desc"><p>Simple scikit-learn interface for Emb-GAM …</p></div>
201+
</dd>
194202
<dt><code class="name"><a title="imodelsx.metrics" href="metrics.html">imodelsx.metrics</a></code></dt>
195203
<dd>
196204
<div class="desc"></div>
197205
</dd>
206+
<dt><code class="name"><a title="imodelsx.process_results" href="process_results.html">imodelsx.process_results</a></code></dt>
207+
<dd>
208+
<div class="desc"></div>
209+
</dd>
198210
<dt><code class="name"><a title="imodelsx.submit_utils" href="submit_utils.html">imodelsx.submit_utils</a></code></dt>
199211
<dd>
200212
<div class="desc"></div>
@@ -232,12 +244,15 @@ <h1>Index</h1>
232244
<ul id="index">
233245
<li><h3><a href="#header-submodules">Sub-modules</a></h3>
234246
<ul>
247+
<li><code><a title="imodelsx.cache_save_utils" href="cache_save_utils.html">imodelsx.cache_save_utils</a></code></li>
235248
<li><code><a title="imodelsx.d3" href="d3/index.html">imodelsx.d3</a></code></li>
236249
<li><code><a title="imodelsx.data" href="data.html">imodelsx.data</a></code></li>
237250
<li><code><a title="imodelsx.dummy_script" href="dummy_script.html">imodelsx.dummy_script</a></code></li>
238251
<li><code><a title="imodelsx.embgam" href="embgam/index.html">imodelsx.embgam</a></code></li>
239252
<li><code><a title="imodelsx.iprompt" href="iprompt/index.html">imodelsx.iprompt</a></code></li>
253+
<li><code><a title="imodelsx.linear_finetune" href="linear_finetune.html">imodelsx.linear_finetune</a></code></li>
240254
<li><code><a title="imodelsx.metrics" href="metrics.html">imodelsx.metrics</a></code></li>
255+
<li><code><a title="imodelsx.process_results" href="process_results.html">imodelsx.process_results</a></code></li>
241256
<li><code><a title="imodelsx.submit_utils" href="submit_utils.html">imodelsx.submit_utils</a></code></li>
242257
<li><code><a title="imodelsx.util" href="util.html">imodelsx.util</a></code></li>
243258
<li><code><a title="imodelsx.viz" href="viz.html">imodelsx.viz</a></code></li>

docs/iprompt/api.html

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,7 @@ <h1 class="title">Module <code>imodelsx.iprompt.api</code></h1>
116116
&#39;output&#39;: output_strs,
117117
&#39;text&#39;: text_strs,
118118
})
119-
if n_shots == 1:
120-
dset = datasets.Dataset.from_pandas(df)
121-
else:
119+
if n_shots &gt; 1:
122120
d2 = defaultdict(list)
123121
for i in range(max_n_datapoints):
124122
all_shots = df.sample(n=n_shots, replace=False)
@@ -133,10 +131,11 @@ <h1 class="title">Module <code>imodelsx.iprompt.api</code></h1>
133131
d2[&#39;output&#39;].append(last_output)
134132
#
135133
df = pd.DataFrame.from_dict(d2)
136-
# shuffle rows
134+
# shuffle rows
135+
if max_n_datapoints &lt; len(df):
137136
df = df.sample(n=max_n_datapoints, replace=False)
138-
dset = datasets.Dataset.from_pandas(df)
139-
print(&#39;loading model...&#39;)
137+
dset = datasets.Dataset.from_pandas(df)
138+
print(f&#39;iPrompt got {len(dset)} datapoints, now loading model...&#39;)
140139

141140
model = model.to(device)
142141
dataloader = DataLoader(
@@ -551,6 +550,7 @@ <h1 class="title">Module <code>imodelsx.iprompt.api</code></h1>
551550
epoch_save_interval=epoch_save_interval,
552551
verbose=verbose,
553552
)
553+
model = model.cpu()
554554
return r[&#39;prefixes&#39;], r
555555

556556
# r = eval_model(args=args, r=r, dset=Dataset.from_dict(dset_test[:128]), model=model, tokenizer=tokenizer)
@@ -1085,6 +1085,7 @@ <h2 id="returns">Returns</h2>
10851085
epoch_save_interval=epoch_save_interval,
10861086
verbose=verbose,
10871087
)
1088+
model = model.cpu()
10881089
return r[&#39;prefixes&#39;], r
10891090

10901091
# r = eval_model(args=args, r=r, dset=Dataset.from_dict(dset_test[:128]), model=model, tokenizer=tokenizer)</code></pre>
@@ -1146,9 +1147,7 @@ <h2 id="params">Params</h2>
11461147
&#39;output&#39;: output_strs,
11471148
&#39;text&#39;: text_strs,
11481149
})
1149-
if n_shots == 1:
1150-
dset = datasets.Dataset.from_pandas(df)
1151-
else:
1150+
if n_shots &gt; 1:
11521151
d2 = defaultdict(list)
11531152
for i in range(max_n_datapoints):
11541153
all_shots = df.sample(n=n_shots, replace=False)
@@ -1163,10 +1162,11 @@ <h2 id="params">Params</h2>
11631162
d2[&#39;output&#39;].append(last_output)
11641163
#
11651164
df = pd.DataFrame.from_dict(d2)
1166-
# shuffle rows
1165+
# shuffle rows
1166+
if max_n_datapoints &lt; len(df):
11671167
df = df.sample(n=max_n_datapoints, replace=False)
1168-
dset = datasets.Dataset.from_pandas(df)
1169-
print(&#39;loading model...&#39;)
1168+
dset = datasets.Dataset.from_pandas(df)
1169+
print(f&#39;iPrompt got {len(dset)} datapoints, now loading model...&#39;)
11701170

11711171
model = model.to(device)
11721172
dataloader = DataLoader(

docs/iprompt/autoprompt.html

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,15 @@ <h1 class="title">Module <code>imodelsx.iprompt.autoprompt</code></h1>
126126
# os.makedirs(save_dir, exist_ok=True)
127127
# pickle.dump(self._prefix_pool, open(os.path.join(save_dir, &#39;prefix_pool.p&#39;), &#39;wb&#39;))
128128

129-
if self._do_final_reranking:
130-
all_prefixes = self._prefix_pool.topk_all(
131-
k=self._num_prefixes_to_test, min_occurrences=1)
129+
all_prefixes = self._prefix_pool.topk_all(
130+
k=self._num_prefixes_to_test, min_occurrences=2)
132131

133-
if not len(all_prefixes):
134-
# In the case where we get no prefixes here (i.e. prompt generation
135-
# only ran for a single step) just take anything from prefix pool.
136-
all_prefixes = list(self._prefix_pool.prefixes)
132+
if not len(all_prefixes):
133+
# In the case where we get no prefixes here (i.e. prompt generation
134+
# only ran for a single step) just take anything from prefix pool.
135+
all_prefixes = random.choices(self._prefix_pool.prefixes, k=self._num_prefixes_to_test)
137136

137+
if self._do_final_reranking:
138138
all_losses, all_accuracies = self._test_prefixes(
139139
prefixes=all_prefixes,
140140
eval_dataloader=eval_dataloader,
@@ -148,14 +148,14 @@ <h1 class="title">Module <code>imodelsx.iprompt.autoprompt</code></h1>
148148
False, True]).reset_index()
149149
else:
150150
all_prefixes = list(self._prefix_pool.prefixes)
151-
all_losses = [-1] * len(all_prefixes)
152-
all_accuracies = [-1] * len(all_prefixes)
151+
all_losses = [self._prefix_pool._avg_loss.get(p, -1) for p in all_prefixes]
152+
all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1) for p in all_prefixes]
153153

154154
df = pd.DataFrame(
155155
zip(*[all_prefixes, all_losses, all_accuracies]),
156156
columns=[&#39;prefix&#39;, &#39;loss&#39;, &#39;accuracy&#39;]
157157
)
158-
# df = df.sort_values(by=&#39;loss&#39;, ascending=True).reset_index()
158+
df = df.sort_values(by=&#39;accuracy&#39;, ascending=False).reset_index()
159159

160160
df[&#39;prefix_str&#39;] = df[&#39;prefix&#39;].map(self.tokenizer.decode)
161161
df[&#39;n_queries&#39;] = df[&#39;prefix&#39;].map(
@@ -430,15 +430,15 @@ <h2 class="section-title" id="header-classes">Classes</h2>
430430
# os.makedirs(save_dir, exist_ok=True)
431431
# pickle.dump(self._prefix_pool, open(os.path.join(save_dir, &#39;prefix_pool.p&#39;), &#39;wb&#39;))
432432

433-
if self._do_final_reranking:
434-
all_prefixes = self._prefix_pool.topk_all(
435-
k=self._num_prefixes_to_test, min_occurrences=1)
433+
all_prefixes = self._prefix_pool.topk_all(
434+
k=self._num_prefixes_to_test, min_occurrences=2)
436435

437-
if not len(all_prefixes):
438-
# In the case where we get no prefixes here (i.e. prompt generation
439-
# only ran for a single step) just take anything from prefix pool.
440-
all_prefixes = list(self._prefix_pool.prefixes)
436+
if not len(all_prefixes):
437+
# In the case where we get no prefixes here (i.e. prompt generation
438+
# only ran for a single step) just take anything from prefix pool.
439+
all_prefixes = random.choices(self._prefix_pool.prefixes, k=self._num_prefixes_to_test)
441440

441+
if self._do_final_reranking:
442442
all_losses, all_accuracies = self._test_prefixes(
443443
prefixes=all_prefixes,
444444
eval_dataloader=eval_dataloader,
@@ -452,14 +452,14 @@ <h2 class="section-title" id="header-classes">Classes</h2>
452452
False, True]).reset_index()
453453
else:
454454
all_prefixes = list(self._prefix_pool.prefixes)
455-
all_losses = [-1] * len(all_prefixes)
456-
all_accuracies = [-1] * len(all_prefixes)
455+
all_losses = [self._prefix_pool._avg_loss.get(p, -1) for p in all_prefixes]
456+
all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1) for p in all_prefixes]
457457

458458
df = pd.DataFrame(
459459
zip(*[all_prefixes, all_losses, all_accuracies]),
460460
columns=[&#39;prefix&#39;, &#39;loss&#39;, &#39;accuracy&#39;]
461461
)
462-
# df = df.sort_values(by=&#39;loss&#39;, ascending=True).reset_index()
462+
df = df.sort_values(by=&#39;accuracy&#39;, ascending=False).reset_index()
463463

464464
df[&#39;prefix_str&#39;] = df[&#39;prefix&#39;].map(self.tokenizer.decode)
465465
df[&#39;n_queries&#39;] = df[&#39;prefix&#39;].map(
@@ -682,15 +682,15 @@ <h3>Methods</h3>
682682
# os.makedirs(save_dir, exist_ok=True)
683683
# pickle.dump(self._prefix_pool, open(os.path.join(save_dir, &#39;prefix_pool.p&#39;), &#39;wb&#39;))
684684

685-
if self._do_final_reranking:
686-
all_prefixes = self._prefix_pool.topk_all(
687-
k=self._num_prefixes_to_test, min_occurrences=1)
685+
all_prefixes = self._prefix_pool.topk_all(
686+
k=self._num_prefixes_to_test, min_occurrences=2)
688687

689-
if not len(all_prefixes):
690-
# In the case where we get no prefixes here (i.e. prompt generation
691-
# only ran for a single step) just take anything from prefix pool.
692-
all_prefixes = list(self._prefix_pool.prefixes)
688+
if not len(all_prefixes):
689+
# In the case where we get no prefixes here (i.e. prompt generation
690+
# only ran for a single step) just take anything from prefix pool.
691+
all_prefixes = random.choices(self._prefix_pool.prefixes, k=self._num_prefixes_to_test)
693692

693+
if self._do_final_reranking:
694694
all_losses, all_accuracies = self._test_prefixes(
695695
prefixes=all_prefixes,
696696
eval_dataloader=eval_dataloader,
@@ -704,14 +704,14 @@ <h3>Methods</h3>
704704
False, True]).reset_index()
705705
else:
706706
all_prefixes = list(self._prefix_pool.prefixes)
707-
all_losses = [-1] * len(all_prefixes)
708-
all_accuracies = [-1] * len(all_prefixes)
707+
all_losses = [self._prefix_pool._avg_loss.get(p, -1) for p in all_prefixes]
708+
all_accuracies = [self._prefix_pool._avg_accuracy.get(p, -1) for p in all_prefixes]
709709

710710
df = pd.DataFrame(
711711
zip(*[all_prefixes, all_losses, all_accuracies]),
712712
columns=[&#39;prefix&#39;, &#39;loss&#39;, &#39;accuracy&#39;]
713713
)
714-
# df = df.sort_values(by=&#39;loss&#39;, ascending=True).reset_index()
714+
df = df.sort_values(by=&#39;accuracy&#39;, ascending=False).reset_index()
715715

716716
df[&#39;prefix_str&#39;] = df[&#39;prefix&#39;].map(self.tokenizer.decode)
717717
df[&#39;n_queries&#39;] = df[&#39;prefix&#39;].map(

docs/iprompt/ipromptx.html

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ <h1 class="title">Module <code>imodelsx.iprompt.ipromptx</code></h1>
116116
tokenizer=self.tokenizer,
117117
criterion=pop_criterion, # &#39;loss&#39; # in [&#39;loss&#39;, &#39;acc&#39;, &#39;combined&#39;]
118118
topk_strategy=pop_topk_strategy,
119+
verbose=verbose,
119120
)
120121
# Stuff to track for early stopping
121122
self._early_stopping_steps = early_stopping_steps
@@ -254,8 +255,8 @@ <h1 class="title">Module <code>imodelsx.iprompt.ipromptx</code></h1>
254255

255256
def _get_population_and_random_generations(self, full_text_ids: torch.Tensor) -&gt; torch.Tensor:
256257
population_pool = self._select_pop_topk(k=self._topk_pop_sample)
257-
if self._iprompt_verbose:
258-
print(&#34;population_pool:&#34;, [self.tokenizer.decode(p) for p in population_pool])
258+
# if self._iprompt_verbose:
259+
# print(&#34;population_pool:&#34;, [self.tokenizer.decode(p) for p in population_pool])
259260
population = random.sample(population_pool, self._pop_size)
260261
population = torch.tensor(population).to(device)
261262

@@ -569,6 +570,7 @@ <h2 class="section-title" id="header-classes">Classes</h2>
569570
tokenizer=self.tokenizer,
570571
criterion=pop_criterion, # &#39;loss&#39; # in [&#39;loss&#39;, &#39;acc&#39;, &#39;combined&#39;]
571572
topk_strategy=pop_topk_strategy,
573+
verbose=verbose,
572574
)
573575
# Stuff to track for early stopping
574576
self._early_stopping_steps = early_stopping_steps
@@ -707,8 +709,8 @@ <h2 class="section-title" id="header-classes">Classes</h2>
707709

708710
def _get_population_and_random_generations(self, full_text_ids: torch.Tensor) -&gt; torch.Tensor:
709711
population_pool = self._select_pop_topk(k=self._topk_pop_sample)
710-
if self._iprompt_verbose:
711-
print(&#34;population_pool:&#34;, [self.tokenizer.decode(p) for p in population_pool])
712+
# if self._iprompt_verbose:
713+
# print(&#34;population_pool:&#34;, [self.tokenizer.decode(p) for p in population_pool])
712714
population = random.sample(population_pool, self._pop_size)
713715
population = torch.tensor(population).to(device)
714716

0 commit comments

Comments
 (0)