From 13451d58168d87f0a31a20650041dd5c0697b2e1 Mon Sep 17 00:00:00 2001 From: danthorpe Date: Wed, 11 Oct 2023 23:23:29 +0100 Subject: [PATCH] feat: Add Authentication --- .../Authentication/Authentication.swift | 104 ++++++++++++++++++ .../AuthenticationDelegate.swift | 32 ++++++ .../Authentication/AuthenticationMethod.swift | 16 +++ .../Authentication/BasicAuthentication.swift | 38 +++++++ .../Authentication/BearerAuthentication.swift | 33 ++++++ .../HeaderBasedAuthentication.swift | 94 ++++++++++++++++ Sources/Networking/Core/HTTPRequestData.swift | 2 +- .../TestAuthenticationDelegate.swift | 37 +++++++ .../BasicCredentialsTests.swift | 23 ++++ .../BearerCredentialsTests.swift | 22 ++++ .../HeaderBasedAuthenticationTests.swift | 98 +++++++++++++++++ 11 files changed, 498 insertions(+), 1 deletion(-) create mode 100644 Sources/Networking/Components/Authentication/Authentication.swift create mode 100644 Sources/Networking/Components/Authentication/AuthenticationDelegate.swift create mode 100644 Sources/Networking/Components/Authentication/AuthenticationMethod.swift create mode 100644 Sources/Networking/Components/Authentication/BasicAuthentication.swift create mode 100644 Sources/Networking/Components/Authentication/BearerAuthentication.swift create mode 100644 Sources/Networking/Components/Authentication/HeaderBasedAuthentication.swift create mode 100644 Sources/TestSupport/TestAuthenticationDelegate.swift create mode 100644 Tests/NetworkingTests/Components/Authentication/BasicCredentialsTests.swift create mode 100644 Tests/NetworkingTests/Components/Authentication/BearerCredentialsTests.swift create mode 100644 Tests/NetworkingTests/Components/Authentication/HeaderBasedAuthenticationTests.swift diff --git a/Sources/Networking/Components/Authentication/Authentication.swift b/Sources/Networking/Components/Authentication/Authentication.swift new file mode 100644 index 00000000..c2171178 --- /dev/null +++ b/Sources/Networking/Components/Authentication/Authentication.swift @@ -0,0 +1,104 @@ +import Helpers + +extension NetworkingComponent { + public func authenticated(with delegate: Delegate) -> some NetworkingComponent { + checkedStatusCode().modified(Authentication(delegate: delegate)) + } +} + +struct Authentication: NetworkingModifier { + typealias Credentials = Delegate.Credentials + let delegate: Delegate + + func send(upstream: NetworkingComponent, request: HTTPRequestData) -> ResponseStream { + guard let method = request.authenticationMethod, method == Credentials.method else { + return upstream.send(request) + } + return ResponseStream { continuation in + Task { + + // Fetch the initial credentials + var credentials: Credentials + do { + credentials = try await delegate.fetch(for: request) + } catch { + continuation.finish( + throwing: AuthenticationError.fetchCredentialsFailed(request, Credentials.method, error) + ) + return + } + + // Update the request to use the credentials + let newRequest = credentials.apply(to: request) + + // Process the stream + do { + for try await event in upstream.send(newRequest) { + continuation.yield(event) + } + continuation.finish() + } catch let StackError.unauthorized(response) { + let newRequest = try await refresh( + unauthorized: &credentials, + response: response, + continuation: continuation + ) + await upstream.send(newRequest).redirect(into: continuation) + } catch { + continuation.finish(throwing: error) + } + } + } + } + + func refresh( + unauthorized credentials: inout Credentials, + response: HTTPResponseData, + continuation: ResponseStream.Continuation + ) async throws -> HTTPRequestData { + do { + credentials = try await delegate.refresh(unauthorized: credentials, from: response) + return credentials.apply(to: response.request) + } catch { + throw AuthenticationError.refreshCredentialsFailed(response, Credentials.method, error) + } + } +} + +public enum AuthenticationError: Error { + case fetchCredentialsFailed(HTTPRequestData, AuthenticationMethod, Error) + case refreshCredentialsFailed(HTTPResponseData, AuthenticationMethod, Error) +} + +extension AuthenticationError: Equatable { + public static func == (lhs: AuthenticationError, rhs: AuthenticationError) -> Bool { + switch (lhs, rhs) { + case let (.fetchCredentialsFailed(lhsR, lhsAM, lhsE), .fetchCredentialsFailed(rhsR, rhsAM, rhsE)): + return lhsR == rhsR && lhsAM == rhsAM && _isEqual(lhsE, rhsE) + case let (.refreshCredentialsFailed(lhsR, lhsAM, lhsE), .refreshCredentialsFailed(rhsR, rhsAM, rhsE)): + return lhsR == rhsR && lhsAM == rhsAM && _isEqual(lhsE, rhsE) + default: + return false + } + } +} + +extension AuthenticationError: NetworkingError { + public var request: HTTPRequestData { + switch self { + case let .fetchCredentialsFailed(request, _, _): + return request + case let .refreshCredentialsFailed(response, _, _): + return response.request + } + } + + public var response: HTTPResponseData? { + switch self { + case .fetchCredentialsFailed: + return nil + case let .refreshCredentialsFailed(response, _, _): + return response + } + } +} diff --git a/Sources/Networking/Components/Authentication/AuthenticationDelegate.swift b/Sources/Networking/Components/Authentication/AuthenticationDelegate.swift new file mode 100644 index 00000000..2811e47c --- /dev/null +++ b/Sources/Networking/Components/Authentication/AuthenticationDelegate.swift @@ -0,0 +1,32 @@ +/// A system which can asynchronously fetch or refresh credentials +/// in order to make authenticated HTTP requests +public protocol AuthenticationDelegate: Sendable { // swiftlint:disable:this class_delegate_protocol + + /// A type which represents the credentials to be used + associatedtype Credentials: AuthenticatingCredentials + + /// The entry point into the authentication flow + /// + /// Conforming types should manage their own state, providing thread safety + /// and perform whatever actions are necessary to retreive credentials from + /// an external system. For example - present a login interface to the user + /// to collect a username and password. + func fetch(for request: HTTPRequestData) async throws -> Credentials + + /// After supplying a request with credentials, it is still possible to + /// encounter HTTP unauthorized errors. In such an event, this method will + /// be called, allowing for a single attempt to retry with a new set of + /// credentials. Typical usecases here would be for OAuth style refreshing + /// of a token. + func refresh(unauthorized: Credentials, from response: HTTPResponseData) async throws -> Credentials +} + +public protocol AuthenticatingCredentials: Hashable, Sendable { + + /// The authentication method + static var method: AuthenticationMethod { get } + + /// Create a new request making use of the credentials in whichever way + /// suits their purpose. E.g. by appending a query parameter + func apply(to request: HTTPRequestData) -> HTTPRequestData +} diff --git a/Sources/Networking/Components/Authentication/AuthenticationMethod.swift b/Sources/Networking/Components/Authentication/AuthenticationMethod.swift new file mode 100644 index 00000000..7072d422 --- /dev/null +++ b/Sources/Networking/Components/Authentication/AuthenticationMethod.swift @@ -0,0 +1,16 @@ +public struct AuthenticationMethod: Hashable, RawRepresentable, Sendable, HTTPRequestDataOption { + public static let defaultOption: Self? = nil + + public let rawValue: String + + public init(rawValue: String) { + self.rawValue = rawValue + } +} + +extension HTTPRequestData { + public var authenticationMethod: AuthenticationMethod? { + get { self[option: AuthenticationMethod.self] } + set { self[option: AuthenticationMethod.self] = newValue } + } +} diff --git a/Sources/Networking/Components/Authentication/BasicAuthentication.swift b/Sources/Networking/Components/Authentication/BasicAuthentication.swift new file mode 100644 index 00000000..0ae9e6fa --- /dev/null +++ b/Sources/Networking/Components/Authentication/BasicAuthentication.swift @@ -0,0 +1,38 @@ +import Foundation + +extension AuthenticationMethod { + public static let basic = AuthenticationMethod(rawValue: "Basic") +} + +public struct BasicCredentials: Hashable, Sendable, AuthenticatingCredentials, HTTPRequestDataOption { + public static var method: AuthenticationMethod = .basic + public static let defaultOption: Self? = nil + + public let user: String + public let password: String + + public init(user: String, password: String) { + self.user = user + self.password = password + } + + public func apply(to request: HTTPRequestData) -> HTTPRequestData { + var copy = request + let joined = user + ":" + password + let data = Data(joined.utf8) + let encoded = data.base64EncodedString() + copy.headerFields[.authorization] = "Basic \(encoded)" + return copy + } +} + +extension HTTPRequestData { + public var basicCredentials: BasicCredentials? { + get { self[option: BasicCredentials.self] } + set { self[option: BasicCredentials.self] = newValue } + } +} + +public typealias BasicAuthentication< + Delegate: AuthenticationDelegate +> = HeaderBasedAuthentication where Delegate.Credentials == BasicCredentials diff --git a/Sources/Networking/Components/Authentication/BearerAuthentication.swift b/Sources/Networking/Components/Authentication/BearerAuthentication.swift new file mode 100644 index 00000000..9742d466 --- /dev/null +++ b/Sources/Networking/Components/Authentication/BearerAuthentication.swift @@ -0,0 +1,33 @@ +import Foundation + +extension AuthenticationMethod { + public static let bearer = AuthenticationMethod(rawValue: "Bearer") +} + +public struct BearerCredentials: Hashable, Sendable, Codable, HTTPRequestDataOption, AuthenticatingCredentials { + public static let method: AuthenticationMethod = .bearer + public static let defaultOption: Self? = nil + + public let token: String + + public init(token: String) { + self.token = token + } + + public func apply(to request: HTTPRequestData) -> HTTPRequestData { + var copy = request + copy.headerFields[.authorization] = "Bearer \(token)" + return copy + } +} + +extension HTTPRequestData { + public var bearerCredentials: BearerCredentials? { + get { self[option: BearerCredentials.self] } + set { self[option: BearerCredentials.self] = newValue } + } +} + +public typealias BearerAuthentication< + Delegate: AuthenticationDelegate +> = HeaderBasedAuthentication where Delegate.Credentials == BearerCredentials diff --git a/Sources/Networking/Components/Authentication/HeaderBasedAuthentication.swift b/Sources/Networking/Components/Authentication/HeaderBasedAuthentication.swift new file mode 100644 index 00000000..6ac47b2e --- /dev/null +++ b/Sources/Networking/Components/Authentication/HeaderBasedAuthentication.swift @@ -0,0 +1,94 @@ +public struct HeaderBasedAuthentication { + actor StateMachine: AuthenticationDelegate { + typealias Credentials = Delegate.Credentials + + private enum State { + case idle + case fetching(Task) + case authorized(Credentials) + } + + let delegate: Delegate + private var state: State = .idle + + @NetworkEnvironment(\.logger) var logger + + init(delegate: Delegate) { + self.delegate = delegate + } + + private func set(state: State) { + self.state = state + } + + func fetch(for request: HTTPRequestData) async throws -> Credentials { + switch state { + case let .authorized(credentials): + return credentials + case let .fetching(task): + return try await task.value + case .idle: + let task = Task { try await performCredentialFetch(for: request) } + set(state: .fetching(task)) + return try await task.value + } + } + + private func performCredentialFetch(for request: HTTPRequestData) async throws -> Credentials { + logger?.info("🔐 Fetching credentials for \(Credentials.method.rawValue, privacy: .public) authorization method") + do { + let credentials = try await delegate.fetch(for: request) + set(state: .authorized(credentials)) + return credentials + } catch { + set(state: .idle) + throw AuthenticationError.fetchCredentialsFailed(request, Credentials.method, error) + } + } + + func refresh(unauthorized credentials: Credentials, from response: HTTPResponseData) async throws -> Credentials { + if case let .fetching(task) = state { + return try await task.value + } + + let task = Task { try await performCredentialRefresh(unauthorized: credentials, from: response) } + set(state: .fetching(task)) + return try await task.value + } + + private func performCredentialRefresh( + unauthorized credentials: Credentials, + from response: HTTPResponseData + ) async throws -> Credentials { + logger?.info( + "🔑 Refreshing credentials for \(Credentials.method.rawValue, privacy: .public) authorization method" + ) + do { + let refreshed = try await delegate.refresh(unauthorized: credentials, from: response) + set(state: .authorized(refreshed)) + return refreshed + } catch { + set(state: .idle) + throw AuthenticationError.refreshCredentialsFailed(response, Credentials.method, error) + } + } + } + + fileprivate let state: StateMachine + + public init(delegate: Delegate) { + state = StateMachine(delegate: delegate) + } +} + +extension HeaderBasedAuthentication: AuthenticationDelegate { + public typealias Credentials = Delegate.Credentials + public func fetch(for request: HTTPRequestData) async throws -> Credentials { + try await state.fetch(for: request) + } + public func refresh( + unauthorized credentials: Credentials, from response: HTTPResponseData + ) async throws -> Credentials { + try await state.refresh(unauthorized: credentials, from: response) + } +} diff --git a/Sources/Networking/Core/HTTPRequestData.swift b/Sources/Networking/Core/HTTPRequestData.swift index 31e59500..9abb9fde 100644 --- a/Sources/Networking/Core/HTTPRequestData.swift +++ b/Sources/Networking/Core/HTTPRequestData.swift @@ -31,7 +31,7 @@ public struct HTTPRequestData: Sendable, Identifiable { id: ID, method: HTTPRequest.Method = .get, scheme: String? = "https", - authority: String?, + authority: String? = nil, path: String? = nil, headerFields: HTTPFields = [:], body: Data? = nil diff --git a/Sources/TestSupport/TestAuthenticationDelegate.swift b/Sources/TestSupport/TestAuthenticationDelegate.swift new file mode 100644 index 00000000..4749cac5 --- /dev/null +++ b/Sources/TestSupport/TestAuthenticationDelegate.swift @@ -0,0 +1,37 @@ +import Networking +import Protected +import XCTestDynamicOverlay + +public final class TestAuthenticationDelegate: @unchecked Sendable { + public typealias Fetch = @Sendable (HTTPRequestData) async throws -> Credentials + public typealias Refresh = @Sendable (Credentials, HTTPResponseData) async throws -> Credentials + + @Protected public var fetchCount: Int = 0 + @Protected public var refreshCount: Int = 0 + + public var fetch: Fetch + public var refresh: Refresh + + public init( + fetch: @escaping Fetch = unimplemented("TestAuthenticationDelegate.fetch"), + refresh: @escaping Refresh = unimplemented("TestAuthenticationDelegate.refresh") + ) { + self.fetch = fetch + self.refresh = refresh + } +} + +extension TestAuthenticationDelegate: AuthenticationDelegate { + public func fetch(for request: HTTPRequestData) async throws -> Credentials { + fetchCount += 1 + return try await fetch(request) + } + + public func refresh( + unauthorized credentials: Credentials, + from response: HTTPResponseData + ) async throws -> Credentials { + refreshCount += 1 + return try await refresh(credentials, response) + } +} diff --git a/Tests/NetworkingTests/Components/Authentication/BasicCredentialsTests.swift b/Tests/NetworkingTests/Components/Authentication/BasicCredentialsTests.swift new file mode 100644 index 00000000..63cbeb23 --- /dev/null +++ b/Tests/NetworkingTests/Components/Authentication/BasicCredentialsTests.swift @@ -0,0 +1,23 @@ +import AssertionExtras +import Dependencies +import Foundation +import HTTPTypes +import TestSupport +import XCTest + +@testable import Networking + +final class BasicCredentialsTests: XCTestCase { + func test__apply_credentials() { + let credentials = BasicCredentials(user: "blob", password: "super!$3cret") + let request = credentials.apply(to: HTTPRequestData(id: "1")) + XCTAssertEqual(request.headerFields[.authorization], "Basic YmxvYjpzdXBlciEkM2NyZXQ=") + } + + func test__provide_credentials() { + var request = HTTPRequestData(id: "1") + request.basicCredentials = BasicCredentials(user: "blob", password: "super!$3cret") + XCTAssertEqual(request.basicCredentials?.user, "blob") + XCTAssertEqual(request.basicCredentials?.password, "super!$3cret") + } +} diff --git a/Tests/NetworkingTests/Components/Authentication/BearerCredentialsTests.swift b/Tests/NetworkingTests/Components/Authentication/BearerCredentialsTests.swift new file mode 100644 index 00000000..91e6e4d4 --- /dev/null +++ b/Tests/NetworkingTests/Components/Authentication/BearerCredentialsTests.swift @@ -0,0 +1,22 @@ +import AssertionExtras +import Dependencies +import Foundation +import HTTPTypes +import TestSupport +import XCTest + +@testable import Networking + +final class BearerCredentialsTests: XCTestCase { + func test__apply_credentials() { + let credentials = BearerCredentials(token: "super!$3cret") + let request = credentials.apply(to: HTTPRequestData(id: "1")) + XCTAssertEqual(request.headerFields[.authorization], "Bearer super!$3cret") + } + + func test__provide_credentials() { + var request = HTTPRequestData(id: "1") + request.bearerCredentials = BearerCredentials(token: "super!$3cret") + XCTAssertEqual(request.bearerCredentials?.token, "super!$3cret") + } +} diff --git a/Tests/NetworkingTests/Components/Authentication/HeaderBasedAuthenticationTests.swift b/Tests/NetworkingTests/Components/Authentication/HeaderBasedAuthenticationTests.swift new file mode 100644 index 00000000..d6d28197 --- /dev/null +++ b/Tests/NetworkingTests/Components/Authentication/HeaderBasedAuthenticationTests.swift @@ -0,0 +1,98 @@ +import AssertionExtras +import Dependencies +import Foundation +import HTTPTypes +import TestSupport +import XCTest + +@testable import Networking + +final class HeaderBasedAuthenticationTests: XCTestCase { + + func test__given_requires_credentials__delegate_is_triggered() async throws { + var request = HTTPRequestData(id: "1") + request.authenticationMethod = .bearer + + let delelgate = TestAuthenticationDelegate( + fetch: { [request] in + XCTAssertEqual(request, $0) + return BearerCredentials(token: "some token") + } + ) + + let authenticator = HeaderBasedAuthentication(delegate: delelgate) + + let newRequest = try await authenticator.fetch(for: request).apply(to: request) + + XCTAssertEqual(newRequest.headerFields[.authorization], "Bearer some token") + XCTAssertEqual(delelgate.fetchCount, 1) + } + + func test__given_delegate_throws_error() async throws { + struct CustomError: Error, Equatable { } + + var request = HTTPRequestData(id: "1") + request.authenticationMethod = .bearer + + let delelgate = TestAuthenticationDelegate( + fetch: { _ in + throw CustomError() + } + ) + + let authenticator = HeaderBasedAuthentication(delegate: delelgate) + + await XCTAssertThrowsError( + try await authenticator.fetch(for: request), + matches: AuthenticationError.fetchCredentialsFailed(request, .bearer, CustomError()) + ) + } + + func test__requests_are_queued_until_delegate_responds() async throws { + + let delelgate = TestAuthenticationDelegate( + fetch: { _ in + return BearerCredentials(token: "some token") + } + ) + + let authenticator = HeaderBasedAuthentication(delegate: delelgate) + + @Sendable func check(authority: String) async throws -> HTTPRequestData { + var request = HTTPRequestData(authority: "example.com") + request.authenticationMethod = .bearer + return try await authenticator.fetch(for: request).apply(to: request) + } + + try await withDependencies { + $0.shortID = .incrementing + } operation: { + try await withMainSerialExecutor { + let requests = try await withThrowingTaskGroup(of: HTTPRequestData.self) { group in + group.addTask { + try await check(authority: "example.com") + } + group.addTask { + try await check(authority: "example.co.uk") + } + group.addTask { + try await check(authority: "example.fr") + } + group.addTask { + try await check(authority: "example.com") + } + + var requests: [HTTPRequestData] = [] + for try await request in group { + requests.append(request) + } + return requests + } + + let authorization = Set(requests.compactMap(\.headerFields[.authorization])) + XCTAssertEqual(authorization, ["Bearer some token"]) + XCTAssertEqual(delelgate.fetchCount, 1) + } + } + } +}