1
+ import jax
1
2
import jax .numpy as jnp
2
- from typing import List , Dict , Any
3
- from alphafold .model import model
3
+ import haiku as hk
4
+ from typing import List , Dict , Any , Tuple
5
+ from alphafold .model import modules_multimer , config
6
+ from alphafold .model .config import CONFIG , CONFIG_MULTIMER , CONFIG_DIFFS
4
7
from alphafold .common import protein
5
- from alphafold .data import pipeline
6
- from unittest .mock import MagicMock
8
+ from alphafold .data import pipeline , templates
9
+ from alphafold .data .tools import hhblits , jackhmmer
10
+ from Bio import SeqIO
11
+ from Bio .Seq import Seq
12
+ from Bio .SeqRecord import SeqRecord
13
+ import logging
14
+ import copy
15
+ import ml_collections
7
16
8
17
# Mock SCOPData
9
- SCOPData = MagicMock ()
10
- SCOPData .protein_letters_3to1 = {
11
- 'ALA' : 'A' , 'CYS' : 'C' , 'ASP' : 'D' , 'GLU' : 'E' , 'PHE' : 'F' ,
12
- 'GLY' : 'G' , 'HIS' : 'H' , 'ILE' : 'I' , 'LYS' : 'K' , 'LEU' : 'L' ,
13
- 'MET' : 'M' , 'ASN' : 'N' , 'PRO' : 'P' , 'GLN' : 'Q' , 'ARG' : 'R' ,
14
- 'SER' : 'S' , 'THR' : 'T' , 'VAL' : 'V' , 'TRP' : 'W' , 'TYR' : 'Y'
15
- }
16
-
17
- # Mock getDomainBySid function
18
- def getDomainBySid (sid ):
19
- """
20
- Mock implementation of getDomainBySid.
21
- This function is a placeholder and should be replaced with actual implementation if needed.
22
- """
23
- return MagicMock ()
18
class SCOPData:
    """Minimal stand-in exposing the three-letter to one-letter residue map.

    Only the 20 standard amino acids are covered; a lookup of any other
    residue name raises ``KeyError``.
    """

    # Three-letter residue code -> one-letter code.
    protein_letters_3to1 = {
        'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
        'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
        'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
        'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y',
    }
25
+
26
# Import-time notice that the stand-in SCOPData (not a real SCOP data
# source) is the one in use.
print("Mock SCOPData is being used in alphafold_integration.py")

