<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1, minimum-scale=1">
<meta name="generator" content="pdoc 0.10.0">
<title>imodelsx.embeddings API documentation</title>
<meta name="description" content="">
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/sanitize.min.css" integrity="sha256-PK9q560IAAa6WVRRh76LtCaI8pjTJ2z11v0miyNNjrs=" crossorigin>
<link rel="preload stylesheet" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/10up-sanitize.css/11.0.1/typography.min.css" integrity="sha256-7l/o7C8jubJiy74VsKTidCy1yBkRtiUGbVkYBylBqUg=" crossorigin>
<link rel="stylesheet preload" as="style" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/styles/github.min.css" crossorigin>
<style>:root{--highlight-color:#fe9}.flex{display:flex !important}body{line-height:1.5em}#content{padding:20px}#sidebar{padding:30px;overflow:hidden}#sidebar > *:last-child{margin-bottom:2cm}.http-server-breadcrumbs{font-size:130%;margin:0 0 15px 0}#footer{font-size:.75em;padding:5px 30px;border-top:1px solid #ddd;text-align:right}#footer p{margin:0 0 0 1em;display:inline-block}#footer p:last-child{margin-right:30px}h1,h2,h3,h4,h5{font-weight:300}h1{font-size:2.5em;line-height:1.1em}h2{font-size:1.75em;margin:1em 0 .50em 0}h3{font-size:1.4em;margin:25px 0 10px 0}h4{margin:0;font-size:105%}h1:target,h2:target,h3:target,h4:target,h5:target,h6:target{background:var(--highlight-color);padding:.2em 0}a{color:#058;text-decoration:none;transition:color .3s ease-in-out}a:hover{color:#e82}.title code{font-weight:bold}h2[id^="header-"]{margin-top:2em}.ident{color:#900}pre code{background:#f8f8f8;font-size:.8em;line-height:1.4em}code{background:#f2f2f1;padding:1px 4px;overflow-wrap:break-word}h1 code{background:transparent}pre{background:#f8f8f8;border:0;border-top:1px solid #ccc;border-bottom:1px solid #ccc;margin:1em 0;padding:1ex}#http-server-module-list{display:flex;flex-flow:column}#http-server-module-list div{display:flex}#http-server-module-list dt{min-width:10%}#http-server-module-list p{margin-top:0}.toc ul,#index{list-style-type:none;margin:0;padding:0}#index code{background:transparent}#index h3{border-bottom:1px solid #ddd}#index ul{padding:0}#index h4{margin-top:.6em;font-weight:bold}@media (min-width:200ex){#index .two-column{column-count:2}}@media (min-width:300ex){#index .two-column{column-count:3}}dl{margin-bottom:2em}dl dl:last-child{margin-bottom:4em}dd{margin:0 0 1em 3em}#header-classes + dl > dd{margin-bottom:3em}dd dd{margin-left:2em}dd p{margin:10px 0}.name{background:#eee;font-weight:bold;font-size:.85em;padding:5px 10px;display:inline-block;min-width:40%}.name:hover{background:#e0e0e0}dt:target .name{background:var(--highlight-color)}.name > span:first-child{white-space:nowrap}.name.class > span:nth-child(2){margin-left:.4em}.inherited{color:#999;border-left:5px solid #eee;padding-left:1em}.inheritance em{font-style:normal;font-weight:bold}.desc h2{font-weight:400;font-size:1.25em}.desc h3{font-size:1em}.desc dt code{background:inherit}.source summary,.git-link-div{color:#666;text-align:right;font-weight:400;font-size:.8em;text-transform:uppercase}.source summary > *{white-space:nowrap;cursor:pointer}.git-link{color:inherit;margin-left:1em}.source pre{max-height:500px;overflow:auto;margin:0}.source pre code{font-size:12px;overflow:visible}.hlist{list-style:none}.hlist li{display:inline}.hlist li:after{content:',\2002'}.hlist li:last-child:after{content:none}.hlist .hlist{display:inline;padding-left:1em}img{max-width:100%}td{padding:0 .5em}.admonition{padding:.1em .5em;margin-bottom:1em}.admonition-title{font-weight:bold}.admonition.note,.admonition.info,.admonition.important{background:#aef}.admonition.todo,.admonition.versionadded,.admonition.tip,.admonition.hint{background:#dfd}.admonition.warning,.admonition.versionchanged,.admonition.deprecated{background:#fd4}.admonition.error,.admonition.danger,.admonition.caution{background:lightpink}</style>
<style media="screen and (min-width: 700px)">@media screen and (min-width:700px){#sidebar{width:30%;height:100vh;overflow:auto;position:sticky;top:0}#content{width:70%;max-width:100ch;padding:3em 4em;border-left:1px solid #ddd}pre code{font-size:1em}.item .name{font-size:1em}main{display:flex;flex-direction:row-reverse;justify-content:flex-end}.toc ul ul,#index ul{padding-left:1.5em}.toc > ul > li{margin-top:.5em}}</style>
<style media="print">@media print{#sidebar h1{page-break-before:always}.source{display:none}}@media print{*{background:transparent !important;color:#000 !important;box-shadow:none !important;text-shadow:none !important}a[href]:after{content:" (" attr(href) ")";font-size:90%}a[href][title]:after{content:none}abbr[title]:after{content:" (" attr(title) ")"}.ir a:after,a[href^="javascript:"]:after,a[href^="#"]:after{content:""}pre,blockquote{border:1px solid #999;page-break-inside:avoid}thead{display:table-header-group}tr,img{page-break-inside:avoid}img{max-width:100% !important}@page{margin:0.5cm}p,h2,h3{orphans:3;widows:3}h1,h2,h3,h4,h5,h6{page-break-after:avoid}}</style>
<script defer src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.1.1/highlight.min.js" integrity="sha256-Uv3H6lx7dJmRfRvH8TH6kJD1TSK1aFcwgx+mdg3epi8=" crossorigin></script>
<script>window.addEventListener('DOMContentLoaded', () => hljs.initHighlighting())</script>
</head>
<body>
<main>
<article id="content">
<header>
<h1 class="title">Module <code>imodelsx.embeddings</code></h1>
</header>
<section id="section-intro">
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">import pandas as pd
import numpy as np
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
from typing import List
from sklearn.feature_extraction.text import TfidfVectorizer
import imodelsx.embeddings
from copy import deepcopy


def get_embs(
    texts: List[str], checkpoint: str = "bert-base-uncased", batch_size: int = 32,
    aggregate: str = "mean"
) -&gt; np.ndarray:
    '''
    Get embeddings for a list of texts.

    Params
    ------
    texts: List[str]: List of texts to get embeddings for.
    checkpoint: str: Name of the checkpoint to use. Use tf-idf for linear embeddings.
    batch_size: int: Batch size to use for inference.
    aggregate: str: Aggregation method to use for the embeddings. Can be "mean" or "first" (to use CLS token for BERT).
    '''
    if checkpoint == "tf-idf":
        return get_embs_linear(texts)

    # load model
    # get embeddings for each text from the corpus
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).to("cuda")

    # calculate embeddings
    embs = []
    for i in tqdm(range(0, len(texts), batch_size)):
        t = texts[i: i + batch_size]
        with torch.no_grad():
            # tokenize
            inputs = tokenizer(
                t, return_tensors="pt", padding=True, truncation=True
            ).to("cuda")
            # Shape: [batch_size, seq_len, hidden_size]
            outputs = model(**inputs).last_hidden_state.detach().cpu().numpy()
            # average over sequence length
            if aggregate == "mean":
                emb = np.mean(outputs, axis=1).squeeze()
            elif aggregate == "first":
                emb = outputs[:, 0, :].squeeze()  # use CLS token
        embs.append(deepcopy(emb))
    embs = np.concatenate(embs)
    return embs


def get_embs_linear(texts: List[str]) -&gt; np.ndarray:
    """Get TF-IDF vectors for a list of texts.

    Parameters
    ----------
    texts (List[str]): List of texts to get TF-IDF vectors for.

    Returns
    -------
    embs: np.ndarray: TF-IDF vectors for the input texts.
    """
    vectorizer = TfidfVectorizer(
        # tokenizer=AutoTokenizer.from_pretrained(checkpoint).tokenize,
        # preprocessor=lambda x: x,
        # token_pattern=None,
        lowercase=False,
        max_features=10000,
    )
    return vectorizer.fit_transform(texts).toarray()</code></pre>
</details>
</section>
<section>
</section>
<section>
</section>
<section>
<h2 class="section-title" id="header-functions">Functions</h2>
<dl>
<dt id="imodelsx.embeddings.get_embs"><code class="name flex">
<span>def <span class="ident">get_embs</span></span>(<span>texts: List[str], checkpoint: str = 'bert-base-uncased', batch_size: int = 32, aggregate: str = 'mean') ‑&gt; numpy.ndarray</span>
</code></dt>
<dd>
<div class="desc"><p>Get embeddings for a list of texts.</p>
<h2 id="params">Params</h2>
<p>texts: List[str]: List of texts to get embeddings for.
checkpoint: str: Name of the checkpoint to use. Use tf-idf for linear embeddings.
batch_size: int: Batch size to use for inference.
aggregate: str: Aggregation method to use for the embeddings. Can be "mean" or "first" (to use CLS token for BERT).</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get_embs(
    texts: List[str], checkpoint: str = "bert-base-uncased", batch_size: int = 32,
    aggregate: str = "mean"
) -&gt; np.ndarray:
    '''
    Get embeddings for a list of texts.

    Params
    ------
    texts: List[str]: List of texts to get embeddings for.
    checkpoint: str: Name of the checkpoint to use. Use tf-idf for linear embeddings.
    batch_size: int: Batch size to use for inference.
    aggregate: str: Aggregation method to use for the embeddings. Can be "mean" or "first" (to use CLS token for BERT).
    '''
    if checkpoint == "tf-idf":
        return get_embs_linear(texts)

    # load model
    # get embeddings for each text from the corpus
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).to("cuda")

    # calculate embeddings
    embs = []
    for i in tqdm(range(0, len(texts), batch_size)):
        t = texts[i: i + batch_size]
        with torch.no_grad():
            # tokenize
            inputs = tokenizer(
                t, return_tensors="pt", padding=True, truncation=True
            ).to("cuda")
            # Shape: [batch_size, seq_len, hidden_size]
            outputs = model(**inputs).last_hidden_state.detach().cpu().numpy()
            # average over sequence length
            if aggregate == "mean":
                emb = np.mean(outputs, axis=1).squeeze()
            elif aggregate == "first":
                emb = outputs[:, 0, :].squeeze()  # use CLS token
        embs.append(deepcopy(emb))
    embs = np.concatenate(embs)
    return embs</code></pre>
</details>
</dd>
<dt id="imodelsx.embeddings.get_embs_linear"><code class="name flex">
<span>def <span class="ident">get_embs_linear</span></span>(<span>texts: List[str]) ‑&gt; numpy.ndarray</span>
</code></dt>
<dd>
<div class="desc"><p>Get TF-IDF vectors for a list of texts.</p>
<h2 id="parameters">Parameters</h2>
<p>texts (List[str]): List of texts to get TF-IDF vectors for.</p>
<h2 id="returns">Returns</h2>
<p>embs: np.ndarray: TF-IDF vectors for the input texts.</p></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get_embs_linear(texts: List[str]) -&gt; np.ndarray:
    """Get TF-IDF vectors for a list of texts.

    Parameters
    ----------
    texts (List[str]): List of texts to get TF-IDF vectors for.

    Returns
    -------
    embs: np.ndarray: TF-IDF vectors for the input texts.
    """
    vectorizer = TfidfVectorizer(
        # tokenizer=AutoTokenizer.from_pretrained(checkpoint).tokenize,
        # preprocessor=lambda x: x,
        # token_pattern=None,
        lowercase=False,
        max_features=10000,
    )
    return vectorizer.fit_transform(texts).toarray()</code></pre>
</details>
</dd>
</dl>
</section>
<section>
</section>
</article>
<nav id="sidebar">
<h1>Index</h1>
<div class="toc">
<ul></ul>
</div>
<ul id="index">
<li><h3>Super-module</h3>
<ul>
<li><code><a title="imodelsx" href="index.html">imodelsx</a></code></li>
</ul>
</li>
<li><h3><a href="#header-functions">Functions</a></h3>
<ul class="">
<li><code><a title="imodelsx.embeddings.get_embs" href="#imodelsx.embeddings.get_embs">get_embs</a></code></li>
<li><code><a title="imodelsx.embeddings.get_embs_linear" href="#imodelsx.embeddings.get_embs_linear">get_embs_linear</a></code></li>
</ul>
</li>
</ul>
</nav>
</main>
<footer id="footer">
<p>Generated by <a href="https://pdoc3.github.io/pdoc" title="pdoc: Python API documentation generator"><cite>pdoc</cite> 0.10.0</a>.</p>
</footer>
</body>
</html>