
Commit 0ce1f45

add treeprompt
1 parent d30c305 commit 0ce1f45

18 files changed: +4403 -468 lines

demo_notebooks/tree_prompt.ipynb

Lines changed: 244 additions & 0 deletions
Large diffs are not rendered by default.

docs/auggam/auggam.html

Lines changed: 574 additions & 95 deletions
Large diffs are not rendered by default.

docs/embeddings.html

Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1" />
<meta name="generator" content="pdoc 0.10.0" />
<title>imodelsx.embeddings API documentation</title>
<meta name="description" content="" />
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
<link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Module <code>imodelsx.embeddings</code></h1>
</header>
<section id="section-intro">
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">import pandas as pd
import numpy as np
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from typing import List
from sklearn.feature_extraction.text import TfidfVectorizer
import imodelsx.embeddings
from copy import deepcopy


def get_embs(
    texts: List[str], checkpoint: str = &#34;bert-base-uncased&#34;, batch_size: int = 32,
    aggregate: str = &#34;mean&#34;
) -&gt; np.ndarray:
    &#39;&#39;&#39;
    Get embeddings for a list of texts.

    Params
    ------
    texts: List[str]: List of texts to get embeddings for.
    checkpoint: str: Name of the checkpoint to use. Use tf-idf for linear embeddings.
    batch_size: int: Batch size to use for inference.
    aggregate: str: Aggregation method to use for the embeddings. Can be &#34;mean&#34; or &#34;first&#34; (to use CLS token for BERT).
    &#39;&#39;&#39;
    if checkpoint == &#34;tf-idf&#34;:
        return get_embs_linear(texts)

    # load model
    # get embeddings for each text from the corpus
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).to(&#34;cuda&#34;)

    # calculate embeddings
    embs = []
    for i in tqdm(range(0, len(texts), batch_size)):
        t = texts[i: i + batch_size]
        with torch.no_grad():
            # tokenize
            inputs = tokenizer(
                t, return_tensors=&#34;pt&#34;, padding=True, truncation=True
            ).to(&#34;cuda&#34;)
            # Shape: [batch_size, seq_len, hidden_size]
            outputs = model(**inputs).last_hidden_state.detach().cpu().numpy()
            # average over sequence length
            if aggregate == &#34;mean&#34;:
                emb = np.mean(outputs, axis=1).squeeze()
            elif aggregate == &#34;first&#34;:
                emb = outputs[:, 0, :].squeeze() # use CLS token
            embs.append(deepcopy(emb))
    embs = np.concatenate(embs)
    return embs


def get_embs_linear(texts: List[str]) -&gt; np.ndarray:
    &#34;&#34;&#34;Get TF-IDF vectors for a list of texts.

    Parameters
    ----------
    texts (List[str]): List of texts to get TF-IDF vectors for.

    Returns
    -------
    embs: np.ndarray: TF-IDF vectors for the input texts.
    &#34;&#34;&#34;
    vectorizer = TfidfVectorizer(
        # tokenizer=AutoTokenizer.from_pretrained(checkpoint).tokenize,
        # preprocessor=lambda x: x,
        # token_pattern=None,
        lowercase=False,
        max_features=10000,
    )
    return vectorizer.fit_transform(texts).toarray()</code></pre>
</details>
</section>
<section>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-functions">Functions</h2>
<dl>
<dt id="imodelsx.embeddings.get_embs"><code class="name flex">
<span>def <span class="ident">get_embs</span></span>(<span>texts: List[str], checkpoint: str = 'bert-base-uncased', batch_size: int = 32, aggregate: str = 'mean') ‑> numpy.ndarray</span>
</code></dt>
<dd>
<div class="desc"><p>Get embeddings for a list of texts.</p>
<h2 id="params">Params</h2>
<p>texts: List[str]: List of texts to get embeddings for.
checkpoint: str: Name of the checkpoint to use. Use tf-idf for linear embeddings.
batch_size: int: Batch size to use for inference.
aggregate: str: Aggregation method to use for the embeddings. Can be "mean" or "first" (to use CLS token for BERT).</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get_embs(
    texts: List[str], checkpoint: str = &#34;bert-base-uncased&#34;, batch_size: int = 32,
    aggregate: str = &#34;mean&#34;
) -&gt; np.ndarray:
    &#39;&#39;&#39;
    Get embeddings for a list of texts.

    Params
    ------
    texts: List[str]: List of texts to get embeddings for.
    checkpoint: str: Name of the checkpoint to use. Use tf-idf for linear embeddings.
    batch_size: int: Batch size to use for inference.
    aggregate: str: Aggregation method to use for the embeddings. Can be &#34;mean&#34; or &#34;first&#34; (to use CLS token for BERT).
    &#39;&#39;&#39;
    if checkpoint == &#34;tf-idf&#34;:
        return get_embs_linear(texts)

    # load model
    # get embeddings for each text from the corpus
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).to(&#34;cuda&#34;)

    # calculate embeddings
    embs = []
    for i in tqdm(range(0, len(texts), batch_size)):
        t = texts[i: i + batch_size]
        with torch.no_grad():
            # tokenize
            inputs = tokenizer(
                t, return_tensors=&#34;pt&#34;, padding=True, truncation=True
            ).to(&#34;cuda&#34;)
            # Shape: [batch_size, seq_len, hidden_size]
            outputs = model(**inputs).last_hidden_state.detach().cpu().numpy()
            # average over sequence length
            if aggregate == &#34;mean&#34;:
                emb = np.mean(outputs, axis=1).squeeze()
            elif aggregate == &#34;first&#34;:
                emb = outputs[:, 0, :].squeeze() # use CLS token
            embs.append(deepcopy(emb))
    embs = np.concatenate(embs)
    return embs</code></pre>
</details>
</dd>
<dt id="imodelsx.embeddings.get_embs_linear"><code class="name flex">
<span>def <span class="ident">get_embs_linear</span></span>(<span>texts: List[str]) ‑> numpy.ndarray</span>
</code></dt>
<dd>
<div class="desc"><p>Get TF-IDF vectors for a list of texts.</p>
<h2 id="parameters">Parameters</h2>
<p>texts (List[str]): List of texts to get TF-IDF vectors for.</p>
<h2 id="returns">Returns</h2>
<p>embs: np.ndarray: TF-IDF vectors for the input texts.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get_embs_linear(texts: List[str]) -&gt; np.ndarray:
    &#34;&#34;&#34;Get TF-IDF vectors for a list of texts.

    Parameters
    ----------
    texts (List[str]): List of texts to get TF-IDF vectors for.

    Returns
    -------
    embs: np.ndarray: TF-IDF vectors for the input texts.
    &#34;&#34;&#34;
    vectorizer = TfidfVectorizer(
        # tokenizer=AutoTokenizer.from_pretrained(checkpoint).tokenize,
        # preprocessor=lambda x: x,
        # token_pattern=None,
        lowercase=False,
        max_features=10000,
    )
    return vectorizer.fit_transform(texts).toarray()</code></pre>
</details>
</dd>
</dl>
</section>
<section>
</section>
</article>
<nav id="sidebar">
<h1>Index</h1>
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3>Super-module</h3>
<ul>
<li><code><a title="imodelsx" href="index.html">imodelsx</a></code></li>
</ul>
</li>
<li><h3><a href="#header-functions">Functions</a></h3>
<ul class="">
<li><code><a title="imodelsx.embeddings.get_embs" href="#imodelsx.embeddings.get_embs">get_embs</a></code></li>
<li><code><a title="imodelsx.embeddings.get_embs_linear" href="#imodelsx.embeddings.get_embs_linear">get_embs_linear</a></code></li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.10.0</a>.</p>
</footer>
</body>
</html>
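
For reference, a minimal usage sketch of the imodelsx.embeddings API documented above. The function names, signatures, and the tf-idf shortcut come straight from the diff; the sample texts are invented for illustration, and a CUDA GPU is assumed because get_embs hard-codes .to("cuda").

# hypothetical usage sketch, not part of the commit
from imodelsx.embeddings import get_embs, get_embs_linear

texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Interpretable models can still be accurate.",
]

# mean-pooled BERT embeddings, one vector per text: shape (len(texts), hidden_size)
embs_mean = get_embs(texts, checkpoint="bert-base-uncased", aggregate="mean")

# CLS-token embeddings instead of mean pooling
embs_cls = get_embs(texts, checkpoint="bert-base-uncased", aggregate="first")

# checkpoint="tf-idf" dispatches to get_embs_linear (dense TF-IDF, at most 10000 features)
embs_tfidf = get_embs(texts, checkpoint="tf-idf")
embs_tfidf_direct = get_embs_linear(texts)  # equivalent direct call

assert embs_mean.shape[0] == embs_tfidf.shape[0] == len(texts)

Note that aggregate="mean" averages over all token positions while aggregate="first" keeps only the CLS token, so both return a single fixed-size vector per input text.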
