ibsi6/LightWeightDiffusionModel

README.


This program implements a diffusion probabilistic model (https://arxiv.org/abs/1503.03585) that generates 2D Gaussian data.

The implementation is done entirely using JAX, Google's numerical computing library, and does not rely on external deep learning libraries like TensorFlow or PyTorch.

To run, install the required libraries (see the JAX installation guide: https://github.com/google/jax?tab=readme-ov-file#installation), then run main.py.




IMPORTANT: Each step in the process waits until the current graph window has been closed before continuing. (This was done to reduce computational resource usage and prevent crashes.) If the program appears to have stopped running, make sure the current graph window is closed.


Contributions: Credit to Muhammed.S for Documentation


--------------------------------------------------------

You can also:

    Adjust Hyperparameters: Modify the number of epochs, batch size, learning rate, and network dimensions in main.py or the respective classes.
    Change Data Distribution: Replace the synthetic data generation with your own dataset or a different data distribution.



--------------------------------------------------------
Classes:

	DiffusionModel

		Methods:

		    __init__(self, input_dim, hidden_dim, output_dim, key):
		        Initializes the model parameters, including the weights and biases for each layer of the neural network.
		        Takes in the input dimension, hidden dimension, output dimension, and a random key for parameter initialization.

		    init_params(self, input_dim, hidden_dim, output_dim, key):
		        Initializes the weights (W) and biases (b) for a simple feedforward neural network. Each weight matrix is drawn from a normal distribution and scaled by a factor that depends on the size of the layer's input.
		        This method returns a dictionary containing the model parameters.

		    forward(self, params, x, t):
		        Implements the forward pass of the neural network. The inputs are the data x and the time step t (which is embedded and passed through the network).
		        The forward pass combines the input data with the time embedding and passes it through three hidden layers, each followed by a ReLU activation. The final layer outputs the predicted noise at the current time step t.
		        The time embedding tells the network which diffusion step the input data belongs to.
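
The DiffusionModel methods above can be sketched roughly as follows. This is a minimal illustration, not the code in main.py: the scalar time feature t / T, the default T=100, the sqrt(2 / fan_in) scaling factor, and the parameter names W0..W3 / b0..b3 are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def init_params(input_dim, hidden_dim, output_dim, key):
    """Build weights and biases for an MLP with three hidden layers."""
    # +1 because this sketch appends one scalar time feature to x (an assumption).
    sizes = [input_dim + 1, hidden_dim, hidden_dim, hidden_dim, output_dim]
    keys = jax.random.split(key, len(sizes) - 1)
    params = {}
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        # Normal draws scaled by sqrt(2 / fan_in); the exact factor is assumed.
        params[f"W{i}"] = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params[f"b{i}"] = jnp.zeros((n_out,))
    return params


def forward(params, x, t, T=100):
    """Predict the noise in x (batch, input_dim) at integer steps t (batch,)."""
    t_emb = (t.astype(jnp.float32) / T)[:, None]   # scalar time feature per sample
    h = jnp.concatenate([x, t_emb], axis=-1)       # combine data with time feature
    for i in range(3):                             # three hidden layers, each with ReLU
        h = jax.nn.relu(h @ params[f"W{i}"] + params[f"b{i}"])
    return h @ params["W3"] + params["b3"]         # linear output: predicted noise


# Usage: 2D data, 64 hidden units, a batch of 5 points at steps 0..4.
params = init_params(2, 64, 2, jax.random.PRNGKey(0))
x = jnp.ones((5, 2))
t = jnp.arange(5)
eps_hat = forward(params, x, t)
```

The output has the same shape as x, since the network predicts one noise value per data coordinate.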

    DiffusionProcess

		Methods:

		    __init__(self, T, beta_start=1e-4, beta_end=0.02):
		        Initializes the diffusion process by setting the number of diffusion steps (T), as well as the range of beta values (beta_start, beta_end) which control the amount of noise added at each step.
		        It computes the betas (noise schedule), the alphas (1 - betas), and their cumulative products (alphas_cumprod), which are used to calculate the variance and scaling factors for the noise.

		    q_sample(self, x0, t, noise):
		        This method performs the forward diffusion step, where it adds noise to the original data x0 at a given time step t.
		        It uses the cumulative product of alphas to scale the data and the noise appropriately at each time step.
		        This method returns the noisy data for the time step t.
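
As a rough sketch of the DiffusionProcess described above: the snippet below assumes a linear beta schedule and the standard closed form x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise. The value T = 100 and the module-level variable names are illustrative and may differ from main.py.

```python
import jax
import jax.numpy as jnp

T = 100                                   # number of diffusion steps (illustrative)
betas = jnp.linspace(1e-4, 0.02, T)       # linear schedule is an assumption
alphas = 1.0 - betas
alphas_cumprod = jnp.cumprod(alphas)      # alpha_bar_t


def q_sample(x0, t, noise):
    """Return x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a_bar = alphas_cumprod[t][:, None]    # shape (batch, 1), broadcasts over features
    return jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise


# Usage: noise four 2D points, each at a different time step.
key = jax.random.PRNGKey(0)
x0 = jnp.ones((4, 2))
noise = jax.random.normal(key, x0.shape)
t = jnp.array([0, 10, 50, 99])
xt = q_sample(x0, t, noise)
```

Because alpha_bar_t shrinks as t grows, later steps keep less of x0 and more of the noise.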




    Trainer

		Methods:

		    __init__(self, model, diffusion_process, data, key, learning_rate=1e-3, batch_size=128):
		        Initializes the trainer by setting the model, diffusion process, training data, and various training hyperparameters (learning rate, batch size).
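		        It also initializes the optimizer state, which holds the parameters plus the two running averages used by the Adam optimizer: the first-moment estimate of the gradients (m, called momentum here) and the second-moment estimate (v, called velocity here).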

		    init_optimizer(self, params):
		        Initializes the optimizer state for the Adam optimizer, setting the momentum and velocity to zero, and storing the model parameters.
		        Returns a dictionary containing the optimizer state.

		    update_optimizer(self, grads, opt_state):
		        Updates the model parameters using the Adam optimizer. It computes the new values of momentum (m) and velocity (v) based on the gradients and updates the parameters accordingly.
		        The method returns the updated optimizer state.

		    loss_fn(self, params, x0, t, noise):
		        Defines the loss function, which is the mean squared error (MSE) between the predicted noise and the actual noise added to the data during the diffusion process.
		        It first generates noisy data xt for a given batch of original data x0 and time steps t, then computes the MSE loss between the predicted noise and the actual noise.

		    train_step(self, opt_state, x0_batch):
		        Performs a single training step by selecting a batch of data, adding noise to it, and computing the loss and gradients.
		        Updates the model parameters using the gradients and the Adam optimizer.
		        The method returns the updated optimizer state and the loss value for the current batch.

		    train(self, num_epochs):
		        Runs the full training loop for a specified number of epochs. For each epoch, the training data is shuffled, split into batches, and the model is updated using the train_step method.
		        It prints the average loss at the end of each epoch.
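
The Trainer pieces above (loss, Adam update, training step) could fit together roughly as shown below. This is a sketch under stated assumptions: predict_noise is a hypothetical one-layer stand-in for DiffusionModel.forward, the random key is passed to train_step explicitly instead of being stored on the trainer, the optimizer state carries an extra "step" counter for bias correction, and the Adam constants are the commonly used defaults.

```python
import jax
import jax.numpy as jnp

T = 100
alphas_cumprod = jnp.cumprod(1.0 - jnp.linspace(1e-4, 0.02, T))


def predict_noise(params, x, t):
    # Hypothetical stand-in for DiffusionModel.forward: one linear layer on [x, t/T].
    h = jnp.concatenate([x, (t / T)[:, None]], axis=-1)
    return h @ params["W"] + params["b"]


def loss_fn(params, x0, t, noise):
    a_bar = alphas_cumprod[t][:, None]
    xt = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise     # forward diffusion
    return jnp.mean((predict_noise(params, xt, t) - noise) ** 2)  # MSE on the noise


def init_optimizer(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return {"params": params, "m": zeros, "v": zeros, "step": 0}


def update_optimizer(grads, opt_state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    step = opt_state["step"] + 1
    tmap = jax.tree_util.tree_map
    m = tmap(lambda m_, g: b1 * m_ + (1 - b1) * g, opt_state["m"], grads)
    v = tmap(lambda v_, g: b2 * v_ + (1 - b2) * g ** 2, opt_state["v"], grads)
    # Bias-corrected moments, then the Adam parameter update.
    params = tmap(
        lambda p, m_, v_: p - lr * (m_ / (1 - b1 ** step))
        / (jnp.sqrt(v_ / (1 - b2 ** step)) + eps),
        opt_state["params"], m, v,
    )
    return {"params": params, "m": m, "v": v, "step": step}


def train_step(opt_state, x0_batch, key):
    k_t, k_n = jax.random.split(key)
    t = jax.random.randint(k_t, (x0_batch.shape[0],), 0, T)  # random step per sample
    noise = jax.random.normal(k_n, x0_batch.shape)
    loss, grads = jax.value_and_grad(loss_fn)(opt_state["params"], x0_batch, t, noise)
    return update_optimizer(grads, opt_state), loss


# Usage: a few steps on one batch of synthetic 2D Gaussian data.
key = jax.random.PRNGKey(0)
opt_state = init_optimizer({"W": jnp.zeros((3, 2)), "b": jnp.zeros((2,))})
x0 = jax.random.normal(key, (128, 2))
losses = []
for _ in range(50):
    key, sub = jax.random.split(key)
    opt_state, loss = train_step(opt_state, x0, sub)
    losses.append(float(loss))
```

A full train method would wrap this loop in an outer loop over epochs, reshuffling the data and averaging the per-batch losses before printing.
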

    Sampler

		Methods:

		    __init__(self, model, diffusion_process, key):
		        Initializes the sampler with the trained model, diffusion process, and a random key for generating noise during sampling.

		    sample(self, params, num_samples):
		        This method implements the reverse diffusion process to generate new data samples.
		        It starts with random noise and iteratively denoises it by predicting the noise at each time step using the trained model.
		        The method handles adding noise at each time step and scaling the predictions by the appropriate factors (i.e., beta_t, alpha_t, and alpha_bar_t).
		        The process is repeated for a predefined number of steps, and the method returns the generated samples at the end.
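
A minimal sketch of the reverse process described above, assuming the common ancestral sampling update with sigma_t^2 = beta_t and no noise added at the last step. The function signature, T = 100, and the zero-noise dummy_model are illustrative stand-ins, not the code in main.py.

```python
import jax
import jax.numpy as jnp

T = 100
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = jnp.cumprod(alphas)


def sample(predict_noise, params, num_samples, key, dim=2):
    """Start from N(0, I) and denoise step by step from t = T-1 down to 0."""
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, (num_samples, dim))
    for t in reversed(range(T)):
        t_batch = jnp.full((num_samples,), t)
        eps_hat = predict_noise(params, x, t_batch)
        # Remove the predicted noise, then rescale by 1 / sqrt(alpha_t).
        coef = betas[t] / jnp.sqrt(1.0 - alphas_cumprod[t])
        mean = (x - coef * eps_hat) / jnp.sqrt(alphas[t])
        if t > 0:
            key, sub = jax.random.split(key)
            # Add fresh noise with variance beta_t (an assumed choice of sigma_t).
            x = mean + jnp.sqrt(betas[t]) * jax.random.normal(sub, x.shape)
        else:
            x = mean                      # no noise is added at the final step
    return x


# Hypothetical stand-in for a trained network: always predicts zero noise.
dummy_model = lambda params, x, t: jnp.zeros_like(x)
samples = sample(dummy_model, None, 16, jax.random.PRNGKey(0))
```

With a trained model in place of dummy_model, the returned array holds num_samples generated 2D points.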

About

Lightweight diffusion probabilistic model. Uses only JAX, NumPy, and Matplotlib.
