Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Authentication #27

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions Sources/Networking/Components/Authentication/Authentication.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import Helpers

extension NetworkingComponent {
public func authenticated<Delegate: AuthenticationDelegate>(with delegate: Delegate) -> some NetworkingComponent {
checkedStatusCode().modified(Authentication(delegate: delegate))
}
}

struct Authentication<Delegate: AuthenticationDelegate>: NetworkingModifier {
typealias Credentials = Delegate.Credentials
let delegate: Delegate

func send(upstream: NetworkingComponent, request: HTTPRequestData) -> ResponseStream<HTTPResponseData> {
guard let method = request.authenticationMethod, method == Credentials.method else {
return upstream.send(request)
}
return ResponseStream { continuation in
Task {

// Fetch the initial credentials
var credentials: Credentials
do {
credentials = try await delegate.fetch(for: request)
} catch {
continuation.finish(
throwing: AuthenticationError.fetchCredentialsFailed(request, Credentials.method, error)
)
return
}

// Update the request to use the credentials
let newRequest = credentials.apply(to: request)

// Process the stream
do {
for try await event in upstream.send(newRequest) {
continuation.yield(event)
}
continuation.finish()
} catch let StackError.unauthorized(response) {
let newRequest = try await refresh(
unauthorized: &credentials,
response: response,
continuation: continuation
)
await upstream.send(newRequest).redirect(into: continuation)
} catch {
continuation.finish(throwing: error)
}
}
}
}

func refresh(
unauthorized credentials: inout Credentials,
response: HTTPResponseData,
continuation: ResponseStream<HTTPResponseData>.Continuation
) async throws -> HTTPRequestData {
do {
credentials = try await delegate.refresh(unauthorized: credentials, from: response)
return credentials.apply(to: response.request)
} catch {
throw AuthenticationError.refreshCredentialsFailed(response, Credentials.method, error)
}
}
}

public enum AuthenticationError: Error {
case fetchCredentialsFailed(HTTPRequestData, AuthenticationMethod, Error)
case refreshCredentialsFailed(HTTPResponseData, AuthenticationMethod, Error)
}

extension AuthenticationError: Equatable {
public static func == (lhs: AuthenticationError, rhs: AuthenticationError) -> Bool {
switch (lhs, rhs) {
case let (.fetchCredentialsFailed(lhsR, lhsAM, lhsE), .fetchCredentialsFailed(rhsR, rhsAM, rhsE)):
return lhsR == rhsR && lhsAM == rhsAM && _isEqual(lhsE, rhsE)
case let (.refreshCredentialsFailed(lhsR, lhsAM, lhsE), .refreshCredentialsFailed(rhsR, rhsAM, rhsE)):
return lhsR == rhsR && lhsAM == rhsAM && _isEqual(lhsE, rhsE)
default:
return false
}
}
}

