@@ -56,37 +56,42 @@ def initialize(self):
56
56
self .models [name ] = FileData (real_hash , local_data_path )
57
57
58
58
def get_model (self , model_name ):
59
- if model_name in self .models :
60
- return self .models [model_name ].file_path
61
- elif len ([optional for optional in self .manifest_data ['optional_files' ] if
62
- optional ['name' ] == model_name ]) > 0 :
63
- self .find_optional_model (model_name )
64
- return self .models [model_name ].file_path
59
+ if self .available ():
60
+ if model_name in self .models :
61
+ return self .models [model_name ].file_path
62
+ elif len ([optional for optional in self .manifest_data ['optional_files' ] if
63
+ optional ['name' ] == model_name ]) > 0 :
64
+ self .find_optional_model (model_name )
65
+ return self .models [model_name ].file_path
66
+ else :
67
+ raise Exception ("model name " + model_name + " not found in manifest" )
65
68
else :
66
- raise Exception ("model name " + model_name + " not found in manifest" )
69
+ raise Exception ("unable to get model {}, model_manifest.json not found." . format ( model_name ) )
67
70
68
71
def find_optional_model (self , file_name ):
69
-
70
- found_models = [optional for optional in self .manifest_data ['optional_files' ] if
71
- optional ['name' ] == file_name ]
72
- if len (found_models ) == 0 :
73
- raise Exception ("file with name '" + file_name + "' not found in model manifest." )
74
- model_info = found_models [0 ]
75
- self .models [file_name ] = {}
76
- source_uri = model_info ['source_uri' ]
77
- fail_on_tamper = model_info .get ("fail_on_tamper" , False )
78
- expected_hash = model_info .get ('md5_checksum' , None )
79
- with self .client .file (source_uri ).getFile () as f :
80
- local_data_path = f .name
81
- real_hash = md5_for_file (local_data_path )
82
- if self .using_frozen :
83
- if real_hash != expected_hash and fail_on_tamper :
84
- raise Exception ("Model File Mismatch for " + file_name +
85
- "\n expected hash: " + expected_hash + "\n real hash: " + real_hash )
72
+ if self .available ():
73
+ found_models = [optional for optional in self .manifest_data ['optional_files' ] if
74
+ optional ['name' ] == file_name ]
75
+ if len (found_models ) == 0 :
76
+ raise Exception ("file with name '" + file_name + "' not found in model manifest." )
77
+ model_info = found_models [0 ]
78
+ self .models [file_name ] = {}
79
+ source_uri = model_info ['source_uri' ]
80
+ fail_on_tamper = model_info .get ("fail_on_tamper" , False )
81
+ expected_hash = model_info .get ('md5_checksum' , None )
82
+ with self .client .file (source_uri ).getFile () as f :
83
+ local_data_path = f .name
84
+ real_hash = md5_for_file (local_data_path )
85
+ if self .using_frozen :
86
+ if real_hash != expected_hash and fail_on_tamper :
87
+ raise Exception ("Model File Mismatch for " + file_name +
88
+ "\n expected hash: " + expected_hash + "\n real hash: " + real_hash )
89
+ else :
90
+ self .models [file_name ] = FileData (real_hash , local_data_path )
86
91
else :
87
92
self .models [file_name ] = FileData (real_hash , local_data_path )
88
93
else :
89
- self . models [ file_name ] = FileData ( real_hash , local_data_path )
94
+ raise Exception ( "unable to get model {}, model_manifest.json not found." . format ( model_name ) )
90
95
91
96
def get_manifest (self ):
92
97
if os .path .exists (self .manifest_frozen_path ):
0 commit comments