Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 22 additions & 18 deletions Sources/GraphQLTransportWS/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,43 +50,52 @@ public actor Client<InitPayload: Equatable & Codable> {
do {
response = try decoder.decode(Response.self, from: message)
} catch {
try await self.error(.noType())
try await messenger.error(.noType())
return
}

switch response.type {
case .connectionAck:
guard
let connectionAckResponse = try? decoder.decode(
let connectionAckResponse: ConnectionAckResponse
do {
connectionAckResponse = try decoder.decode(
ConnectionAckResponse.self,
from: message
)
else {
try await error(.invalidResponseFormat(messageType: .connectionAck))
} catch {
try await messenger.error(.invalidResponseFormat(messageType: .connectionAck, error: error))
return
}
try await onConnectionAck(connectionAckResponse, self)
case .next:
guard let nextResponse = try? decoder.decode(NextResponse.self, from: message) else {
try await error(.invalidResponseFormat(messageType: .next))
let nextResponse: NextResponse
do {
nextResponse = try decoder.decode(NextResponse.self, from: message)
} catch {
try await messenger.error(.invalidResponseFormat(messageType: .next, error: error))
return
}
try await onNext(nextResponse, self)
case .error:
guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else {
try await error(.invalidResponseFormat(messageType: .error))
let errorResponse: ErrorResponse
do {
errorResponse = try decoder.decode(ErrorResponse.self, from: message)
} catch {
try await messenger.error(.invalidResponseFormat(messageType: .error, error: error))
return
}
try await onError(errorResponse, self)
case .complete:
guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message)
else {
try await error(.invalidResponseFormat(messageType: .complete))
let completeResponse: CompleteResponse
do {
completeResponse = try decoder.decode(CompleteResponse.self, from: message)
} catch {
try await messenger.error(.invalidResponseFormat(messageType: .complete, error: error))
return
}
try await onComplete(completeResponse, self)
default:
try await error(.invalidType())
try await messenger.error(.invalidType())
}
}

Expand Down Expand Up @@ -123,9 +132,4 @@ public actor Client<InitPayload: Equatable & Codable> {
)
)
}

/// Send an error through the messenger and close the connection
private func error(_ error: GraphQLTransportWSError) async throws {
try await messenger.error(error.message, code: error.code.rawValue)
}
}
8 changes: 4 additions & 4 deletions Sources/GraphQLTransportWS/GraphqlTransportWSError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ struct GraphQLTransportWSError: Error {
)
}

