From cc4143951151c4dc4f8300c9f8f1070b6562c04c Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Tue, 26 May 2026 16:51:00 -0600 Subject: [PATCH 1/4] feat: GraphQL executions are handled in parallel Subscriptions were already parallelized, but this parallelizes normal queries as well. --- Sources/GraphQLTransportWS/Server.swift | 45 +++++++++++++------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index f270f9f..82ba846 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -25,7 +25,7 @@ where private var initialized = false private var initResult: InitPayloadResult? - private var subscriptionTasks = [String: Task]() + private var executionTasks = [String: Task]() /// Create a new server /// @@ -53,7 +53,7 @@ where } deinit { - subscriptionTasks.values.forEach { $0.cancel() } + executionTasks.values.forEach { $0.cancel() } } /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. @@ -133,13 +133,9 @@ where } let id = subscribeRequest.id - if subscriptionTasks[id] != nil { - try await error(.subscriberAlreadyExists(id: id)) - } - let graphQLRequest = subscribeRequest.payload - var isStreaming = false + let isStreaming: Bool do { isStreaming = try graphQLRequest.isSubscription() } catch { @@ -147,8 +143,16 @@ where return } - if isStreaming { - subscriptionTasks[id] = Task { + guard executionTasks[id] == nil else { + try await self.error(.subscriberAlreadyExists(id: id)) + return + } + executionTasks[id] = Task { + defer { + executionTasks.removeValue(forKey: id) + } + + if isStreaming { do { let stream = try await onSubscribe(graphQLRequest, initResult) for try await event in stream { @@ -157,19 +161,16 @@ where } } catch { try await sendError(error, id: id) - subscriptionTasks.removeValue(forKey: id) - throw error } try await self.sendComplete(id: id) - subscriptionTasks.removeValue(forKey: id) - } - } else { - do { - let result = try await onExecute(graphQLRequest, initResult) - try await sendNext(result, id: id) - try await sendComplete(id: id) - } catch { - try await sendError(error, id: id) + } else { + do { + let result = try await onExecute(graphQLRequest, initResult) + try await sendNext(result, id: id) + try await sendComplete(id: id) + } catch { + try await sendError(error, id: id) + } } } } @@ -181,9 +182,9 @@ where } let id = completeRequest.id - if let task = subscriptionTasks[id] { + if let task = executionTasks[id] { task.cancel() - subscriptionTasks.removeValue(forKey: id) + executionTasks.removeValue(forKey: id) } try await onOperationComplete(id) } From 5ba5b78b69c6542f24df03b6b38c44651cce3a26 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Tue, 26 May 2026 16:57:25 -0600 Subject: [PATCH 2/4] feat: Avoids sending multiple messages on message failures --- Sources/GraphQLTransportWS/Server.swift | 26 ++++++++++++------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 82ba846..966db16 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -153,25 +153,28 @@ where } if isStreaming { + let stream: SubscriptionSequenceType do { - let stream = try await onSubscribe(graphQLRequest, initResult) - for try await event in stream { - try Task.checkCancellation() - try await self.sendNext(event, id: id) - } + stream = try await onSubscribe(graphQLRequest, initResult) } catch { try await sendError(error, id: id) + return } - try await self.sendComplete(id: id) + for try await event in stream { + try await self.sendNext(event, id: id) + } + executionTasks.removeValue(forKey: id) } else { + let result: GraphQLResult do { - let result = try await onExecute(graphQLRequest, initResult) - try await sendNext(result, id: id) - try await sendComplete(id: id) + result = try await onExecute(graphQLRequest, initResult) } catch { try await sendError(error, id: id) + return } + try await sendNext(result, id: id) } + try await sendComplete(id: id) } } @@ -240,11 +243,6 @@ where try await sendError([error], id: id) } - /// Send an `error` response through the messenger - private func sendError(_ errorMessage: String, id: String) async throws { - try await sendError(GraphQLError(message: errorMessage), id: id) - } - /// Send an error through the messenger and close the connection private func error(_ error: GraphQLTransportWSError) async throws { try await messenger.error(error.message, code: error.code.rawValue) From 60546141d5b9a4320294957606535301370629e7 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Tue, 26 May 2026 17:02:20 -0600 Subject: [PATCH 3/4] refactor: Move `error` to messenger --- Sources/GraphQLTransportWS/Client.swift | 17 ++++++--------- Sources/GraphQLTransportWS/Messenger.swift | 7 ++++++ Sources/GraphQLTransportWS/Server.swift | 25 +++++++++------------- 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index 3012143..f11f43b 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -50,7 +50,7 @@ public actor Client { do { response = try decoder.decode(Response.self, from: message) } catch { - try await self.error(.noType()) + try await messenger.error(.noType()) return } @@ -62,31 +62,31 @@ public actor Client { from: message ) else { - try await error(.invalidResponseFormat(messageType: .connectionAck)) + try await messenger.error(.invalidResponseFormat(messageType: .connectionAck)) return } try await onConnectionAck(connectionAckResponse, self) case .next: guard let nextResponse = try? decoder.decode(NextResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .next)) + try await messenger.error(.invalidResponseFormat(messageType: .next)) return } try await onNext(nextResponse, self) case .error: guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .error)) + try await messenger.error(.invalidResponseFormat(messageType: .error)) return } try await onError(errorResponse, self) case .complete: guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .complete)) + try await messenger.error(.invalidResponseFormat(messageType: .complete)) return } try await onComplete(completeResponse, self) default: - try await error(.invalidType()) + try await messenger.error(.invalidType()) } } @@ -123,9 +123,4 @@ public actor Client { ) ) } - - /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLTransportWSError) async throws { - try await messenger.error(error.message, code: error.code.rawValue) - } } diff --git a/Sources/GraphQLTransportWS/Messenger.swift b/Sources/GraphQLTransportWS/Messenger.swift index 543c6fa..c31fc08 100644 --- a/Sources/GraphQLTransportWS/Messenger.swift +++ b/Sources/GraphQLTransportWS/Messenger.swift @@ -15,3 +15,10 @@ public protocol Messenger: Sendable { /// - code: An error code func error(_ message: String, code: Int) async throws } + +extension Messenger { + /// Send an error through the messenger and close the connection + func error(_ error: GraphQLTransportWSError) async throws { + try await self.error(error.message, code: error.code.rawValue) + } +} diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 966db16..304ed0d 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -70,7 +70,7 @@ where do { request = try decoder.decode(Request.self, from: message) } catch { - try await self.error(.noType()) + try await messenger.error(.noType()) return } @@ -83,26 +83,26 @@ where from: message ) else { - try await error(.invalidRequestFormat(messageType: .connectionInit)) + try await messenger.error(.invalidRequestFormat(messageType: .connectionInit)) return } try await onConnectionInit(connectionInitRequest, messenger) case .subscribe: guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: message) else { - try await error(.invalidRequestFormat(messageType: .subscribe)) + try await messenger.error(.invalidRequestFormat(messageType: .subscribe)) return } try await onSubscribe(subscribeRequest) case .complete: guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: message) else { - try await error(.invalidRequestFormat(messageType: .complete)) + try await messenger.error(.invalidRequestFormat(messageType: .complete)) return } try await onOperationComplete(completeRequest) default: - try await error(.invalidType()) + try await messenger.error(.invalidType()) } } @@ -111,14 +111,14 @@ where _: Messenger ) async throws { guard !initialized else { - try await error(.tooManyInitializations()) + try await messenger.error(.tooManyInitializations()) return } do { initResult = try await onInit(connectionInitRequest.payload) } catch { - try await self.error(.forbidden()) + try await messenger.error(.forbidden()) return } initialized = true @@ -128,7 +128,7 @@ where private func onSubscribe(_ subscribeRequest: SubscribeRequest) async throws { guard initialized, let initResult else { - try await error(.notInitialized()) + try await messenger.error(.notInitialized()) return } @@ -144,7 +144,7 @@ where } guard executionTasks[id] == nil else { - try await self.error(.subscriberAlreadyExists(id: id)) + try await messenger.error(.subscriberAlreadyExists(id: id)) return } executionTasks[id] = Task { @@ -180,7 +180,7 @@ where private func onOperationComplete(_ completeRequest: CompleteRequest) async throws { guard initialized else { - try await error(.notInitialized()) + try await messenger.error(.notInitialized()) return } @@ -242,9 +242,4 @@ where private func sendError(_ error: Error, id: String) async throws { try await sendError([error], id: id) } - - /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLTransportWSError) async throws { - try await messenger.error(error.message, code: error.code.rawValue) - } } From 8dd79b26d36fb445e75e2e4d99399d25d5d2ac69 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Tue, 26 May 2026 17:15:35 -0600 Subject: [PATCH 4/4] feat: Improve formatting error reporting --- Sources/GraphQLTransportWS/Client.swift | 31 ++++++---- .../GraphqlTransportWSError.swift | 8 +-- Sources/GraphQLTransportWS/Server.swift | 25 +++++---- .../GraphQLTransportWSTests.swift | 56 +++++++++++++++++++ 4 files changed, 95 insertions(+), 25 deletions(-) diff --git a/Sources/GraphQLTransportWS/Client.swift b/Sources/GraphQLTransportWS/Client.swift index f11f43b..e363be4 100644 --- a/Sources/GraphQLTransportWS/Client.swift +++ b/Sources/GraphQLTransportWS/Client.swift @@ -56,32 +56,41 @@ public actor Client { switch response.type { case .connectionAck: - guard - let connectionAckResponse = try? decoder.decode( + let connectionAckResponse: ConnectionAckResponse + do { + connectionAckResponse = try decoder.decode( ConnectionAckResponse.self, from: message ) - else { - try await messenger.error(.invalidResponseFormat(messageType: .connectionAck)) + } catch { + try await messenger.error(.invalidResponseFormat(messageType: .connectionAck, error: error)) return } try await onConnectionAck(connectionAckResponse, self) case .next: - guard let nextResponse = try? decoder.decode(NextResponse.self, from: message) else { - try await messenger.error(.invalidResponseFormat(messageType: .next)) + let nextResponse: NextResponse + do { + nextResponse = try decoder.decode(NextResponse.self, from: message) + } catch { + try await messenger.error(.invalidResponseFormat(messageType: .next, error: error)) return } try await onNext(nextResponse, self) case .error: - guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else { - try await messenger.error(.invalidResponseFormat(messageType: .error)) + let errorResponse: ErrorResponse + do { + errorResponse = try decoder.decode(ErrorResponse.self, from: message) + } catch { + try await messenger.error(.invalidResponseFormat(messageType: .error, error: error)) return } try await onError(errorResponse, self) case .complete: - guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message) - else { - try await messenger.error(.invalidResponseFormat(messageType: .complete)) + let completeResponse: CompleteResponse + do { + completeResponse = try decoder.decode(CompleteResponse.self, from: message) + } catch { + try await messenger.error(.invalidResponseFormat(messageType: .complete, error: error)) return } try await onComplete(completeResponse, self) diff --git a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift index d2ae94b..8d21304 100644 --- a/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift +++ b/Sources/GraphQLTransportWS/GraphqlTransportWSError.swift @@ -58,16 +58,16 @@ struct GraphQLTransportWSError: Error { ) } - static func invalidRequestFormat(messageType: RequestMessageType) -> Self { + static func invalidRequestFormat(messageType: RequestMessageType, error: Error) -> Self { return self.init( - "Request message doesn't match '\(messageType.type.rawValue)' JSON format", + "Request message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)", code: .miscellaneous ) } - static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { + static func invalidResponseFormat(messageType: ResponseMessageType, error: Error) -> Self { return self.init( - "Response message doesn't match '\(messageType.type.rawValue)' JSON format", + "Response message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)", code: .miscellaneous ) } diff --git a/Sources/GraphQLTransportWS/Server.swift b/Sources/GraphQLTransportWS/Server.swift index 304ed0d..9c5b815 100644 --- a/Sources/GraphQLTransportWS/Server.swift +++ b/Sources/GraphQLTransportWS/Server.swift @@ -77,27 +77,32 @@ where // handle incoming message switch request.type { case .connectionInit: - guard - let connectionInitRequest = try? decoder.decode( + let connectionInitRequest: ConnectionInitRequest + do { + connectionInitRequest = try decoder.decode( ConnectionInitRequest.self, from: message ) - else { - try await messenger.error(.invalidRequestFormat(messageType: .connectionInit)) + } catch { + try await messenger.error(.invalidRequestFormat(messageType: .connectionInit, error: error)) return } try await onConnectionInit(connectionInitRequest, messenger) case .subscribe: - guard let subscribeRequest = try? decoder.decode(SubscribeRequest.self, from: message) - else { - try await messenger.error(.invalidRequestFormat(messageType: .subscribe)) + let subscribeRequest: SubscribeRequest + do { + subscribeRequest = try decoder.decode(SubscribeRequest.self, from: message) + } catch { + try await messenger.error(.invalidRequestFormat(messageType: .subscribe, error: error)) return } try await onSubscribe(subscribeRequest) case .complete: - guard let completeRequest = try? decoder.decode(CompleteRequest.self, from: message) - else { - try await messenger.error(.invalidRequestFormat(messageType: .complete)) + let completeRequest: CompleteRequest + do { + completeRequest = try decoder.decode(CompleteRequest.self, from: message) + } catch { + try await messenger.error(.invalidRequestFormat(messageType: .complete, error: error)) return } try await onOperationComplete(completeRequest) diff --git a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift index a1c397d..695016b 100644 --- a/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift +++ b/Tests/GraphQLTransportWSTests/GraphQLTransportWSTests.swift @@ -268,6 +268,62 @@ struct GraphqlTransportWSTests { ) } + /// Tests malformed requests include decoder details in the transport error + @Test func malformedRequestIncludesDecodingDetails() async throws { + let api = TestAPI() + let context = TestContext() + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await api.execute( + request: graphQLRequest.query, + context: context + ) + }, + onSubscribe: { graphQLRequest, _ in + try await api.subscribe( + request: graphQLRequest.query, + context: context + ).get() + } + ) + let (incoming, continuation) = AsyncThrowingStream.makeStream() + + continuation.yield(Data(#"{"type":"complete"}"#.utf8)) + continuation.finish() + + try await server.listen(to: incoming) + + let error = await #expect(throws: TestMessengerError.self) { + for try await _ in serverMessenger.stream {} + } + #expect(error?.code == 4400) + #expect(error?.message.contains("Request message doesn't match 'complete' JSON format") == true) + #expect(error?.message.contains("keyNotFound") == true) + #expect(error?.message.contains(#""id""#) == true) + } + + /// Tests malformed responses include decoder details in the transport error + @Test func malformedResponseIncludesDecodingDetails() async throws { + let messenger = TestMessenger() + let client = Client(messenger: messenger) + let (incoming, continuation) = AsyncThrowingStream.makeStream() + + continuation.yield(Data(#"{"type":"next"}"#.utf8)) + continuation.finish() + + try await client.listen(to: incoming) + + let error = await #expect(throws: TestMessengerError.self) { + for try await _ in messenger.stream {} + } + #expect(error?.code == 4400) + #expect(error?.message.contains("Response message doesn't match 'next' JSON format") == true) + #expect(error?.message.contains("keyNotFound") == true) + #expect(error?.message.contains(#""id""#) == true) + } + enum TestError: Error { case couldBeAnything }