# Export SCOPData for use in other modules
__all__ = ['SCOPData', 'AlphaFoldIntegration']
24
30
25
31
class AlphaFoldIntegration :
26
32
def __init__ (self ):
27
- self .model_runner = None
33
+ self .model_apply = None
34
+ self .model_params = None
28
35
self .feature_dict = None
36
+ self .msa_runner = None
37
+ self .template_searcher = None
38
+ self .config = None # Will be initialized in setup_model
29
39
30
40
def setup_model (self , model_params : Dict [str , Any ]):
31
41
"""
@@ -34,7 +44,68 @@ def setup_model(self, model_params: Dict[str, Any]):
34
44
Args:
35
45
model_params (Dict[str, Any]): Parameters for the AlphaFold model.
36
46
"""
37
- self .model_runner = model .RunModel (model_params )
47
+ logging .info ("Setting up AlphaFold model" )
48
+
49
+ try :
50
+ # Initialize the config
51
+ model_name = model_params .get ('model_name' , 'model_1' )
52
+ if 'multimer' in model_name :
53
+ base_config = copy .deepcopy (CONFIG_MULTIMER )
54
+ else :
55
+ base_config = copy .deepcopy (CONFIG )
56
+
57
+ # Update the base config with model-specific differences
58
+ if model_name in CONFIG_DIFFS :
59
+ base_config .update_from_flattened_dict (CONFIG_DIFFS [model_name ])
60
+
61
+ # Ensure global_config is present and correctly initialized
62
+ if 'global_config' not in base_config :
63
+ base_config .global_config = ml_collections .ConfigDict ({
64
+ 'deterministic' : False ,
65
+ 'subbatch_size' : 4 ,
66
+ 'use_remat' : False ,
67
+ 'zero_init' : True ,
68
+ 'eval_dropout' : False ,
69
+ })
70
+
71
+ # Update config with any additional parameters
72
+ base_config .update (model_params )
73
+
74
+ self .config = base_config
75
+
76
+ def create_model (config ):
77
+ model = modules_multimer .AlphaFold (config )
78
+ def apply (params , inputs ):
79
+ return model .apply ({'params' : params }, ** inputs )
80
+ return model , apply
81
+
82
+ model_creator = hk .transform (create_model )
83
+
84
+ rng = jax .random .PRNGKey (0 )
85
+ dummy_input = {
86
+ 'aatype' : jnp .zeros ((1 , 256 ), dtype = jnp .int32 ),
87
+ 'residue_index' : jnp .zeros ((1 , 256 ), dtype = jnp .int32 ),
88
+ 'seq_length' : jnp .array ([256 ], dtype = jnp .int32 ),
89
+ 'template_aatype' : jnp .zeros ((1 , 1 , 256 ), dtype = jnp .int32 ),
90
+ 'template_all_atom_masks' : jnp .zeros ((1 , 1 , 256 , 37 ), dtype = jnp .float32 ),
91
+ 'template_all_atom_positions' : jnp .zeros ((1 , 1 , 256 , 37 , 3 ), dtype = jnp .float32 ),
92
+ 'template_sum_probs' : jnp .zeros ((1 , 1 ), dtype = jnp .float32 ),
93
+ 'is_distillation' : jnp .array (0 , dtype = jnp .int32 ),
94
+ }
95
+ self .model_params = model_creator .init (rng , self .config )
96
+ _ , self .model_apply = model_creator .apply (self .model_params , rng , self .config )
97
+
98
+ # Test the model with dummy input
99
+ _ = self .model_apply (self .model_params , dummy_input )
100
+ logging .info ("AlphaFold model initialized successfully" )
101
+
102
+ self .msa_runner = jackhmmer .Jackhmmer (binary_path = model_params .get ('jackhmmer_binary_path' , '/usr/bin/jackhmmer' ))
103
+ self .template_searcher = hhblits .HHBlits (binary_path = model_params .get ('hhblits_binary_path' , '/usr/bin/hhblits' ))
104
+ logging .info ("MSA runner and template searcher initialized" )
105
+
106
+ except Exception as e :
107
+ logging .error (f"Error in AlphaFold setup: { str (e )} " )
108
+ raise ValueError (f"Failed to set up AlphaFold model: { str (e )} " )
38
109
39
110
def prepare_features (self , sequence : str ):
40
111
"""
@@ -46,8 +117,32 @@ def prepare_features(self, sequence: str):
46
117
Returns:
47
118
Dict: Feature dictionary for AlphaFold.
48
119
"""
49
- self .feature_dict = pipeline .make_sequence_features (sequence )
50
- self .feature_dict .update (pipeline .make_msa_features ([sequence ]))
120
+ sequence_features = pipeline .make_sequence_features (sequence )
121
+ msa = self ._run_msa (sequence )
122
+ msa_features = pipeline .make_msa_features (msas = [msa ])
123
+ template_features = self ._search_templates (sequence )
124
+
125
+ self .feature_dict = {** sequence_features , ** msa_features , ** template_features }
126
+
127
+ def _run_msa (self , sequence : str ) -> List [Tuple [str , str ]]:
128
+ """Run MSA and return results."""
129
+ with open ("temp.fasta" , "w" ) as f :
130
+ SeqIO .write (SeqRecord (Seq (sequence ), id = "query" ), f , "fasta" )
131
+ result = self .msa_runner .query ("temp.fasta" )
132
+ return [("query" , sequence )] + [(hit .name , hit .sequence ) for hit in result .hits ]
133
+
134
+ def _search_templates (self , sequence : str ) -> Dict [str , Any ]:
135
+ """Search for templates and prepare features."""
136
+ with open ("temp.fasta" , "w" ) as f :
137
+ SeqIO .write (SeqRecord (Seq (sequence ), id = "query" ), f , "fasta" )
138
+ hits = self .template_searcher .query ("temp.fasta" )
139
+ templates_result = templates .TemplateHitFeaturizer (
140
+ mmcif_dir = "/path/to/mmcif/files" ,
141
+ max_template_date = "2021-11-01" ,
142
+ max_hits = 20 ,
143
+ kalign_binary_path = "/path/to/kalign"
144
+ ).get_templates (query_sequence = sequence , hits = hits )
145
+ return templates_result .features
51
146
52
147
def predict_structure (self ) -> protein .Protein :
53
148
"""
@@ -56,10 +151,10 @@ def predict_structure(self) -> protein.Protein:
56
151
Returns:
57
152
protein.Protein: Predicted protein structure.
58
153
"""
59
- if self .model_runner is None or self .feature_dict is None :
154
+ if self .model_apply is None or self .feature_dict is None :
60
155
raise ValueError ("Model or features not set up. Call setup_model() and prepare_features() first." )
61
156
62
- prediction_result = self .model_runner . predict ( self .feature_dict )
157
+ prediction_result = self .model_apply ( self . model_params , self .feature_dict )
63
158
return protein .from_prediction (prediction_result )
64
159
65
160
def get_plddt_scores (self ) -> jnp .ndarray :
@@ -69,10 +164,10 @@ def get_plddt_scores(self) -> jnp.ndarray:
69
164
Returns:
70
165
jnp.ndarray: Array of pLDDT scores.
71
166
"""
72
- if self .model_runner is None or self .feature_dict is None :
167
+ if self .model_apply is None or self .feature_dict is None :
73
168
raise ValueError ("Model or features not set up. Call setup_model() and prepare_features() first." )
74
169
75
- prediction_result = self .model_runner . predict ( self .feature_dict )
170
+ prediction_result = self .model_apply ( self . model_params , self .feature_dict )
76
171
return prediction_result ['plddt' ]
77
172
78
173
def get_predicted_aligned_error (self ) -> jnp .ndarray :
@@ -82,8 +177,8 @@ def get_predicted_aligned_error(self) -> jnp.ndarray:
82
177
Returns:
83
178
jnp.ndarray: 2D array of predicted aligned errors.
84
179
"""
85
- if self .model_runner is None or self .feature_dict is None :
180
+ if self .model_apply is None or self .feature_dict is None :
86
181
raise ValueError ("Model or features not set up. Call setup_model() and prepare_features() first." )
87
182
88
- prediction_result = self .model_runner . predict ( self .feature_dict )
183
+ prediction_result = self .model_apply ( self . model_params , self .feature_dict )
89
184
return prediction_result ['predicted_aligned_error' ]
0 commit comments