diff --git a/Sources/InfomaniakDI/SimpleResolver.swift b/Sources/InfomaniakDI/SimpleResolver.swift index 5e30a2f..28d25d7 100644 --- a/Sources/InfomaniakDI/SimpleResolver.swift +++ b/Sources/InfomaniakDI/SimpleResolver.swift @@ -49,17 +49,18 @@ public protocol SimpleStorable: Sendable { /// A minimalist DI solution /// Once initiated, stores types as long as the app lives public final class SimpleResolver: SimpleResolvable, SimpleStorable, CustomDebugStringConvertible, @unchecked Sendable { + private let recursiveLock = NSRecursiveLock() + public var debugDescription: String { - var buffer: String! - queue.sync { - buffer = """ - <\(type(of: self)):\(Unmanaged.passUnretained(self).toOpaque()) - \(factories.count) factories and \(store.count) stored types - factories: \(factories) - store: \(store)> - """ - } - return buffer + recursiveLock.lock() + defer { recursiveLock.unlock() } + + return """ + <\(type(of: self)):\(Unmanaged.passUnretained(self).toOpaque()) + \(factories.count) factories and \(store.count) stored types + factories: \(factories) + store: \(store)> + """ } enum ErrorDomain: Error { @@ -76,19 +77,16 @@ public final class SimpleResolver: SimpleResolvable, SimpleStorable, CustomDebug /// Resolved object collection var store = [String: Any]() - /// A serial queue for thread safety - private let queue = DispatchQueue(label: "com.infomaniakDI.resolver") - // MARK: SimpleStorable public func store(factory: Factoryable, forCustomTypeIdentifier customIdentifier: String? = nil) { - let type = factory.type + recursiveLock.lock() + defer { recursiveLock.unlock() } + let type = factory.type let identifier = buildIdentifier(type: type, forIdentifier: customIdentifier) - queue.sync { - factories[identifier] = factory - } + factories[identifier] = factory } // MARK: SimpleResolvable @@ -97,38 +95,36 @@ public final class SimpleResolver: SimpleResolvable, SimpleStorable, CustomDebug forCustomTypeIdentifier customIdentifier: String?, factoryParameters: [String: Any]? = nil, resolver: SimpleResolvable) throws -> Service { - let serviceIdentifier = buildIdentifier(type: type, forIdentifier: customIdentifier) + recursiveLock.lock() + defer { recursiveLock.unlock() } - // load form store - var fetchedService: Any? - queue.sync { - fetchedService = store[serviceIdentifier] - } - if let service = fetchedService as? Service { - return service - } + let serviceIdentifier = buildIdentifier(type: type, forIdentifier: customIdentifier) + return try loadOrResolve( + serviceIdentifier: serviceIdentifier, + factoryParameters: factoryParameters, + resolver: resolver + ) + } - // load service from factory - var factory: Factoryable? - queue.sync { - factory = factories[serviceIdentifier] - } - guard let factory = factory else { - throw ErrorDomain.factoryMissing(identifier: serviceIdentifier) - } + private func loadOrResolve(serviceIdentifier: String, + factoryParameters: [String: Any]?, + resolver: SimpleResolvable) throws -> Service { + if let fetchedObject = store[serviceIdentifier], + let fetchedService = fetchedObject as? Service { + return fetchedService + } else { + guard let factory = factories[serviceIdentifier] else { + throw ErrorDomain.factoryMissing(identifier: serviceIdentifier) + } - // Apply factory closure - let builtType = try factory.build(factoryParameters: factoryParameters, resolver: resolver) - guard let service = builtType as? Service else { - throw ErrorDomain.typeMissmatch(expected: "\(Service.Type.self)", got: "\(builtType.self)") - } + let builtType = try factory.build(factoryParameters: factoryParameters, resolver: resolver) + guard let service = builtType as? Service else { + throw ErrorDomain.typeMissmatch(expected: "\(Service.Type.self)", got: "\(builtType.self)") + } - // keep in store built object for later - queue.sync { store[serviceIdentifier] = service + return service } - - return service } // MARK: internal @@ -145,9 +141,9 @@ public final class SimpleResolver: SimpleResolvable, SimpleStorable, CustomDebug // MARK: testing func removeAll() { - queue.sync { - factories.removeAll() - store.removeAll() - } + recursiveLock.lock() + factories.removeAll() + store.removeAll() + recursiveLock.unlock() } }