 import numpy as np
 from gensim import utils, matutils
 from gensim.models import ldamodel
-from .ldaseq_sslm_inner import fit_sslm
+from .ldaseq_sslm_inner import fit_sslm, sslm_counts_init
 from .ldaseq_posterior_inner import fit_lda_post

 logger = logging.getLogger(__name__)
@@ -670,157 +670,8 @@ def __init__(self, vocab_len=None, num_time_slices=None, num_topics=None, obs_va
         self.w_phi_sum = None
         self.w_phi_l_sq = None
         self.m_update_coeff_g = None
+        self.config_c_address = 0

-    def update_zeta(self):
-        """Update the Zeta variational parameter.
-
-        Zeta is described in the appendix and is equal to sum(exp(mean[word] + variance[word] / 2)),
-        taken over all words, computed once per time slice. It is the value of the variational
-        parameter zeta which maximizes the lower bound.
-
-        Returns
-        -------
-        list of float
-            The updated zeta values for each time slice.
-
-        """
-        for j, val in enumerate(self.zeta):
-            self.zeta[j] = np.sum(np.exp(self.mean[:, j + 1] + self.variance[:, j + 1] / 2))
-        return self.zeta
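
The per-slice loop above can also be written as a single vectorized NumPy expression; a sketch, assuming `mean` and `variance` have shape `(vocab_len, num_time_slices + 1)` with column 0 holding the prior state:

    # sum exp(mean + variance / 2) over the vocabulary axis, one value per time slice
    zeta = np.exp(self.mean[:, 1:] + self.variance[:, 1:] / 2).sum(axis=0)
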
-
-    def compute_post_variance(self, word, chain_variance):
-        r"""Get the variance, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1)
-        <https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf>`_.
-
-        This function accepts the word to compute the variance for, along with the associated sslm class object,
-        and returns the `variance` and the posterior approximation `fwd_variance`.
-
-        Notes
-        -----
-        This function essentially computes Var[\beta_{t,w}] for t = 1:T, using a forward (filtering)
-        pass followed by a backward (smoothing) pass:
-
-        .. math::
-
-            fwd\_variance[t] \equiv E((\beta_{t,w} - fwd\_mean[t])^2 \mid \hat{\beta}_{1:t})
-            = c \cdot (fwd\_variance[t - 1] + chain\_variance),
-            \quad c = \frac{obs\_variance}{fwd\_variance[t - 1] + chain\_variance + obs\_variance}
-
-        .. math::
-
-            variance[t] \equiv E((\beta_{t,w} - mean[t])^2 \mid \hat{\beta}_{1:T})
-            = c \cdot (variance[t + 1] - chain\_variance) + (1 - c) \cdot fwd\_variance[t],
-            \quad c = \left(\frac{fwd\_variance[t]}{fwd\_variance[t] + chain\_variance}\right)^2
-
-        Parameters
-        ----------
-        word : int
-            The word's ID.
-        chain_variance : float
-            Gaussian parameter that dictates how the beta values evolve over time.
-
-        Returns
-        -------
-        (numpy.ndarray, numpy.ndarray)
-            The first returned value is the variance of each word in each time slice, the second value is the
-            inferred posterior variance for the same pairs.
-
-        """
-        INIT_VARIANCE_CONST = 1000
-
-        T = self.num_time_slices
-        variance = self.variance[word]
-        fwd_variance = self.fwd_variance[word]
-        # forward pass: set the initial variance very high (diffuse prior)
-        fwd_variance[0] = chain_variance * INIT_VARIANCE_CONST
-        for t in range(1, T + 1):
-            if self.obs_variance:
-                c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
-            else:
-                c = 0
-            fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance)
-
-        # backward pass
-        variance[T] = fwd_variance[T]
-        for t in range(T - 1, -1, -1):
-            if fwd_variance[t] > 0.0:
-                c = np.power((fwd_variance[t] / (fwd_variance[t] + chain_variance)), 2)
-            else:
-                c = 0
-            variance[t] = (c * (variance[t + 1] - chain_variance)) + ((1 - c) * fwd_variance[t])
-
-        return variance, fwd_variance
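
A short usage sketch (assuming a populated `sslm` instance named `chain`; 0.005 mirrors gensim's usual chain-variance default):

    # smooth the variance chain for word ID 0 across all time slices
    variance, fwd_variance = chain.compute_post_variance(word=0, chain_variance=0.005)
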
-
-    def compute_post_mean(self, word, chain_variance):
-        r"""Get the mean, based on the `Variational Kalman Filtering approach for Approximate Inference (section 3.1)
-        <https://mimno.infosci.cornell.edu/info6150/readings/dynamic_topic_models.pdf>`_.
-
-        Notes
-        -----
-        This function essentially computes E[\beta_{t,w}] for t = 1:T, again with a forward pass over
-        the variational observations obs (written \hat{\beta} below) and a backward smoothing pass:
-
-        .. math::
-
-            fwd\_mean[t] \equiv E(\beta_{t,w} \mid \hat{\beta}_{1:t})
-            = c \cdot fwd\_mean[t - 1] + (1 - c) \cdot \hat{\beta}_{t,w},
-            \quad c = \frac{obs\_variance}{fwd\_variance[t - 1] + chain\_variance + obs\_variance}
-
-        .. math::
-
-            mean[t] \equiv E(\beta_{t,w} \mid \hat{\beta}_{1:T})
-            = c \cdot fwd\_mean[t] + (1 - c) \cdot mean[t + 1],
-            \quad c = \frac{chain\_variance}{fwd\_variance[t] + chain\_variance}
-
-        Parameters
-        ----------
-        word : int
-            The word's ID.
-        chain_variance : float
-            Gaussian parameter that dictates how the beta values evolve over time.
-
-        Returns
-        -------
-        (numpy.ndarray, numpy.ndarray)
-            The first returned value is the mean of each word in each time slice, the second value is the
-            inferred posterior mean for the same pairs.
-
-        """
-        T = self.num_time_slices
-        obs = self.obs[word]
-        fwd_variance = self.fwd_variance[word]
-        mean = self.mean[word]
-        fwd_mean = self.fwd_mean[word]
-
-        # forward pass
-        fwd_mean[0] = 0
-        for t in range(1, T + 1):
-            c = self.obs_variance / (fwd_variance[t - 1] + chain_variance + self.obs_variance)
-            fwd_mean[t] = c * fwd_mean[t - 1] + (1 - c) * obs[t - 1]
-
-        # backward pass
-        mean[T] = fwd_mean[T]
-        for t in range(T - 1, -1, -1):
-            if chain_variance == 0.0:
-                c = 0.0
-            else:
-                c = chain_variance / (fwd_variance[t] + chain_variance)
-            mean[t] = c * fwd_mean[t] + (1 - c) * mean[t + 1]
-        return mean, fwd_mean
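
Taken together, the two methods above implement a scalar Kalman smoother per word. A minimal self-contained sketch of the combined forward-backward recursion, with toy observations and illustrative variance values (the function name and inputs are hypothetical, not part of this module):

    import numpy as np

    def smooth_chain(obs, obs_variance=0.5, chain_variance=0.005):
        # obs holds one word's variational observations, one per time slice
        T = len(obs)
        fwd_variance = np.zeros(T + 1)
        fwd_mean = np.zeros(T + 1)
        variance = np.zeros(T + 1)
        mean = np.zeros(T + 1)

        # forward (filtering) pass, starting from a diffuse prior variance
        fwd_variance[0] = chain_variance * 1000
        for t in range(1, T + 1):
            c = obs_variance / (fwd_variance[t - 1] + chain_variance + obs_variance)
            fwd_variance[t] = c * (fwd_variance[t - 1] + chain_variance)
            fwd_mean[t] = c * fwd_mean[t - 1] + (1 - c) * obs[t - 1]

        # backward (smoothing) pass
        variance[T] = fwd_variance[T]
        mean[T] = fwd_mean[T]
        for t in range(T - 1, -1, -1):
            c = (fwd_variance[t] / (fwd_variance[t] + chain_variance)) ** 2
            variance[t] = c * (variance[t + 1] - chain_variance) + (1 - c) * fwd_variance[t]
            c = chain_variance / (fwd_variance[t] + chain_variance)
            mean[t] = c * fwd_mean[t] + (1 - c) * mean[t + 1]
        return mean, variance

    mean, variance = smooth_chain(obs=np.array([-2.3, -2.1, -2.4]))
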
-
-    def compute_expected_log_prob(self):
-        """Compute the expected log probability given values of m.
-
-        The appendix describes the expectation of log-probabilities in equation 5 of the DTM paper;
-        the implementation below is the result of solving that equation and matches the original
-        Blei DTM code.
-
-        Returns
-        -------
-        numpy.ndarray of float
-            The expected value for the log probabilities for each word and time slice.
-
-        """
-        for (w, t), val in np.ndenumerate(self.e_log_prob):
-            self.e_log_prob[w][t] = self.mean[w][t + 1] - np.log(self.zeta[t])
-        return self.e_log_prob
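
Since `zeta[t]` already aggregates over the vocabulary, the element-wise loop above collapses into one broadcasted expression; a sketch under the same shape assumptions as `update_zeta`:

    # self.mean[:, 1:] has shape (vocab_len, T); np.log(self.zeta) broadcasts across rows
    self.e_log_prob = self.mean[:, 1:] - np.log(self.zeta)
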

     def sslm_counts_init(self, obs_variance, chain_variance, sstats):
         """Initialize the State Space Language Model with LDA sufficient statistics.
@@ -839,28 +690,8 @@ def sslm_counts_init(self, obs_variance, chain_variance, sstats):
             expected shape (`self.vocab_len`, `num_topics`).

         """
-        W = self.vocab_len
-        T = self.num_time_slices
-
-        log_norm_counts = np.copy(sstats)
-        log_norm_counts /= sum(log_norm_counts)
-        log_norm_counts += 1.0 / W
-        log_norm_counts /= sum(log_norm_counts)
-        log_norm_counts = np.log(log_norm_counts)
-
-        # set the variational observations to the transformed counts
-        self.obs = (np.repeat(log_norm_counts, T, axis=0)).reshape(W, T)
-        # set variational parameters
-        self.obs_variance = obs_variance
-        self.chain_variance = chain_variance
-
-        # compute the posterior variance and mean
-        for w in range(W):
-            self.variance[w], self.fwd_variance[w] = self.compute_post_variance(w, self.chain_variance)
-            self.mean[w], self.fwd_mean[w] = self.compute_post_mean(w, self.chain_variance)

-        self.zeta = self.update_zeta()
-        self.e_log_prob = self.compute_expected_log_prob()
+        sslm_counts_init(self, obs_variance, chain_variance, sstats)
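
The pure-Python initialization now delegates to the Cython `sslm_counts_init` imported at the top of the module. A hypothetical calling sketch (the `chain` object and `sstats` column are assumed for illustration; 0.5 and 0.005 mirror gensim's usual observation- and chain-variance defaults):

    # initialize one topic's state-space chain from that topic's column
    # of the LDA sufficient-statistics matrix
    chain.sslm_counts_init(obs_variance=0.5, chain_variance=0.005, sstats=sstats[:, k])
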

     def fit_sslm(self, sstats):
         """Fits variational distribution.