Skip to content

Commit 0d09689

Browse files
Improve model config error handling (#333)
1 parent d31123f commit 0d09689

File tree

3 files changed

+85
-23
lines changed

3 files changed

+85
-23
lines changed

Libraries/MLXLLM/LLMModelFactory.swift

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -317,12 +317,26 @@ public class LLMModelFactory: ModelFactory {
317317
let modelDirectory = try await downloadModel(
318318
hub: hub, configuration: configuration, progressHandler: progressHandler)
319319

320-
// load the generic config to understand which model and how to load the weights
320+
// Load the generic config to understand which model and how to load the weights
321321
let configurationURL = modelDirectory.appending(component: "config.json")
322-
let baseConfig = try JSONDecoder().decode(
323-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
324-
let model = try typeRegistry.createModel(
325-
configuration: configurationURL, modelType: baseConfig.modelType)
322+
323+
let baseConfig: BaseConfiguration
324+
do {
325+
baseConfig = try JSONDecoder().decode(
326+
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
327+
} catch let error as DecodingError {
328+
throw ModelFactoryError.configurationDecodingError(
329+
configurationURL.lastPathComponent, configuration.name, error)
330+
}
331+
332+
let model: LanguageModel
333+
do {
334+
model = try typeRegistry.createModel(
335+
configuration: configurationURL, modelType: baseConfig.modelType)
336+
} catch let error as DecodingError {
337+
throw ModelFactoryError.configurationDecodingError(
338+
configurationURL.lastPathComponent, configuration.name, error)
339+
}
326340

327341
// apply the weights to the bare model
328342
try loadWeights(
@@ -338,12 +352,12 @@ public class LLMModelFactory: ModelFactory {
338352
DefaultMessageGenerator()
339353
}
340354

355+
let processor = LLMUserInputProcessor(
356+
tokenizer: tokenizer, configuration: configuration,
357+
messageGenerator: messageGenerator)
358+
341359
return .init(
342-
configuration: configuration, model: model,
343-
processor: LLMUserInputProcessor(
344-
tokenizer: tokenizer, configuration: configuration,
345-
messageGenerator: messageGenerator),
346-
tokenizer: tokenizer)
360+
configuration: configuration, model: model, processor: processor, tokenizer: tokenizer)
347361
}
348362

349363
}

Libraries/MLXLMCommon/ModelFactory.swift

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,40 @@ import Tokenizers
77
public enum ModelFactoryError: LocalizedError {
88
case unsupportedModelType(String)
99
case unsupportedProcessorType(String)
10+
case configurationDecodingError(String, String, DecodingError)
1011

1112
public var errorDescription: String? {
1213
switch self {
13-
case .unsupportedModelType(let type): "Unsupported model type: \(type)"
14-
case .unsupportedProcessorType(let type): "Unsupported processor type: \(type)"
14+
case .unsupportedModelType(let type):
15+
return "Unsupported model type: \(type)"
16+
case .unsupportedProcessorType(let type):
17+
return "Unsupported processor type: \(type)"
18+
case .configurationDecodingError(let file, let modelName, let decodingError):
19+
let errorDetail = extractDecodingErrorDetail(decodingError)
20+
return "Failed to parse \(file) for model '\(modelName)': \(errorDetail)"
21+
}
22+
}
23+
24+
private func extractDecodingErrorDetail(_ error: DecodingError) -> String {
25+
switch error {
26+
case .keyNotFound(let key, let context):
27+
let path = (context.codingPath + [key]).map { $0.stringValue }.joined(separator: ".")
28+
return "Missing field '\(path)'"
29+
case .typeMismatch(_, let context):
30+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
31+
return "Type mismatch at '\(path)'"
32+
case .valueNotFound(_, let context):
33+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
34+
return "Missing value at '\(path)'"
35+
case .dataCorrupted(let context):
36+
if context.codingPath.isEmpty {
37+
return "Invalid JSON"
38+
} else {
39+
let path = context.codingPath.map { $0.stringValue }.joined(separator: ".")
40+
return "Invalid data at '\(path)'"
41+
}
42+
@unknown default:
43+
return error.localizedDescription
1544
}
1645
}
1746
}

Libraries/MLXVLM/VLMModelFactory.swift

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,24 @@ public class VLMModelFactory: ModelFactory {
212212
let configurationURL = modelDirectory.appending(
213213
component: "config.json"
214214
)
215-
let baseConfig = try JSONDecoder().decode(
216-
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
217215

218-
let model = try typeRegistry.createModel(
219-
configuration: configurationURL, modelType: baseConfig.modelType)
216+
let baseConfig: BaseConfiguration
217+
do {
218+
baseConfig = try JSONDecoder().decode(
219+
BaseConfiguration.self, from: Data(contentsOf: configurationURL))
220+
} catch let error as DecodingError {
221+
throw ModelFactoryError.configurationDecodingError(
222+
configurationURL.lastPathComponent, configuration.name, error)
223+
}
224+
225+
let model: LanguageModel
226+
do {
227+
model = try typeRegistry.createModel(
228+
configuration: configurationURL, modelType: baseConfig.modelType)
229+
} catch let error as DecodingError {
230+
throw ModelFactoryError.configurationDecodingError(
231+
configurationURL.lastPathComponent, configuration.name, error)
232+
}
220233

221234
// apply the weights to the bare model
222235
try loadWeights(
@@ -228,17 +241,23 @@ public class VLMModelFactory: ModelFactory {
228241
hub: hub
229242
)
230243

231-
let processorConfiguration = modelDirectory.appending(
244+
let processorConfigurationURL = modelDirectory.appending(
232245
component: "preprocessor_config.json"
233246
)
234-
let baseProcessorConfig = try JSONDecoder().decode(
235-
BaseProcessorConfiguration.self,
236-
from: Data(
237-
contentsOf: processorConfiguration
247+
248+
let baseProcessorConfig: BaseProcessorConfiguration
249+
do {
250+
baseProcessorConfig = try JSONDecoder().decode(
251+
BaseProcessorConfiguration.self,
252+
from: Data(contentsOf: processorConfigurationURL)
238253
)
239-
)
254+
} catch let error as DecodingError {
255+
throw ModelFactoryError.configurationDecodingError(
256+
processorConfigurationURL.lastPathComponent, configuration.name, error)
257+
}
258+
240259
let processor = try processorRegistry.createModel(
241-
configuration: processorConfiguration,
260+
configuration: processorConfigurationURL,
242261
processorType: baseProcessorConfig.processorClass, tokenizer: tokenizer)
243262

244263
return .init(

0 commit comments

Comments
 (0)