Skip to content

Commit 0db0319

Browse files
Fix AlphaFold integration: Correctly define and use apply function within model setup
1 parent c6eb36f commit 0db0319

File tree

1 file changed

+124
-29
lines changed

1 file changed

+124
-29
lines changed
Lines changed: 124 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,41 @@
1+
import jax
12
import jax.numpy as jnp
2-
from typing import List, Dict, Any
3-
from alphafold.model import model
3+
import haiku as hk
4+
from typing import List, Dict, Any, Tuple
5+
from alphafold.model import modules_multimer, config
6+
from alphafold.model.config import CONFIG, CONFIG_MULTIMER, CONFIG_DIFFS
47
from alphafold.common import protein
5-
from alphafold.data import pipeline
6-
from unittest.mock import MagicMock
8+
from alphafold.data import pipeline, templates
9+
from alphafold.data.tools import hhblits, jackhmmer
10+
from Bio import SeqIO
11+
from Bio.Seq import Seq
12+
from Bio.SeqRecord import SeqRecord
13+
import logging
14+
import copy
15+
import ml_collections
716

817
# Mock SCOPData
9-
SCOPData = MagicMock()
10-
SCOPData.protein_letters_3to1 = {
11-
'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
12-
'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
13-
'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
14-
'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
15-
}
16-
17-
# Mock getDomainBySid function
18-
def getDomainBySid(sid):
19-
"""
20-
Mock implementation of getDomainBySid.
21-
This function is a placeholder and should be replaced with actual implementation if needed.
22-
"""
23-
return MagicMock()
18+
class SCOPData:
    """Lightweight stand-in that exposes a residue-code lookup table.

    The single class attribute ``protein_letters_3to1`` maps the
    upper-case three-letter code of each of the 20 standard amino acids
    to its one-letter code.
    """

    # Keyword form builds the same str -> str mapping, in the same key order.
    protein_letters_3to1 = dict(
        ALA='A', CYS='C', ASP='D', GLU='E', PHE='F',
        GLY='G', HIS='H', ILE='I', LYS='K', LEU='L',
        MET='M', ASN='N', PRO='P', GLN='Q', ARG='R',
        SER='S', THR='T', VAL='V', TRP='W', TYR='Y',
    )
25+
26+
print("Mock SCOPData is being used in alphafold_integration.py")
27+
28+
# Export SCOPData for use in other modules
29+
__all__ = ['SCOPData', 'AlphaFoldIntegration']
2430

2531
class AlphaFoldIntegration:
2632
def __init__(self):
27-
self.model_runner = None
33+
self.model_apply = None
34+
self.model_params = None
2835
self.feature_dict = None
36+
self.msa_runner = None
37+
self.template_searcher = None
38+
self.config = None # Will be initialized in setup_model
2939

3040
def setup_model(self, model_params: Dict[str, Any]):
3141
"""
@@ -34,7 +44,68 @@ def setup_model(self, model_params: Dict[str, Any]):
3444
Args:
3545
model_params (Dict[str, Any]): Parameters for the AlphaFold model.
3646
"""
37-
self.model_runner = model.RunModel(model_params)
47+
logging.info("Setting up AlphaFold model")
48+
49+
try:
50+
# Initialize the config
51+
model_name = model_params.get('model_name', 'model_1')
52+
if 'multimer' in model_name:
53+
base_config = copy.deepcopy(CONFIG_MULTIMER)
54+
else:
55+
base_config = copy.deepcopy(CONFIG)
56+
57+
# Update the base config with model-specific differences
58+
if model_name in CONFIG_DIFFS:
59+
base_config.update_from_flattened_dict(CONFIG_DIFFS[model_name])
60+
61+
# Ensure global_config is present and correctly initialized
62+
if 'global_config' not in base_config:
63+
base_config.global_config = ml_collections.ConfigDict({
64+
'deterministic': False,
65+
'subbatch_size': 4,
66+
'use_remat': False,
67+
'zero_init': True,
68+
'eval_dropout': False,
69+
})
70+
71+
# Update config with any additional parameters
72+
base_config.update(model_params)
73+
74+
self.config = base_config
75+
76+
def create_model(config):
77+
model = modules_multimer.AlphaFold(config)
78+
def apply(params, inputs):
79+
return model.apply({'params': params}, **inputs)
80+
return model, apply
81+
82+
model_creator = hk.transform(create_model)
83+
84+
rng = jax.random.PRNGKey(0)
85+
dummy_input = {
86+
'aatype': jnp.zeros((1, 256), dtype=jnp.int32),
87+
'residue_index': jnp.zeros((1, 256), dtype=jnp.int32),
88+
'seq_length': jnp.array([256], dtype=jnp.int32),
89+
'template_aatype': jnp.zeros((1, 1, 256), dtype=jnp.int32),
90+
'template_all_atom_masks': jnp.zeros((1, 1, 256, 37), dtype=jnp.float32),
91+
'template_all_atom_positions': jnp.zeros((1, 1, 256, 37, 3), dtype=jnp.float32),
92+
'template_sum_probs': jnp.zeros((1, 1), dtype=jnp.float32),
93+
'is_distillation': jnp.array(0, dtype=jnp.int32),
94+
}
95+
self.model_params = model_creator.init(rng, self.config)
96+
_, self.model_apply = model_creator.apply(self.model_params, rng, self.config)
97+
98+
# Test the model with dummy input
99+
_ = self.model_apply(self.model_params, dummy_input)
100+
logging.info("AlphaFold model initialized successfully")
101+
102+
self.msa_runner = jackhmmer.Jackhmmer(binary_path=model_params.get('jackhmmer_binary_path', '/usr/bin/jackhmmer'))
103+
self.template_searcher = hhblits.HHBlits(binary_path=model_params.get('hhblits_binary_path', '/usr/bin/hhblits'))
104+
logging.info("MSA runner and template searcher initialized")
105+
106+
except Exception as e:
107+
logging.error(f"Error in AlphaFold setup: {str(e)}")
108+
raise ValueError(f"Failed to set up AlphaFold model: {str(e)}")
38109