static func invalidRequestFormat(messageType: RequestMessageType) -> Self {
static func invalidRequestFormat(messageType: RequestMessageType, error: Error) -> Self {
return self.init(
"Request message doesn't match '\(messageType.type.rawValue)' JSON format",
"Request message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)",
code: .miscellaneous
)
}

static func invalidResponseFormat(messageType: ResponseMessageType) -> Self {
static func invalidResponseFormat(messageType: ResponseMessageType, error: Error) -> Self {
return self.init(
"Response message doesn't match '\(messageType.type.rawValue)' JSON format",
"Response message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)",
code: .miscellaneous
)
}
Expand Down
7 changes: 7 additions & 0 deletions Sources/GraphQLTransportWS/Messenger.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ public protocol Messenger: Sendable {
/// - code: An error code
func error(_ message: String, code: Int) async throws
}

extension Messenger {
/// Send an error through the messenger and close the connection
func error(_ error: GraphQLTransportWSError) async throws {
try await self.error(error.message, code: error.code.rawValue)
}
}
105 changes: 52 additions & 53 deletions Sources/GraphQLTransportWS/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ where

private var initialized = false
private var initResult: InitPayloadResult?
private var subscriptionTasks = [String: Task<Void, any Error>]()
private var executionTasks = [String: Task<Void, any Error>]()

/// Create a new server
///
Expand Down Expand Up @@ -53,7 +53,7 @@ where
}

deinit {
subscriptionTasks.values.forEach { $0.cancel() }
executionTasks.values.forEach { $0.cancel() }
}

/// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed.
Expand All @@ -70,39 +70,44 @@ where
do {
request = try decoder.decode(Request.self, from: message)
} catch {
try await self.error(.noType())
try await messenger.error(.noType())
return
}

// handle incoming message
switch request.type {
case .connectionInit:
guard
let connectionInitRequest = try? decoder.decode(
let connectionInitRequest: ConnectionInitRequest<InitPayload>
do {
connectionInitRequest = try decoder.decode(
ConnectionInitRequest<InitPayload>.self,
from: message
)
else {
try await error(.invalidRequestFormat(messageType: .connectionInit))
} catch {
try await messenger.error(.invalidRequestFormat(messageType: .connectionInit, error: error))
return
}
try await onConnectionInit(connectionInitRequest, messenger)
case .subscribe:
guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: message)
else {
try await error(.invalidRequestFormat(messageType: .subscribe))
let subscribeRequest: SubscribeRequest
do {
subscribeRequest = try decoder.decode(SubscribeRequest.self, from: message)
} catch {
try await messenger.error(.invalidRequestFormat(messageType: .subscribe, error: error))
return
}
try await onSubscribe(subscribeRequest)
case .complete:
guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: message)
else {
try await error(.invalidRequestFormat(messageType: .complete))
let completeRequest: CompleteRequest
do {
completeRequest = try decoder.decode(CompleteRequest.self, from: message)
} catch {
try await messenger.error(.invalidRequestFormat(messageType: .complete, error: error))
return
}
try await onOperationComplete(completeRequest)
default:
try await error(.invalidType())
try await messenger.error(.invalidType())
}
}

Expand All @@ -111,14 +116,14 @@ where
_: Messenger
) async throws {
guard !initialized else {
try await error(.tooManyInitializations())
try await messenger.error(.tooManyInitializations())
return
}

do {
initResult = try await onInit(connectionInitRequest.payload)
} catch {
try await self.error(.forbidden())
try await messenger.error(.forbidden())
return
}
initialized = true
Expand All @@ -128,62 +133,66 @@ where

private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws {
guard initialized, let initResult else {
try await error(.notInitialized())
try await messenger.error(.notInitialized())
return
}

let id = subscribeRequest.id
if subscriptionTasks[id] != nil {
try await error(.subscriberAlreadyExists(id: id))
}

let graphQLRequest = subscribeRequest.payload

var isStreaming = false
let isStreaming: Bool
do {
isStreaming = try graphQLRequest.isSubscription()
} catch {
try await sendError(error, id: id)
return
}

if isStreaming {
subscriptionTasks[id] = Task {
guard executionTasks[id] == nil else {
try await messenger.error(.subscriberAlreadyExists(id: id))
return
}
executionTasks[id] = Task {
defer {
executionTasks.removeValue(forKey: id)
}

if isStreaming {
let stream: SubscriptionSequenceType
do {
let stream = try await onSubscribe(graphQLRequest, initResult)
for try await event in stream {
try Task.checkCancellation()
try await self.sendNext(event, id: id)
}
stream = try await onSubscribe(graphQLRequest, initResult)
} catch {
try await sendError(error, id: id)
subscriptionTasks.removeValue(forKey: id)
throw error
return
}
for try await event in stream {
try await self.sendNext(event, id: id)
}
executionTasks.removeValue(forKey: id)
} else {
let result: GraphQLResult
do {
result = try await onExecute(graphQLRequest, initResult)
} catch {
try await sendError(error, id: id)
return
}
try await self.sendComplete(id: id)
subscriptionTasks.removeValue(forKey: id)
}
} else {
do {
let result = try await onExecute(graphQLRequest, initResult)
try await sendNext(result, id: id)
try await sendComplete(id: id)
} catch {
try await sendError(error, id: id)
}
try await sendComplete(id: id)
}
}

private func onOperationComplete(_ completeRequest: CompleteRequest) async throws {
guard initialized else {
try await error(.notInitialized())
try await messenger.error(.notInitialized())
return
}

let id = completeRequest.id
if let task = subscriptionTasks[id] {
if let task = executionTasks[id] {
task.cancel()
subscriptionTasks.removeValue(forKey: id)
executionTasks.removeValue(forKey: id)
}
try await onOperationComplete(id)
}
Expand Down Expand Up @@ -238,14 +247,4 @@ where
private func sendError(_ error: Error, id: String) async throws {
try await sendError([error], id: id)
}

/// Send an `error` response through the messenger
private func sendError(_ errorMessage: String, id: String) async throws {
try await sendError(GraphQLError(message: errorMessage), id: id)
}

/// Send an error through the messenger and close the connection
private func error(_ error: GraphQLTransportWSError) async throws {
try await messenger.error(error.message, code: error.code.rawValue)
}
}
56 changes: 56 additions & 0 deletions Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,62 @@ struct GraphqlTransportWSTests {
)
}

/// Tests malformed requests include decoder details in the transport error
@Test func malformedRequestIncludesDecodingDetails() async throws {
let api = TestAPI()
let context = TestContext()
let server = Server<TokenInitPayload, Void, AsyncThrowingStream<GraphQLResult, Error>>(
messenger: serverMessenger,
onInit: { _ in },
onExecute: { graphQLRequest, _ in
try await api.execute(
request: graphQLRequest.query,
context: context
)
},
onSubscribe: { graphQLRequest, _ in
try await api.subscribe(
request: graphQLRequest.query,
context: context
).get()
}
)
let (incoming, continuation) = AsyncThrowingStream<Data, any Error>.makeStream()

continuation.yield(Data(#"{"type":"complete"}"#.utf8))
continuation.finish()

try await server.listen(to: incoming)

let error = await #expect(throws: TestMessengerError.self) {
for try await _ in serverMessenger.stream {}
}
#expect(error?.code == 4400)
#expect(error?.message.contains("Request message doesn't match 'complete' JSON format") == true)
#expect(error?.message.contains("keyNotFound") == true)
#expect(error?.message.contains(#""id""#) == true)
}

/// Tests malformed responses include decoder details in the transport error
@Test func malformedResponseIncludesDecodingDetails() async throws {
let messenger = TestMessenger()
let client = Client<TokenInitPayload>(messenger: messenger)
let (incoming, continuation) = AsyncThrowingStream<Data, any Error>.makeStream()

continuation.yield(Data(#"{"type":"next"}"#.utf8))
continuation.finish()

try await client.listen(to: incoming)

let error = await #expect(throws: TestMessengerError.self) {
for try await _ in messenger.stream {}
}
#expect(error?.code == 4400)
#expect(error?.message.contains("Response message doesn't match 'next' JSON format") == true)
#expect(error?.message.contains("keyNotFound") == true)
#expect(error?.message.contains(#""id""#) == true)
}

enum TestError: Error {
case couldBeAnything
}
Expand Down
Loading