diff --git a/examples/large_scale_demo.py b/examples/large_scale_demo.py
new file mode 100644
index 0000000..543b4e7
--- /dev/null
+++ b/examples/large_scale_demo.py
@@ -0,0 +1,344 @@
+# basic imports
+import numpy as np
+import os
+from pathlib import Path
+import gc
+import pandas as pd
+import time
+import pickle
+
+
+# Import stopwords
+import nltk
+nltk.download('stopwords')
+from nltk.corpus import stopwords
+
+# Cuda imports
+import cupy as cp
+import cudf
+from cudf import Series
+from cuml.feature_extraction.text import CountVectorizer
+from cuml.preprocessing.text.stem import PorterStemmer
+
+# Import TensorLy
+import tensorly as tl
+
+# Import utility functions from other files
+from .tlda.tlda_wrapper import TLDA
+from .tlda.file_operations import get_files_in_dir
+
+# Root Filepath -- can modify
+ROOT_DIR = "/Users/skangaslahti/tlda/data"
+
+# Data Relative Paths -- can modify
+INDIR = "MeTooMonthCleaned/"
+
+# Output Relative paths -- do not change
+X_MAT_FILEPATH_PREFIX = "x_mat/"
+X_FILEPATH = "X_full.obj"
+X_DF_FILEPATH = "X_df.obj"
+X_LST_FILEPATH = "X_lst.obj"
+CORPUS_FILEPATH_PREFIX = "corpus/"
+GENSIM_CORPUS_FILEPATH = "corpus.obj"
+COUNTVECTOR_FILEPATH = "countvec.obj"
+TLDA_FILEPATH = "tlda.obj"
+VOCAB_FILEPATH = "vocab.csv"
+EXISTING_VOCAB_FILEPATH = "vocab.obj"
+TOPIC_FILEPATH_PREFIX = 'predicted_topics/'
+DOCUMENT_TOPIC_FILEPATH = 'dtm.csv'
+COHERENCE_FILEPATH = 'coherence.obj'
+DOCUMENT_TOPIC_FILEPATH_TOT = 'dtm_df.csv'
+OUT_ID_DATA_PREFIX = 'ids/'
+TOP_WORDS_FILEPATH = 'top_words.csv'
+
+# Device settings
+backend = "cupy"
+tl.set_backend(backend)
+device = 'cuda'
+porter = PorterStemmer()
+
+
+def basic_clean(df):
+    df['tweets'] = df['tweets'].astype('str')
+    df = df.drop_duplicates(keep="first")
+    return df
+
+
+def partial_fit(self, data):
+    if hasattr(self, 'vocabulary_'):
+        vocab = self.vocabulary_  # series
+    else:
+        vocab = Series()
+    self.fit(data)
+    vocab = vocab.append(self.vocabulary_)
+    self.vocabulary_ = vocab.unique()
+
+
+# declare the stop words
+# potentially add extra stop words depending on the application dataset
+stop_words = stopwords.words('english')
+added_words = []
+
+# set stop words and countvectorizer method
+stop_words = list(np.append(stop_words, added_words))
+CountVectorizer.partial_fit = partial_fit
+
+
+# define function with no preprocessing
+def custom_preprocessor(doc):
+    return doc
+
+
+def fit_topics(num_tops, curr_dir, alpha_0 = 0.01, learning_rate = 0.0004, theta_param = 5.005, ortho_loss_param = 1000, smoothing = 1e-5, initialize_first_docs = False, n_eigenvec = None):
+
+    # make final directories for outputs
+    save_dir = os.path.join(ROOT_DIR, curr_dir)
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    # initialize RAPIDS CountVectorizer
+    countvec = CountVectorizer(stop_words = stop_words,
+                               lowercase = True,
+                               ngram_range = (1, 2),
+                               preprocessor = custom_preprocessor,
+                               max_df = 0.5,
+                               min_df = 0.00125)
+
+    # set directory for saving CountVectorizer and TLDA
+    eigenvec_str = "_n_eigenvec_" + (str(n_eigenvec) if n_eigenvec is not None else "None")
+    exp_save_dir = os.path.join(save_dir, "num_tops_" + str(num_tops) + "_alpha0_" + str(alpha_0) + "_learning_rate_" + str(learning_rate) + "_theta_" + str(theta_param) + "_orthogonality_" + str(ortho_loss_param) + "_initialize_first_docs_" + str(initialize_first_docs) + eigenvec_str + "/")
+    if not os.path.exists(exp_save_dir):
+        os.makedirs(exp_save_dir)
+
+    # DEFAULT PARAMS -- Grid search according to dataset
+    batch_size_pca = 100000
+    batch_size_grad = 80000
+    n_iter_train = 200
+    n_iter_test = 10
+
+    # SET SEED
+    seed = 57
+
+    # Program controls -- decide which portions to run
+    # (rebuild the vocabulary and X matrices only when no fitted CountVectorizer is saved yet)
+    if os.path.exists(save_dir + "/" + COUNTVECTOR_FILEPATH):
+        first_run = 0
+    else:
+        first_run = 1
+    vocab_build = first_run
+    save_files = first_run
+    stgd = 1
+    recover_top_words = 1
+
+    # Start
+    print("\n\nSTART...")
+
+    # Set files to read
+    inDir = os.path.join(ROOT_DIR, INDIR)
+    dl = sorted(get_files_in_dir(inDir))
+
+    # Build the vocabulary
+    if vocab_build == 1:
+        if not os.path.exists(save_dir + "/" + EXISTING_VOCAB_FILEPATH):
+            num_data_rows = 0
+            for i, f in enumerate(dl):
+                print("Beginning vocabulary build: " + f)
+                path_in = os.path.join(inDir, f)
+
+                mempool = cp.get_default_memory_pool()
+                mempool.free_all_blocks()
+                pinned_mempool = cp.get_default_pinned_memory_pool()
+                pinned_mempool.free_all_blocks()
+
+                # read in dataframe
+                df = pd.read_csv(path_in, names = ['tweets'])
+
+                # basic preprocessing
+                mask = df['tweets'].str.len() > 10
+                df = df.loc[mask]
+                df = cudf.from_pandas(df)
+                df = basic_clean(df)
+
+                mempool = cp.get_default_memory_pool()
+                mempool.free_all_blocks()
+                pinned_mempool = cp.get_default_pinned_memory_pool()
+                pinned_mempool.free_all_blocks()
+                gc.collect()
+
+                # add vocabulary from current file to CountVectorizer vocabulary
+                countvec.partial_fit(df['tweets'])
+                print("End " + f)
+
+                # count rows of data
+                num_data_rows += len(df.index)
+                print(num_data_rows)
+                print(len(df.index))
+        else:
+            countvec.vocabulary_ = countvec.vocabulary
+            vocab = len(countvec.vocabulary_)
+
+        # Save fitted CountVectorizer and vocabulary
+        pickle.dump(countvec, open(os.path.join(save_dir, COUNTVECTOR_FILEPATH), 'wb'))
+        vocab = len(countvec.vocabulary_)
+        df_voc = cudf.DataFrame({'words': countvec.vocabulary_})
+        df_voc.to_csv(save_dir + "/" + VOCAB_FILEPATH)
+        print("right after countvec partial fit vocab\n\n\n: ", vocab)
+    else:
+        # reload the fitted CountVectorizer saved by an earlier run
+        countvec = pickle.load(open(os.path.join(save_dir, COUNTVECTOR_FILEPATH), 'rb'))
+
+    # make directories to save:
+    #   - X matrices
+    #   - corpus (only needed if computing coherence)
+    x_mat_dir = os.path.join(save_dir, X_MAT_FILEPATH_PREFIX)
+    if not os.path.exists(x_mat_dir):
+        os.makedirs(x_mat_dir)
+    corpus_dir = os.path.join(save_dir, CORPUS_FILEPATH_PREFIX)
+    if not os.path.exists(corpus_dir):
+        os.makedirs(corpus_dir)
+
+
+    # transform X matrices with fitted CountVectorizer and save to disk
+    transform_time = 0.0
+    if save_files == 1:
+        for f in dl:
+            print("Beginning CountVectorizer transform: " + f)
+            path_in = os.path.join(inDir, f)
+
+            mempool = cp.get_default_memory_pool()
+            mempool.free_all_blocks()
+            pinned_mempool = cp.get_default_pinned_memory_pool()
+            pinned_mempool.free_all_blocks()
+
+            # read in dataframe
+            df = pd.read_csv(path_in, names = ['tweets'])
+
+            # basic preprocessing
+            mask = df['tweets'].str.len() > 10
+            df = df.loc[mask]
+            df = cudf.from_pandas(df)
+            df = basic_clean(df)
+
+            mempool = cp.get_default_memory_pool()
+            mempool.free_all_blocks()
+            gc.collect()
+
+            # transform data from current file
+            t1 = time.time()
+            corpus = countvec.transform(df['tweets'])
+            t2 = time.time()
+            transform_time += t2 - t1
+            X_batch = tl.tensor(corpus.toarray())
+
+            # save current X matrix and corpus to disk
+            pickle.dump(
+                X_batch,
+                open(x_mat_dir + Path(f).stem + '.obj', 'wb')
+            )
+            pickle.dump(
+                corpus,
+                open(corpus_dir + Path(f).stem + '.obj', 'wb')
+            )
+            del X_batch
+            del corpus
+            print("End " + f)
+            del df
+            del mask
+
+            gc.collect()
+
+    print("Transform Time:" + str(transform_time))
+
+
+    # initialize TLDA using parameters from above
+    tlda = TLDA(
+        num_tops, alpha_0, n_iter_train, n_iter_test, learning_rate,
+        pca_batch_size = batch_size_pca, third_order_cumulant_batch = batch_size_grad,
+        gamma_shape = 1.0, smoothing = smoothing, theta = theta_param, ortho_loss_criterion = ortho_loss_param, random_seed = seed,
+        n_eigenvec = n_eigenvec,
+    )
+
+    tot_tlda_time = 0.0
+    if stgd == 1:
+        # keep track of iterations
+        i = 0
+
+        t1 = time.time()
+        for f in dl:
+            mempool = cp.get_default_memory_pool()
+            mempool.free_all_blocks()
+            pinned_mempool = cp.get_default_pinned_memory_pool()
+            pinned_mempool.free_all_blocks()
+
+            print("Beginning TLDA: " + f)
+
+            # load saved X matrix batch from disk
+            X_batch = pickle.load(
+                open(x_mat_dir + Path(f).stem + '.obj', 'rb')
+            )
+
+            mempool = cp.get_default_memory_pool()
+            mempool.free_all_blocks()
+            pinned_mempool = cp.get_default_pinned_memory_pool()
+            pinned_mempool.free_all_blocks()
+            gc.collect()
+
+            t3 = time.time()
+            # fit tensor LDA fully online
+            if initialize_first_docs and i == 0:
+                # fully fit tensor LDA on first batch
+                tlda.fit(X_batch)
+            else:
+                # partial fit tensor LDA on remaining batches
+                tlda.partial_fit_online(X_batch)
+
+            t4 = time.time()
+            print("New fit time: " + str(t4 - t3))
+            tot_tlda_time += t4 - t3
+
+            del X_batch
+            gc.collect()
+            mempool = cp.get_default_memory_pool()
+            mempool.free_all_blocks()
+            pinned_mempool = cp.get_default_pinned_memory_pool()
+            pinned_mempool.free_all_blocks()
+
+            i += 1
+
+        # save the trained TLDA model so later runs can reload it instead of refitting
+        pickle.dump(tlda, open(exp_save_dir + TLDA_FILEPATH, 'wb'))
+    else:
+        tlda = pickle.load(open(exp_save_dir + TLDA_FILEPATH, 'rb'))
+
+
+    # save top words in each topic
+    if recover_top_words == 1:
+        n_top_words = 100
+
+        top_words_df = cudf.DataFrame({})
+        for k in range(0, num_tops):
+            t_n_indices = tlda.unwhitened_factors_[:, k].argsort()[:-n_top_words - 1:-1]
+            top_words_LDA = countvec.vocabulary_[t_n_indices]
+            top_words_df['words_' + str(k)] = top_words_LDA.reset_index(drop=True)
+
+        top_words_df.to_csv(exp_save_dir + TOP_WORDS_FILEPATH)
+        del top_words_df
+
+        gc.collect()
+        mempool = cp.get_default_memory_pool()
+        mempool.free_all_blocks()
+        pinned_mempool = cp.get_default_pinned_memory_pool()
+        pinned_mempool.free_all_blocks()
+
+
+def main():
+    curr_dir = "metoo_evaluation_initialized_paper_exps/"
+
+    # set parameters
+    num_tops = 10
+    alpha_0 = 0.01
+    lr = 0.0001
+    pca_dim = 40
+
+    # run method to fit topics and save top words in each topic
+    fit_topics(
+        num_tops = num_tops,
+        curr_dir = curr_dir,
+        alpha_0 = alpha_0,
+        learning_rate = lr,
+        n_eigenvec = pca_dim
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tlda/third_order_cumulant.py b/tlda/third_order_cumulant.py
index cf33e0a..7fe2a5c 100755
--- a/tlda/third_order_cumulant.py
+++ b/tlda/third_order_cumulant.py
@@ -123,7 +123,7 @@ def fit(self, X, verbose = True):
         print("Total iterations: " + str(i))
 
-    def _predict_topic(self, X_batch, weights, verbose = True):
+    def _predict_topic(self, X_batch, adjusted_factors, weights):
        '''Infer the document-topic distribution vector for a given document
 
        Parameters
        ----------
@@ -143,33 +143,39 @@ def fit(self, X, verbose = True):
        # factors = nvocab x ntopics
        n_topics = self.n_topic
        n_docs = X_batch.shape[0]
+        n_words = X_batch.shape[1]
 
-        gammad = tl.tensor(tl.gamma(self.gamma_shape, scale= 1.0/self.gamma_shape, size = (n_docs,n_topics)))
-        exp_elogthetad = tl.exp(dirichlet_expectation(gammad)) #ndocs, n_topics ## CONVERT TO TL
-        phinorm = (tl.matmul(exp_elogthetad,self.unwhitened_factors_.T) + 1e-20) #ndoc X nwords
+        #print("adjusted factors shape: " + str(adjusted_factors.shape))
+
+        #gammad = tl.tensor(cp.random.gamma(self.gamma_shape, scale= 1.0/self.gamma_shape, size = (n_docs,n_topics))) ## not working
+        gammad = tl.tensor(cp.random.gamma(self.gamma_shape, scale= 1.0/self.gamma_shape, size = (n_docs,n_topics)))
+        exp_elogthetad = tl.tensor(cp.exp(dirichlet_expectation(gammad))) #ndocs, n_topics ## CONVERT TO TL
+        #exp_elogbetad = tl.tensor(cp.exp(dirichlet_expectation(adjusted_factors.T)))
+
+        epsilon = tl.finfo(gammad.dtype).eps
+        phinorm = (tl.matmul(exp_elogthetad,adjusted_factors.T) + epsilon) #ndoc X nwords
         max_gamma_change = 1.0
-        iter = 0
-        if verbose:
-            print("Begin Document Topic Prediction")
-        while (max_gamma_change > 1e-2 and iter < self.n_iter_test):
+        #epsilon = tl.finfo(gammad.dtype).eps
+        i = 0
+        print("Begin Document Topic Prediction")
+        while (max_gamma_change > 5e-3 and i < self.n_iter_test):
             lastgamma = tl.copy(gammad)
             x_phi_norm = X_batch / phinorm
-            x_phi_norm_factors = tl.matmul(x_phi_norm,self.unwhitened_factors_)
+            x_phi_norm_factors = tl.matmul(x_phi_norm, adjusted_factors)
             gammad = ((exp_elogthetad * (x_phi_norm_factors)) + weights) # estimate for the variational mixing param
-            exp_elogthetad = tl.exp(dirichlet_expectation(gammad))
-            phinorm = tl.matmul(exp_elogthetad,self.unwhitened_factors_.T) + 1e-20
+            exp_elogthetad = tl.tensor(cp.exp(dirichlet_expectation(gammad))) ## CONVERT TO TL
+            phinorm = (tl.matmul(exp_elogthetad,adjusted_factors.T) + epsilon)
 
             mean_gamma_change_pdoc = tl.sum(tl.abs(gammad - lastgamma),axis=1) / n_topics
             max_gamma_change = tl.max(mean_gamma_change_pdoc)
-            iter += 1
-            if verbose:
-                print("End Document Topic Prediction Iteration " + str(iter) + " out of " + str(self.n_iter_test))
-                print("Current Maximal Change:" + str(max_gamma_change))
+            i += 1
+            print("End Document Topic Prediction Iteration " + str(i) + " out of " + str(self.n_iter_test))
+            print("Current Maximal Change:" + str(max_gamma_change))
 
         del X_batch
         return gammad
 
-    def predict(self, X_test, weights):
+    def predict(self, X_test, adjusted_factors, weights):
         '''Infer the document/topic distribution from the factors and weights and make the factor non-negative
@@ -180,18 +186,36 @@
 
         Returns
         -------
-        gammad_norm2 : tensor of shape (number_documents, number_topics) equal to
+        gammad : tensor of shape (number_documents, number_topics) equal to
             the normalized document/topic distribution for X_test
-
-        factor : tensor of shape (vocabulary_size, number_topics) equal to the
-            adjusted factor
         '''
-        gammad_l = self._predict_topic(X_test, weights)
-        gammad_norm = tl.exp(dirichlet_expectation(gammad_l)) ## CONVERT TO TL
-        reshape_obj = tl.sum(gammad_norm,axis=1)
-        denom = tl.reshape(reshape_obj,(-1,1))
-        gammad_norm2 = gammad_norm/denom
+        # factors = nvocab x ntopics
+        n_topics = self.n_topic
+        n_docs = X_test.shape[0]
+
+        gammad = tl.tensor(cp.random.gamma(self.gamma_shape, scale= 1.0/self.gamma_shape, size = (n_docs,n_topics)))
+        exp_elogthetad = tl.exp(dirichlet_expectation(gammad)) #ndocs, n_topic
+
+        epsilon = tl.finfo(gammad.dtype).eps
+        phinorm = (tl.matmul(exp_elogthetad,adjusted_factors.T) + epsilon) #ndoc X nwords
+        max_gamma_change = 1.0
+        i = 0
+        print("Begin Document Topic Prediction")
+        while (max_gamma_change > 5e-3 and i < self.n_iter_test):
+            lastgamma = tl.copy(gammad)
+            x_phi_norm = X_test / phinorm
+            x_phi_norm_factors = tl.matmul(x_phi_norm, adjusted_factors)
+            gammad = ((exp_elogthetad * (x_phi_norm_factors)) + weights) # estimate for the variational mixing param
+            exp_elogthetad = tl.exp(dirichlet_expectation(gammad))
+            phinorm = (tl.matmul(exp_elogthetad,adjusted_factors.T) + epsilon)
+
+            mean_gamma_change_pdoc = tl.sum(tl.abs(gammad - lastgamma),axis=1) / n_topics
+            max_gamma_change = tl.max(mean_gamma_change_pdoc)
+            i += 1
+            print("End Document Topic Prediction Iteration " + str(i) + " out of " + str(self.n_iter_test))
+            print("Current Maximal Change:" + str(max_gamma_change))
+
+        del X_test
 
-        return gammad_norm2
+        return gammad
diff --git a/tlda/tlda_wrapper.py b/tlda/tlda_wrapper.py
index f9f737a..05e6a09 100755
--- a/tlda/tlda_wrapper.py
+++ b/tlda/tlda_wrapper.py
@@ -160,6 +160,9 @@ def _unwhiten_factors(self):
         # Un-centers the data
         factors_unwhitened += tl.reshape(self.mean,(self.vocab,1))
         factors_unwhitened [factors_unwhitened < 0.] = 0. # remove non-negative probabilities
+
+        # Save unwhitened factors before postprocessing
+        self.unwhitened_factors_raw_ = tl.copy(factors_unwhitened)
 
         # Smoothing
         factors_unwhitened *= (1. - self.smoothing)
@@ -173,12 +176,12 @@ def _unwhiten_factors(self):
         self.weights_ = tl.tensor(alpha_norm)
 
         # Normalize the factors
-        factors_unwhitened /= factors_unwhitened.sum(axis=0)
+
         return factors_unwhitened
 
     @property
-    def unwhitened_factors(self): # This doesnt work
+    def unwhitened_factors(self):
         """Unwhitened learned factors of shape (n_topic, vocabulary_size)
 
         On the first call, this will compute and store the unwhitened factors.
@@ -186,8 +189,7 @@ def unwhitened_factors(self): # This doesnt work
         """
         if self.unwhitened_factors_ is None:
             self.unwhitened_factors_ = self._unwhiten_factors()
-        else:
-            return self.unwhitened_factors_
+        return self.unwhitened_factors_
 
     def transform(self, X=None, predict=True):
         """
@@ -202,7 +204,7 @@ def transform(self, X=None, predict=True):
         self.third_order.unwhitened_factors_ = self.unwhitened_factors_
 
         if predict:
-            predicted_topics = self.third_order.predict(X, self.weights_)
+            predicted_topics = self.third_order.predict(X, self.unwhitened_factors_raw_, self.weights_)
             return predicted_topics
 
-        return predicted_topics
+        return
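Usage note (not part of the patch): a minimal sketch of how the updated wrapper API might be driven after this change. The positional TLDA arguments follow examples/large_scale_demo.py above; the import path, the toy input matrix, and the final row-normalization are assumptions added only for illustration, since predict() now returns the raw variational gamma parameters rather than a normalized distribution.

# sketch_transform_usage.py -- hypothetical illustration, not part of the patch
import numpy as np
import tensorly as tl

from tlda.tlda_wrapper import TLDA  # assumed import path

tl.set_backend("cupy")  # the third-order cumulant code calls into cupy directly

# toy stand-in for a CountVectorizer output: (n_documents x vocabulary_size) counts
X_batch = tl.tensor(np.random.poisson(1.0, size=(200, 1000)).astype("float64"))

# positional arguments as in the demo: num_tops, alpha_0, n_iter_train, n_iter_test, learning_rate
tlda = TLDA(10, 0.01, 200, 10, 0.0001)
tlda.fit(X_batch)

# transform() now hands the pre-smoothing factors (unwhitened_factors_raw_)
# and the learned weights to predict(), which returns the gamma parameters
gammad = tlda.transform(X_batch)

# assumption: convert the returned gamma parameters into per-document topic
# proportions by row-normalizing (the pre-change predict() normalized internally)
doc_topic = gammad / tl.reshape(tl.sum(gammad, axis=1), (-1, 1))
print(doc_topic.shape)  # (n_documents, num_tops)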