From c477f77954b162f10a258ca80a615f0b516c2778 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Thu, 27 Mar 2025 12:48:45 +0000 Subject: [PATCH] TLSConfiguration --- .../Redis/Connection/RedisConnection.swift | 48 ++--- Sources/Redis/Connection/TSTLSOptions.swift | 178 ------------------ Sources/Redis/RedisClient.swift | 35 ---- Sources/Redis/RedisClientConfiguration.swift | 35 +++- 4 files changed, 47 insertions(+), 249 deletions(-) delete mode 100644 Sources/Redis/Connection/TSTLSOptions.swift diff --git a/Sources/Redis/Connection/RedisConnection.swift b/Sources/Redis/Connection/RedisConnection.swift index b19ebd57..e7172086 100644 --- a/Sources/Redis/Connection/RedisConnection.swift +++ b/Sources/Redis/Connection/RedisConnection.swift @@ -15,6 +15,7 @@ import Logging import NIOCore import NIOPosix +import NIOSSL import RESP #if canImport(Network) @@ -81,24 +82,6 @@ public struct RedisConnection: Sendable { (self.requestStream, self.requestContinuation) = AsyncStream.makeStream(of: RequestStreamElement.self) } - #if canImport(Network) - /// Initialize Client with TLS options - public init( - address: ServerAddress, - configuration: RedisClientConfiguration, - transportServicesTLSOptions: TSTLSOptions, - eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, - logger: Logger - ) throws { - self.address = address - self.configuration = configuration - self.eventLoopGroup = eventLoopGroup - self.logger = logger - self.tlsOptions = transportServicesTLSOptions.options - (self.requestStream, self.requestContinuation) = AsyncStream.makeStream(of: RequestStreamElement.self) - } - #endif - public func run() async throws { let asyncChannel = try await self.makeClient( address: self.address @@ -265,26 +248,14 @@ public struct RedisConnection: Sendable { result = try await bootstrap .connect(host: host, port: port) { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(RESPTokenHandler()) - return try NIOAsyncChannel( - wrappingChannelSynchronously: channel, - configuration: .init() - ) - } + setupChannel(channel) } self.logger.debug("Client connnected to \(host):\(port)") case .unixDomainSocket(let path): result = try await bootstrap .connect(unixDomainSocketPath: path) { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(RESPTokenHandler()) - return try NIOAsyncChannel( - wrappingChannelSynchronously: channel, - configuration: .init() - ) - } + setupChannel(channel) } self.logger.debug("Client connnected to socket path \(path)") } @@ -294,6 +265,19 @@ public struct RedisConnection: Sendable { } } + private func setupChannel(_ channel: Channel) -> EventLoopFuture> { + channel.eventLoop.makeCompletedFuture { + if case .enable(let sslContext, let tlsServerName) = self.configuration.tls.base { + try channel.pipeline.syncOperations.addHandler(NIOSSLClientHandler(context: sslContext, serverHostname: tlsServerName)) + } + try channel.pipeline.syncOperations.addHandler(RESPTokenHandler()) + return try NIOAsyncChannel( + wrappingChannelSynchronously: channel, + configuration: .init() + ) + } + } + /// create a BSD sockets based bootstrap private func createSocketsBootstrap() -> ClientBootstrap { ClientBootstrap(group: self.eventLoopGroup) diff --git a/Sources/Redis/Connection/TSTLSOptions.swift b/Sources/Redis/Connection/TSTLSOptions.swift deleted file mode 100644 index b6979236..00000000 --- a/Sources/Redis/Connection/TSTLSOptions.swift +++ /dev/null @@ -1,178 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the Hummingbird server framework project -// -// Copyright (c) 2021-2024 the Hummingbird authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -#if canImport(Network) -import Foundation -import Network -import Security - -/// Wrapper for NIO transport services TLS options -public struct TSTLSOptions: Sendable { - public struct Error: Swift.Error, Equatable { - enum _Internal: Equatable { - case invalidFormat - case interactionNotAllowed - case verificationFailed - } - - private let value: _Internal - init(_ value: _Internal) { - self.value = value - } - - // invalid format - public static var invalidFormat: Self { .init(.invalidFormat) } - // unable to import p12 as no interaction is allowed - public static var interactionNotAllowed: Self { .init(.interactionNotAllowed) } - // MAC verification failed during PKCS12 import (wrong password?) - public static var verificationFailed: Self { .init(.verificationFailed) } - } - - public struct Identity { - let secIdentity: SecIdentity - - public static func secIdentity(_ secIdentity: SecIdentity) -> Self { - .init(secIdentity: secIdentity) - } - - public static func p12(filename: String, password: String) throws -> Self { - guard let secIdentity = try Self.loadP12(filename: filename, password: password) else { throw Error.invalidFormat } - return .init(secIdentity: secIdentity) - } - - private static func loadP12(filename: String, password: String) throws -> SecIdentity? { - let data = try Data(contentsOf: URL(fileURLWithPath: filename)) - let options: [String: String] = [kSecImportExportPassphrase as String: password] - var rawItems: CFArray? - let result = SecPKCS12Import(data as CFData, options as CFDictionary, &rawItems) - switch result { - case errSecSuccess: - break - case errSecInteractionNotAllowed: - throw Error.interactionNotAllowed - case errSecPkcs12VerifyFailure: - throw Error.verificationFailed - default: - throw Error.invalidFormat - } - let items = rawItems! as! [[String: Any]] - let firstItem = items[0] - return firstItem[kSecImportItemIdentity as String] as! SecIdentity? - } - } - - /// Struct defining an array of certificates - public struct Certificates { - let certificates: [SecCertificate] - - /// Create certificate array from already loaded SecCertificate array - public static var none: Self { .init(certificates: []) } - - /// Create certificate array from already loaded SecCertificate array - public static func certificates(_ secCertificates: [SecCertificate]) -> Self { .init(certificates: secCertificates) } - - /// Create certificate array from DER file - public static func der(filename: String) throws -> Self { - let certificateData = try Data(contentsOf: URL(fileURLWithPath: filename)) - guard let secCertificate = SecCertificateCreateWithData(nil, certificateData as CFData) else { throw Error.invalidFormat } - return .init(certificates: [secCertificate]) - } - } - - /// Initialize TSTLSOptions - public init(_ options: NWProtocolTLS.Options?) { - if let options { - self.value = .some(options) - } else { - self.value = .none - } - } - - /// TSTLSOptions holding options - public static func options(_ options: NWProtocolTLS.Options) -> Self { - .init(value: .some(options)) - } - - public static func options( - serverIdentity: Identity - ) -> Self? { - let options = NWProtocolTLS.Options() - - // server identity - guard let secIdentity = sec_identity_create(serverIdentity.secIdentity) else { return nil } - sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) - - return .init(value: .some(options)) - } - - public static func options( - clientIdentity: Identity, - trustRoots: Certificates = .none, - serverName: String? = nil - ) -> Self? { - let options = NWProtocolTLS.Options() - - // server identity - guard let secIdentity = sec_identity_create(clientIdentity.secIdentity) else { return nil } - sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) - if let serverName { - sec_protocol_options_set_tls_server_name(options.securityProtocolOptions, serverName) - } - // sec_protocol_options_set - sec_protocol_options_set_local_identity(options.securityProtocolOptions, secIdentity) - - // add verify block to control certificate verification - if trustRoots.certificates.count > 0 { - sec_protocol_options_set_verify_block( - options.securityProtocolOptions, - { _, sec_trust, sec_protocol_verify_complete in - let trust = sec_trust_copy_ref(sec_trust).takeRetainedValue() - SecTrustSetAnchorCertificates(trust, trustRoots.certificates as CFArray) - SecTrustEvaluateAsyncWithError(trust, Self.tlsDispatchQueue) { _, result, error in - if let error { - print("Trust failed: \(error.localizedDescription)") - } - sec_protocol_verify_complete(result) - } - }, - Self.tlsDispatchQueue - ) - } - return .init(value: .some(options)) - } - - /// Empty TSTLSOptions - public static var none: Self { - .init(value: .none) - } - - var options: NWProtocolTLS.Options? { - if case .some(let options) = self.value { return options } - return nil - } - - /// Internal storage for TSTLSOptions. @unchecked Sendable while NWProtocolTLS.Options - /// is not Sendable - private enum Internal: @unchecked Sendable { - case some(NWProtocolTLS.Options) - case none - } - - private let value: Internal - private init(value: Internal) { self.value = value } - - /// Dispatch queue used by Network framework TLS to control certificate verification - static let tlsDispatchQueue = DispatchQueue(label: "WSTSTLSConfiguration") -} -#endif diff --git a/Sources/Redis/RedisClient.swift b/Sources/Redis/RedisClient.swift index 298a65cd..8e9c327e 100644 --- a/Sources/Redis/RedisClient.swift +++ b/Sources/Redis/RedisClient.swift @@ -24,13 +24,6 @@ import NIOTransportServices /// /// Supports TLS via both NIOSSL and Network framework. public struct RedisClient { - enum MultiPlatformTLSConfiguration: Sendable { - case niossl(TLSConfiguration) - #if canImport(Network) - case ts(TSTLSOptions) - #endif - } - /// Server address let serverAddress: ServerAddress /// configuration @@ -39,8 +32,6 @@ public struct RedisClient { let eventLoopGroup: EventLoopGroup /// Logger let logger: Logger - /// TLS configuration - let tlsConfiguration: MultiPlatformTLSConfiguration? /// Initialize Redis client /// @@ -53,7 +44,6 @@ public struct RedisClient { public init( _ address: ServerAddress, configuration: RedisClientConfiguration = .init(), - tlsConfiguration: TLSConfiguration? = nil, eventLoopGroup: EventLoopGroup = MultiThreadedEventLoopGroup.singleton, logger: Logger ) { @@ -61,32 +51,7 @@ public struct RedisClient { self.configuration = configuration self.eventLoopGroup = eventLoopGroup self.logger = logger - self.tlsConfiguration = tlsConfiguration.map { .niossl($0) } - } - - #if canImport(Network) - /// Initialize Redis client - /// - /// - Parameters: - /// - address: redis database address - /// - configuration: Redis client configuration - /// - transportServicesTLSOptions: Redis TLS connection configuration - /// - eventLoopGroup: EventLoopGroup to run WebSocket client on - /// - logger: Logger - public init( - _ address: ServerAddress, - configuration: RedisClientConfiguration = .init(), - transportServicesTLSOptions: TSTLSOptions, - eventLoopGroup: NIOTSEventLoopGroup = NIOTSEventLoopGroup.singleton, - logger: Logger - ) { - self.serverAddress = address - self.configuration = configuration - self.eventLoopGroup = eventLoopGroup - self.logger = logger - self.tlsConfiguration = .ts(transportServicesTLSOptions) } - #endif } extension RedisClient { diff --git a/Sources/Redis/RedisClientConfiguration.swift b/Sources/Redis/RedisClientConfiguration.swift index a0d0cb77..9d19492d 100644 --- a/Sources/Redis/RedisClientConfiguration.swift +++ b/Sources/Redis/RedisClientConfiguration.swift @@ -12,19 +12,46 @@ // //===----------------------------------------------------------------------===// +import NIOSSL + /// Configuration for the redis client public struct RedisClientConfiguration: Sendable { - public enum RESPVersion: Sendable { - case v2 - case v3 + public struct RESPVersion: Sendable, Equatable { + enum Base { + case v2 + case v3 + } + let base: Base + + public static var v2: Self { .init(base: .v2) } + public static var v3: Self { .init(base: .v3) } + } + + public struct TLS: Sendable { + enum Base { + case disable + case enable(NIOSSLContext, String?) + } + let base: Base + + public static var disable: Self { .init(base: .disable) } + public static func enable(tlsConfiguration: TLSConfiguration, tlsServerName: String?) throws -> Self { + .init(base: .enable(try NIOSSLContext(configuration: tlsConfiguration), tlsServerName)) + } } public var respVersion: RESPVersion + public var tls: TLS /// Initialize RedisClientConfiguration /// - Parameters /// - respVersion: RESP version to use - public init(respVersion: RESPVersion = .v3) { + /// - tlsConfiguration: TLS configuration + public init( + respVersion: RESPVersion = .v3, + tls: TLS = .disable + ) { self.respVersion = respVersion + self.tls = tls } }