39110
def prepare_features(self, sequence: str):
40111
"""
@@ -46,8 +117,32 @@ def prepare_features(self, sequence: str):
46117
Returns:
47118
Dict: Feature dictionary for AlphaFold.
48119
"""
49-
self.feature_dict = pipeline.make_sequence_features(sequence)
50-
self.feature_dict.update(pipeline.make_msa_features([sequence]))
120+
sequence_features = pipeline.make_sequence_features(sequence)
121+
msa = self._run_msa(sequence)
122+
msa_features = pipeline.make_msa_features(msas=[msa])
123+
template_features = self._search_templates(sequence)
124+
125+
self.feature_dict = {**sequence_features, **msa_features, **template_features}
126+
127+
def _run_msa(self, sequence: str) -> List[Tuple[str, str]]:
128+
"""Run MSA and return results."""
129+
with open("temp.fasta", "w") as f:
130+
SeqIO.write(SeqRecord(Seq(sequence), id="query"), f, "fasta")
131+
result = self.msa_runner.query("temp.fasta")
132+
return [("query", sequence)] + [(hit.name, hit.sequence) for hit in result.hits]
133+
134+
def _search_templates(self, sequence: str) -> Dict[str, Any]:
135+
"""Search for templates and prepare features."""
136+
with open("temp.fasta", "w") as f:
137+
SeqIO.write(SeqRecord(Seq(sequence), id="query"), f, "fasta")
138+
hits = self.template_searcher.query("temp.fasta")
139+
templates_result = templates.TemplateHitFeaturizer(
140+
mmcif_dir="/path/to/mmcif/files",
141+
max_template_date="2021-11-01",
142+
max_hits=20,
143+
kalign_binary_path="/path/to/kalign"
144+
).get_templates(query_sequence=sequence, hits=hits)
145+
return templates_result.features
51146

52147
def predict_structure(self) -> protein.Protein:
53148
"""
@@ -56,10 +151,10 @@ def predict_structure(self) -> protein.Protein:
56151
Returns:
57152
protein.Protein: Predicted protein structure.
58153
"""
59-
if self.model_runner is None or self.feature_dict is None:
154+
if self.model_apply is None or self.feature_dict is None:
60155
raise ValueError("Model or features not set up. Call setup_model() and prepare_features() first.")
61156

62-
prediction_result = self.model_runner.predict(self.feature_dict)
157+
prediction_result = self.model_apply(self.model_params, self.feature_dict)
63158
return protein.from_prediction(prediction_result)
64159

65160
def get_plddt_scores(self) -> jnp.ndarray:
@@ -69,10 +164,10 @@ def get_plddt_scores(self) -> jnp.ndarray:
69164
Returns:
70165
jnp.ndarray: Array of pLDDT scores.
71166
"""
72-
if self.model_runner is None or self.feature_dict is None:
167+
if self.model_apply is None or self.feature_dict is None:
73168
raise ValueError("Model or features not set up. Call setup_model() and prepare_features() first.")
74169

75-
prediction_result = self.model_runner.predict(self.feature_dict)
170+
prediction_result = self.model_apply(self.model_params, self.feature_dict)
76171
return prediction_result['plddt']
77172

78173
def get_predicted_aligned_error(self) -> jnp.ndarray:
@@ -82,8 +177,8 @@ def get_predicted_aligned_error(self) -> jnp.ndarray:
82177
Returns:
83178
jnp.ndarray: 2D array of predicted aligned errors.
84179
"""
85-
if self.model_runner is None or self.feature_dict is None:
180+
if self.model_apply is None or self.feature_dict is None:
86181
raise ValueError("Model or features not set up. Call setup_model() and prepare_features() first.")
87182

88-
prediction_result = self.model_runner.predict(self.feature_dict)
183+
prediction_result = self.model_apply(self.model_params, self.feature_dict)
89184
return prediction_result['predicted_aligned_error']

0 commit comments

Comments
 (0)