MurrellGroup/ChainStorm.jl

ChainStorm

This repo implements a structure/sequence co-design model, using diffusion/flow matching (from Flowfusion.jl) with an architecture based primarily on AlphaFold 2's Invariant Point Attention (here via InvariantPointAttention.jl). The protein backbone is represented as a sequence of "frames", each with a location and a rotation, together with a discrete amino acid character. The model is trained to take noised input (where the locations, rotations, and discrete states have all been perturbed, to a random degree, by a noising process) and predict the original (i.e. un-noised) protein structure. Once the model is trained, samples from the distribution of training structures can be generated by drawing from a random starting distribution and taking many small steps.
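To make the noise-then-denoise idea concrete, here is a toy sketch in plain Julia. It is not ChainStorm's or Flowfusion.jl's code: it assumes ordinary real-valued vectors and linear interpolation between a random start and the clean data, whereas the real model also handles rotations and discrete amino acid states. The names `toy_noise`, `toy_sample`, and `oracle` are hypothetical, and `oracle` simply returns the known answer in place of a trained network.

```julia
# Toy illustration only; not the ChainStorm / Flowfusion.jl implementation.

# Noising: at time t in [0, 1], blend a random start x0 with the clean data x1.
# t = 0 gives pure noise, t = 1 gives the un-noised data.
toy_noise(x0, x1, t) = (1 - t) .* x0 .+ t .* x1

# Sampling: start from noise and take many small steps, each one moving the
# current state a little towards the predictor's guess of the clean data.
function toy_sample(predict, x0; steps = 100)
    x = copy(x0)
    dt = 1 / steps
    for k in 0:steps-1
        t = k / steps
        x1hat = predict(t, x)                    # predicted un-noised state
        x .+= dt .* (x1hat .- x) ./ (1 - t)      # small step towards it
    end
    return x
end

target = [1.0, -2.0, 0.5]
oracle(t, x) = target          # hypothetical stand-in for a trained model
x0 = randn(3)                  # draw from the random starting distribution
xT = toy_sample(oracle, x0)    # ends (numerically) at the target
```

With a perfect predictor the trajectory lands on the target; a trained model instead predicts a plausible structure from the noised input, so different random starts give different samples.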

ProteinDesignAnimation.mp4

New to Julia?

Install Julia first (use juliaup); you can then run the code snippets below directly in the Julia REPL.

ChainStorm installation

using Pkg
pkg"registry add https://github.com/MurrellGroup/MurrellGroupRegistry"
#Pkg.add(["CUDA", "cuDNN"]) #<- If GPU
Pkg.add(url = "https://github.com/MurrellGroup/ChainStorm.jl")

Quick start

This will load up a model and generate a single small protein with two chains, each of length 20:

using ChainStorm
model = load_model()
b = dummy_batch([20,20]) #<- The model's only input
g = flow_quickgen(b, model) #<- Model inference call
export_pdb("gen.pdb", g, b.chainids, b.resinds) #<- Save PDB

Or try this in a minimal Colab notebook:

Open In Colab

Visualization, and using the GPU

using Pkg
Pkg.add(["GLMakie", "ProtPlot"])

using ChainStorm, GLMakie, ProtPlot

#If GPU:
using CUDA
dev = ChainStorm.gpu
#dev = identity #<- If no GPU

model = load_model() |> dev

chainlengths = [54,54]
b = dummy_batch(chainlengths)
paths = ChainStorm.Tracker() #The trajectories will end up in here
g = flow_quickgen(b, model, d = dev, tracker = paths) #<- Model inference call
id = join(string.(chainlengths),"_")*"-"*join(rand('A':'Z', 4))
export_pdb("$(id).pdb", g, b.chainids, b.resinds) #<- Save PDB
samp = gen2prot(g, b.chainids, b.resinds)
animate_trajectory("$(id).mp4", samp, first_trajectory(paths), viewmode = :fit) #<- Animate design process

Note: If you need the GLMakie animations to run headless on Linux, install xvfb and run these commands in the terminal before starting your Julia session or script:

Xvfb :99 -screen 0 1024x768x24 &
export DISPLAY=:99

Training

#In addition to ChainStorm, also install these:
using Pkg
Pkg.add(["JLD2", "Flux", "CannotWaitForTheseOptimisers", "LearningSchedules", "DLProteinFormats"])
Pkg.add(["CUDA", "cuDNN"])

using ChainStorm, DLProteinFormats, Flux, CannotWaitForTheseOptimisers, LearningSchedules, JLD2
using DLProteinFormats: load, PDBSimpleFlat, batch_flatrecs, sample_batched_inds, length2batch
using CUDA
device = gpu

dat = load(PDBSimpleFlat);

model = ChainStormV1(384, 3, 3) |> device
sched = burnin_learning_schedule(0.000005f0, 0.001f0, 1.05f0, 0.99995f0)
opt_state = Flux.setup(Muon(eta = sched.lr), model)

for epoch in 1:100
    batchinds = sample_batched_inds(dat,l2b = length2batch(1500, 1.9))
    for (i, b) in enumerate(batchinds)
        bat = batch_flatrecs(dat[b])
        ts = training_sample(bat) |> device
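        # Self-conditioning: after the first epoch, half of the batches first run the
        # model once (outside the gradient) and feed its predicted frames back in.
        sc_frames = nothing
        if epoch > 1 && rand() < 0.5
            sc_frames, _ = model(ts.t, ts.Xt, ts.chainids, ts.resinds)
        end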
        l, grad = Flux.withgradient(model) do m
            fr, aalogs = m(ts.t, ts.Xt, ts.chainids, ts.resinds, sc_frames = sc_frames)
            l_loc, l_rot, l_aas = losses(fr, aalogs, ts)
            l_loc + l_rot + l_aas
        end
        Flux.update!(opt_state, model, grad[1])
        (mod(i, 10) == 0) && Flux.adjust!(opt_state, next_rate(sched))
        println(l)
    end
    jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state=cpu(opt_state))
end
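The training loop above adjusts the learning rate every 10 batches via `next_rate(sched)`. The sketch below is a toy stand-in for that kind of schedule, not the LearningSchedules.jl implementation. It assumes the four arguments of `burnin_learning_schedule` are, in order, a starting rate, a peak rate, a per-update growth factor, and a per-update decay factor; that reading is an assumption, so check the LearningSchedules documentation before relying on it. `ToyBurninSchedule` and `toy_next_rate!` are hypothetical names.

```julia
# Hypothetical stand-in; NOT the LearningSchedules.jl implementation.
# Assumed behaviour: grow the rate geometrically up to a peak, then decay it.
mutable struct ToyBurninSchedule
    lr::Float32        # current learning rate
    peak::Float32      # rate at which burn-in stops
    inflate::Float32   # multiplier applied during burn-in
    decay::Float32     # multiplier applied after the peak
    growing::Bool      # true while still in the burn-in phase
end

ToyBurninSchedule(start, peak, inflate, decay) =
    ToyBurninSchedule(start, peak, inflate, decay, true)

function toy_next_rate!(s::ToyBurninSchedule)
    if s.growing
        s.lr *= s.inflate
        if s.lr >= s.peak          # cap at the peak and switch to decay
            s.lr = s.peak
            s.growing = false
        end
    else
        s.lr *= s.decay
    end
    return s.lr
end

# Same numbers as in the training script above.
s = ToyBurninSchedule(0.000005f0, 0.001f0, 1.05f0, 0.99995f0)
rates = [toy_next_rate!(s) for _ in 1:2000]
```

Under these assumptions the rate climbs from 5e-6 to 1e-3 in a little over 100 updates (about 1,100 batches at one update per 10 batches) and then shrinks slowly.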
