From 185653d9564846cde4a3fc04955cb5dd4fe51827 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Thu, 21 May 2026 13:54:21 -0600 Subject: [PATCH 1/4] feat: GraphQL executions are handled in parallel Subscriptions were already parallelized, but this parallelizes normal queries as well. --- Sources/GraphQLWS/Server.swift | 58 ++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index c831e68..93a9cbd 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -26,7 +26,7 @@ where private var initialized = false private var initResult: InitPayloadResult? - private var subscriptionTasks = [String: Task]() + private var executionTasks = [String: Task]() /// Create a new server /// @@ -54,7 +54,7 @@ where } deinit { - subscriptionTasks.values.forEach { $0.cancel() } + executionTasks.values.forEach { $0.cancel() } } /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. @@ -133,7 +133,6 @@ where } initialized = true try await sendConnectionAck() - // TODO: Should we send the `ka` message? } private func onStart(_ startRequest: StartRequest, _: Messenger) async throws { @@ -143,13 +142,9 @@ where } let id = startRequest.id - if subscriptionTasks[id] != nil { - try await error(.subscriberAlreadyExists(id: id)) - } - let graphQLRequest = startRequest.payload - var isStreaming = false + let isStreaming: Bool do { isStreaming = try graphQLRequest.isSubscription() } catch { @@ -157,8 +152,18 @@ where return } - if isStreaming { - subscriptionTasks[id] = Task { + let onSubscribe = self.onSubscribe + let onExecute = self.onExecute + guard executionTasks[id] == nil else { + try await self.error(.subscriberAlreadyExists(id: id)) + return + } + executionTasks[id] = Task { + defer { + executionTasks.removeValue(forKey: id) + } + + if isStreaming { do { let stream = try await onSubscribe(graphQLRequest, initResult) for try await event in stream { @@ -166,22 +171,21 @@ where try await self.sendData(event, id: id) } } catch { - try await sendError(error, id: id) - subscriptionTasks.removeValue(forKey: id) - throw error + try await self.sendError(error, id: id) } try await self.sendComplete(id: id) - subscriptionTasks.removeValue(forKey: id) - } - } else { - do { - let result = try await onExecute(graphQLRequest, initResult) - try await sendData(result, id: id) - try await sendComplete(id: id) - } catch { - try await sendError(error, id: id) + } else { + do { + let result = try await onExecute(graphQLRequest, initResult) + try await self.sendData(result, id: id) + try await self.sendComplete(id: id) + } catch { + try await self.sendError(error, id: id) + } } + executionTasks.removeValue(forKey: id) } + } private func onStop(_ stopRequest: StopRequest) async throws { @@ -191,9 +195,9 @@ where } let id = stopRequest.id - if let task = subscriptionTasks[id] { + if let task = executionTasks[id] { task.cancel() - subscriptionTasks.removeValue(forKey: id) + executionTasks.removeValue(forKey: id) } try await onOperationComplete(id) } @@ -201,10 +205,10 @@ where private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) async throws { - for (_, subscriptionTask) in subscriptionTasks { - subscriptionTask.cancel() + for (_, task) in executionTasks { + task.cancel() } - subscriptionTasks.removeAll() + executionTasks.removeAll() try await messenger.close() } From be6aa8787137b0e3d6caebad002d38369716d9c6 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Fri, 22 May 2026 12:41:54 -0600 Subject: [PATCH 2/4] feat: Avoids sending multiple messages on message failures --- Sources/GraphQLWS/Server.swift | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index 93a9cbd..fff3db4 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -164,26 +164,28 @@ where } if isStreaming { + let stream: SubscriptionSequenceType do { - let stream = try await onSubscribe(graphQLRequest, initResult) - for try await event in stream { - try Task.checkCancellation() - try await self.sendData(event, id: id) - } + stream = try await onSubscribe(graphQLRequest, initResult) } catch { - try await self.sendError(error, id: id) + try await sendError(error, id: id) + return + } + for try await event in stream { + try await sendData(event, id: id) } - try await self.sendComplete(id: id) } else { + let result: GraphQLResult do { - let result = try await onExecute(graphQLRequest, initResult) - try await self.sendData(result, id: id) - try await self.sendComplete(id: id) + result = try await onExecute(graphQLRequest, initResult) } catch { - try await self.sendError(error, id: id) + try await sendError(error, id: id) + return } + try await sendData(result, id: id) } - executionTasks.removeValue(forKey: id) + try await sendComplete(id: id) + try await onOperationComplete(id) } } @@ -275,11 +277,6 @@ where try await sendError([error], id: id) } - /// Send an `error` response through the messenger - private func sendError(_ errorMessage: String, id: String) async throws { - try await sendError(GraphQLError(message: errorMessage), id: id) - } - /// Send an error through the messenger and close the connection private func error(_ error: GraphQLWSError) async throws { try await messenger.error(error.message, code: error.code.rawValue) From 76fb1264b29fb114a39566caa10052ce81522ce7 Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Fri, 22 May 2026 11:46:47 -0600 Subject: [PATCH 3/4] refactor: Move `error` to messenger --- Sources/GraphQLWS/Client.swift | 27 ++++++++++++++------------- Sources/GraphQLWS/Messenger.swift | 7 +++++++ Sources/GraphQLWS/Server.swift | 29 +++++++++++++---------------- 3 files changed, 34 insertions(+), 29 deletions(-) diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index c7c4ff9..8905e82 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -62,7 +62,7 @@ public actor Client { do { response = try decoder.decode(Response.self, from: message) } catch { - try await self.error(.noType()) + try await messenger.error(.noType()) return } @@ -74,7 +74,9 @@ public actor Client { from: message ) else { - try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR) + ) return } try await onConnectionError(connectionErrorResponse, self) @@ -85,7 +87,9 @@ public actor Client { from: message ) else { - try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR) + ) return } try await onConnectionAck(connectionAckResponse, self) @@ -96,31 +100,33 @@ public actor Client { from: message ) else { - try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE) + ) return } try await onConnectionKeepAlive(connectionKeepAliveResponse, self) case .GQL_DATA: guard let nextResponse = try? decoder.decode(DataResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .GQL_DATA)) + try await messenger.error(.invalidResponseFormat(messageType: .GQL_DATA)) return } try await onData(nextResponse, self) case .GQL_ERROR: guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .GQL_ERROR)) + try await messenger.error(.invalidResponseFormat(messageType: .GQL_ERROR)) return } try await onError(errorResponse, self) case .GQL_COMPLETE: guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) + try await messenger.error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) return } try await onComplete(completeResponse, self) default: - try await error(.invalidType()) + try await messenger.error(.invalidType()) } } @@ -166,9 +172,4 @@ public actor Client { ) ) } - - /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLWSError) async throws { - try await messenger.error(error.message, code: error.code.rawValue) - } } diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index 543c6fa..a6abb77 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -15,3 +15,10 @@ public protocol Messenger: Sendable { /// - code: An error code func error(_ message: String, code: Int) async throws } + +extension Messenger { + /// Send an error through the messenger and close the connection + func error(_ error: GraphQLWSError) async throws { + try await self.error(error.message, code: error.code.rawValue) + } +} diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index fff3db4..a4aab8b 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -71,7 +71,7 @@ where do { request = try decoder.decode(Request.self, from: message) } catch { - try await self.error(.noType()) + try await messenger.error(.noType()) return } @@ -84,19 +84,19 @@ where from: message ) else { - try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) + try await messenger.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) return } try await onConnectionInit(connectionInitRequest, messenger) case .GQL_START: guard let startRequest = try? decoder.decode(StartRequest.self, from: message) else { - try await error(.invalidRequestFormat(messageType: .GQL_START)) + try await messenger.error(.invalidRequestFormat(messageType: .GQL_START)) return } try await onStart(startRequest, messenger) case .GQL_STOP: guard let stopRequest = try? decoder.decode(StopRequest.self, from: message) else { - try await error(.invalidRequestFormat(messageType: .GQL_STOP)) + try await messenger.error(.invalidRequestFormat(messageType: .GQL_STOP)) return } try await onStop(stopRequest) @@ -107,12 +107,14 @@ where from: message ) else { - try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE) + ) return } try await onConnectionTerminate(connectionTerminateRequest, messenger) default: - try await error(.invalidType()) + try await messenger.error(.invalidType()) } } @@ -121,14 +123,14 @@ where _: Messenger ) async throws { guard !initialized else { - try await error(.tooManyInitializations()) + try await messenger.error(.tooManyInitializations()) return } do { initResult = try await onInit(connectionInitRequest.payload) } catch { - try await self.error(.unauthorized()) + try await messenger.error(.unauthorized()) return } initialized = true @@ -137,7 +139,7 @@ where private func onStart(_ startRequest: StartRequest, _: Messenger) async throws { guard initialized, let initResult else { - try await error(.notInitialized()) + try await messenger.error(.notInitialized()) return } @@ -155,7 +157,7 @@ where let onSubscribe = self.onSubscribe let onExecute = self.onExecute guard executionTasks[id] == nil else { - try await self.error(.subscriberAlreadyExists(id: id)) + try await messenger.error(.subscriberAlreadyExists(id: id)) return } executionTasks[id] = Task { @@ -192,7 +194,7 @@ where private func onStop(_ stopRequest: StopRequest) async throws { guard initialized else { - try await error(.notInitialized()) + try await messenger.error(.notInitialized()) return } @@ -276,9 +278,4 @@ where private func sendError(_ error: Error, id: String) async throws { try await sendError([error], id: id) } - - /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLWSError) async throws { - try await messenger.error(error.message, code: error.code.rawValue) - } } From 56a5cab64650f3f9dee2eb21f0d87948d06816fc Mon Sep 17 00:00:00 2001 From: Jay Herron Date: Fri, 22 May 2026 12:32:14 -0600 Subject: [PATCH 4/4] feat: Improve formatting error reporting --- Sources/GraphQLWS/Client.swift | 57 ++++++++++++++-------- Sources/GraphQLWS/GraphQLWSError.swift | 8 ++-- Sources/GraphQLWS/Requests.swift | 3 +- Sources/GraphQLWS/Server.swift | 38 ++++++++++----- Tests/GraphQLWSTests/GraphQLWSTests.swift | 58 +++++++++++++++++++++++ 5 files changed, 126 insertions(+), 38 deletions(-) diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index 8905e82..6fa46a6 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -68,60 +68,77 @@ public actor Client { switch response.type { case .GQL_CONNECTION_ERROR: - guard - let connectionErrorResponse = try? decoder.decode( + let connectionErrorResponse: ConnectionErrorResponse + do { + connectionErrorResponse = try decoder.decode( ConnectionErrorResponse.self, from: message ) - else { + } catch { try await messenger.error( - .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR) + .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR, error: error) ) return } try await onConnectionError(connectionErrorResponse, self) case .GQL_CONNECTION_ACK: - guard - let connectionAckResponse = try? decoder.decode( + let connectionAckResponse: ConnectionAckResponse + do { + connectionAckResponse = try decoder.decode( ConnectionAckResponse.self, from: message ) - else { + } catch { try await messenger.error( - .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR) + .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR, error: error) ) return } try await onConnectionAck(connectionAckResponse, self) case .GQL_CONNECTION_KEEP_ALIVE: - guard - let connectionKeepAliveResponse = try? decoder.decode( + let connectionKeepAliveResponse: ConnectionKeepAliveResponse + do { + connectionKeepAliveResponse = try decoder.decode( ConnectionKeepAliveResponse.self, from: message ) - else { + } catch { try await messenger.error( - .invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE) + .invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE, error: error) ) return } try await onConnectionKeepAlive(connectionKeepAliveResponse, self) case .GQL_DATA: - guard let nextResponse = try? decoder.decode(DataResponse.self, from: message) else { - try await messenger.error(.invalidResponseFormat(messageType: .GQL_DATA)) + let dataResponse: DataResponse + do { + dataResponse = try decoder.decode(DataResponse.self, from: message) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_DATA, error: error) + ) return } - try await onData(nextResponse, self) + try await onData(dataResponse, self) case .GQL_ERROR: - guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else { - try await messenger.error(.invalidResponseFormat(messageType: .GQL_ERROR)) + let errorResponse: ErrorResponse + do { + errorResponse = try decoder.decode(ErrorResponse.self, from: message) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_ERROR, error: error) + ) return } try await onError(errorResponse, self) case .GQL_COMPLETE: - guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message) - else { - try await messenger.error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) + let completeResponse: CompleteResponse + do { + completeResponse = try decoder.decode(CompleteResponse.self, from: message) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_COMPLETE, error: error) + ) return } try await onComplete(completeResponse, self) diff --git a/Sources/GraphQLWS/GraphQLWSError.swift b/Sources/GraphQLWS/GraphQLWSError.swift index f2d11e0..acd11c6 100644 --- a/Sources/GraphQLWS/GraphQLWSError.swift +++ b/Sources/GraphQLWS/GraphQLWSError.swift @@ -58,16 +58,16 @@ struct GraphQLWSError: Error { ) } - static func invalidRequestFormat(messageType: RequestMessageType) -> Self { + static func invalidRequestFormat(messageType: RequestMessageType, error: Error) -> Self { return self.init( - "Request message doesn't match '\(messageType.type.rawValue)' JSON format", + "Request message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)", code: .invalidRequestFormat ) } - static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { + static func invalidResponseFormat(messageType: ResponseMessageType, error: Error) -> Self { return self.init( - "Response message doesn't match '\(messageType.type.rawValue)' JSON format", + "Response message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)", code: .invalidResponseFormat ) } diff --git a/Sources/GraphQLWS/Requests.swift b/Sources/GraphQLWS/Requests.swift index b77d384..8fe7bd2 100644 --- a/Sources/GraphQLWS/Requests.swift +++ b/Sources/GraphQLWS/Requests.swift @@ -67,8 +67,7 @@ public struct StopRequest: Equatable, Codable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) - if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE - { + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_STOP { throw DecodingError.dataCorrupted( .init( codingPath: decoder.codingPath, diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index a4aab8b..eedc49a 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -78,37 +78,51 @@ where // handle incoming message switch request.type { case .GQL_CONNECTION_INIT: - guard - let connectionInitRequest = try? decoder.decode( + let connectionInitRequest: ConnectionInitRequest + do { + connectionInitRequest = try decoder.decode( ConnectionInitRequest.self, from: message ) - else { - try await messenger.error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_CONNECTION_INIT, error: error) + ) return } try await onConnectionInit(connectionInitRequest, messenger) case .GQL_START: - guard let startRequest = try? decoder.decode(StartRequest.self, from: message) else { - try await messenger.error(.invalidRequestFormat(messageType: .GQL_START)) + let startRequest: StartRequest + do { + startRequest = try decoder.decode(StartRequest.self, from: message) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_START, error: error) + ) return } try await onStart(startRequest, messenger) case .GQL_STOP: - guard let stopRequest = try? decoder.decode(StopRequest.self, from: message) else { - try await messenger.error(.invalidRequestFormat(messageType: .GQL_STOP)) + let stopRequest: StopRequest + do { + stopRequest = try decoder.decode(StopRequest.self, from: message) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_STOP, error: error) + ) return } try await onStop(stopRequest) case .GQL_CONNECTION_TERMINATE: - guard - let connectionTerminateRequest = try? decoder.decode( + let connectionTerminateRequest: ConnectionTerminateRequest + do { + connectionTerminateRequest = try decoder.decode( ConnectionTerminateRequest.self, from: message ) - else { + } catch { try await messenger.error( - .invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE) + .invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE, error: error) ) return } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index ad350ad..b4251e6 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -259,6 +259,64 @@ struct GraphqlTransportWSTests { ) } + /// Tests malformed requests include decoder details in the transport error + @Test func malformedRequestIncludesDecodingDetails() async throws { + let api = TestAPI() + let context = TestContext() + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await api.execute( + request: graphQLRequest.query, + context: context + ) + }, + onSubscribe: { graphQLRequest, _ in + try await api.subscribe( + request: graphQLRequest.query, + context: context + ).get() + } + ) + let (incoming, continuation) = AsyncThrowingStream.makeStream() + + continuation.yield(Data(#"{"type":"stop"}"#.utf8)) + continuation.finish() + + try await server.listen(to: incoming) + + let error = await #expect(throws: TestMessengerError.self) { + for try await _ in serverMessenger.stream {} + } + #expect(error?.code == 4413) + #expect(error?.message.contains("Request message doesn't match 'stop' JSON format") == true) + #expect(error?.message.contains("keyNotFound") == true) + #expect(error?.message.contains(#""id""#) == true) + } + + /// Tests malformed responses include decoder details in the transport error + @Test func malformedResponseIncludesDecodingDetails() async throws { + let messenger = TestMessenger() + let client = Client(messenger: messenger) + let (incoming, continuation) = AsyncThrowingStream.makeStream() + + continuation.yield(Data(#"{"type":"data"}"#.utf8)) + continuation.finish() + + try await client.listen(to: incoming) + + let error = await #expect(throws: TestMessengerError.self) { + for try await _ in messenger.stream {} + } + #expect(error?.code == 4414) + #expect( + error?.message.contains("Response message doesn't match 'data' JSON format") == true + ) + #expect(error?.message.contains("keyNotFound") == true) + #expect(error?.message.contains(#""id""#) == true) + } + enum TestError: Error { case couldBeAnything }