extension AuthenticationError: NetworkingError {
public var request: HTTPRequestData {
switch self {
case let .fetchCredentialsFailed(request, _, _):
return request
case let .refreshCredentialsFailed(response, _, _):
return response.request
}
}

public var response: HTTPResponseData? {
switch self {
case .fetchCredentialsFailed:
return nil
case let .refreshCredentialsFailed(response, _, _):
return response
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/// A system which can asynchronously fetch or refresh credentials
/// in order to make authenticated HTTP requests
public protocol AuthenticationDelegate<Credentials>: Sendable { // swiftlint:disable:this class_delegate_protocol

/// A type which represents the credentials to be used
associatedtype Credentials: AuthenticatingCredentials

/// The entry point into the authentication flow
///
/// Conforming types should manage their own state, providing thread safety
/// and perform whatever actions are necessary to retreive credentials from
/// an external system. For example - present a login interface to the user
/// to collect a username and password.
func fetch(for request: HTTPRequestData) async throws -> Credentials

/// After supplying a request with credentials, it is still possible to
/// encounter HTTP unauthorized errors. In such an event, this method will
/// be called, allowing for a single attempt to retry with a new set of
/// credentials. Typical usecases here would be for OAuth style refreshing
/// of a token.
func refresh(unauthorized: Credentials, from response: HTTPResponseData) async throws -> Credentials
}

public protocol AuthenticatingCredentials: Hashable, Sendable {

/// The authentication method
static var method: AuthenticationMethod { get }

/// Create a new request making use of the credentials in whichever way
/// suits their purpose. E.g. by appending a query parameter
func apply(to request: HTTPRequestData) -> HTTPRequestData
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
public struct AuthenticationMethod: Hashable, RawRepresentable, Sendable, HTTPRequestDataOption {
public static let defaultOption: Self? = nil

public let rawValue: String

public init(rawValue: String) {
self.rawValue = rawValue
}
}

extension HTTPRequestData {
public var authenticationMethod: AuthenticationMethod? {
get { self[option: AuthenticationMethod.self] }
set { self[option: AuthenticationMethod.self] = newValue }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import Foundation

extension AuthenticationMethod {
public static let basic = AuthenticationMethod(rawValue: "Basic")
}

public struct BasicCredentials: Hashable, Sendable, AuthenticatingCredentials, HTTPRequestDataOption {
public static var method: AuthenticationMethod = .basic
public static let defaultOption: Self? = nil

public let user: String
public let password: String

public init(user: String, password: String) {
self.user = user
self.password = password
}

public func apply(to request: HTTPRequestData) -> HTTPRequestData {
var copy = request
let joined = user + ":" + password
let data = Data(joined.utf8)
let encoded = data.base64EncodedString()
copy.headerFields[.authorization] = "Basic \(encoded)"
return copy
}
}

extension HTTPRequestData {
public var basicCredentials: BasicCredentials? {
get { self[option: BasicCredentials.self] }
set { self[option: BasicCredentials.self] = newValue }
}
}

public typealias BasicAuthentication<
Delegate: AuthenticationDelegate
> = HeaderBasedAuthentication<Delegate> where Delegate.Credentials == BasicCredentials
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import Foundation

extension AuthenticationMethod {
public static let bearer = AuthenticationMethod(rawValue: "Bearer")
}

public struct BearerCredentials: Hashable, Sendable, Codable, HTTPRequestDataOption, AuthenticatingCredentials {
public static let method: AuthenticationMethod = .bearer
public static let defaultOption: Self? = nil

public let token: String

public init(token: String) {
self.token = token
}

public func apply(to request: HTTPRequestData) -> HTTPRequestData {
var copy = request
copy.headerFields[.authorization] = "Bearer \(token)"
return copy
}
}

extension HTTPRequestData {
public var bearerCredentials: BearerCredentials? {
get { self[option: BearerCredentials.self] }
set { self[option: BearerCredentials.self] = newValue }
}
}

public typealias BearerAuthentication<
Delegate: AuthenticationDelegate
> = HeaderBasedAuthentication<Delegate> where Delegate.Credentials == BearerCredentials
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
public struct HeaderBasedAuthentication<Delegate: AuthenticationDelegate> {
actor StateMachine: AuthenticationDelegate {
typealias Credentials = Delegate.Credentials

private enum State {
case idle
case fetching(Task<Credentials, Error>)
case authorized(Credentials)
}

let delegate: Delegate
private var state: State = .idle

@NetworkEnvironment(\.logger) var logger

init(delegate: Delegate) {
self.delegate = delegate
}

private func set(state: State) {
self.state = state
}

func fetch(for request: HTTPRequestData) async throws -> Credentials {
switch state {
case let .authorized(credentials):
return credentials
case let .fetching(task):
return try await task.value
case .idle:
let task = Task { try await performCredentialFetch(for: request) }
set(state: .fetching(task))
return try await task.value
}
}

private func performCredentialFetch(for request: HTTPRequestData) async throws -> Credentials {
logger?.info("🔐 Fetching credentials for \(Credentials.method.rawValue, privacy: .public) authorization method")
do {
let credentials = try await delegate.fetch(for: request)
set(state: .authorized(credentials))
return credentials
} catch {
set(state: .idle)
throw AuthenticationError.fetchCredentialsFailed(request, Credentials.method, error)
}
}

func refresh(unauthorized credentials: Credentials, from response: HTTPResponseData) async throws -> Credentials {
if case let .fetching(task) = state {
return try await task.value
}

let task = Task { try await performCredentialRefresh(unauthorized: credentials, from: response) }
set(state: .fetching(task))
return try await task.value
}

private func performCredentialRefresh(
unauthorized credentials: Credentials,
from response: HTTPResponseData
) async throws -> Credentials {
logger?.info(
"🔑 Refreshing credentials for \(Credentials.method.rawValue, privacy: .public) authorization method"
)
do {
let refreshed = try await delegate.refresh(unauthorized: credentials, from: response)
set(state: .authorized(refreshed))
return refreshed
} catch {
set(state: .idle)
throw AuthenticationError.refreshCredentialsFailed(response, Credentials.method, error)
}
}
}

fileprivate let state: StateMachine

public init(delegate: Delegate) {
state = StateMachine(delegate: delegate)
}
}

extension HeaderBasedAuthentication: AuthenticationDelegate {
public typealias Credentials = Delegate.Credentials
public func fetch(for request: HTTPRequestData) async throws -> Credentials {
try await state.fetch(for: request)
}
public func refresh(
unauthorized credentials: Credentials, from response: HTTPResponseData
) async throws -> Credentials {
try await state.refresh(unauthorized: credentials, from: response)
}
}
2 changes: 1 addition & 1 deletion Sources/Networking/Core/HTTPRequestData.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public struct HTTPRequestData: Sendable, Identifiable {
id: ID,
method: HTTPRequest.Method = .get,
scheme: String? = "https",
authority: String?,
authority: String? = nil,
path: String? = nil,
headerFields: HTTPFields = [:],
body: Data? = nil
Expand Down
37 changes: 37 additions & 0 deletions Sources/TestSupport/TestAuthenticationDelegate.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import Networking
import Protected
import XCTestDynamicOverlay

public final class TestAuthenticationDelegate<Credentials: AuthenticatingCredentials>: @unchecked Sendable {
public typealias Fetch = @Sendable (HTTPRequestData) async throws -> Credentials
public typealias Refresh = @Sendable (Credentials, HTTPResponseData) async throws -> Credentials

@Protected public var fetchCount: Int = 0
@Protected public var refreshCount: Int = 0

public var fetch: Fetch
public var refresh: Refresh

public init(
fetch: @escaping Fetch = unimplemented("TestAuthenticationDelegate.fetch"),
refresh: @escaping Refresh = unimplemented("TestAuthenticationDelegate.refresh")
) {
self.fetch = fetch
self.refresh = refresh
}
}

extension TestAuthenticationDelegate: AuthenticationDelegate {
public func fetch(for request: HTTPRequestData) async throws -> Credentials {
fetchCount += 1
return try await fetch(request)
}

public func refresh(
unauthorized credentials: Credentials,
from response: HTTPResponseData
) async throws -> Credentials {
refreshCount += 1
return try await refresh(credentials, response)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import AssertionExtras
import Dependencies
import Foundation
import HTTPTypes
import TestSupport
import XCTest

@testable import Networking

final class BasicCredentialsTests: XCTestCase {
func test__apply_credentials() {
let credentials = BasicCredentials(user: "blob", password: "super!$3cret")
let request = credentials.apply(to: HTTPRequestData(id: "1"))
XCTAssertEqual(request.headerFields[.authorization], "Basic YmxvYjpzdXBlciEkM2NyZXQ=")
}

func test__provide_credentials() {
var request = HTTPRequestData(id: "1")
request.basicCredentials = BasicCredentials(user: "blob", password: "super!$3cret")
XCTAssertEqual(request.basicCredentials?.user, "blob")
XCTAssertEqual(request.basicCredentials?.password, "super!$3cret")
}
}
Loading
Loading