Commit d5e389f

wukevin and jackdent authored
Easy command line interface with MSA server flag (#190)
Co-authored-by: Jack Dent <jack@chaidiscovery.com>
1 parent c086906 commit d5e389f

11 files changed: +143 -38 lines changed

README.md

Lines changed: 27 additions & 2 deletions
Original file line number, diff line number, diff line change
@@ -22,9 +22,18 @@ This Python package requires Linux, and a GPU with CUDA and bfloat16 support. We
2222

2323
## Running the model
2424

25-
The model accepts inputs in the FASTA file format, and allows you to specify the number of trunk recycles and diffusion timesteps via the `chai_lab.chai1.run_inference` function. By default, the model generates five sample predictions, and uses embeddings without MSAs or templates.
25+
### Command line inference
2626

27-
The following script demonstrates how to provide inputs to the model, and obtain a list of PDB files for downstream analysis:
27+
You can fold a FASTA file containing all the sequences (including modified residues, nucleotides, and ligands as SMILES strings) in a complex of interest by calling:
28+
```shell
29+
chai fold input.fasta output_folder
30+
```
31+
32+
By default, the model generates five sample predictions, and uses embeddings without MSAs or templates. For additional information about how to supply MSAs and restraints to the model, see the documentation below, or run `chai fold --help`.
33+
34+
### Programmatic inference
35+
36+
The main entrypoint into the Chai-1 folding code is through the `chai_lab.chai1.run_inference` function. The following script demonstrates how to programmatically provide inputs to the model, and obtain a list of PDB files for downstream analysis:
2837

2938
```shell
3039
python examples/predict_structure.py
@@ -56,6 +65,8 @@ CHAI_DOWNLOADS_DIR=/tmp/downloads python ./examples/predict_structure.py
5665

5766
Chai-1 supports MSAs provided as an `aligned.pqt` file. This file format is similar to an `a3m` file, but has additional columns that provide metadata like the source database and sequence pairing keys. We provide code to convert `a3m` files to `aligned.pqt` files. For more information on how to provide MSAs to Chai-1, see [this documentation](examples/msas/README.md).
5867

68+
For user convenience, we also support automatic MSA generation through the ColabFold [MMseqs2](https://github.com/soedinglab/MMseqs2) server, enabled with the `--msa-server` flag. As detailed in the ColabFold [repository](https://github.com/sokrypton/ColabFold), please keep in mind that this is a shared resource. Note that the results reported in our preprint and the webserver use a different MSA search strategy than MMseqs2, though we expect results to be broadly similar.
69+
5970
</p>
6071
</details>
6172

@@ -121,6 +132,20 @@ If you find Chai-1 useful in your research or use any structures produced by the
121132
}
122133
```
123134

135+
You can also access this information by running `chai citation`.
136+
137+
Additionally, if you use the automatic MMseqs2 MSA generation described above, please also cite:
138+
139+
```
140+
@article{mirdita2022colabfold,
141+
title={ColabFold: making protein folding accessible to all},
142+
author={Mirdita, Milot and Sch{\"u}tze, Konstantin and Moriwaki, Yoshitaka and Heo, Lim and Ovchinnikov, Sergey and Steinegger, Martin},
143+
journal={Nature methods},
144+
year={2022},
145+
}
146+
```
147+
148+
124149
## Licence
125150

126151
Chai-1 is released under an Apache 2.0 License (both code and model weights), which means it can be used for both academic and commercial purposes, including for drug discovery.
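
To make the command line usage from this README diff concrete, the sketch below assembles the `chai fold` invocation from Python. The helper name `build_fold_command` is hypothetical and not part of `chai_lab`. It relies only on the two positional arguments and the `--msa-server` flag named in the README text, and it does not launch anything by itself.

```python
from pathlib import Path


def build_fold_command(
    fasta_file: Path, output_dir: Path, msa_server: bool = False
) -> list[str]:
    """Build the argv list for `chai fold` (hypothetical helper, not in chai_lab)."""
    command = ["chai", "fold", str(fasta_file), str(output_dir)]
    if msa_server:
        # Requests MSAs from the ColabFold MMseqs2 server, which is a shared resource.
        command.append("--msa-server")
    return command


command = build_fold_command(
    Path("input.fasta"), Path("output_folder"), msa_server=True
)
print(" ".join(command))
```

Passing the resulting list to `subprocess.run(command, check=True)` would start the fold, assuming the package is installed and a supported GPU is available.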

chai_lab/chai1.py

Lines changed: 33 additions & 6 deletions
@@ -29,6 +29,7 @@
2929
from chai_lab.data.dataset.embeddings.embedding_context import EmbeddingContext
3030
from chai_lab.data.dataset.embeddings.esm import get_esm_embedding_context
3131
from chai_lab.data.dataset.inference_dataset import load_chains_from_raw, read_inputs
32+
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
3233
from chai_lab.data.dataset.msas.load import get_msa_contexts
3334
from chai_lab.data.dataset.msas.msa_context import MSAContext
3435
from chai_lab.data.dataset.structure.all_atom_structure_context import (
@@ -83,6 +84,7 @@
8384
)
8485
from chai_lab.data.io.cif_utils import outputs_to_cif
8586
from chai_lab.data.parsing.restraints import parse_pairwise_table
87+
from chai_lab.data.parsing.structure.entity_type import EntityType
8688
from chai_lab.model.diffusion_schedules import InferenceNoiseSchedule
8789
from chai_lab.model.utils import center_random_augmentation
8890
from chai_lab.ranking.frames import get_frames_and_mask
@@ -268,14 +270,24 @@ def run_inference(
268270
*,
269271
output_dir: Path,
270272
use_esm_embeddings: bool = True,
273+
msa_server: bool = False,
271274
msa_directory: Path | None = None,
272-
constraint_path: Path | str | None = None,
275+
constraint_path: Path | None = None,
273276
# expose some params for easy tweaking
274277
num_trunk_recycles: int = 3,
275278
num_diffn_timesteps: int = 200,
276279
seed: int | None = None,
277-
device: torch.device | None = None,
280+
device: str | None = None,
278281
) -> StructureCandidates:
282+
if output_dir.exists():
283+
assert not any(
284+
output_dir.iterdir()
285+
), f"Output directory {output_dir} is not empty."
286+
torch_device = torch.device(device if device is not None else "cuda:0")
287+
assert not (
288+
msa_server and msa_directory
289+
), "Cannot specify both MSA server and directory"
290+
279291
# Prepare inputs
280292
assert fasta_file.exists(), fasta_file
281293
fasta_inputs = read_inputs(fasta_file, length_limit=None)
@@ -290,14 +302,28 @@ def run_inference(
290302

291303
# Load structure context
292304
chains = load_chains_from_raw(fasta_inputs)
305+
del fasta_inputs # Do not reference inputs after creating chains from them
306+
293307
merged_context = AllAtomStructureContext.merge(
294308
[c.structure_context for c in chains]
295309
)
296310
n_actual_tokens = merged_context.num_tokens
297311
raise_if_too_many_tokens(n_actual_tokens)
298312

299-
# Load MSAs
300-
if msa_directory is not None:
313+
# Generate and/or load MSAs
314+
if msa_server:
315+
protein_sequences = [
316+
chain.entity_data.sequence
317+
for chain in chains
318+
if chain.entity_data.entity_type == EntityType.PROTEIN
319+
]
320+
msa_dir = output_dir / "msas"
321+
msa_dir.mkdir(parents=True, exist_ok=False)
322+
generate_colabfold_msas(protein_seqs=protein_sequences, msa_dir=msa_dir)
323+
msa_context, msa_profile_context = get_msa_contexts(
324+
chains, msa_directory=msa_dir
325+
)
326+
elif msa_directory is not None:
301327
msa_context, msa_profile_context = get_msa_contexts(
302328
chains, msa_directory=msa_directory
303329
)
@@ -308,6 +334,7 @@ def run_inference(
308334
msa_profile_context = MSAContext.create_empty(
309335
n_tokens=n_actual_tokens, depth=MAX_MSA_DEPTH
310336
)
337+
311338
assert (
312339
msa_context.num_tokens == merged_context.num_tokens
313340
), f"Discrepant tokens in input and MSA: {merged_context.num_tokens} != {msa_context.num_tokens}"
@@ -320,7 +347,7 @@ def run_inference(
320347

321348
# Load ESM embeddings
322349
if use_esm_embeddings:
323-
embedding_context = get_esm_embedding_context(chains, device=device)
350+
embedding_context = get_esm_embedding_context(chains, device=torch_device)
324351
else:
325352
embedding_context = EmbeddingContext.empty(n_tokens=n_actual_tokens)
326353

@@ -351,7 +378,7 @@ def run_inference(
351378
num_trunk_recycles=num_trunk_recycles,
352379
num_diffn_timesteps=num_diffn_timesteps,
353380
seed=seed,
354-
device=device,
381+
device=torch_device,
355382
)
356383

357384
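
The argument checks added at the top of `run_inference` can be illustrated in isolation. The following is a minimal sketch, not the code from the commit: the function name `check_inference_args` is hypothetical, it raises `ValueError` where the diff uses `assert` (assertions are skipped when Python runs with `-O`), and it returns the device as a plain string so that the example does not need `torch`.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def check_inference_args(
    output_dir: Path,
    msa_server: bool = False,
    msa_directory: Path | None = None,
    device: str | None = None,
) -> str:
    """Mirror the checks from the diff and return the resolved device string."""
    # An existing output directory must be empty so old results are not mixed in.
    if output_dir.exists() and any(output_dir.iterdir()):
        raise ValueError(f"Output directory {output_dir} is not empty.")
    # Server-generated MSAs and a user-supplied MSA directory are mutually exclusive.
    if msa_server and msa_directory is not None:
        raise ValueError("Cannot specify both MSA server and directory")
    # Same default as in the diff when no device is given.
    return device if device is not None else "cuda:0"


with tempfile.TemporaryDirectory() as tmp:
    empty_dir = Path(tmp)
    resolved = check_inference_args(empty_dir)  # an empty directory is accepted
    try:
        check_inference_args(empty_dir, msa_server=True, msa_directory=empty_dir)
        conflict_rejected = False
    except ValueError:
        conflict_rejected = True
```

In the commit itself the resolved value is wrapped in `torch.device(...)` before being passed on to the ESM embedding and diffusion steps.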

chai_lab/data/dataset/inference_dataset.py

Lines changed: 6 additions & 3 deletions
@@ -93,6 +93,7 @@ def _synth_subchain_id(idx: int) -> str:
9393
def raw_inputs_to_entitites_data(
9494
inputs: list[Input], identifier: str = "test"
9595
) -> list[AllAtomEntityData]:
96+
"""Load an entity for each raw input."""
9697
entities = []
9798

9899
# track unique entities
@@ -157,7 +158,7 @@ def load_chains_from_raw(
157158
tokenizer: AllAtomResidueTokenizer | None = None,
158159
) -> list[Chain]:
159160
"""
160-
loads and tokenizes each input chain
161+
Loads and tokenizes each input chain; skips over inputs that fail to tokenize.
161162
"""
162163

163164
if tokenizer is None:
@@ -186,12 +187,14 @@ def load_chains_from_raw(
186187
logger.exception(f"Failed to tokenize input {entity_data=} {sym_id=}")
187188
tok = None
188189
structure_contexts.append(tok)
189-
assert len(structure_contexts) == len(entities)
190+
190191
# Join the untokenized entity data with the tokenized chain data, removing
191192
# chains we failed to tokenize
192193
chains = [
193194
Chain(entity_data=entity_data, structure_context=structure_context)
194-
for entity_data, structure_context in zip(entities, structure_contexts)
195+
for entity_data, structure_context in zip(
196+
entities, structure_contexts, strict=True
197+
)
195198
if structure_context is not None
196199
]
197200

chai_lab/data/dataset/msas/colabfold.py

Lines changed: 3 additions & 1 deletion
@@ -352,6 +352,8 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path):
352352
This implementation also relies on ColabFold's chain pairing algorithm
353353
rather than using Chai-1's own algorithm, which could also lead to
354354
differences in results.
355+
356+
Places .aligned.pqt files in msa_dir; does not save intermediate a3m files.
355357
"""
356358
assert msa_dir.is_dir(), "MSA directory must be a dir"
357359
assert not any(msa_dir.iterdir()), "MSA directory must be empty"
@@ -366,7 +368,7 @@ def generate_colabfold_msas(protein_seqs: list[str], msa_dir: Path):
366368
a3ms_dir.mkdir()
367369

368370
# Generate MSAs for each protein chain
369-
print(f"Running MSA generation for {len(protein_seqs)} protein sequences")
371+
logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences")
370372
msas = _run_mmseqs2(
371373
protein_seqs,
372374
mmseqs_dir,

chai_lab/data/dataset/msas/load.py

Lines changed: 3 additions & 1 deletion
@@ -47,7 +47,9 @@ def get_msa_contexts(
4747
def get_msa_contexts_for_seq(seq) -> MSAContext:
4848
path = msa_directory / expected_basename(seq)
4949
if not path.is_file():
50-
logger.warning(f"No MSA found for sequence: {seq}")
50+
if seq != "X":
51+
# Don't warn for the special "X" sequence
52+
logger.warning(f"No MSA found for sequence: {seq}")
5153
[tokenized_seq] = tokenize_sequences_to_arrays([seq])[0]
5254
return MSAContext.create_single_seq(
5355
MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq)

chai_lab/main.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
1+
# Copyright (c) 2024 Chai Discovery, Inc.
2+
# Licensed under the Apache License, Version 2.0.
3+
# See the LICENSE file for details.
4+
5+
"""Command line interface."""
6+
7+
import logging
8+
9+
import typer
10+
11+
from chai_lab.chai1 import run_inference
12+
13+
CITATION = """
14+
@article{Chai-1-Technical-Report,
15+
title = {Chai-1: Decoding the molecular interactions of life},
16+
author = {{Chai Discovery}},
17+
year = 2024,
18+
journal = {bioRxiv},
19+
publisher = {Cold Spring Harbor Laboratory},
20+
doi = {10.1101/2024.10.10.615955},
21+
url = {https://www.biorxiv.org/content/early/2024/10/11/2024.10.10.615955},
22+
elocation-id = {2024.10.10.615955},
23+
eprint = {https://www.biorxiv.org/content/early/2024/10/11/2024.10.10.615955.full.pdf}
24+
}
25+
""".strip()
26+
27+
28+
def citation():
29+
"""Print citation information"""
30+
typer.echo(CITATION)
31+
32+
33+
def cli():
34+
app = typer.Typer()
35+
app.command("fold", help="Run Chai-1 to fold a complex.")(run_inference)
36+
app.command("citation", help="Print citation information")(citation)
37+
app()
38+
39+
40+
if __name__ == "__main__":
41+
logging.basicConfig(level=logging.INFO)
42+
cli()
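
The command layout that `cli()` exposes (a `fold` command and a `citation` command) can be sketched with the standard library. This is a stand-in built on `argparse`, not the Typer code from the commit. The argument names are assumptions taken from the `run_inference` signature and the README usage, and only the `--msa-server` option is modeled.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """argparse stand-in for the two commands registered in chai_lab/main.py."""
    parser = argparse.ArgumentParser(prog="chai")
    subcommands = parser.add_subparsers(dest="command", required=True)

    fold = subcommands.add_parser("fold", help="Run Chai-1 to fold a complex.")
    fold.add_argument("fasta_file", type=Path)
    fold.add_argument("output_dir", type=Path)
    fold.add_argument("--msa-server", action="store_true")

    subcommands.add_parser("citation", help="Print citation information")
    return parser


args = build_parser().parse_args(
    ["fold", "input.fasta", "output_folder", "--msa-server"]
)
print(args.command, args.fasta_file, args.output_dir, args.msa_server)
```

Typer derives the real interface from the `run_inference` signature, so `chai fold --help` remains the authoritative list of options.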

examples/msas/README.md

Lines changed: 18 additions & 3 deletions
@@ -2,8 +2,6 @@
22

33
While Chai-1 performs very well in "single-sequence mode," it can also be given additional evolutionary information to further improve performance. As in other folding methods, this evolutionary information is provided in the form of a multiple sequence alignment (MSA). This information is given in the form of a `MSAContext` object (see `chai_lab/data/dataset/msas/msa_context.py`); we provide code for building these `MSAContext` objects through `aligned.pqt` files, though you can play with building out an `MSAContext` yourself as well.
44

5-
Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. We provide an example of how to generate MSAs using [ColabFold](https://github.com/sokrypton/ColabFold) in `examples/msas/predict_with_msas.py`. Performance will vary depending on the input MSA databases and search algorithms used.
6-
75
## The `.aligned.pqt` file format
86

97
The easiest way to provide MSA information to Chai-1 is through the `.aligned.pqt` file format that we have defined. This file can be thought of as an augmented `a3m` file, and is essentially a dataframe saved in parquet format with the following four (required) columns:
@@ -58,4 +56,21 @@ import pandas as pd
5856

5957
aligned_pqt = pd.read_parquet("examples/msas/703adc2c74b8d7e613549b6efcf37126da7963522dc33852ad3c691eef1da06f.aligned.pqt")
6058
aligned_pqt.head()
61-
```
59+
```
60+
61+
62+
## Additional MSA generation strategies
63+
64+
Multiple strategies can be used for generating MSAs. In our [technical report](https://chaiassets.com/chai-1/paper/technical_report_v1.pdf), we generated MSAs using [jackhmmer](https://github.com/EddyRivasLab/hmmer). Other algorithms such as [MMseqs2](https://github.com/soedinglab/MMseqs2) can also be used. In this vein, we provide support for automatic MSA generation via the [ColabFold](https://github.com/sokrypton/ColabFold) server using `chai fold input.fasta output_directory --msa-server` or by invoking `run_inference` as follows:
65+
66+
```python
67+
candidates = run_inference(
68+
...
69+
msa_server=True,
70+
...
71+
)
72+
```
73+
74+
Please note that performance will vary depending on the input MSA databases and search algorithms used.
75+
76+
In addition, people have found that tweaking MSA inputs can be a fruitful path to improving folding results -- we encourage such exploration of this for Chai-1 as well!

examples/msas/predict_with_msas.py

Lines changed: 5 additions & 16 deletions
@@ -2,12 +2,8 @@
22
from pathlib import Path
33

44
import numpy as np
5-
import torch
65

76
from chai_lab.chai1 import run_inference
8-
from chai_lab.data.dataset.inference_dataset import read_inputs
9-
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
10-
from chai_lab.data.parsing.structure.entity_type import EntityType
117

128
tmp_dir = Path(tempfile.mkdtemp())
139

@@ -25,16 +21,6 @@
2521
fasta_path = tmp_dir / "example.fasta"
2622
fasta_path.write_text(example_fasta)
2723

28-
# Generate MSAs
29-
msa_dir = tmp_dir / "msas"
30-
msa_dir.mkdir()
31-
protein_seqs = [
32-
input.sequence
33-
for input in read_inputs(fasta_path)
34-
if input.entity_type == EntityType.PROTEIN.value
35-
]
36-
generate_colabfold_msas(protein_seqs=protein_seqs, msa_dir=msa_dir)
37-
3824

3925
# Generate structure
4026
output_dir = tmp_dir / "outputs"
@@ -45,9 +31,12 @@
4531
num_trunk_recycles=3,
4632
num_diffn_timesteps=200,
4733
seed=42,
48-
device=torch.device("cuda:0"),
34+
device="cuda:0",
4935
use_esm_embeddings=True,
50-
msa_directory=msa_dir,
36+
# See example .aligned.pqt files in this directory
37+
msa_directory=Path(__file__).parent,
38+
# Exclusive with msa_directory; can be used for MMseqs2 server MSA generation
39+
msa_server=False,
5140
)
5241
cif_paths = candidates.cif_paths
5342
scores = [rd.aggregate_score for rd in candidates.ranking_data]
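
Once `cif_paths` and `scores` are available, a common next step is to pick the top-ranked prediction. The sketch below uses made-up paths and plain float scores as stand-ins for the values returned by `run_inference`, and it assumes that a higher aggregate score indicates a better candidate. Real scores may be tensors or arrays and may need converting to floats first.

```python
from pathlib import Path

# Stand-in values for illustration only; in the example script these come from
# `candidates.cif_paths` and `candidates.ranking_data`.
cif_paths = [Path(f"candidate_{i}.cif") for i in range(3)]
scores = [0.41, 0.78, 0.65]

# Assumption: higher aggregate score means a better prediction.
best_index = max(range(len(scores)), key=scores.__getitem__)
best_path = cif_paths[best_index]
print(f"Best candidate: {best_path} (score {scores[best_index]:.2f})")
```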

examples/predict_structure.py

Lines changed: 1 addition & 2 deletions
@@ -1,7 +1,6 @@
11
from pathlib import Path
22

33
import numpy as np
4-
import torch
54

65
from chai_lab.chai1 import run_inference
76

@@ -36,7 +35,7 @@
3635
num_trunk_recycles=3,
3736
num_diffn_timesteps=200,
3837
seed=42,
39-
device=torch.device("cuda:0"),
38+
device="cuda:0",
4039
use_esm_embeddings=True,
4140
)
4241

examples/restraints/predict_with_restraints.py

Lines changed: 1 addition & 3 deletions
@@ -1,8 +1,6 @@
11
import logging
22
from pathlib import Path
33

4-
import torch
5-
64
from chai_lab.chai1 import run_inference
75

86
logging.basicConfig(level=logging.INFO)
@@ -32,6 +30,6 @@
3230
num_trunk_recycles=3,
3331
num_diffn_timesteps=200,
3432
seed=42,
35-
device=torch.device("cuda:0"),
33+
device="cuda:0",
3634
use_esm_embeddings=True,
3735
)

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -68,4 +68,7 @@ exclude = [
6868
]
6969

7070
[tool.hatch.build.targets.wheel]
71-
# should use packages from sdist section
71+
# should use packages from sdist section
72+
73+
[project.scripts]
74+
chai = "chai_lab.main:cli"
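
The `[project.scripts]` entry maps the `chai` command to a `module:attribute` reference, here `chai_lab.main:cli`. As a rough sketch of what such a reference means, the hypothetical helper below resolves a string of that shape by hand. Installers generate a wrapper script that does something similar; this is an illustration rather than their actual implementation. A standard-library target is used so the example runs without `chai_lab` installed.

```python
import importlib


def resolve_entry_point(spec: str):
    """Resolve a "module:attribute" string to the object it names."""
    module_name, _, attr_path = spec.partition(":")
    obj = importlib.import_module(module_name)
    # Walk dotted attribute paths such as "Class.method"; skip empty parts.
    for part in filter(None, attr_path.split(".")):
        obj = getattr(obj, part)
    return obj


# With chai_lab installed, resolve_entry_point("chai_lab.main:cli") would return
# the `cli` function. A stdlib target is used here instead.
dumps = resolve_entry_point("json:dumps")
print(dumps([1, 2]))
```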
