This repo implements a structure/sequence co-design model, using diffusion/flow matching (from Flowfusion.jl) with an architecture based primarily on AlphaFold 2's Invariant Point Attention (here via InvariantPointAttention.jl). The protein backbone is represented as a sequence of "frames", each with a location and rotation, as well as a discrete amino acid character. The model is trained to take noised input (where the locations, rotations, and discrete states have all been perturbed, to a random degree, by a noising process) and predict the original (i.e. un-noised) protein structure. With a model thus trained, samples from the distribution of training structures can be generated by taking many small steps from a random starting distribution.
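For intuition, here is a minimal toy sketch of what that sampling loop looks like for the locations alone (toy_sample and predict_clean are illustrative names, not the ChainStorm or Flowfusion API; rotations and the discrete amino acid states are handled analogously): at each step the model predicts the clean structure, and the current state takes a small step toward that prediction.
#Purely illustrative toy of the sampling loop (locations only, base Julia only)
function toy_sample(predict_clean, nres; nsteps = 200)
    Xt = randn(3, nres)                              #Random starting locations
    for i in 1:nsteps
        X0hat = predict_clean(Xt, (i - 1) / nsteps)  #Model's estimate of the un-noised structure at time t
        Xt = Xt .+ (X0hat .- Xt) ./ (nsteps - i + 1) #Small step from the current state toward that estimate
    end
    return Xt
end
toy_sample((Xt, t) -> zero(Xt), 20) #e.g. a "model" that always predicts the origin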
(ProteinDesignAnimation.mp4: an animation of the design process.)
Go here for instructions on how to install Julia (use juliaup), and then you can run the code snippets below directly in the Julia REPL.
using Pkg
pkg"registry add https://github.com/MurrellGroup/MurrellGroupRegistry"
#Pkg.add(["CUDA", "cuDNN"]) #<- If GPU
Pkg.add(url = "https://github.com/MurrellGroup/ChainStorm.jl")
This will load up a model and generate a single small protein with two chains, each of length 20:
using ChainStorm
model = load_model()
b = dummy_batch([20,20]) #<- The model's only input
g = flow_quickgen(b, model) #<- Model inference call
export_pdb("gen.pdb", g, b.chainids, b.resinds) #<- Save PDB
Or try this in a minimal Colab notebook.

The larger example below also tracks the sampling trajectory and exports an animation of the design process:
using Pkg
Pkg.add(["GLMakie", "ProtPlot"])
using ChainStorm, GLMakie, ProtPlot
#If GPU (requires Pkg.add(["CUDA", "cuDNN"])):
using CUDA
dev = ChainStorm.gpu
#dev = identity #<- If no GPU
model = load_model() |> dev
chainlengths = [54,54]
b = dummy_batch(chainlengths)
paths = ChainStorm.Tracker() #The trajectories will end up in here
g = flow_quickgen(b, model, d = dev, tracker = paths) #<- Model inference call
id = join(string.(chainlengths), "_") * "-" * join(rand('A':'Z', 4)) #<- Random ID for the output files
export_pdb("$(id).pdb", g, b.chainids, b.resinds) #<- Save PDB
samp = gen2prot(g, b.chainids, b.resinds) #<- Convert the generated sample to a protein structure for plotting
animate_trajectory("$(id).mp4", samp, first_trajectory(paths), viewmode = :fit) #<- Animate design process
Note: if you need the GLMakie animations to render headless on Linux, you can install Xvfb, then run these in the terminal before starting your Julia session/script:
Xvfb :99 -screen 0 1024x768x24 &
export DISPLAY=:99
To train a ChainStorm model yourself, install these packages in addition to ChainStorm:
using Pkg
Pkg.add(["JLD2", "Flux", "CannotWaitForTheseOptimisers", "LearningSchedules", "DLProteinFormats"])
Pkg.add(["CUDA", "cuDNN"])
using ChainStorm, DLProteinFormats, Flux, CannotWaitForTheseOptimisers, LearningSchedules, JLD2
using DLProteinFormats: load, PDBSimpleFlat, batch_flatrecs, sample_batched_inds, length2batch
using CUDA
device = gpu #<- Use identity instead to run on CPU
dat = load(PDBSimpleFlat); #<- Load the training dataset
model = ChainStormV1(384, 3, 3) |> device
sched = burnin_learning_schedule(0.000005f0, 0.001f0, 1.05f0, 0.99995f0) #<- Learning rate schedule (burn-in, then decay)
opt_state = Flux.setup(Muon(eta = sched.lr), model)
for epoch in 1:100
    batchinds = sample_batched_inds(dat, l2b = length2batch(1500, 1.9)) #<- Batches of record indices, with batch size scaled by sequence length
    for (i, b) in enumerate(batchinds)
        bat = batch_flatrecs(dat[b])
        ts = training_sample(bat) |> device #<- Noised training sample at a random time t
        #Self-conditioning: after the first epoch, half the time the model sees its own frame predictions
        sc_frames = nothing
        if epoch > 1 && rand() < 0.5
            sc_frames, _ = model(ts.t, ts.Xt, ts.chainids, ts.resinds)
        end
        l, grad = Flux.withgradient(model) do m
            fr, aalogs = m(ts.t, ts.Xt, ts.chainids, ts.resinds, sc_frames = sc_frames)
            l_loc, l_rot, l_aas = losses(fr, aalogs, ts) #<- Location, rotation, and amino acid losses
            l_loc + l_rot + l_aas
        end
        Flux.update!(opt_state, model, grad[1])
        (mod(i, 10) == 0) && Flux.adjust!(opt_state, next_rate(sched)) #<- Advance the learning rate schedule every 10 batches
        println(l)
    end
    jldsave("model_epoch_$epoch.jld", model_state = Flux.state(cpu(model)), opt_state = cpu(opt_state)) #<- Checkpoint after each epoch
end
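The checkpoints saved above can be reloaded later to resume training or to run inference. A minimal sketch, assuming the model was built with the same ChainStormV1 hyperparameters used during training:
using ChainStorm, Flux, JLD2
chk = JLD2.load("model_epoch_1.jld")       #<- Pick the checkpoint you want to restore
model = ChainStormV1(384, 3, 3)            #<- Must match the architecture that was trained
Flux.loadmodel!(model, chk["model_state"]) #<- Copy the saved weights into the model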