diff --git a/Sources/GraphQLWS/Client.swift b/Sources/GraphQLWS/Client.swift index c7c4ff9..6fa46a6 100644 --- a/Sources/GraphQLWS/Client.swift +++ b/Sources/GraphQLWS/Client.swift @@ -62,65 +62,88 @@ public actor Client { do { response = try decoder.decode(Response.self, from: message) } catch { - try await self.error(.noType()) + try await messenger.error(.noType()) return } switch response.type { case .GQL_CONNECTION_ERROR: - guard - let connectionErrorResponse = try? decoder.decode( + let connectionErrorResponse: ConnectionErrorResponse + do { + connectionErrorResponse = try decoder.decode( ConnectionErrorResponse.self, from: message ) - else { - try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR, error: error) + ) return } try await onConnectionError(connectionErrorResponse, self) case .GQL_CONNECTION_ACK: - guard - let connectionAckResponse = try? decoder.decode( + let connectionAckResponse: ConnectionAckResponse + do { + connectionAckResponse = try decoder.decode( ConnectionAckResponse.self, from: message ) - else { - try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR)) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_CONNECTION_ERROR, error: error) + ) return } try await onConnectionAck(connectionAckResponse, self) case .GQL_CONNECTION_KEEP_ALIVE: - guard - let connectionKeepAliveResponse = try? decoder.decode( + let connectionKeepAliveResponse: ConnectionKeepAliveResponse + do { + connectionKeepAliveResponse = try decoder.decode( ConnectionKeepAliveResponse.self, from: message ) - else { - try await error(.invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE)) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_CONNECTION_KEEP_ALIVE, error: error) + ) return } try await onConnectionKeepAlive(connectionKeepAliveResponse, self) case .GQL_DATA: - guard let nextResponse = try? decoder.decode(DataResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .GQL_DATA)) + let dataResponse: DataResponse + do { + dataResponse = try decoder.decode(DataResponse.self, from: message) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_DATA, error: error) + ) return } - try await onData(nextResponse, self) + try await onData(dataResponse, self) case .GQL_ERROR: - guard let errorResponse = try? decoder.decode(ErrorResponse.self, from: message) else { - try await error(.invalidResponseFormat(messageType: .GQL_ERROR)) + let errorResponse: ErrorResponse + do { + errorResponse = try decoder.decode(ErrorResponse.self, from: message) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_ERROR, error: error) + ) return } try await onError(errorResponse, self) case .GQL_COMPLETE: - guard let completeResponse = try? decoder.decode(CompleteResponse.self, from: message) - else { - try await error(.invalidResponseFormat(messageType: .GQL_COMPLETE)) + let completeResponse: CompleteResponse + do { + completeResponse = try decoder.decode(CompleteResponse.self, from: message) + } catch { + try await messenger.error( + .invalidResponseFormat(messageType: .GQL_COMPLETE, error: error) + ) return } try await onComplete(completeResponse, self) default: - try await error(.invalidType()) + try await messenger.error(.invalidType()) } } @@ -166,9 +189,4 @@ public actor Client { ) ) } - - /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLWSError) async throws { - try await messenger.error(error.message, code: error.code.rawValue) - } } diff --git a/Sources/GraphQLWS/GraphQLWSError.swift b/Sources/GraphQLWS/GraphQLWSError.swift index f2d11e0..acd11c6 100644 --- a/Sources/GraphQLWS/GraphQLWSError.swift +++ b/Sources/GraphQLWS/GraphQLWSError.swift @@ -58,16 +58,16 @@ struct GraphQLWSError: Error { ) } - static func invalidRequestFormat(messageType: RequestMessageType) -> Self { + static func invalidRequestFormat(messageType: RequestMessageType, error: Error) -> Self { return self.init( - "Request message doesn't match '\(messageType.type.rawValue)' JSON format", + "Request message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)", code: .invalidRequestFormat ) } - static func invalidResponseFormat(messageType: ResponseMessageType) -> Self { + static func invalidResponseFormat(messageType: ResponseMessageType, error: Error) -> Self { return self.init( - "Response message doesn't match '\(messageType.type.rawValue)' JSON format", + "Response message doesn't match '\(messageType.type.rawValue)' JSON format: \(error)", code: .invalidResponseFormat ) } diff --git a/Sources/GraphQLWS/Messenger.swift b/Sources/GraphQLWS/Messenger.swift index 543c6fa..a6abb77 100644 --- a/Sources/GraphQLWS/Messenger.swift +++ b/Sources/GraphQLWS/Messenger.swift @@ -15,3 +15,10 @@ public protocol Messenger: Sendable { /// - code: An error code func error(_ message: String, code: Int) async throws } + +extension Messenger { + /// Send an error through the messenger and close the connection + func error(_ error: GraphQLWSError) async throws { + try await self.error(error.message, code: error.code.rawValue) + } +} diff --git a/Sources/GraphQLWS/Requests.swift b/Sources/GraphQLWS/Requests.swift index b77d384..8fe7bd2 100644 --- a/Sources/GraphQLWS/Requests.swift +++ b/Sources/GraphQLWS/Requests.swift @@ -67,8 +67,7 @@ public struct StopRequest: Equatable, Codable { public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: Self.CodingKeys.self) - if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_CONNECTION_TERMINATE - { + if try container.decode(RequestMessageType.self, forKey: .type) != .GQL_STOP { throw DecodingError.dataCorrupted( .init( codingPath: decoder.codingPath, diff --git a/Sources/GraphQLWS/Server.swift b/Sources/GraphQLWS/Server.swift index c831e68..eedc49a 100644 --- a/Sources/GraphQLWS/Server.swift +++ b/Sources/GraphQLWS/Server.swift @@ -26,7 +26,7 @@ where private var initialized = false private var initResult: InitPayloadResult? - private var subscriptionTasks = [String: Task]() + private var executionTasks = [String: Task]() /// Create a new server /// @@ -54,7 +54,7 @@ where } deinit { - subscriptionTasks.values.forEach { $0.cancel() } + executionTasks.values.forEach { $0.cancel() } } /// Listen and react to the provided async sequence of client messages. This function will block until the stream is completed. @@ -71,48 +71,64 @@ where do { request = try decoder.decode(Request.self, from: message) } catch { - try await self.error(.noType()) + try await messenger.error(.noType()) return } // handle incoming message switch request.type { case .GQL_CONNECTION_INIT: - guard - let connectionInitRequest = try? decoder.decode( + let connectionInitRequest: ConnectionInitRequest + do { + connectionInitRequest = try decoder.decode( ConnectionInitRequest.self, from: message ) - else { - try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_INIT)) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_CONNECTION_INIT, error: error) + ) return } try await onConnectionInit(connectionInitRequest, messenger) case .GQL_START: - guard let startRequest = try? decoder.decode(StartRequest.self, from: message) else { - try await error(.invalidRequestFormat(messageType: .GQL_START)) + let startRequest: StartRequest + do { + startRequest = try decoder.decode(StartRequest.self, from: message) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_START, error: error) + ) return } try await onStart(startRequest, messenger) case .GQL_STOP: - guard let stopRequest = try? decoder.decode(StopRequest.self, from: message) else { - try await error(.invalidRequestFormat(messageType: .GQL_STOP)) + let stopRequest: StopRequest + do { + stopRequest = try decoder.decode(StopRequest.self, from: message) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_STOP, error: error) + ) return } try await onStop(stopRequest) case .GQL_CONNECTION_TERMINATE: - guard - let connectionTerminateRequest = try? decoder.decode( + let connectionTerminateRequest: ConnectionTerminateRequest + do { + connectionTerminateRequest = try decoder.decode( ConnectionTerminateRequest.self, from: message ) - else { - try await error(.invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE)) + } catch { + try await messenger.error( + .invalidRequestFormat(messageType: .GQL_CONNECTION_TERMINATE, error: error) + ) return } try await onConnectionTerminate(connectionTerminateRequest, messenger) default: - try await error(.invalidType()) + try await messenger.error(.invalidType()) } } @@ -121,35 +137,30 @@ where _: Messenger ) async throws { guard !initialized else { - try await error(.tooManyInitializations()) + try await messenger.error(.tooManyInitializations()) return } do { initResult = try await onInit(connectionInitRequest.payload) } catch { - try await self.error(.unauthorized()) + try await messenger.error(.unauthorized()) return } initialized = true try await sendConnectionAck() - // TODO: Should we send the `ka` message? } private func onStart(_ startRequest: StartRequest, _: Messenger) async throws { guard initialized, let initResult else { - try await error(.notInitialized()) + try await messenger.error(.notInitialized()) return } let id = startRequest.id - if subscriptionTasks[id] != nil { - try await error(.subscriberAlreadyExists(id: id)) - } - let graphQLRequest = startRequest.payload - var isStreaming = false + let isStreaming: Bool do { isStreaming = try graphQLRequest.isSubscription() } catch { @@ -157,43 +168,54 @@ where return } - if isStreaming { - subscriptionTasks[id] = Task { + let onSubscribe = self.onSubscribe + let onExecute = self.onExecute + guard executionTasks[id] == nil else { + try await messenger.error(.subscriberAlreadyExists(id: id)) + return + } + executionTasks[id] = Task { + defer { + executionTasks.removeValue(forKey: id) + } + + if isStreaming { + let stream: SubscriptionSequenceType do { - let stream = try await onSubscribe(graphQLRequest, initResult) - for try await event in stream { - try Task.checkCancellation() - try await self.sendData(event, id: id) - } + stream = try await onSubscribe(graphQLRequest, initResult) } catch { try await sendError(error, id: id) - subscriptionTasks.removeValue(forKey: id) - throw error + return + } + for try await event in stream { + try await sendData(event, id: id) + } + } else { + let result: GraphQLResult + do { + result = try await onExecute(graphQLRequest, initResult) + } catch { + try await sendError(error, id: id) + return } - try await self.sendComplete(id: id) - subscriptionTasks.removeValue(forKey: id) - } - } else { - do { - let result = try await onExecute(graphQLRequest, initResult) try await sendData(result, id: id) - try await sendComplete(id: id) - } catch { - try await sendError(error, id: id) } + try await sendComplete(id: id) + try await onOperationComplete(id) } + } private func onStop(_ stopRequest: StopRequest) async throws { guard initialized else { - try await error(.notInitialized()) + try await messenger.error(.notInitialized()) return } let id = stopRequest.id - if let task = subscriptionTasks[id] { + if let task = executionTasks[id] { task.cancel() - subscriptionTasks.removeValue(forKey: id) + executionTasks.removeValue(forKey: id) } try await onOperationComplete(id) } @@ -201,10 +223,10 @@ where private func onConnectionTerminate(_: ConnectionTerminateRequest, _ messenger: Messenger) async throws { - for (_, subscriptionTask) in subscriptionTasks { - subscriptionTask.cancel() + for (_, task) in executionTasks { + task.cancel() } - subscriptionTasks.removeAll() + executionTasks.removeAll() try await messenger.close() } @@ -270,14 +292,4 @@ where private func sendError(_ error: Error, id: String) async throws { try await sendError([error], id: id) } - - /// Send an `error` response through the messenger - private func sendError(_ errorMessage: String, id: String) async throws { - try await sendError(GraphQLError(message: errorMessage), id: id) - } - - /// Send an error through the messenger and close the connection - private func error(_ error: GraphQLWSError) async throws { - try await messenger.error(error.message, code: error.code.rawValue) - } } diff --git a/Tests/GraphQLWSTests/GraphQLWSTests.swift b/Tests/GraphQLWSTests/GraphQLWSTests.swift index ad350ad..b4251e6 100644 --- a/Tests/GraphQLWSTests/GraphQLWSTests.swift +++ b/Tests/GraphQLWSTests/GraphQLWSTests.swift @@ -259,6 +259,64 @@ struct GraphqlTransportWSTests { ) } + /// Tests malformed requests include decoder details in the transport error + @Test func malformedRequestIncludesDecodingDetails() async throws { + let api = TestAPI() + let context = TestContext() + let server = Server>( + messenger: serverMessenger, + onInit: { _ in }, + onExecute: { graphQLRequest, _ in + try await api.execute( + request: graphQLRequest.query, + context: context + ) + }, + onSubscribe: { graphQLRequest, _ in + try await api.subscribe( + request: graphQLRequest.query, + context: context + ).get() + } + ) + let (incoming, continuation) = AsyncThrowingStream.makeStream() + + continuation.yield(Data(#"{"type":"stop"}"#.utf8)) + continuation.finish() + + try await server.listen(to: incoming) + + let error = await #expect(throws: TestMessengerError.self) { + for try await _ in serverMessenger.stream {} + } + #expect(error?.code == 4413) + #expect(error?.message.contains("Request message doesn't match 'stop' JSON format") == true) + #expect(error?.message.contains("keyNotFound") == true) + #expect(error?.message.contains(#""id""#) == true) + } + + /// Tests malformed responses include decoder details in the transport error + @Test func malformedResponseIncludesDecodingDetails() async throws { + let messenger = TestMessenger() + let client = Client(messenger: messenger) + let (incoming, continuation) = AsyncThrowingStream.makeStream() + + continuation.yield(Data(#"{"type":"data"}"#.utf8)) + continuation.finish() + + try await client.listen(to: incoming) + + let error = await #expect(throws: TestMessengerError.self) { + for try await _ in messenger.stream {} + } + #expect(error?.code == 4414) + #expect( + error?.message.contains("Response message doesn't match 'data' JSON format") == true + ) + #expect(error?.message.contains("keyNotFound") == true) + #expect(error?.message.contains(#""id""#) == true) + } + enum TestError: Error { case couldBeAnything }