From 625e9d3f648bdacdcb4ef9f289915a595006928d Mon Sep 17 00:00:00 2001 From: Ken Hu <106191785+kenhuuu@users.noreply.github.com> Date: Mon, 1 Jun 2026 21:10:02 -0700 Subject: [PATCH 1/3] Standardize request interceptors and switch to JSON request serialization Redesign the request interceptor API across all GLVs to use a consistent mutate-only contract. Interceptors receive a mutable HTTP request object and modify it in place. The driver auto-serializes the request body to JSON (application/json) after all interceptors run via an idempotent serializeBody() method on the request object. Responses remain GraphBinary. The GraphBinary request serializers are retained in each GLV but are no longer wired into the default request path. Users who need GraphBinary request bodies can write a custom interceptor that calls the serializer. JavaScript: Add serializeBody() to HttpRequest class. Add toJSON() to RequestMessage for clean serialization without field-list duplication. The bulkResults field is changed from a string to a boolean in gremlin-go and gremlin-dotnet to match the provider specification which defines it as a JSON boolean. Assisted-by: Kiro:claude-opus-4-6 --- docs/src/reference/gremlin-variants.asciidoc | 208 ++++++- docs/src/upgrade/release-4.x.x.asciidoc | 45 ++ gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs | 5 +- .../src/Gremlin.Net/Driver/Connection.cs | 53 +- .../src/Gremlin.Net/Driver/GremlinClient.cs | 75 +-- .../Gremlin.Net/Driver/HttpRequestContext.cs | 54 +- .../Gremlin.Net/Driver/IMessageSerializer.cs | 9 +- .../Driver/Remote/DriverRemoteConnection.cs | 4 +- .../GraphBinary4/RequestMessageSerializer.cs | 8 +- .../Driver/GremlinClientTests.cs | 89 +++ .../Driver/ConnectionTests.cs | 201 +++---- .../Driver/DriverRemoteConnectionTests.cs | 6 +- .../Driver/GremlinClientTests.cs | 22 +- .../Driver/HttpRequestContextTests.cs | 168 +++++- .../GraphBinary4MessageSerializerTests.cs | 2 +- .../tinkerpop/gremlin/driver/Cluster.java | 103 +--- .../tinkerpop/gremlin/driver/HttpRequest.java | 65 ++- .../gremlin/driver/RequestInterceptor.java | 20 +- .../tinkerpop/gremlin/driver/auth/Basic.java | 8 +- .../tinkerpop/gremlin/driver/auth/Sigv4.java | 15 +- .../handler/HttpGremlinRequestEncoder.java | 43 +- .../PayloadSerializingInterceptor.java | 73 --- .../driver/simple/SimpleHttpClient.java | 6 +- .../tinkerpop/gremlin/driver/ClusterTest.java | 143 ----- .../gremlin/driver/InterceptorTest.java | 219 ++++++++ .../gremlin/driver/auth/Sigv4Test.java | 56 +- gremlin-go/driver/auth.go | 16 +- gremlin-go/driver/client.go | 10 + gremlin-go/driver/connection.go | 22 +- gremlin-go/driver/connection_test.go | 148 ++++- gremlin-go/driver/driverRemoteConnection.go | 10 + gremlin-go/driver/interceptor.go | 44 +- gremlin-go/driver/interceptor_test.go | 526 +++++++++++++----- gremlin-go/driver/request.go | 11 +- gremlin-go/driver/request_test.go | 2 +- gremlin-go/driver/serializer.go | 22 +- .../gremlin-javascript/lib/driver/auth.ts | 8 +- .../lib/driver/connection.ts | 57 +- .../lib/driver/http-request.ts | 72 +++ .../lib/driver/request-message.ts | 38 ++ gremlin-js/gremlin-javascript/lib/index.ts | 2 + .../test/integration/client-tests.js | 86 +++ .../gremlin-javascript/test/unit/auth-test.js | 73 ++- .../test/unit/connection-test.js | 150 +++++ .../test/unit/http-request-test.js | 238 ++++++++ .../src/main/python/examples/connections.py | 2 - .../driver/aiohttp/transport.py | 5 - .../main/python/gremlin_python/driver/auth.py | 26 +- .../python/gremlin_python/driver/client.py | 3 - .../gremlin_python/driver/connection.py | 50 +- .../driver/driver_remote_connection.py | 2 - .../gremlin_python/driver/http_request.py | 65 +++ .../src/main/python/tests/feature/terrain.py | 2 +- .../main/python/tests/integration/conftest.py | 20 +- .../tests/integration/driver/test_auth.py | 57 +- .../tests/integration/driver/test_client.py | 44 ++ .../tests/unit/driver/test_http_streaming.py | 77 +-- .../tests/unit/driver/test_interceptor.py | 221 ++++++++ .../handler/HttpRequestMessageDecoder.java | 6 + .../server/GremlinDriverIntegrateTest.java | 118 +++- .../GremlinServerAuthIntegrateTest.java | 9 +- .../gremlin/server/TestClientFactory.java | 5 - .../socket/server/TestHttpGremlinHandler.java | 25 +- .../gremlin/util/message/RequestMessage.java | 9 + 64 files changed, 2890 insertions(+), 1091 deletions(-) delete mode 100644 gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java delete mode 100644 gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java create mode 100644 gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java create mode 100644 gremlin-js/gremlin-javascript/lib/driver/http-request.ts create mode 100644 gremlin-js/gremlin-javascript/test/unit/connection-test.js create mode 100644 gremlin-js/gremlin-javascript/test/unit/http-request-test.js create mode 100644 gremlin-python/src/main/python/gremlin_python/driver/http_request.py create mode 100644 gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py diff --git a/docs/src/reference/gremlin-variants.asciidoc b/docs/src/reference/gremlin-variants.asciidoc index 7a8405b2924..8003ffb73ac 100644 --- a/docs/src/reference/gremlin-variants.asciidoc +++ b/docs/src/reference/gremlin-variants.asciidoc @@ -271,8 +271,50 @@ More details can be found in provider docs link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_graph_driver_provider_requirements[here].|true |RequestInterceptors |Functions that modify HTTP requests before sending. Used for authentication and custom headers. |empty |PDTRegistry |A `*PDTRegistry` for hydrating and dehydrating <>. |`nil` +|Auth |A single RequestInterceptor for authentication (e.g. `BasicAuth`). Always appended to the end of the interceptor list so it runs last. |nil |========================================================= +[[gremlin-go-interceptors]] +=== RequestInterceptor + +Gremlin-Go allows modification of the underlying HTTP request through `RequestInterceptor` functions. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin-Go is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A `RequestInterceptor` is a function with the signature `func(*HttpRequest) error` that mutates the `HttpRequest` +in place. A slice of these is maintained and will be run sequentially for each request. When creating a +`DriverRemoteConnection` or `Client`, the `RequestInterceptors` field on the settings struct accepts an ordered +slice of interceptors. Order matters, so if one interceptor depends on another's output, ensure they are added in +the correct order. Note that authentication (e.g. `BasicAuth`, `SigV4Auth`) is also implemented using interceptors. +The `auth` convenience on connection settings appends the auth interceptor to the end of the list so it runs last. + +[source,go] +---- +remote, err := gremlingo.NewDriverRemoteConnection("http://localhost:8182/gremlin", + func(settings *gremlingo.DriverRemoteConnectionSettings) { + settings.RequestInterceptors = []gremlingo.RequestInterceptor{ + gremlingo.BasicAuth("username", "password"), + func(req *gremlingo.HttpRequest) error { + req.Headers.Set("X-Custom-Header", "value") + return nil + }, + } + }) +---- + +Each interceptor receives a `*gremlingo.HttpRequest` whose `Body` field initially contains a `*RequestMessage`. +The `RequestMessage` struct has mutable fields (`Gremlin string` and `Fields map[string]interface{}`), so an +interceptor can modify them directly or replace the body entirely. By default, the driver serializes the body +to JSON (`application/json`) by calling `HttpRequest.SerializeBody()` after all interceptors have run. An +interceptor that needs the serialized bytes (for example, to compute a signature hash) can call +`SerializeBody()` itself; the method is idempotent. If you require a GraphBinary-encoded request body instead +of JSON, you can write a custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer +and sets the resulting `[]byte` as the body. + +For an example of a simple `RequestInterceptor` that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-go/driver/auth.go[basic authentication]. + [[gremlin-go-strategies]] === Traversal Strategies @@ -967,9 +1009,11 @@ The following table describes the various configuration options for the Gremlin |connectionPool.trustStore |File location for a SSL Certificate Chain to use when SSL is enabled. If this value is not provided and SSL is enabled, the default `TrustManager` will be used. |_none_ |connectionPool.trustStorePassword |The password of the `trustStore` if it is password-protected |_none_ |connectionPool.validationRequest |A script that is used to test server connectivity. A good script to use is one that evaluates quickly and returns no data. The default simply returns an empty string, but if a graph is required by a particular provider, a good traversal might be `g.inject()`. |_''_ +|auth |An authentication `RequestInterceptor` (e.g. `Auth.basic()`). Always appended to the end of the interceptor list so it runs last. |_none_ |bulkResults |Sets whether the server should attempt to get bulk results or not. |false |enableUserAgentOnConnect |Enables sending a user agent to the server during connection requests. More details can be found in provider docs link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_graph_driver_provider_requirements[here].|true |hosts |The list of hosts that the driver will connect to. |localhost +|interceptors |A list of `RequestInterceptor` instances that modify the HTTP request before sending. |empty |nioPoolSize |Size of the pool for handling request/response operations. |available processors |password |The password to submit on requests that require authentication. |_none_ |path |The URL path to the Gremlin Server. |_/gremlin_ @@ -1273,21 +1317,26 @@ to "g" and "g1" and "g2" are automatically rebound into "g" on the server-side. Gremlin-Java allows for modification of the underlying HTTP request through the use of `RequestInterceptors`. This is intended to be an advanced feature which means that you will need to understand how the implementation works in order -to safely utilize it. Gremlin-Java is written in a way that you should be able to interact with a TinkerPop-enabled -server without having to use interceptors. This is intended for cases where the server has special capabilities. - -A `RequestInterceptor` is simply a `UnaryOperator`. A list of these are maintained and will be run sequentially for -each request. When building a `Cluster` instance, the methods `addInterceptorAfter()`, `addInterceptorBefore()`, -`addInterceptor()`, and `removeInterceptor()` can be used to add or remove interceptors. It's important to remember -that order matters so if one interceptor depends on another's output then ensure they are added in the correct order. -Note that `Auth` is also implemented using interceptors, and `Auth` is always run last after your list of interceptors -has already ran. By default, the `PayloadSerializingInterceptor` with the name `serializer` is added to your list of -interceptors. This interceptor is used for serializing the body of the request. The first interceptor is provided with -a `org.apache.tinkerpop.gremlin.driver.HttpRequest` that contains a `RequestMessage` in the body. As a reminder -`RequestMessage` is immutable and only certain keys can be added to them. If you want to customize the body by adding -other fields, you will need to make a different copy of the `RequestMessage` or completely change the body to contain a -different data type. The very last interceptor should have a `org.apache.tinkerpop.gremlin.driver.HttpRequest` that -contains a byte[] in the body. +to safely utilize it. Gremlin-Java is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A `RequestInterceptor` is a function that mutates the `HttpRequest` in place (it returns nothing). +A list of these is maintained and will be run sequentially for each request. When building a +`Cluster` instance, the `interceptors()` method accepts an ordered list of interceptors. It's +important to remember that order matters so if one interceptor depends on another's output then +ensure they are added in the correct order. Note that `Auth` is also implemented using interceptors. +The `auth()` convenience on the `Cluster.Builder` appends the auth interceptor to the end of the list so it runs last. + +Each interceptor is provided with a `org.apache.tinkerpop.gremlin.driver.HttpRequest` that contains +a `RequestMessage` in the body. As a reminder `RequestMessage` is immutable and only certain keys +can be added to them. If you want to customize the body by adding other fields, you will need to +make a different copy of the `RequestMessage` or completely change the body to contain a different +data type. By default, the driver serializes the body to JSON (`application/json`) by calling +`HttpRequest.serializeBody()` after all interceptors have run. An interceptor that needs the +serialized bytes (for example, to compute a signature hash) can call `serializeBody()` itself; the +method is idempotent. If you require a GraphBinary-encoded request body instead of JSON, you can +write a custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer and +sets the resulting `byte[]` as the body. For an example of a simple `RequestInterceptor` that only modifies the header of the request take a look at link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java[basic authentication]. @@ -1826,6 +1875,7 @@ can be passed in the constructor of a new `Client` or `DriverRemoteConnection` : |options.traversalSource |String |The traversal source. |'g' |options.headers |Object |Additional HTTP header key/values included with each request. |undefined |options.interceptors |RequestInterceptor/RequestInterceptor[] |One or more functions that can modify the HTTP request before it is sent. |undefined +|options.auth |RequestInterceptor |An auth interceptor that is always appended to the end of the interceptor list so it runs last. |undefined |options.preciseNumbers |Boolean |When `true`, wraps deserialized numbers in typed wrappers that preserve the server's original type. |undefined |options.reader |GraphBinaryReader |The reader to use for deserializing responses. |GraphBinaryReader |options.writer |GraphBinaryWriter |The writer to use for serializing requests. |GraphBinaryWriter @@ -1834,6 +1884,47 @@ can be passed in the constructor of a new `Client` or `DriverRemoteConnection` : |options.pdtRegistry |ProviderDefinedTypeRegistry |A registry for hydrating and dehydrating <>. |undefined |========================================================= +[[gremlin-javascript-interceptors]] +=== RequestInterceptor + +Gremlin-JavaScript allows modification of the underlying HTTP request through `RequestInterceptor` functions. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin-JavaScript is written in a way that you should be able to interact with most + TinkerPop-enabled servers without having to use interceptors. This is intended for cases where the server has special + capabilities. + +A `RequestInterceptor` is a function with the signature `(request: HttpRequest) => void | Promise` that mutates +the `HttpRequest` in place (it returns nothing). A list of these is maintained and will be run sequentially for each +request. When creating a `DriverRemoteConnection` or `Client`, the `interceptors` option accepts one interceptor or an +array of interceptors. Order matters, so if one interceptor depends on another's output, ensure they are added in the +correct order. Note that authentication (e.g. `auth.basic()`, `auth.sigv4()`) is also implemented using interceptors. +The `auth` convenience on connection options appends the auth interceptor to the end of the list so it runs last. + +[source,javascript] +---- +const { auth } = gremlin.driver; +const g = traversal().with_(new DriverRemoteConnection('http://localhost:8182/gremlin', { + interceptors: [ + auth.basic('myuser', 'mypassword'), + (request) => { + request.headers['X-Custom-Header'] = 'value'; + } + ] +})); +---- + +Each interceptor receives an `HttpRequest` whose `body` field initially contains a `RequestMessage`. The +`RequestMessage` class is immutable (fields are private, accessed via getters). To customize the body by adding other +fields, you must build a new `RequestMessage` via `RequestMessage.build(gremlin)` and its builder methods, or set a +completely different body (such as a `Buffer`). By default, the driver serializes the body to JSON +(`application/json`) by calling `HttpRequest.serializeBody()` after all interceptors have run. An interceptor that +needs the serialized bytes (for example, to compute a signature hash) can call `serializeBody()` itself; the method +is idempotent. If you require a GraphBinary-encoded request body instead of JSON, you can write a custom interceptor +that serializes the `RequestMessage` with the GraphBinary writer and sets the resulting `Buffer` as the body. + +For an example of a simple `RequestInterceptor` that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-js/gremlin-javascript/lib/driver/auth.ts[basic authentication]. + [[gremlin-javascript-logging]] === Logging @@ -2458,8 +2549,52 @@ The following options can be passed to the `GremlinClient` constructor: |loggerFactory |An `ILoggerFactory` for logging. |`NullLoggerFactory` |interceptors |A list of `Func` that modify HTTP requests before sending. |_none_ |pdtRegistry |A `ProviderDefinedTypeRegistry` for hydrating and dehydrating <>. |`null` +|auth |A single `Func` for authentication. Always appended to the end of the interceptor list so it runs last. |_none_ |========================================================= +[[gremlin-dotnet-interceptors]] +=== RequestInterceptor + +Gremlin.Net allows modification of the underlying HTTP request through request interceptors. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin.Net is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A request interceptor is a `Func` that mutates the `HttpRequestContext` in place (it returns +`Task` for async support but does not produce a value). A list of these is maintained and will be run sequentially for +each request. When creating a `GremlinClient`, the `interceptors` parameter accepts an ordered collection of +interceptors. Order matters, so if one interceptor depends on another's output, ensure they are added in the correct +order. Note that authentication (e.g. `Auth.BasicAuth()`, `Auth.SigV4Auth()`) is also implemented using interceptors. +These factory methods return interceptor delegates that can be included in the `interceptors` list. Alternatively, the +`auth` parameter on `GremlinClient` appends the auth interceptor to the end of the list so it runs last. + +[source,csharp] +---- +var server = new GremlinServer("localhost", 8182); +using var client = new GremlinClient(server, + interceptors: new Func[] + { + Auth.BasicAuth("username", "password"), + context => + { + context.Headers["X-Custom-Header"] = "value"; + return Task.CompletedTask; + } + }); +---- + +Each interceptor receives an `HttpRequestContext` whose `Body` property initially contains a `RequestMessage`. +The `RequestMessage` has an immutable `Gremlin` property and a mutable `Fields` dictionary, so an interceptor can +add or modify fields directly. Alternatively, you can replace the body entirely. By default, the driver serializes +the body to JSON (`application/json`) by calling `HttpRequestContext.SerializeBody()` after all interceptors have +run. An interceptor that needs the serialized bytes (for example, to compute a signature hash) can call +`SerializeBody()` itself; the method is idempotent. If you require a GraphBinary-encoded request body instead of +JSON, you can write a custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer and +sets the resulting `byte[]` as the body. + +For an example of a simple request interceptor that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs[basic authentication]. + [[gremlin-dotnet-logging]] === Logging @@ -3058,7 +3193,7 @@ can be passed to the `Client` or `DriverRemoteConnection` instance as keyword ar |request_serializer |The request serializer implementation.|`gremlin_python.driver.serializer.GraphBinarySerializersV4` |response_serializer |The response serializer implementation.|`gremlin_python.driver.serializer.GraphBinarySerializersV4` |interceptors |The request interceptors to run after request serialization.|`None` -|auth |The authentication scheme to use when submitting requests that require authentication. |`None` +|auth |An authentication interceptor. Always appended to the end of the interceptor list so it runs last. |`None` |pool_size |The number of connections used by the pool. |4 |enable_user_agent_on_connect |Enables sending a user agent to the server during connection requests. More details can be found in provider docs @@ -3080,6 +3215,47 @@ g = traversal().with_( read_timeout=30)) ---- +[[gremlin-python-interceptors]] +=== RequestInterceptor + +Gremlin-Python allows modification of the underlying HTTP request through request interceptors. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin-Python is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A request interceptor is a callable that receives an `HttpRequest` and mutates it in place (it returns nothing). +A list of these is maintained and will be run sequentially for each request. When creating a +`DriverRemoteConnection` or `Client`, the `interceptors` keyword argument accepts an ordered list of interceptors. +Order matters, so if one interceptor depends on another's output, ensure they are added in the correct order. Note +that authentication (e.g. `basic()`, `sigv4()`) is also implemented using interceptors. The `auth` convenience +parameter on `DriverRemoteConnection` and `Client` appends the auth interceptor to the end of the list so it +runs last. + +[source,python] +---- +from gremlin_python.driver.auth import basic + +g = traversal().with_(DriverRemoteConnection( + 'http://localhost:8182/gremlin', 'g', + auth=basic('username', 'password'), + interceptors=[ + lambda req: req.headers.update({'X-Custom-Header': 'value'}) + ])) +---- + +Each interceptor receives an `HttpRequest` whose `body` attribute initially contains a `RequestMessage`. +The `RequestMessage` is a named tuple with mutable `fields` (a `dict`) and a `gremlin` string. Because +`fields` is a regular dictionary, an interceptor can add or modify entries directly. To replace the body +entirely, assign a new value to `req.body`. By default, the driver serializes the body to JSON +(`application/json`) by calling `HttpRequest.serialize_body()` after all interceptors have run. An interceptor +that needs the serialized bytes (for example, to compute a signature hash) can call `serialize_body()` itself; +the method is idempotent. If you require a GraphBinary-encoded request body instead of JSON, you can write a +custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer and sets the resulting +`bytes` as the body. + +For an example of a simple request interceptor that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-python/src/main/python/gremlin_python/driver/auth.py[basic authentication]. + [[gremlin-python-strategies]] === Traversal Strategies diff --git a/docs/src/upgrade/release-4.x.x.asciidoc b/docs/src/upgrade/release-4.x.x.asciidoc index ac8debb0a99..0f78dde55b9 100644 --- a/docs/src/upgrade/release-4.x.x.asciidoc +++ b/docs/src/upgrade/release-4.x.x.asciidoc @@ -32,6 +32,37 @@ complete list of all the modifications that are part of this release. === Upgrading for Users +==== Request Interceptors + +When TinkerPop supported WebSockets prior to 4.0.0, the Java driver offered a `RequestInterceptor` interface (and +its predecessor, `HandshakeInterceptor`) that allowed modification of the raw Netty `FullHttpRequest`. For WebSocket +connections, the interceptor only ran on the initial HTTP upgrade handshake. With the move to HTTP, the notion of +the "interceptor" has shifted to a per-request concern that is now standardized across all GLVs. + +All GLVs now support request interceptors, which allow modification of the HTTP request before it is sent to the +server. An interceptor is a function that receives the mutable HTTP request object and can modify headers, the +request body, the URI, or the HTTP method. Interceptors are run in the order they are registered. + +The most common use case for interceptors is authentication (e.g., SigV4 signing), but they can also be used to add +provider-specific fields, inject custom headers, or transform the request body. + +Here is a simple Java example that adds a custom header: + +[source,java] +---- +Cluster cluster = Cluster.build("localhost") + .interceptors(request -> request.headers().put("X-Custom-Header", "value")) + .create(); +---- + +Authentication is also an interceptor. Each GLV provides convenience methods (e.g., `Auth.basic()`, `Auth.sigv4()`) +that return interceptors and can be registered alongside custom ones. + +For full details on the interceptor API for each language variant, refer to the RequestInterceptor section in +each GLV's documentation in the +link:https://tinkerpop.apache.org/docs/x.y.z/reference/#gremlin-drivers-variants[Gremlin Drivers and Variants] +reference. + ==== Gremlator link:https://gremlator.com[Gremlator] has been rebuilt entirely in JavaScript as a browser-based single-page @@ -567,6 +598,20 @@ registration. ==== Graph Driver Providers +===== Request Interceptors + +Graph driver providers should implement request interceptor support in their drivers. Interceptors allow users to +modify the HTTP request before it is sent, which is essential for authentication schemes (like SigV4), adding +provider-specific request fields, and other server-specific capabilities. + +All TinkerPop reference drivers now include interceptor support. The interceptor contract is standardized across all +GLVs: interceptors receive a mutable HTTP request object, can modify headers/body/URI, and the driver auto-serializes +the request body to JSON after all interceptors have run. + +For the full specification of how interceptors should behave, see the +link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_http_request_interceptor[HTTP Request Interceptor] +section in the provider documentation. + == TinkerPop 4.0.0-beta.2 *Release Date: April 1, 2026* diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs index e10aafe55fb..eaf706f6dda 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs @@ -94,6 +94,9 @@ public static Func SigV4Auth( } } + // Ensure the body is serialized before signing so we have bytes to hash. + context.SerializeBody(); + // Use the async path — important for credential providers that perform // network I/O (e.g. IMDS on EC2, ECS task role endpoint). var immutableCreds = await cachedProvider.GetCredentialsAsync() @@ -116,7 +119,7 @@ private static void SignRequest(HttpRequestContext context, ? bytes : throw new InvalidOperationException( "SigV4 signing requires Body to be byte[]. " + - "Ensure serialization occurs before the SigV4 interceptor."), + "Ensure SerializeBody() was called before signing."), AuthenticationRegion = clientConfig.AuthenticationRegion, OverrideSigningServiceName = clientConfig.AuthenticationServiceName, }; diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs index e5c94611d4a..82a0d17724a 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs @@ -32,8 +32,6 @@ using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; using Gremlin.Net.Process; -using Gremlin.Net.Process.Traversal; -using Gremlin.Net.Structure.IO; namespace Gremlin.Net.Driver { @@ -44,7 +42,6 @@ internal class Connection : IDisposable { private readonly HttpClient _httpClient; private readonly Uri _uri; - private readonly IMessageSerializer? _requestSerializer; private readonly IMessageSerializer _responseSerializer; private readonly ConnectionSettings _settings; private readonly IReadOnlyList> _interceptors; @@ -55,24 +52,15 @@ internal class Connection : IDisposable /// so a single instance handles concurrent requests efficiently. /// /// The Gremlin Server URI. - /// - /// The serializer for outgoing requests. When non-null, the request body is serialized - /// to byte[] before interceptors run and the Content-Type header is set - /// automatically. When null, the body is passed as a - /// and an interceptor is responsible for serializing it to byte[] and setting - /// Content-Type. This follows the Python driver's request_serializer=None - /// pattern. - /// /// The serializer for incoming responses (always required). /// Connection settings. /// Optional request interceptors. - public Connection(Uri uri, IMessageSerializer? requestSerializer, + public Connection(Uri uri, IMessageSerializer responseSerializer, ConnectionSettings settings, IReadOnlyList>? interceptors = null) { _uri = uri; - _requestSerializer = requestSerializer; _responseSerializer = responseSerializer; _settings = settings; _interceptors = interceptors ?? Array.Empty>(); @@ -97,13 +85,12 @@ public Connection(Uri uri, IMessageSerializer? requestSerializer, /// /// Constructor that accepts a pre-configured HttpClient (for testing). /// - internal Connection(Uri uri, IMessageSerializer? requestSerializer, + internal Connection(Uri uri, IMessageSerializer responseSerializer, ConnectionSettings settings, HttpClient httpClient, IReadOnlyList>? interceptors = null) { _uri = uri; - _requestSerializer = requestSerializer; _responseSerializer = responseSerializer; _settings = settings; _httpClient = httpClient; @@ -139,7 +126,7 @@ public async Task> SubmitAsync(RequestMessage requestMessage, headers["bulkResults"] = "true"; } - // Promote transactionId to HTTP header before serialization. + // Promote transactionId to HTTP header before interceptors run. // The field remains in the serialized body as well (dual transmission // per the HTTP transaction protocol specification). if (requestMessage.Fields.TryGetValue(Tokens.ArgsTransactionId, out var txIdObj) && @@ -148,26 +135,20 @@ public async Task> SubmitAsync(RequestMessage requestMessage, headers["X-Transaction-Id"] = txId; } - object body; - if (_requestSerializer != null) - { - var requestBytes = await _requestSerializer.SerializeMessageAsync(requestMessage, cancellationToken) - .ConfigureAwait(false); - body = requestBytes; - headers["Content-Type"] = _requestSerializer.MimeType; - } - else - { - body = requestMessage; - } - - var context = new HttpRequestContext("POST", _uri, headers, body); + var context = new HttpRequestContext("POST", _uri, headers, requestMessage); foreach (var interceptor in _interceptors) { await interceptor(context).ConfigureAwait(false); } + // Auto-serialize after interceptors: idempotent if already serialized by an interceptor. + // Skip if body is HttpContent (an escape hatch for full wire-format control). + if (context.Body is not System.Net.Http.HttpContent) + { + context.SerializeBody(); + } + // The HttpResponseMessage is NOT disposed here — ownership transfers to // StreamingResponseContext via the background task. HttpResponseMessage response; @@ -184,10 +165,8 @@ public async Task> SubmitAsync(RequestMessage requestMessage, else { throw new InvalidOperationException( - "Request body must be byte[] or HttpContent after all interceptors complete, " + - "but found " + (context.Body?.GetType().Name ?? "null") + - ". Either provide a requestSerializer or add an interceptor " + - "that serializes the RequestMessage."); + "Request body must be byte[] or HttpContent after serialization, " + + "but found " + (context.Body?.GetType().Name ?? "null") + "."); } foreach (var header in context.Headers) @@ -196,6 +175,10 @@ public async Task> SubmitAsync(RequestMessage requestMessage, { httpRequest.Content.Headers.ContentType = new MediaTypeHeaderValue(header.Value); } + else if (string.Equals(header.Key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + // Content-Length is set automatically by ByteArrayContent; skip to avoid conflict. + } else { httpRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); @@ -345,5 +328,7 @@ protected virtual void Dispose(bool disposing) } #endregion + + internal IReadOnlyList> Interceptors => _interceptors; } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs index 4599e47de5e..402a8f77111 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs @@ -24,6 +24,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; @@ -48,86 +49,64 @@ public class GremlinClient : IGremlinClient /// Initializes a new instance of the class for the specified Gremlin Server. /// /// The the requests should be sent to. - /// - /// A instance to serialize outgoing request messages. - /// When null, the request body is passed as a to - /// interceptors, and an interceptor must serialize it to byte[] and set the - /// Content-Type header. This follows the Python driver's - /// request_serializer=None pattern. - /// /// /// A instance to deserialize incoming response messages. - /// Always required. + /// Defaults to . /// /// The for the HTTP connection. /// A factory to create loggers. If not provided, then nothing will be logged. /// /// An optional list of request interceptors. Each interceptor receives a mutable /// and can modify headers, body, URI, and method - /// before the request is sent. + /// before the request is sent. Interceptors that need the serialized bytes (e.g. + /// for payload signing) should call . + /// + /// + /// An optional auth interceptor. As a convenience, this is appended to the end of the + /// interceptor list so it runs last (after any user interceptors have modified the request). + /// This is equivalent to including the auth interceptor as the last element of . /// /// /// An optional for automatic hydration of /// provider-defined types. /// - public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? requestSerializer, - IMessageSerializer responseSerializer, + public GremlinClient(GremlinServer gremlinServer, + IMessageSerializer? responseSerializer = null, ConnectionSettings? connectionSettings = null, ILoggerFactory? loggerFactory = null, IReadOnlyList>? interceptors = null, + Func? auth = null, ProviderDefinedTypeRegistry? pdtRegistry = null) { connectionSettings ??= new ConnectionSettings(); LoggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + var actualResponseSerializer = responseSerializer ?? new GraphBinary4MessageSerializer(); + if (pdtRegistry != null) { - requestSerializer?.SetPdtRegistry(pdtRegistry); - responseSerializer.SetPdtRegistry(pdtRegistry); + actualResponseSerializer.SetPdtRegistry(pdtRegistry); + } + + // Append auth interceptor to the end of the list so it runs last. + IReadOnlyList>? allInterceptors = interceptors; + if (auth != null) + { + var list = interceptors?.ToList() ?? new List>(); + list.Add(auth); + allInterceptors = list; } _connection = new Connection( gremlinServer.Uri, - requestSerializer, - responseSerializer, + actualResponseSerializer, connectionSettings, - interceptors); + allInterceptors); var logger = LoggerFactory.CreateLogger(); logger.InitializedHttpConnection(gremlinServer.Uri); } - /// - /// Initializes a new instance of the class with a single - /// serializer used for both request serialization and response deserialization. - /// This is the backward-compatible convenience constructor. - /// - /// The the requests should be sent to. - /// - /// A instance used for both request serialization and - /// response deserialization. Defaults to . - /// - /// The for the HTTP connection. - /// A factory to create loggers. If not provided, then nothing will be logged. - /// - /// An optional list of request interceptors. - /// - /// - /// An optional for automatic hydration of - /// provider-defined types. - /// - public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? messageSerializer = null, - ConnectionSettings? connectionSettings = null, - ILoggerFactory? loggerFactory = null, - IReadOnlyList>? interceptors = null, - ProviderDefinedTypeRegistry? pdtRegistry = null) - : this(gremlinServer, - messageSerializer ?? new GraphBinary4MessageSerializer(), - messageSerializer ?? new GraphBinary4MessageSerializer(), - connectionSettings, loggerFactory, interceptors, pdtRegistry) - { - } - /// public async Task> SubmitAsync(RequestMessage requestMessage, CancellationToken cancellationToken = default) @@ -199,5 +178,7 @@ internal void UntrackTransaction(RemoteTransaction tx) } #endregion + + internal Connection Connection => _connection; } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs index 9fc170b45fe..ac98a91d8f7 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs @@ -24,6 +24,8 @@ using System; using System.Collections.Generic; using System.Security.Cryptography; +using System.Text.Json; +using Gremlin.Net.Driver.Messages; namespace Gremlin.Net.Driver { @@ -48,10 +50,10 @@ public class HttpRequestContext public Dictionary Headers { get; } /// - /// Gets or sets the request body. This is byte[] when serialization has occurred - /// (default path), or RequestMessage when serialization is deferred to interceptors - /// (requestSerializer = null). Interceptors may also set this to an - /// instance for full control over the wire format. + /// Gets or sets the request body. Initially a , becomes + /// byte[] after is called. Interceptors may also + /// set this to an instance for full control + /// over the wire format. /// public object Body { get; set; } @@ -61,8 +63,8 @@ public class HttpRequestContext /// The HTTP method. /// The request URI. /// The HTTP headers. - /// The request body. Typically byte[] (post-serialization) or - /// RequestMessage (pre-serialization). + /// The request body. Typically RequestMessage (pre-serialization) + /// or byte[] (post-serialization). public HttpRequestContext(string method, Uri uri, Dictionary headers, object body) { Method = method ?? throw new ArgumentNullException(nameof(method)); @@ -71,6 +73,46 @@ public HttpRequestContext(string method, Uri uri, Dictionary hea Body = body; } + /// + /// Serializes the body to JSON if it is still a . + /// Sets the body to the resulting byte[], and sets the Content-Type + /// and Content-Length headers. This method is idempotent: if the body is + /// already byte[], it returns those bytes without re-serializing. + /// + /// The serialized body bytes. + /// + /// Thrown if the body is neither nor byte[]. + /// + public byte[] SerializeBody() + { + if (Body is byte[] existing) + { + return existing; + } + + if (Body is RequestMessage message) + { + var payload = new Dictionary + { + [Tokens.ArgsGremlin] = message.Gremlin + }; + foreach (var field in message.Fields) + { + payload[field.Key] = field.Value; + } + + var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(payload); + Body = jsonBytes; + Headers["Content-Type"] = "application/json"; + Headers["Content-Length"] = jsonBytes.Length.ToString(); + return jsonBytes; + } + + throw new InvalidOperationException( + "Cannot serialize body of type " + (Body?.GetType().Name ?? "null") + + ". Body must be RequestMessage or byte[]."); + } + /// /// Returns the lowercase hex-encoded SHA-256 digest of the body. /// Throws if is not byte[], diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs index 831a8ee379a..4d69834e90a 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs @@ -31,23 +31,24 @@ namespace Gremlin.Net.Driver { /// - /// Serializes data to and from Gremlin Server. + /// Serializes and deserializes data to and from Gremlin Server. /// public interface IMessageSerializer { /// /// Gets the MIME type produced by this serializer (e.g. /// "application/vnd.graphbinary-v4.0"). Used by the driver to set - /// Content-Type and Accept headers automatically. + /// the Accept header. /// string MimeType { get; } /// - /// Serializes a . + /// Serializes a to bytes. This can be called from + /// interceptors to produce a serialized request body. /// /// The to serialize. /// The token to cancel the operation. The default value is None. - /// The serialized message. + /// The serialized message bytes. Task SerializeMessageAsync(RequestMessage requestMessage, CancellationToken cancellationToken = default); diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs index e6d5c032f90..3112bf783ce 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs @@ -137,11 +137,11 @@ public async Task> SubmitAsync(GremlinLan } } - // Default bulkResults to "true" if not set per-request + // Default bulkResults to true if not set per-request // (consistent with Java RequestOptions.fromGremlinLang and Python extract_request_options) if (!requestMsg.HasField(Tokens.ArgsBulkResults)) { - requestMsg.AddField(Tokens.ArgsBulkResults, "true"); + requestMsg.AddField(Tokens.ArgsBulkResults, true); } var resultSet = await _client.SubmitAsync(requestMsg.Create(), cancellationToken) diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs index 675f01ee2ec..1c35bb822d6 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs @@ -1,4 +1,4 @@ -#region License +#region License /* * Licensed to the Apache Software Foundation (ASF) under one @@ -31,6 +31,12 @@ namespace Gremlin.Net.Structure.IO.GraphBinary4 /// /// Serializes a in the GraphBinary 4.0 wire format. /// + /// + /// This serializer is no longer used by the default driver flow (which now serializes + /// requests as JSON via ), + /// but remains available for custom interceptors or alternative transport protocols that + /// require GraphBinary-encoded requests. + /// public class RequestMessageSerializer { /// diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs index a5fdac7c8e3..6370a5bb185 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs @@ -272,5 +272,94 @@ public async Task ShouldHandlePdtInCollection() Assert.Equal(3, p2.Fields["x"]); Assert.Equal(4, p2.Fields["y"]); } + + [Fact] + public async Task ShouldAutoSerializeRequestMessageWithInterceptorMutation() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + var interceptors = new List> + { + ctx => + { + if (ctx.Body is RequestMessage msg) + { + var g = msg.Fields.ContainsKey("g") ? (string)msg.Fields["g"] : "g"; + ctx.Body = RequestMessage.Build("g.inject(99)").AddG(g).Create(); + } + return Task.CompletedTask; + } + }; + + using var gremlinClient = new GremlinClient(gremlinServer, interceptors: interceptors); + + var response = await gremlinClient.SubmitWithSingleResultAsync("g.inject(1)"); + Assert.Equal(99, response); + } + + [Fact] + public async Task ShouldPropagateExceptionThrownDuringInterceptor() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + var callCount = 0; + var interceptors = new List> + { + ctx => + { + callCount++; + if (callCount == 1) + { + throw new InvalidOperationException("interceptor broke"); + } + return Task.CompletedTask; + } + }; + + using var gremlinClient = new GremlinClient(gremlinServer, interceptors: interceptors); + + // First request should fail with interceptor error + var ex = await Assert.ThrowsAsync(async () => + { + var resultSet = await gremlinClient.SubmitAsync("g.inject(1)"); + await resultSet.ToListAsync(); + }); + Assert.Contains("interceptor broke", ex.Message); + + // Subsequent request should succeed, proving connection recovery + var response = await gremlinClient.SubmitWithSingleResultAsync("g.inject(2)"); + Assert.Equal(2, response); + } + + [Fact] + public async Task ShouldPropagateErrorWhenInterceptorSetsUnsupportedBodyType() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + var callCount = 0; + var interceptors = new List> + { + ctx => + { + callCount++; + if (callCount == 1) + { + ctx.Body = 42; + } + return Task.CompletedTask; + } + }; + + using var gremlinClient = new GremlinClient(gremlinServer, interceptors: interceptors); + + // First request should fail with serialization error + var ex = await Assert.ThrowsAsync(async () => + { + var resultSet = await gremlinClient.SubmitAsync("g.inject(1)"); + await resultSet.ToListAsync(); + }); + Assert.Contains("Cannot serialize body", ex.Message); + + // Subsequent request should succeed, proving connection recovery + var response = await gremlinClient.SubmitWithSingleResultAsync("g.inject(2)"); + Assert.Equal(2, response); + } } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs index 6a8c2b78569..56010293637 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs @@ -33,7 +33,6 @@ using System.Threading.Tasks; using Gremlin.Net.Driver; using Gremlin.Net.Driver.Messages; -using Gremlin.Net.Process.Traversal; using Gremlin.Net.Structure.IO; using NSubstitute; using Xunit; @@ -85,12 +84,12 @@ public async Task ShouldSetContentTypeHeader() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(handler.CapturedRequest); - Assert.Equal(SerializationTokens.GraphBinary4MimeType, + Assert.Equal("application/json", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); } @@ -100,7 +99,7 @@ public async Task ShouldSetAcceptHeader() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -115,7 +114,7 @@ public async Task ShouldSendPostRequest() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -129,7 +128,7 @@ public async Task ShouldSendToCorrectUri() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -143,7 +142,7 @@ public async Task ShouldSetAcceptEncodingWhenCompressionEnabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -158,7 +157,7 @@ public async Task ShouldNotSetAcceptEncodingWhenCompressionDisabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = false }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -173,7 +172,7 @@ public async Task ShouldSetUserAgentWhenEnabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableUserAgentOnConnect = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -187,7 +186,7 @@ public async Task ShouldNotSetUserAgentWhenDisabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableUserAgentOnConnect = false }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -201,7 +200,7 @@ public async Task ShouldSetBulkResultsHeaderWhenEnabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { BulkResults = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -216,7 +215,7 @@ public async Task ShouldNotSetBulkResultsHeaderWhenDisabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { BulkResults = false }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -242,7 +241,7 @@ public async Task ShouldDecompressDeflateResponse() var (httpClient, handler) = CreateMockHttpClient(compressedBytes, "deflate"); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); // Should not throw — decompression should work var result = await connection.SubmitAsync(CreateTestRequest()); @@ -256,7 +255,7 @@ public void ShouldDisposeWithoutError() var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + var connection = new Connection(TestUri, serializer, settings, httpClient); connection.Dispose(); // Double dispose should not throw @@ -280,8 +279,6 @@ private static IMessageSerializer CreateMockSerializer( { var serializer = Substitute.For(); serializer.MimeType.Returns(mimeType); - serializer.SerializeMessageAsync(Arg.Any(), Arg.Any()) - .Returns(Task.FromResult(new byte[] { 0x84 })); serializer.DeserializeMessageAsync(Arg.Any(), Arg.Any()) .Returns(callInfo => ToAsyncEnumerable(results)); return serializer; @@ -311,7 +308,7 @@ public async Task ShouldCallInterceptorsInOrder() ctx => { callOrder.Add(3); return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -331,7 +328,7 @@ public async Task ShouldPropagateInterceptorException() _ => throw expectedException, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var ex = await Assert.ThrowsAsync( () => connection.SubmitAsync(CreateTestRequest())); @@ -357,7 +354,7 @@ public async Task ShouldAllowInterceptorToModifyHeaders() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -390,7 +387,7 @@ public async Task ShouldSeeEarlierInterceptorModifications() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -401,7 +398,7 @@ public async Task ShouldSeeEarlierInterceptorModifications() } [Fact] - public async Task ShouldSerializeBeforeInterceptorsWhenRequestSerializerProvided() + public async Task ShouldPassRequestMessageToInterceptorsBeforeSerialization() { var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -417,66 +414,43 @@ public async Task ShouldSerializeBeforeInterceptorsWhenRequestSerializerProvided }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); - Assert.IsType(observedBody); + Assert.IsType(observedBody); } [Fact] - public async Task ShouldPassRequestMessageWhenRequestSerializerIsNull() + public async Task ShouldThrowWhenBodyIsUnsupportedTypeAfterInterceptors() { - var (httpClient, _) = CreateMockHttpClient(); + var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - object? observedBody = null; var interceptors = new List> { ctx => { - observedBody = ctx.Body; - // Serialize the body so the request can proceed - ctx.Body = new byte[] { 0x84 }; - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; + ctx.Body = "unsupported type"; return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, - interceptors); - - await connection.SubmitAsync(CreateTestRequest()); - - Assert.IsType(observedBody); - } - - [Fact] - public async Task ShouldThrowWhenBodyIsNotByteArrayAfterInterceptors() - { - var (httpClient, handler) = CreateMockHttpClient(); - var serializer = CreateMockSerializer(); - var settings = new ConnectionSettings(); - - // No interceptor serializes the body - var interceptors = new List>(); - - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var ex = await Assert.ThrowsAsync( () => connection.SubmitAsync(CreateTestRequest())); - Assert.Contains("byte[] or HttpContent", ex.Message); - Assert.Contains("RequestMessage", ex.Message); + Assert.Contains("String", ex.Message); // HTTP request should not have been sent Assert.Null(handler.CapturedRequest); } [Fact] - public async Task ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSerializer() + public async Task ShouldSucceedWhenInterceptorPreSerializesBody() { var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -484,17 +458,15 @@ public async Task ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSeria var interceptors = new List> { - async ctx => + ctx => { - if (ctx.Body is RequestMessage msg) - { - ctx.Body = await serializer.SerializeMessageAsync(msg); - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; - } + // Interceptor calls SerializeBody() early (e.g. for signing) + ctx.SerializeBody(); + return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -504,7 +476,7 @@ public async Task ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSeria } [Fact] - public async Task ShouldNotSetContentTypeWhenRequestSerializerIsNull() + public async Task ShouldNotSetContentTypeBeforeInterceptorsRun() { var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -516,21 +488,18 @@ public async Task ShouldNotSetContentTypeWhenRequestSerializerIsNull() ctx => { hadContentType = ctx.Headers.ContainsKey("Content-Type"); - // Serialize so the request can proceed - ctx.Body = new byte[] { 0x84 }; - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); - Assert.False(hadContentType, "Content-Type should not be set before interceptors when requestSerializer is null"); + Assert.False(hadContentType, "Content-Type should not be set before interceptors run"); Assert.NotNull(handler.CapturedRequest); - Assert.Equal("application/vnd.graphbinary-v4.0", + Assert.Equal("application/json", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); } @@ -542,7 +511,7 @@ public async Task ShouldWorkWithEmptyInterceptorList() var settings = new ConnectionSettings(); var interceptors = new List>(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -558,7 +527,7 @@ public async Task ShouldWorkWithNoInterceptorsParameter() var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -583,7 +552,7 @@ public async Task ShouldAllowInterceptorToModifyUri() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -609,7 +578,7 @@ public async Task ShouldAllowInterceptorToReplaceBody() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -637,7 +606,7 @@ public async Task ShouldStopInterceptorChainOnException() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await Assert.ThrowsAsync( @@ -666,7 +635,7 @@ public async Task ShouldSupportAsyncInterceptors() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -678,7 +647,7 @@ public async Task ShouldSupportAsyncInterceptors() } [Fact] - public async Task ShouldAllowInterceptorToReadSerializedBody() + public async Task ShouldAllowInterceptorToCallSerializeBody() { var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -689,19 +658,19 @@ public async Task ShouldAllowInterceptorToReadSerializedBody() { ctx => { - capturedBody = ctx.Body as byte[]; + capturedBody = ctx.SerializeBody(); return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(capturedBody); - // The mock serializer returns { 0x84 } - Assert.Equal(new byte[] { 0x84 }, capturedBody); + // Should be valid JSON containing "gremlin" field + Assert.Contains("gremlin", System.Text.Encoding.UTF8.GetString(capturedBody!)); } [Fact] @@ -721,7 +690,7 @@ public async Task ShouldWorkWithSingleInterceptor() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -746,7 +715,7 @@ public async Task ShouldAllowInterceptorToRemoveHeader() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -772,7 +741,7 @@ public async Task ShouldThrowWhenBodyIsNullAfterInterceptors() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var ex = await Assert.ThrowsAsync( @@ -808,7 +777,7 @@ public async Task ShouldPreserveMultipleInterceptorHeaderModifications() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -820,7 +789,7 @@ public async Task ShouldPreserveMultipleInterceptorHeaderModifications() } [Fact] - public async Task ShouldAllowCustomContentTypeWhenRequestSerializerIsNull() + public async Task ShouldAllowInterceptorToOverrideContentType() { var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -830,53 +799,37 @@ public async Task ShouldAllowCustomContentTypeWhenRequestSerializerIsNull() { ctx => { + // Pre-serialize with custom content type ctx.Body = new byte[] { 0x01 }; - ctx.Headers["Content-Type"] = "application/json"; + ctx.Headers["Content-Type"] = "application/custom"; return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(handler.CapturedRequest); - Assert.Equal("application/json", + Assert.Equal("application/custom", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); } [Fact] - public async Task ShouldUseResponseSerializerWhenRequestSerializerIsNull() + public async Task ShouldUseResponseSerializerForDeserialization() { var (httpClient, _) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer(); var responseSerializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - var interceptors = new List> - { - async ctx => - { - if (ctx.Body is RequestMessage msg) - { - ctx.Body = await requestSerializer.SerializeMessageAsync(msg); - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; - } - }, - }; - - using var connection = new Connection(TestUri, null, responseSerializer, settings, httpClient, - interceptors); + using var connection = new Connection(TestUri, responseSerializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); // Verify the response serializer was called for deserialization responseSerializer.Received(1) .DeserializeMessageAsync(Arg.Any(), Arg.Any()); - // Verify the request serializer was NOT called by Connection (interceptor called it directly) - requestSerializer.DidNotReceive() - .DeserializeMessageAsync(Arg.Any(), Arg.Any()); } [Fact] @@ -896,7 +849,7 @@ public async Task ShouldAcceptHttpContentBodyFromInterceptor() }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -910,10 +863,9 @@ public async Task ShouldAcceptHttpContentBodyFromInterceptor() public async Task ShouldUseResponseSerializerMimeTypeForAcceptHeader() { var (httpClient, handler) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer("application/custom-request"); var responseSerializer = CreateMockSerializer("application/custom-response"); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + using var connection = new Connection(TestUri, responseSerializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -923,39 +875,18 @@ public async Task ShouldUseResponseSerializerMimeTypeForAcceptHeader() } [Fact] - public async Task ShouldUseRequestSerializerMimeTypeForContentTypeHeader() - { - var (httpClient, handler) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer("application/custom-request"); - var responseSerializer = CreateMockSerializer("application/custom-response"); - var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); - - await connection.SubmitAsync(CreateTestRequest()); - - Assert.NotNull(handler.CapturedRequest); - Assert.Equal("application/custom-request", - handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); - } - - [Fact] - public async Task ShouldUseDifferentMimeTypesForRequestAndResponseSerializers() + public async Task ShouldSetContentTypeToApplicationJson() { var (httpClient, handler) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer("application/vnd.custom-request-v1.0"); - var responseSerializer = CreateMockSerializer("application/vnd.custom-response-v2.0"); + var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(handler.CapturedRequest); - // Content-Type comes from request serializer - Assert.Equal("application/vnd.custom-request-v1.0", + Assert.Equal("application/json", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); - // Accept comes from response serializer - Assert.Contains(handler.CapturedRequest.Headers.Accept, - h => h.MediaType == "application/vnd.custom-response-v2.0"); } [Fact] @@ -964,7 +895,7 @@ public async Task ShouldReturnStreamingResultSet() var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(new List { "hello", 42 }); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -980,7 +911,7 @@ public async Task ShouldReturnEmptyResultSetWhenNoResults() var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(new List()); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -1001,7 +932,7 @@ public async Task ShouldStreamResponseWithoutFullBuffering() var httpClient = new HttpClient(handler); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs index 7d832e96e0a..e7ea25f6480 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs @@ -133,7 +133,7 @@ public async Task ShouldDefaultBulkResultsToTrue() await connection.SubmitAsync(gl); Assert.NotNull(capturedRequest); - Assert.Equal("true", capturedRequest!.Fields[Tokens.ArgsBulkResults]); + Assert.Equal(true, capturedRequest!.Fields[Tokens.ArgsBulkResults]); } [Fact] @@ -189,13 +189,13 @@ public async Task ShouldRespectPerRequestBulkResults() gl.AddStep("V", Array.Empty()); gl.OptionsStrategies.Add(new OptionsStrategy(new Dictionary { - { Tokens.ArgsBulkResults, "false" } + { Tokens.ArgsBulkResults, false } })); await connection.SubmitAsync(gl); Assert.NotNull(capturedRequest); - Assert.Equal("false", capturedRequest!.Fields[Tokens.ArgsBulkResults]); + Assert.Equal(false, capturedRequest!.Fields[Tokens.ArgsBulkResults]); } /// diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs index f187ada59cd..58442eb7d12 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs @@ -21,6 +21,8 @@ #endregion +using System; +using System.Threading.Tasks; using Gremlin.Net.Driver; using Gremlin.Net.Structure.IO.GraphBinary4; using Xunit; @@ -37,10 +39,10 @@ public void ShouldCreateClientWithDefaultSettings() } [Fact] - public void ShouldCreateClientWithCustomSerializer() + public void ShouldCreateClientWithCustomResponseSerializer() { var serializer = new GraphBinary4MessageSerializer(); - using var client = new GremlinClient(new GremlinServer(), messageSerializer: serializer); + using var client = new GremlinClient(new GremlinServer(), responseSerializer: serializer); // Should not throw } @@ -64,5 +66,21 @@ public void ShouldDisposeWithoutError() // Should not throw on double dispose client.Dispose(); } + + [Fact] + public void ShouldAppendAuthInterceptorToEndOfList() + { + Func userInterceptor = ctx => Task.CompletedTask; + Func auth = ctx => Task.CompletedTask; + + using var client = new GremlinClient(new GremlinServer(), + auth: auth, + interceptors: new[] { userInterceptor }); + + var interceptors = client.Connection.Interceptors; + Assert.Equal(2, interceptors.Count); + Assert.Same(userInterceptor, interceptors[0]); + Assert.Same(auth, interceptors[1]); + } } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs index 86d2bc678d3..10522026445 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs @@ -24,6 +24,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Text.Json; using Gremlin.Net.Driver; using Gremlin.Net.Driver.Messages; using Xunit; @@ -37,7 +38,7 @@ public void ShouldConstructWithByteArrayBody() { var method = "POST"; var uri = new Uri("http://localhost:8182/gremlin"); - var headers = new Dictionary { { "Content-Type", "application/vnd.graphbinary-v4.0" } }; + var headers = new Dictionary { { "Content-Type", "application/json" } }; var body = new byte[] { 0x01, 0x02, 0x03 }; var context = new HttpRequestContext(method, uri, headers, body); @@ -127,5 +128,170 @@ public void ShouldThrowWhenComputingPayloadHashForNullBody() Assert.Contains("null", ex.Message); } + + #region SerializeBody Tests + + [Fact] + public void SerializeBodyShouldReturnJsonBytesForRequestMessage() + { + var message = RequestMessage.Build("g.V().has('name','marko')").AddG("g").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + Assert.IsType(context.Body); + Assert.Same(context.Body, result); + + var json = JsonDocument.Parse(result); + Assert.Equal("g.V().has('name','marko')", json.RootElement.GetProperty("gremlin").GetString()); + Assert.Equal("g", json.RootElement.GetProperty("g").GetString()); + Assert.Equal("gremlin-lang", json.RootElement.GetProperty("language").GetString()); + } + + [Fact] + public void SerializeBodyShouldSetContentTypeHeader() + { + var message = RequestMessage.Build("g.V()").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + context.SerializeBody(); + + Assert.Equal("application/json", context.Headers["Content-Type"]); + } + + [Fact] + public void SerializeBodyShouldSetContentLengthHeader() + { + var message = RequestMessage.Build("g.V()").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + Assert.Equal(result.Length.ToString(), context.Headers["Content-Length"]); + } + + [Fact] + public void SerializeBodyShouldBeIdempotentWhenBodyIsBytes() + { + var originalBytes = new byte[] { 0x01, 0x02, 0x03 }; + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), originalBytes); + + var result = context.SerializeBody(); + + Assert.Same(originalBytes, result); + Assert.Same(originalBytes, context.Body); + } + + [Fact] + public void SerializeBodyShouldBeIdempotentOnMultipleCalls() + { + var message = RequestMessage.Build("g.V()").AddG("g").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var first = context.SerializeBody(); + var second = context.SerializeBody(); + + Assert.Same(first, second); + } + + [Fact] + public void SerializeBodyShouldIncludeAllFields() + { + var message = RequestMessage.Build("g.V()") + .AddG("g") + .AddLanguage("gremlin-lang") + .AddBatchSize(100) + .AddEvaluationTimeout(30000) + .Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("g.V()", json.RootElement.GetProperty("gremlin").GetString()); + Assert.Equal("g", json.RootElement.GetProperty("g").GetString()); + Assert.Equal("gremlin-lang", json.RootElement.GetProperty("language").GetString()); + Assert.Equal(100, json.RootElement.GetProperty("batchSize").GetInt32()); + Assert.Equal(30000, json.RootElement.GetProperty("evaluationTimeout").GetInt32()); + } + + [Fact] + public void SerializeBodyShouldThrowForUnsupportedType() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), "unsupported"); + + var ex = Assert.Throws(() => context.SerializeBody()); + + Assert.Contains("String", ex.Message); + } + + [Fact] + public void SerializeBodyShouldThrowForNullBody() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), null!); + + var ex = Assert.Throws(() => context.SerializeBody()); + + Assert.Contains("null", ex.Message); + } + + [Fact] + public void SerializeBodyShouldReflectFieldMutations() + { + var message = RequestMessage.Build("g.V()").AddG("g").Create(); + // Mutate fields before serialization (simulating an interceptor) + message.Fields["customField"] = "customValue"; + + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("customValue", json.RootElement.GetProperty("customField").GetString()); + } + + [Fact] + public void SerializeBodyShouldReflectBodyReplacement() + { + var original = RequestMessage.Build("g.V()").AddG("g").Create(); + var replacement = RequestMessage.Build("g.E()").AddG("traversal").Create(); + + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), original); + + // Interceptor replaces the body + context.Body = replacement; + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("g.E()", json.RootElement.GetProperty("gremlin").GetString()); + Assert.Equal("traversal", json.RootElement.GetProperty("g").GetString()); + } + + [Fact] + public void SerializeBodyShouldIncludeBindingsField() + { + var message = RequestMessage.Build("g.V(x)") + .AddBindingsString("[x:1,y:'marko']") + .Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("[x:1,y:'marko']", json.RootElement.GetProperty("bindings").GetString()); + } + + #endregion } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs index 41298a6f7e6..dbee0f209f5 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs @@ -41,7 +41,7 @@ public async Task ShouldSerializeRequestMessageStartingWithVersionByte() var actual = await serializer.SerializeMessageAsync(msg); - // First byte should be version byte 0x84 — no MIME prefix + // First byte should be version byte 0x84, no MIME prefix Assert.Equal(0x84, actual[0]); } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java index d4a3967f029..a40e0f6a725 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java @@ -29,10 +29,8 @@ import io.netty.util.concurrent.Future; import org.apache.commons.configuration2.Configuration; import org.apache.commons.lang3.concurrent.BasicThreadFactory; -import org.apache.commons.lang3.tuple.Pair; import org.apache.tinkerpop.gremlin.driver.auth.Auth; import org.apache.tinkerpop.gremlin.driver.exception.NoHostAvailableException; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; import org.apache.tinkerpop.gremlin.driver.remote.HttpRemoteTransaction; import org.apache.tinkerpop.gremlin.structure.Transaction; import org.apache.tinkerpop.gremlin.util.MessageSerializer; @@ -63,7 +61,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.LinkedList; import java.util.List; import java.util.Optional; import java.util.Set; @@ -86,7 +83,6 @@ * @author Stephen Mallette (http://stephen.genoprime.com) */ public final class Cluster { - public static final String SERIALIZER_INTERCEPTOR_NAME = "serializer"; private static final Logger logger = LoggerFactory.getLogger(Cluster.class); private final Manager manager; @@ -141,10 +137,6 @@ public static Builder build(final String address) { return new Builder(address); } - public static Builder build(final RequestInterceptor serializingInterceptor) { - return new Builder(serializingInterceptor); - } - public static Builder build(final File configurationFile) throws FileNotFoundException { final Settings settings = Settings.read(new FileInputStream(configurationFile)); return getBuilderFromSettings(settings); @@ -370,8 +362,8 @@ MessageSerializer getSerializer() { return manager.serializer; } - List> getRequestInterceptors() { - return manager.interceptors; + List getRequestInterceptors() { + return Collections.unmodifiableList(manager.interceptors); } ScheduledExecutorService executor() { @@ -483,7 +475,6 @@ public boolean isBulkResultsEnabled() { } public final static class Builder { - private static int INTERCEPTOR_NOT_FOUND = -1; private final List addresses = new ArrayList<>(); private int port = 8182; @@ -510,25 +501,18 @@ public final static class Builder { private boolean sslSkipCertValidation = false; private SslContext sslContext = null; private LoadBalancingStrategy loadBalancingStrategy = new LoadBalancingStrategy.RoundRobin(); - private LinkedList> interceptors = new LinkedList<>(); + private List interceptors = new ArrayList<>(); + private Auth auth = null; private long connectionSetupTimeoutMillis = Connection.CONNECTION_SETUP_TIMEOUT_MILLIS; private long idleConnectionTimeoutMillis = Connection.CONNECTION_IDLE_TIMEOUT_MILLIS; private boolean enableUserAgentOnConnect = true; private boolean bulkResults = false; private Builder() { - addInterceptor(SERIALIZER_INTERCEPTOR_NAME, - new PayloadSerializingInterceptor(new GraphBinaryMessageSerializerV4())); } private Builder(final String address) { addContactPoint(address); - addInterceptor(SERIALIZER_INTERCEPTOR_NAME, - new PayloadSerializingInterceptor(new GraphBinaryMessageSerializerV4())); - } - - private Builder(final RequestInterceptor bodySerializer) { - addInterceptor(SERIALIZER_INTERCEPTOR_NAME, bodySerializer); } /** @@ -751,82 +735,28 @@ public Builder loadBalancingStrategy(final LoadBalancingStrategy loadBalancingSt } /** - * Adds a {@link RequestInterceptor} after another one that will allow manipulation of the {@code HttpRequest} - * prior to its being sent to the server. + * Sets the list of {@link RequestInterceptor} instances that will be run in order to allow + * modification of the {@link HttpRequest} prior to its being sent to the server. */ - public Builder addInterceptorAfter(final String priorInterceptorName, final String nameOfInterceptor, - final RequestInterceptor interceptor) { - final int index = getInterceptorIndex(priorInterceptorName); - if (INTERCEPTOR_NOT_FOUND == index) { - throw new IllegalArgumentException(priorInterceptorName + " interceptor not found"); - } else if (getInterceptorIndex(nameOfInterceptor) != INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(nameOfInterceptor + " interceptor already exists"); - } - interceptors.add(index + 1, Pair.of(nameOfInterceptor, interceptor)); - + public Builder interceptors(final List interceptors) { + this.interceptors = new ArrayList<>(interceptors); return this; } /** - * Adds a {@link RequestInterceptor} before another one that will allow manipulation of the {@code HttpRequest} - * prior to its being sent to the server. + * Sets the list of {@link RequestInterceptor} instances that will be run in order. */ - public Builder addInterceptorBefore(final String subsequentInterceptorName, final String nameOfInterceptor, - final RequestInterceptor interceptor) { - final int index = getInterceptorIndex(subsequentInterceptorName); - if (INTERCEPTOR_NOT_FOUND == index) { - throw new IllegalArgumentException(subsequentInterceptorName + " interceptor not found"); - } else if (getInterceptorIndex(nameOfInterceptor) != INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(nameOfInterceptor + " interceptor already exists"); - } else if (index == 0) { - interceptors.addFirst(Pair.of(nameOfInterceptor, interceptor)); - } else { - interceptors.add(index - 1, Pair.of(nameOfInterceptor, interceptor)); - } - + public Builder interceptors(final RequestInterceptor... interceptors) { + this.interceptors = new ArrayList<>(List.of(interceptors)); return this; } /** - * Adds a {@link RequestInterceptor} to the end of the list that will allow manipulation of the - * {@code HttpRequest} prior to its being sent to the server. - */ - public Builder addInterceptor(final String name, final RequestInterceptor interceptor) { - if (getInterceptorIndex(name) != INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(name + " interceptor already exists"); - } - interceptors.add(Pair.of(name, interceptor)); - return this; - } - - /** - * Removes a {@link RequestInterceptor} from the list. This can be used to remove the default interceptors that - * aren't needed. - */ - public Builder removeInterceptor(final String name) { - final int index = getInterceptorIndex(name); - if (index == INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(name + " interceptor not found"); - } - interceptors.remove(index); - return this; - } - - private int getInterceptorIndex(final String name) { - for (int i = 0; i < interceptors.size(); i++) { - if (interceptors.get(i).getLeft().equals(name)) { - return i; - } - } - - return INTERCEPTOR_NOT_FOUND; - } - - /** - * Adds an Auth {@link RequestInterceptor} to the end of list of interceptors. + * Adds an Auth {@link RequestInterceptor} that will always be appended to the end of the interceptor list + * when the {@link Cluster} is created, regardless of the order in which builder methods are called. */ public Builder auth(final Auth auth) { - addInterceptor(auth.getClass().getSimpleName().toLowerCase() + "-auth", auth); + this.auth = auth; return this; } @@ -906,6 +836,7 @@ List getContactPoints() { public Cluster create() { if (addresses.isEmpty()) addContactPoint("localhost"); if (null == serializer) serializer = Serializers.GRAPHBINARY_V4.simpleInstance(); + if (null != auth) interceptors.add(auth); return new Cluster(this); } } @@ -946,7 +877,7 @@ class Manager { private final LoadBalancingStrategy loadBalancingStrategy; private final Optional sslContextOptional; private final Supplier validationRequest; - private final List> interceptors; + private final List interceptors; /** * Thread pool for requests. @@ -985,7 +916,7 @@ private Manager(final Builder builder) { this.loadBalancingStrategy = builder.loadBalancingStrategy; this.contactPoints = builder.getContactPoints(); - this.interceptors = builder.interceptors; + this.interceptors = Collections.unmodifiableList(new ArrayList<>(builder.interceptors)); this.enableUserAgentOnConnect = builder.enableUserAgentOnConnect; this.bulkResults = builder.bulkResults; diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java index 253274940f1..ebe5011b7e7 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java @@ -18,16 +18,27 @@ */ package org.apache.tinkerpop.gremlin.driver; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; +import org.apache.tinkerpop.shaded.jackson.core.JsonProcessingException; +import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; + import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; import java.util.Map; /** - * HttpRequest represents the data that will be used to create the actual request to the remote endpoint. It will be - * passed to different {@link RequestInterceptor} that can update its values. The body can be anything as the - * interceptor may change what the payload is. Also contains some convenience Strings for common HTTP header key and - * values and HTTP methods. + * Represents the HTTP request that will be sent to the server. It is passed through the + * {@link RequestInterceptor} chain where interceptors can modify headers, body, URI, and method. + *

+ * The body starts as a {@link RequestMessage} and can be serialized to JSON bytes via {@link #serializeBody()}. + * After all interceptors run, if the body is still a {@code RequestMessage}, the driver will call + * {@code serializeBody()} automatically before sending. */ public class HttpRequest { + + private static final ObjectMapper mapper = new ObjectMapper(); + public static class Headers { // Add as needed. Headers are case-insensitive; lower case for now to match Netty. public static final String ACCEPT = "accept"; @@ -79,7 +90,7 @@ public Map headers() { /** * Get the body of the request. * - * @return an Object representing the body. + * @return an Object representing the body ({@link RequestMessage} or {@code byte[]}). */ public Object getBody() { return body; @@ -104,8 +115,7 @@ public String getMethod() { } /** - * Set the HTTP body of the request. During processing, the body can be any type but the final interceptor must set - * the body to a {@code byte[]}. + * Set the HTTP body of the request. * * @return this HttpRequest for method chaining. */ @@ -133,4 +143,45 @@ public HttpRequest setUri(final URI uri) { this.uri = uri; return this; } + + /** + * Serialize the body to JSON bytes if it is still a {@link RequestMessage}. If the body is already + * {@code byte[]}, this method is idempotent and returns the existing bytes. This method also sets the + * {@code Content-Type} header to {@code application/json} and the {@code Content-Length} header to the + * byte length of the serialized body. + *

+ * Interceptors that need the serialized payload (e.g., for computing a signature hash) should call + * this method rather than serializing independently. + * + * @return the serialized body bytes + * @throws IllegalStateException if the body is neither a {@link RequestMessage} nor {@code byte[]} + */ + public byte[] serializeBody() { + if (body instanceof byte[]) { + return (byte[]) body; + } + + if (!(body instanceof RequestMessage)) { + throw new IllegalStateException("Cannot serialize body of type " + + (body == null ? "null" : body.getClass().getSimpleName()) + + ". Expected RequestMessage or byte[]."); + } + + final RequestMessage requestMessage = (RequestMessage) body; + + // Build JSON map: gremlin is top-level, plus all fields from the message + final Map jsonMap = new LinkedHashMap<>(); + jsonMap.put("gremlin", requestMessage.getGremlin()); + jsonMap.putAll(requestMessage.getFields()); + + try { + final byte[] jsonBytes = mapper.writeValueAsBytes(jsonMap); + this.body = jsonBytes; + this.headers.put(Headers.CONTENT_TYPE, "application/json"); + this.headers.put(Headers.CONTENT_LENGTH, String.valueOf(jsonBytes.length)); + return jsonBytes; + } catch (JsonProcessingException e) { + throw new IllegalStateException("Failed to serialize RequestMessage to JSON", e); + } + } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java index 74055303321..0a4a3d49706 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java @@ -18,14 +18,20 @@ */ package org.apache.tinkerpop.gremlin.driver; -import java.util.function.UnaryOperator; - /** - * Interceptors are run as a list to allow modification of the HTTP request before it is sent to the server. The first - * interceptor will be provided with a {@link HttpRequest} that holds a - * {@link org.apache.tinkerpop.gremlin.util.message.RequestMessage} in the body. The final interceptor should contain a - * {@code byte[]} in the body. + * Interceptors are run as an ordered list to allow modification of the {@link HttpRequest} before it is sent to the + * server. The interceptor receives an {@link HttpRequest} whose body starts as a + * {@link org.apache.tinkerpop.gremlin.util.message.RequestMessage}. Interceptors mutate the request in place. + * After all interceptors run, if the body is still a {@code RequestMessage} the driver will auto-serialize it to JSON. + * Interceptors that need the serialized bytes (e.g., for signing) should call {@link HttpRequest#serializeBody()}. */ -public interface RequestInterceptor extends UnaryOperator { +@FunctionalInterface +public interface RequestInterceptor { + /** + * Intercept and mutate the HTTP request before it is sent. + * + * @param httpRequest the mutable HTTP request + */ + void intercept(HttpRequest httpRequest); } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java index f5a48c21d6e..f33173958e0 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java @@ -18,10 +18,9 @@ */ package org.apache.tinkerpop.gremlin.driver.auth; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaderNames; import org.apache.tinkerpop.gremlin.driver.HttpRequest; +import java.nio.charset.StandardCharsets; import java.util.Base64; public class Basic implements Auth { @@ -35,10 +34,9 @@ public Basic(final String username, final String password) { } @Override - public HttpRequest apply(final HttpRequest httpRequest) { + public void intercept(final HttpRequest httpRequest) { final String valueToEncode = username + ":" + password; httpRequest.headers().put(HttpRequest.Headers.AUTHORIZATION, - "Basic " + Base64.getEncoder().encodeToString(valueToEncode.getBytes())); - return httpRequest; + "Basic " + Base64.getEncoder().encodeToString(valueToEncode.getBytes(StandardCharsets.UTF_8))); } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java index fb51aa80736..a19e75d7499 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.tinkerpop.gremlin.driver.HttpRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,8 +48,9 @@ import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_SECURITY_TOKEN; /** - * A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that provides headers required for SigV4. Because - * the signing process requires final header and body data, this interceptor should almost always be last. + * A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that signs requests with AWS SigV4. + * This interceptor calls {@link HttpRequest#serializeBody()} to ensure the body is serialized before + * computing the payload hash for signing. It should typically be the last interceptor in the chain. */ public class Sigv4 implements Auth { private static final Logger logger = LoggerFactory.getLogger(Sigv4.class); @@ -70,8 +72,11 @@ public Sigv4(final String regionName, final AwsCredentialsProvider awsCredential } @Override - public HttpRequest apply(final HttpRequest httpRequest) { + public void intercept(final HttpRequest httpRequest) { try { + // Ensure the body is serialized to bytes so we can compute the payload hash. + httpRequest.serializeBody(); + final ContentStreamProvider content = toContentStream(httpRequest); // Convert Http request into an AWS SDK signable request final SdkHttpRequest awsSignableRequest = toSignableRequest(httpRequest); @@ -91,7 +96,6 @@ public HttpRequest apply(final HttpRequest httpRequest) { logger.error("Error signing HTTP request: {}", ex.getMessage(), ex); throw new AuthenticationException(ex); } - return httpRequest; } private void setSessionToken(final Map headers, final AwsCredentials credentials) { @@ -123,9 +127,6 @@ private String getSingleHeaderValue(final Map> headers, fin private ContentStreamProvider toContentStream(final HttpRequest httpRequest) { // carry over the entity (or an empty entity, if no entity is provided) - if (!(httpRequest.getBody() instanceof byte[])) { - throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + httpRequest.getBody().getClass()); - } final byte[] body = (byte[]) httpRequest.getBody(); return (body.length != 0) ? ContentStreamProvider.fromByteArray(body) : ContentStreamProvider.fromUtf8String(""); } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java index 0f8ede4c7fd..c1f5eca0ea5 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java @@ -29,7 +29,6 @@ import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; -import org.apache.commons.lang3.tuple.Pair; import org.apache.tinkerpop.gremlin.driver.HttpRequest; import org.apache.tinkerpop.gremlin.driver.RequestInterceptor; import org.apache.tinkerpop.gremlin.driver.UserAgent; @@ -49,7 +48,7 @@ import static org.apache.tinkerpop.gremlin.driver.handler.InactiveChannelHandler.REQUEST_SENT; /** - * Converts {@link RequestMessage} to a {@code HttpRequest}. + * Converts {@link RequestMessage} to an HTTP request, running interceptors and serializing the body to JSON. */ @ChannelHandler.Sharable public final class HttpGremlinRequestEncoder extends MessageToMessageEncoder { @@ -57,11 +56,11 @@ public final class HttpGremlinRequestEncoder extends MessageToMessageEncoder serializer; private final boolean userAgentEnabled; private final boolean bulkResults; - private final List> interceptors; + private final List interceptors; private final URI uri; public HttpGremlinRequestEncoder(final MessageSerializer serializer, - final List> interceptors, + final List interceptors, final boolean userAgentEnabled, boolean bulkResults, final URI uri) { this.serializer = serializer; this.interceptors = interceptors; @@ -72,7 +71,6 @@ public HttpGremlinRequestEncoder(final MessageSerializer serializer, @Override protected void encode(final ChannelHandlerContext channelHandlerContext, final RequestMessage requestMessage, final List objects) throws Exception { - final String mimeType = serializer.mimeTypesSupported()[0]; if (requestMessage.getField("gremlin") instanceof GremlinLang) { throw new ResponseException(HttpResponseStatus.BAD_REQUEST, String.format( "An error occurred during serialization of this request [%s] - it could not be sent to the server - Reason: GremlinLang is not intended to be send as query.", @@ -81,9 +79,10 @@ protected void encode(final ChannelHandlerContext channelHandlerContext, final R final InetSocketAddress remoteAddress = getRemoteAddress(channelHandlerContext.channel()); try { - Map headersMap = new HashMap<>(); + final Map headersMap = new HashMap<>(); headersMap.put(HttpRequest.Headers.HOST, remoteAddress.getAddress().getHostAddress()); - headersMap.put(HttpRequest.Headers.ACCEPT, mimeType); + // Accept header uses the response serializer's mime type (GraphBinary for responses) + headersMap.put(HttpRequest.Headers.ACCEPT, serializer.mimeTypesSupported()[0]); headersMap.put(HttpRequest.Headers.ACCEPT_ENCODING, HttpRequest.Headers.DEFLATE); if (userAgentEnabled) { headersMap.put(HttpRequest.Headers.USER_AGENT, UserAgent.USER_AGENT); @@ -91,21 +90,26 @@ protected void encode(final ChannelHandlerContext channelHandlerContext, final R if (bulkResults) { headersMap.put(Tokens.BULK_RESULTS, "true"); } - - // Add X-Transaction-Id header to comply with specification's dual transmission (header and body) + + // Promote transactionId to HTTP header for dual transmission (header and body) final String transactionId = requestMessage.getField(Tokens.ARGS_TRANSACTION_ID); if (transactionId != null) { headersMap.put(Tokens.Headers.TRANSACTION_ID, transactionId); } - - HttpRequest gremlinRequest = new HttpRequest(headersMap, requestMessage, uri); - for (final Pair interceptor : interceptors) { - gremlinRequest = interceptor.getRight().apply(gremlinRequest); + final HttpRequest gremlinRequest = new HttpRequest(headersMap, requestMessage, uri); + + for (final RequestInterceptor interceptor : interceptors) { + interceptor.intercept(gremlinRequest); } + // Auto-serialize if interceptors did not already produce bytes + gremlinRequest.serializeBody(); + + final byte[] bodyBytes = (byte[]) gremlinRequest.getBody(); + final FullHttpRequest finalRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, - uri.getPath(), convertBody(gremlinRequest)); + uri.getPath(), Unpooled.wrappedBuffer(bodyBytes)); gremlinRequest.headers().forEach((k, v) -> finalRequest.headers().add(k, v)); objects.add(finalRequest); @@ -128,15 +132,4 @@ private static InetSocketAddress getRemoteAddress(Channel channel) { } return remoteAddress; } - - private static ByteBuf convertBody(final HttpRequest request) { - final Object body = request.getBody(); - if (body instanceof byte[]) { - request.headers().put(HttpRequest.Headers.CONTENT_LENGTH, String.valueOf(((byte[]) body).length)); - return Unpooled.wrappedBuffer((byte[]) body); - } else { - throw new IllegalArgumentException("Final body must be byte[] but found " - + body.getClass().getSimpleName()); - } - } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java deleted file mode 100644 index ae35bc05af4..00000000000 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.tinkerpop.gremlin.driver.interceptor; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.handler.codec.http.HttpResponseStatus; -import org.apache.tinkerpop.gremlin.driver.HttpRequest; -import org.apache.tinkerpop.gremlin.driver.RequestInterceptor; -import org.apache.tinkerpop.gremlin.driver.exception.ResponseException; -import org.apache.tinkerpop.gremlin.util.MessageSerializer; -import org.apache.tinkerpop.gremlin.util.message.RequestMessage; -import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; -import org.apache.tinkerpop.gremlin.util.ser.SerializationException; -import org.apache.tinkerpop.gremlin.util.ser.Serializers; - -import java.util.Map; - -/** - * A {@link RequestInterceptor} that serializes the request body usng the provided {@link MessageSerializer}. This - * interceptor should be run before other interceptors that need to calculate values based on the request body. - */ -public class PayloadSerializingInterceptor implements RequestInterceptor { - // Should be thread-safe as the GraphBinaryWriter/GraphSONMessageSerializer doesn't maintain state. - private final MessageSerializer serializer; - - public PayloadSerializingInterceptor(final MessageSerializer serializer) { - this.serializer = serializer; - } - - @Override - public HttpRequest apply(HttpRequest httpRequest) { - if (!(httpRequest.getBody() instanceof RequestMessage)) { - throw new IllegalArgumentException("Only RequestMessage serialization is supported"); - } - - final RequestMessage request = (RequestMessage) httpRequest.getBody(); - final ByteBuf requestBuf; - try { - requestBuf = serializer.serializeRequestAsBinary(request, ByteBufAllocator.DEFAULT); - } catch (SerializationException se) { - throw new RuntimeException(new ResponseException(HttpResponseStatus.BAD_REQUEST, String.format( - "An error occurred during serialization of this request [%s] - it could not be sent to the server - Reason: %s", - request, se))); - } - - // Convert from ByteBuf to bytes[] because that's what the final request body should contain. - final byte[] requestBytes = ByteBufUtil.getBytes(requestBuf); - requestBuf.release(); - - httpRequest.setBody(requestBytes); - httpRequest.headers().put(HttpRequest.Headers.CONTENT_TYPE, serializer.mimeTypesSupported()[0]); - - return httpRequest; - } -} diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java index 984d6c7817f..faa7aac7df0 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java @@ -23,12 +23,10 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; -import org.apache.commons.lang3.tuple.Pair; import org.apache.tinkerpop.gremlin.driver.Channelizer; import org.apache.tinkerpop.gremlin.driver.handler.HttpContentDecompressionHandler; import org.apache.tinkerpop.gremlin.driver.handler.HttpGremlinResponseStreamDecoder; import org.apache.tinkerpop.gremlin.driver.handler.HttpGremlinRequestEncoder; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; import org.apache.tinkerpop.gremlin.util.MessageSerializer; import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import io.netty.bootstrap.Bootstrap; @@ -109,9 +107,7 @@ protected void initChannel(final SocketChannel ch) { new HttpContentDecompressionHandler(), new HttpGremlinResponseStreamDecoder(serializer, Integer.MAX_VALUE), new HttpGremlinRequestEncoder(serializer, - Collections.singletonList( - Pair.of("serializer", new PayloadSerializingInterceptor( - new GraphBinaryMessageSerializerV4()))), + Collections.emptyList(), false, false, uri), callbackResponseHandler); } diff --git a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java deleted file mode 100644 index 0edc12cd0ae..00000000000 --- a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.tinkerpop.gremlin.driver; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; -import org.junit.Test; - -import java.util.List; - -import static org.apache.tinkerpop.gremlin.driver.Cluster.SERIALIZER_INTERCEPTOR_NAME; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.core.IsInstanceOf.instanceOf; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -/** - * Test Cluster and Cluster.Builder. - */ -public class ClusterTest { - private static final RequestInterceptor TEST_INTERCEPTOR = httpRequest -> new HttpRequest(null, null, null); - - @Test - public void shouldNotAllowModifyingRelativeToNonExistentInterceptor() { - try { - Cluster.build().addInterceptorAfter("none", "test", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("none interceptor not found", e.getMessage()); - } - - try { - Cluster.build().addInterceptorBefore("none", "test", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("none interceptor not found", e.getMessage()); - } - - try { - Cluster.build().removeInterceptor("nonexistent"); - fail("Should not have allowed interceptor to be removed."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("nonexistent interceptor not found", e.getMessage()); - } - } - - @Test - public void shouldAddToInterceptorToBeginningIfBeforeFirst() { - final Cluster testCluster = Cluster.build() - .addInterceptor("b", req -> req) - .addInterceptorBefore(SERIALIZER_INTERCEPTOR_NAME, "a", TEST_INTERCEPTOR) - .create(); - assertEquals("a", testCluster.getRequestInterceptors().get(0).getLeft()); - assertEquals(TEST_INTERCEPTOR, testCluster.getRequestInterceptors().get(0).getRight()); - assertEquals(SERIALIZER_INTERCEPTOR_NAME, testCluster.getRequestInterceptors().get(1).getLeft()); - } - - @Test - public void shouldAddToInterceptorAfter() { - final Cluster testCluster = Cluster.build() - .addInterceptor("b", req -> req) - .addInterceptorAfter(SERIALIZER_INTERCEPTOR_NAME, "a", TEST_INTERCEPTOR) - .create(); - assertEquals(SERIALIZER_INTERCEPTOR_NAME, testCluster.getRequestInterceptors().get(0).getLeft()); - assertEquals("a", testCluster.getRequestInterceptors().get(1).getLeft()); - assertEquals(TEST_INTERCEPTOR, testCluster.getRequestInterceptors().get(1).getRight()); - assertEquals("b", testCluster.getRequestInterceptors().get(2).getLeft()); - - } - - @Test - public void shouldAddToInterceptorLast() { - final Cluster testCluster = Cluster.build() - .addInterceptor("c", req -> req) - .addInterceptor("b", req -> req) - .addInterceptor("a", req -> req) - .create(); - assertEquals(SERIALIZER_INTERCEPTOR_NAME, testCluster.getRequestInterceptors().get(0).getLeft()); - assertEquals("c", testCluster.getRequestInterceptors().get(1).getLeft()); - assertEquals("b", testCluster.getRequestInterceptors().get(2).getLeft()); - assertEquals("a", testCluster.getRequestInterceptors().get(3).getLeft()); - } - - @Test - public void shouldNotAllowAddingDuplicateName() { - try { - Cluster.build().addInterceptor("name", req -> req).addInterceptor("name", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("name interceptor already exists", e.getMessage()); - } - - try { - Cluster.build().addInterceptor("name", req -> req).addInterceptorAfter("name", "name", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("name interceptor already exists", e.getMessage()); - } - - try { - Cluster.build().addInterceptor("name", req -> req).addInterceptorBefore("name", "name", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("name interceptor already exists", e.getMessage()); - } - } - - @Test - public void shouldContainBodySerializerByDefault() { - final List> interceptors = Cluster.build().create().getRequestInterceptors(); - assertEquals(1, interceptors.size()); - assertTrue(interceptors.get(0).getRight() instanceof PayloadSerializingInterceptor); - } - - @Test - public void shouldRemoveDefaultSerializer() { - final Cluster testCluster = Cluster.build().removeInterceptor(SERIALIZER_INTERCEPTOR_NAME).create(); - assertEquals(0, testCluster.getRequestInterceptors().size()); - } -} diff --git a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java new file mode 100644 index 00000000000..a10c36ae32a --- /dev/null +++ b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.driver; + +import org.apache.tinkerpop.gremlin.driver.auth.Auth; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; +import org.apache.tinkerpop.shaded.jackson.databind.JsonNode; +import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; +import org.junit.Test; + +import java.net.URI; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class InterceptorTest { + + private static final ObjectMapper mapper = new ObjectMapper(); + + private HttpRequest createRequest(final RequestMessage msg) throws Exception { + return new HttpRequest(new HashMap<>(), msg, new URI("http://localhost:8182/gremlin")); + } + + private HttpRequest createRequest() throws Exception { + return createRequest(RequestMessage.build("g.V()").addG("g").create()); + } + + @Test + public void interceptorReceivesRequestMessageInBody() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V()").create(); + final HttpRequest request = createRequest(msg); + + final List captured = new ArrayList<>(); + final RequestInterceptor interceptor = req -> captured.add(req.getBody()); + interceptor.intercept(request); + + assertEquals(1, captured.size()); + assertSame(msg, captured.get(0)); + } + + @Test + public void interceptorCanReadAndModifyHeaders() throws Exception { + final HttpRequest request = createRequest(); + request.headers().put("X-Existing", "original"); + + final RequestInterceptor interceptor = req -> { + assertEquals("original", req.headers().get("X-Existing")); + req.headers().put("X-Existing", "modified"); + req.headers().put("X-New", "added"); + }; + interceptor.intercept(request); + + assertEquals("modified", request.headers().get("X-Existing")); + assertEquals("added", request.headers().get("X-New")); + } + + @Test + public void interceptorCanModifyUri() throws Exception { + final HttpRequest request = createRequest(); + final URI newUri = new URI("http://other-host:9999/gremlin"); + + final RequestInterceptor interceptor = req -> req.setUri(newUri); + interceptor.intercept(request); + + assertEquals(newUri, request.getUri()); + } + + @Test + public void interceptorsRunInRegistrationOrder() throws Exception { + final HttpRequest request = createRequest(); + final List order = new ArrayList<>(); + + final List interceptors = List.of( + req -> order.add(1), + req -> order.add(2), + req -> order.add(3) + ); + + for (final RequestInterceptor i : interceptors) { + i.intercept(request); + } + + assertEquals(List.of(1, 2, 3), order); + } + + @Test + public void serializeBodyConvertsRequestMessageToJsonBytes() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V().count()").addG("g").create(); + final HttpRequest request = createRequest(msg); + + final byte[] result = request.serializeBody(); + + final JsonNode json = mapper.readTree(result); + assertEquals("g.V().count()", json.get("gremlin").asText()); + assertEquals("g", json.get("g").asText()); + } + + @Test + public void serializeBodySetsContentTypeHeader() throws Exception { + final HttpRequest request = createRequest(); + + request.serializeBody(); + + assertEquals("application/json", request.headers().get(HttpRequest.Headers.CONTENT_TYPE)); + } + + @Test + public void serializeBodySetsContentLengthHeader() throws Exception { + final HttpRequest request = createRequest(); + + final byte[] result = request.serializeBody(); + + assertEquals(String.valueOf(result.length), request.headers().get(HttpRequest.Headers.CONTENT_LENGTH)); + } + + @Test + public void serializeBodyIsIdempotentWithPreSerializedBytes() throws Exception { + final byte[] existing = "{\"gremlin\":\"g.V()\"}".getBytes(); + final HttpRequest request = new HttpRequest(new HashMap<>(), existing, new URI("http://localhost:8182/gremlin")); + + final byte[] first = request.serializeBody(); + final byte[] second = request.serializeBody(); + + assertSame(existing, first); + assertSame(first, second); + } + + @Test + public void serializeBodyIsIdempotentWithRequestMessage() throws Exception { + final HttpRequest request = createRequest(); + + final byte[] first = request.serializeBody(); + final byte[] second = request.serializeBody(); + + assertSame(first, second); + } + + @Test + public void serializeBodyIncludesAllFields() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V()") + .addG("g") + .addLanguage("gremlin-lang") + .addTimeoutMillis(5000L) + .addBulkResults(true) + .create(); + final HttpRequest request = createRequest(msg); + + final byte[] result = request.serializeBody(); + final JsonNode json = mapper.readTree(result); + + assertEquals("g.V()", json.get("gremlin").asText()); + assertEquals("g", json.get("g").asText()); + assertEquals("gremlin-lang", json.get("language").asText()); + assertEquals(5000, json.get("timeoutMs").asLong()); + assertEquals("true", json.get("bulkResults").asText()); + } + + @Test + public void interceptorCanReplaceBodyBeforeSerialization() throws Exception { + final RequestMessage original = RequestMessage.build("g.V()").addG("g").create(); + final HttpRequest request = createRequest(original); + + final RequestInterceptor interceptor = req -> { + req.setBody(RequestMessage.build("g.E()").addG("gmodern").create()); + }; + interceptor.intercept(request); + request.serializeBody(); + + final JsonNode json = mapper.readTree((byte[]) request.getBody()); + assertEquals("g.E()", json.get("gremlin").asText()); + assertEquals("gmodern", json.get("g").asText()); + } + + @Test + public void authIsAlwaysLastInterceptorRegardlessOfBuilderCallOrder() throws Exception { + // auth called before interceptors + final Cluster cluster1 = Cluster.build("localhost") + .auth(Auth.basic("user", "pass")) + .interceptors(req -> req.headers().put("X-Custom", "value")) + .create(); + + final List interceptors1 = cluster1.getRequestInterceptors(); + assertEquals(2, interceptors1.size()); + assertTrue("Auth should be last interceptor", interceptors1.get(interceptors1.size() - 1) instanceof Auth); + cluster1.close(); + + // auth called after interceptors + final Cluster cluster2 = Cluster.build("localhost") + .interceptors(req -> req.headers().put("X-Custom", "value")) + .auth(Auth.basic("user", "pass")) + .create(); + + final List interceptors2 = cluster2.getRequestInterceptors(); + assertEquals(2, interceptors2.size()); + assertTrue("Auth should be last interceptor", interceptors2.get(interceptors2.size() - 1) instanceof Auth); + cluster2.close(); + } + +} diff --git a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java index 5c7ecd1fd70..a72bcd09467 100644 --- a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java +++ b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java @@ -20,10 +20,9 @@ package org.apache.tinkerpop.gremlin.driver.auth; import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; import java.util.HashMap; import org.apache.tinkerpop.gremlin.driver.HttpRequest; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -52,7 +51,6 @@ public class Sigv4Test { private static final String REGION = "us-west-2"; private static final String SERVICE_NAME = "service-name"; - private static final byte[] REQUEST_BODY = "{\"gremlin\":\"2-1\"}".getBytes(StandardCharsets.UTF_8); private static final String HOST = "localhost"; private static final String URI_WITH_QUERY_PARAMS = "http://" + HOST + ":8182?a=1&b=2"; private static final String KEY = "foo"; @@ -71,8 +69,16 @@ public void setup() { @Test public void shouldAddSignedHeaders() throws Exception { when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); - HttpRequest httpRequest = createRequest(); - sigv4.apply(httpRequest); + HttpRequest httpRequest = createRequestWithRequestMessage(); + sigv4.intercept(httpRequest); + validateExpectedHeaders(httpRequest); + } + + @Test + public void shouldSignPreSerializedBody() throws Exception { + when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); + HttpRequest httpRequest = createRequestWithBytes(); + sigv4.intercept(httpRequest); validateExpectedHeaders(httpRequest); } @@ -80,35 +86,48 @@ public void shouldAddSignedHeaders() throws Exception { public void shouldAddSignedHeadersAndSessionToken() throws Exception { String sessionToken = "foobarz"; when(credentialsProvider.resolveCredentials()).thenReturn(AwsSessionCredentials.create(KEY, SECRET, sessionToken)); - HttpRequest httpRequest = createRequest(); - sigv4.apply(httpRequest); + HttpRequest httpRequest = createRequestWithRequestMessage(); + sigv4.intercept(httpRequest); validateExpectedHeaders(httpRequest); assertEquals(sessionToken, httpRequest.headers().get(X_AMZ_SECURITY_TOKEN)); } @Test - public void shouldThrowIfRequestNonByteArray() { + public void shouldThrowIfBodyIsUnsupportedType() { + // serializeBody() will throw because body is a String Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class, - () -> sigv4.apply(new HttpRequest(new HashMap<>(), "not byte array", new URI(URI_WITH_QUERY_PARAMS)))); - assertTrue(ex.getMessage().contains("Expected byte[] in HttpRequest body")); + () -> sigv4.intercept(new HttpRequest(new HashMap<>(), "not valid", URI.create(URI_WITH_QUERY_PARAMS)))); + assertTrue(ex.getCause().getMessage().contains("Cannot serialize body of type")); } @Test - public void shouldThrowIfNoRequestMethod() { + public void shouldThrowIfNoRequestMethod() throws Exception { + when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); + final byte[] body = "{\"gremlin\":\"g.V()\"}".getBytes(); Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class, - () -> sigv4.apply(new HttpRequest(new HashMap<>(), REQUEST_BODY, new URI(URI_WITH_QUERY_PARAMS), null))); - assertTrue(ex.getMessage().contains("The request method must not be null")); + () -> sigv4.intercept(new HttpRequest(new HashMap<>(), body, URI.create(URI_WITH_QUERY_PARAMS), null))); + assertTrue(ex.getCause().getMessage().contains("The request method must not be null")); } @Test public void shouldThrowIfNoRequestURI() { + when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); + final byte[] body = "{\"gremlin\":\"g.V()\"}".getBytes(); Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class, - () -> sigv4.apply(new HttpRequest(new HashMap<>(), REQUEST_BODY, null))); - assertTrue(ex.getMessage().contains("The request URI must not be null")); + () -> sigv4.intercept(new HttpRequest(new HashMap<>(), body, null))); + assertTrue(ex.getCause().getMessage().contains("The request URI must not be null")); + } + + private HttpRequest createRequestWithRequestMessage() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V()").addG("g").create(); + HttpRequest httpRequest = new HttpRequest(new HashMap<>(), msg, new URI(URI_WITH_QUERY_PARAMS)); + httpRequest.headers().put("Host", "this-should-be-ignored-for-signed-host-header"); + return httpRequest; } - private HttpRequest createRequest() throws URISyntaxException { - HttpRequest httpRequest = new HttpRequest(new HashMap<>(), REQUEST_BODY, new URI(URI_WITH_QUERY_PARAMS)); + private HttpRequest createRequestWithBytes() throws Exception { + final byte[] body = "{\"gremlin\":\"2-1\"}".getBytes(); + HttpRequest httpRequest = new HttpRequest(new HashMap<>(), body, new URI(URI_WITH_QUERY_PARAMS)); httpRequest.headers().put("Content-Type", "application/json"); httpRequest.headers().put("Host", "this-should-be-ignored-for-signed-host-header"); return httpRequest; @@ -123,5 +142,4 @@ private void validateExpectedHeaders(HttpRequest httpRequest) { containsString("/" + REGION + "/service-name/aws4_request"), containsString("Signature="))); } - -} \ No newline at end of file +} diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go index e0f279110ab..411458e65c3 100644 --- a/gremlin-go/driver/auth.go +++ b/gremlin-go/driver/auth.go @@ -49,13 +49,12 @@ func SigV4Auth(region, service string) RequestInterceptor { // SigV4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4 // with the provided credentials provider. If provider is nil, uses default credential chain. // If the request body has not been serialized yet (*RequestMessage), it is automatically -// serialized to GraphBinary before signing. +// serialized to JSON before signing via SerializeBody(). // // Caches the signer and credentials provider for efficiency. func SigV4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor { // Create signer once - it's stateless and safe to reuse signer := v4.NewSigner() - serialize := SerializeRequest() // Cache for resolved credentials provider (lazy initialization) var cachedProvider aws.CredentialsProvider @@ -63,15 +62,10 @@ func SigV4AuthWithCredentials(region, service string, credentialsProvider aws.Cr var providerErr error return func(req *HttpRequest) error { - // If Body is still *RequestMessage, serialize it to GraphBinary before signing. - if _, ok := req.Body.(*RequestMessage); ok { - if err := serialize(req); err != nil { - return fmt.Errorf("SigV4 auto-serialization failed: %w", err) - } - } - - if _, ok := req.Body.([]byte); !ok { - return fmt.Errorf("SigV4 signing requires body to be []byte; got %T", req.Body) + // Ensure body is serialized to JSON bytes before signing. + // SerializeBody is idempotent: safe to call even if already serialized. + if _, err := req.SerializeBody(); err != nil { + return fmt.Errorf("SigV4 signing requires a serialized body: %w", err) } ctx := context.Background() diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index ced49a1184d..aa34406fcd6 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -67,6 +67,11 @@ type ClientSettings struct { // RequestInterceptors are functions that modify HTTP requests before sending. RequestInterceptors []RequestInterceptor + + // Auth is a RequestInterceptor for authentication (e.g. BasicAuth, SigV4Auth). + // As a convenience, this is always appended to the end of the interceptor list + // so it runs last, after any user interceptors have modified the request. + Auth RequestInterceptor } // Client is used to connect and interact with a Gremlin-supported server. @@ -122,6 +127,11 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C conn.AddInterceptor(interceptor) } + // Auth interceptor is always last so it runs after user interceptors + if settings.Auth != nil { + conn.AddInterceptor(settings.Auth) + } + client := &Client{ url: url, traversalSource: settings.TraversalSource, diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index 6a61500132f..8673414d254 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -53,7 +53,6 @@ type connection struct { httpClient *http.Client connSettings *connectionSettings logHandler *logHandler - serializer *GraphBinarySerializer interceptors []RequestInterceptor wg sync.WaitGroup } @@ -111,7 +110,6 @@ func newConnection(handler *logHandler, url string, connSettings *connectionSett httpClient: &http.Client{Transport: transport}, // No Timeout - allows streaming connSettings: connSettings, logHandler: handler, - serializer: newGraphBinarySerializer(handler), } } @@ -205,19 +203,12 @@ func (c *connection) sendRequest(req *RequestMessage) (*http.Response, error) { } } - // After interceptors, serialize if Body is still *RequestMessage - if r, ok := httpReq.Body.(*RequestMessage); ok { - if c.serializer != nil { - data, err := c.serializer.SerializeMessage(r) - if err != nil { - c.logHandler.logf(Error, failedToSendRequest, err.Error()) - return nil, err - } - httpReq.Body = data - } else { - errMsg := "request body was not serialized; either provide a serializer or add an interceptor that serializes the request" - c.logHandler.logf(Error, failedToSendRequest, errMsg) - return nil, fmt.Errorf("%s", errMsg) + // After interceptors, auto-serialize the body to JSON if still a *RequestMessage. + // SerializeBody is idempotent: if an interceptor already called it, this is a no-op. + if _, ok := httpReq.Body.(*RequestMessage); ok { + if _, err := httpReq.SerializeBody(); err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + return nil, err } } @@ -277,7 +268,6 @@ func (c *connection) streamResponse(resp *http.Response, rs ResultSet) { // setHttpRequestHeaders sets default headers on HttpRequest (for interceptors) func (c *connection) setHttpRequestHeaders(req *HttpRequest) { - req.Headers.Set(HeaderContentType, graphBinaryMimeType) req.Headers.Set(HeaderAccept, graphBinaryMimeType) if c.connSettings.enableUserAgentOnConnect { diff --git a/gremlin-go/driver/connection_test.go b/gremlin-go/driver/connection_test.go index 572cc954600..9314b0616d7 100644 --- a/gremlin-go/driver/connection_test.go +++ b/gremlin-go/driver/connection_test.go @@ -907,13 +907,12 @@ func TestNewConnection(t *testing.T) { } func TestSetHttpRequestHeaders(t *testing.T) { - t.Run("sets content type and accept headers", func(t *testing.T) { + t.Run("sets accept header", func(t *testing.T) { conn := newConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{}) req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") conn.setHttpRequestHeaders(req) - assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Content-Type")) assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Accept")) }) @@ -1012,7 +1011,7 @@ func TestConnectionWithMockServer(t *testing.T) { select { case receivedHeaders := <-headersCh: - assert.Equal(t, graphBinaryMimeType, receivedHeaders.Get("Content-Type")) + assert.Equal(t, "application/json", receivedHeaders.Get("Content-Type")) assert.Equal(t, "deflate", receivedHeaders.Get("Accept-Encoding")) assert.NotEmpty(t, receivedHeaders.Get(userAgentHeader)) case <-time.After(time.Second): @@ -1105,8 +1104,7 @@ func TestConnectionWithMockServer(t *testing.T) { require.NoError(t, err) _, _ = rs.All() - // Interceptor should see the default headers - assert.Equal(t, graphBinaryMimeType, interceptorHeaders.Get("Content-Type")) + // Interceptor should see the default headers (Content-Type is not set until SerializeBody runs) assert.Equal(t, graphBinaryMimeType, interceptorHeaders.Get("Accept")) }) @@ -1311,6 +1309,142 @@ func TestDriverRemoteConnectionSettingsWiring(t *testing.T) { }) } +func TestInterceptorIntegration(t *testing.T) { + testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) + testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) + + t.Run("should auto serialize request message with interceptor mutation", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + // Interceptor replaces the RequestMessage body with a different query. + // The driver should auto-serialize the modified RequestMessage and the server + // should execute the modified query. + client, err := NewClient(testNoAuthUrl, + func(settings *ClientSettings) { + settings.TraversalSource = testServerModernGraphAlias + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { + if msg, ok := req.Body.(*RequestMessage); ok { + req.Body = &RequestMessage{ + Gremlin: "g.inject(99)", + Fields: map[string]interface{}{"g": msg.Fields["g"], "language": "gremlin-lang"}, + } + } + return nil + }, + } + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + rs, err := client.Submit("g.inject(1)") + assert.Nil(t, err) + result, ok, err := rs.One() + assert.Nil(t, err) + assert.True(t, ok) + val, err := result.GetInt() + assert.Nil(t, err) + assert.Equal(t, 99, val) + }) + + t.Run("should propagate exception thrown during interceptor", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + // Only throw on the first request to verify recovery. + var callCount int + var mu sync.Mutex + + client, err := NewClient(testNoAuthUrl, + func(settings *ClientSettings) { + settings.TraversalSource = testServerModernGraphAlias + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { + mu.Lock() + callCount++ + count := callCount + mu.Unlock() + if count == 1 { + return fmt.Errorf("interceptor broke") + } + return nil + }, + } + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + // First request should fail with interceptor error + rs, err := client.Submit("g.inject(1)") + if err != nil { + assert.Contains(t, err.Error(), "interceptor broke") + } else { + _, err = rs.All() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "interceptor broke") + } + + // Subsequent request should succeed, proving connection recovery + rs, err = client.Submit("g.inject(2)") + assert.Nil(t, err) + result, ok, err := rs.One() + assert.Nil(t, err) + assert.True(t, ok) + val, err := result.GetInt() + assert.Nil(t, err) + assert.Equal(t, 2, val) + }) + + t.Run("should propagate error when interceptor sets unsupported body type", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + // Only set invalid body on the first request to verify recovery. + var callCount int + var mu sync.Mutex + + client, err := NewClient(testNoAuthUrl, + func(settings *ClientSettings) { + settings.TraversalSource = testServerModernGraphAlias + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { + mu.Lock() + callCount++ + count := callCount + mu.Unlock() + if count == 1 { + req.Body = 42 + } + return nil + }, + } + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + // First request should fail with serialization error + rs, err := client.Submit("g.inject(1)") + if err != nil { + assert.Contains(t, err.Error(), "unsupported body type") + } else { + _, err = rs.All() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "unsupported body type") + } + + // Subsequent request should succeed, proving connection recovery + rs, err = client.Submit("g.inject(2)") + assert.Nil(t, err) + result, ok, err := rs.One() + assert.Nil(t, err) + assert.True(t, ok) + val, err := result.GetInt() + assert.Nil(t, err) + assert.Equal(t, 2, val) + }) +} + // TestConnectionWithMockServer_BasicAuth verifies that BasicAuth interceptor sets the correct // Authorization header and the body is still valid serialized bytes. func TestConnectionWithMockServer_BasicAuth(t *testing.T) { @@ -1340,6 +1474,6 @@ func TestConnectionWithMockServer_BasicAuth(t *testing.T) { // Body should still be valid serialized bytes assert.NotEmpty(t, capturedBody, "serialized body should be non-empty with BasicAuth") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + assert.Equal(t, byte('{'), capturedBody[0], + "body should start with '{' (JSON)") } diff --git a/gremlin-go/driver/driverRemoteConnection.go b/gremlin-go/driver/driverRemoteConnection.go index a5a47a8c517..3450b7e1792 100644 --- a/gremlin-go/driver/driverRemoteConnection.go +++ b/gremlin-go/driver/driverRemoteConnection.go @@ -60,6 +60,11 @@ type DriverRemoteConnectionSettings struct { // RequestInterceptors are functions that modify HTTP requests before sending. RequestInterceptors []RequestInterceptor + // Auth is a RequestInterceptor for authentication (e.g. BasicAuth, SigV4Auth). + // As a convenience, this is always appended to the end of the interceptor list + // so it runs last, after any user interceptors have modified the request. + Auth RequestInterceptor + // PDTRegistry enables registry-based dehydration in the gremlin-lang translator. PDTRegistry *PDTRegistry } @@ -118,6 +123,11 @@ func NewDriverRemoteConnection( conn.AddInterceptor(interceptor) } + // Auth interceptor is always last so it runs after user interceptors + if settings.Auth != nil { + conn.AddInterceptor(settings.Auth) + } + client := &Client{ url: url, traversalSource: settings.TraversalSource, diff --git a/gremlin-go/driver/interceptor.go b/gremlin-go/driver/interceptor.go index e7e0c8e0879..7485303d400 100644 --- a/gremlin-go/driver/interceptor.go +++ b/gremlin-go/driver/interceptor.go @@ -23,9 +23,12 @@ import ( "bytes" "crypto/sha256" "encoding/hex" + "encoding/json" + "fmt" "io" "net/http" "net/url" + "strconv" ) // Common HTTP header keys @@ -90,24 +93,33 @@ func (r *HttpRequest) PayloadHash() string { } } -// RequestInterceptor is a function that modifies an HTTP request before it is sent. -type RequestInterceptor func(*HttpRequest) error - -// SerializeRequest returns a RequestInterceptor that serializes the raw *RequestMessage body -// to GraphBinary []byte. Place this before auth interceptors (e.g., SigV4Auth) that -// need the serialized body bytes. -func SerializeRequest() RequestInterceptor { - serializer := newGraphBinarySerializer(nil) - return func(req *HttpRequest) error { - r, ok := req.Body.(*RequestMessage) - if !ok { - return nil // already serialized or not a *RequestMessage +// SerializeBody serializes the request body to JSON if it is still a *RequestMessage. +// If the body is already []byte, it returns those bytes (idempotent). +// On successful serialization from *RequestMessage, it sets the body to the resulting bytes +// and updates the Content-Type and Content-Length headers. +func (r *HttpRequest) SerializeBody() ([]byte, error) { + switch b := r.Body.(type) { + case []byte: + return b, nil + case *RequestMessage: + payload := make(map[string]interface{}) + payload["gremlin"] = b.Gremlin + for k, v := range b.Fields { + payload[k] = v } - data, err := serializer.SerializeMessage(r) + data, err := json.Marshal(payload) if err != nil { - return err + return nil, fmt.Errorf("failed to serialize request to JSON: %w", err) } - req.Body = data - return nil + r.Body = data + r.Headers.Set(HeaderContentType, "application/json") + r.Headers.Set("Content-Length", strconv.Itoa(len(data))) + return data, nil + default: + return nil, fmt.Errorf("cannot serialize request body of type %T; expected *RequestMessage or []byte. "+ + "If an interceptor modified the body, it must set it to []byte or leave it as *RequestMessage", r.Body) } } + +// RequestInterceptor is a function that modifies an HTTP request before it is sent. +type RequestInterceptor func(*HttpRequest) error diff --git a/gremlin-go/driver/interceptor_test.go b/gremlin-go/driver/interceptor_test.go index adb51786fc9..3931dff5795 100644 --- a/gremlin-go/driver/interceptor_test.go +++ b/gremlin-go/driver/interceptor_test.go @@ -21,12 +21,13 @@ package gremlingo import ( "bytes" + "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "reflect" - "strings" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -70,15 +71,210 @@ func TestInterceptorReceivesRawRequest(t *testing.T) { } } -// TestSigV4AuthWithSerializeInterceptor verifies that SerializeRequest() + SigV4Auth -// works in a chain. SerializeRequest converts *RequestMessage to []byte, then SigV4Auth -// can sign the serialized body. -func TestSigV4AuthWithSerializeInterceptor(t *testing.T) { +// TestInterceptorCanModifyHeaders verifies that interceptors can read and modify headers. +func TestInterceptorCanModifyHeaders(t *testing.T) { var capturedHeaders http.Header - var capturedBody []byte server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + conn.AddInterceptor(func(req *HttpRequest) error { + req.Headers.Set("X-Custom-Header", "custom-value") + return nil + }) + + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header"), + "interceptor should be able to set custom headers") +} + +// TestInterceptorCanModifyURI verifies that interceptors can modify the request URI. +func TestInterceptorCanModifyURI(t *testing.T) { + var capturedPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL+"/gremlin", &connectionSettings{}) + + conn.AddInterceptor(func(req *HttpRequest) error { + req.URL.Path = "/custom-path" + return nil + }) + + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, "/custom-path", capturedPath, + "interceptor should be able to modify the URL path") +} + +// TestInterceptor_ChainOrder verifies that interceptors run in the order they are added. +func TestInterceptor_ChainOrder(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + var order []int + + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 1) + return nil + }) + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 2) + return nil + }) + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 3) + return nil + }) + + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, []int{1, 2, 3}, order, + "interceptors should run in the order they were added") +} + +// TestSerializeBody_JSONOutput verifies that SerializeBody produces valid JSON containing +// the gremlin field and all fields from RequestMessage.Fields. +func TestSerializeBody_JSONOutput(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V().has('name','marko')", + Fields: map[string]interface{}{"language": "gremlin-lang", "g": "g"}, + } + + data, err := req.SerializeBody() + require.NoError(t, err) + + var parsed map[string]interface{} + err = json.Unmarshal(data, &parsed) + require.NoError(t, err, "SerializeBody should produce valid JSON") + + assert.Equal(t, "g.V().has('name','marko')", parsed["gremlin"]) + assert.Equal(t, "gremlin-lang", parsed["language"]) + assert.Equal(t, "g", parsed["g"]) +} + +// TestSerializeBody_SetsContentTypeHeader verifies that SerializeBody sets Content-Type to application/json. +func TestSerializeBody_SetsContentTypeHeader(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V()", + Fields: map[string]interface{}{}, + } + + _, err = req.SerializeBody() + require.NoError(t, err) + + assert.Equal(t, "application/json", req.Headers.Get(HeaderContentType), + "SerializeBody should set Content-Type to application/json") +} + +// TestSerializeBody_SetsContentLengthHeader verifies that SerializeBody sets Content-Length +// to the byte length of the serialized body. +func TestSerializeBody_SetsContentLengthHeader(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V()", + Fields: map[string]interface{}{}, + } + + data, err := req.SerializeBody() + require.NoError(t, err) + + expected := strconv.Itoa(len(data)) + assert.Equal(t, expected, req.Headers.Get("Content-Length"), + "SerializeBody should set Content-Length to byte length of the body") +} + +// TestSerializeBody_Idempotent verifies that calling SerializeBody when body is already +// []byte returns the same bytes without re-serialization. +func TestSerializeBody_Idempotent(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V()", + Fields: map[string]interface{}{"g": "g"}, + } + + // First call serializes + data1, err := req.SerializeBody() + require.NoError(t, err) + + // Second call should return same bytes + data2, err := req.SerializeBody() + require.NoError(t, err) + + assert.Equal(t, data1, data2, + "SerializeBody should be idempotent, returning the same bytes on subsequent calls") +} + +// TestSerializeBody_MultipleCalls verifies that multiple calls produce identical results. +func TestSerializeBody_MultipleCalls(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V().count()", + Fields: map[string]interface{}{"language": "gremlin-lang"}, + } + + results := make([][]byte, 3) + for i := range results { + data, err := req.SerializeBody() + require.NoError(t, err) + results[i] = data + } + + assert.Equal(t, results[0], results[1]) + assert.Equal(t, results[1], results[2]) +} + +// TestSerializeBody_UnsupportedType verifies that SerializeBody returns an error when +// body is neither *RequestMessage nor []byte. +func TestSerializeBody_UnsupportedType(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = 42 + + _, err = req.SerializeBody() + require.Error(t, err, "SerializeBody should return error for unsupported body type") + assert.Contains(t, err.Error(), "cannot serialize request body") +} + +// TestFieldMutationBeforeSerialization verifies that an interceptor can modify +// RequestMessage fields and the serialized output reflects those changes. +func TestFieldMutationBeforeSerialization(t *testing.T) { + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err == nil { capturedBody = body @@ -89,35 +285,34 @@ func TestSigV4AuthWithSerializeInterceptor(t *testing.T) { conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) - mockProvider := &mockCredentialsProvider{ - accessKey: "MOCK_ID", - secretKey: "MOCK_KEY", - } - - conn.AddInterceptor(SerializeRequest()) - conn.AddInterceptor(SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", mockProvider)) + // Interceptor that adds a custom field before serialization + conn.AddInterceptor(func(req *HttpRequest) error { + r, ok := req.Body.(*RequestMessage) + if !ok { + return fmt.Errorf("expected *RequestMessage, got %T", req.Body) + } + r.Fields["customField"] = "customValue" + return nil + }) - rs, err := conn.submit(&RequestMessage{Gremlin: "g.V().count()", Fields: map[string]interface{}{}}) + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{"g": "g"}}) require.NoError(t, err) - _, _ = rs.All() // drain + _, _ = rs.All() - // SigV4 should have added Authorization and X-Amz-Date headers - assert.NotEmpty(t, capturedHeaders.Get("Authorization"), - "SigV4Auth should set Authorization header after SerializeRequest") - assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"), - "SigV4Auth should set X-Amz-Date header") - assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256", - "Authorization header should use AWS4-HMAC-SHA256 signing algorithm") + // Parse the JSON body sent to the server + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "server should receive valid JSON") - // Body should be valid serialized bytes - assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + assert.Equal(t, "g.V()", parsed["gremlin"]) + assert.Equal(t, "g", parsed["g"]) + assert.Equal(t, "customValue", parsed["customField"], + "interceptor field mutation should be reflected in the serialized output") } -// TestSigV4Auth_AutoSerializesInChain verifies that SigV4Auth works as the only -// interceptor — it auto-serializes *RequestMessage before signing. -func TestSigV4Auth_AutoSerializesInChain(t *testing.T) { +// TestSigV4AuthWithSerializeBody verifies that SigV4Auth calls SerializeBody and signs +// the request properly. +func TestSigV4AuthWithSerializeBody(t *testing.T) { var capturedHeaders http.Header var capturedBody []byte @@ -145,17 +340,86 @@ func TestSigV4Auth_AutoSerializesInChain(t *testing.T) { require.NoError(t, err) _, _ = rs.All() + // SigV4 should have added Authorization and X-Amz-Date headers assert.NotEmpty(t, capturedHeaders.Get("Authorization"), "SigV4Auth should set Authorization header") - assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256") + assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"), + "SigV4Auth should set X-Amz-Date header") + assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256", + "Authorization header should use AWS4-HMAC-SHA256 signing algorithm") + + // Body should be valid JSON assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "body should be valid JSON") + assert.Equal(t, "g.V().count()", parsed["gremlin"]) +} + +// TestSigV4Auth_AutoSerializesRequestMessage verifies that SigV4Auth automatically +// serializes *RequestMessage to JSON bytes before signing. +func TestSigV4Auth_AutoSerializesRequestMessage(t *testing.T) { + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + } + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) + + req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") + require.NoError(t, err) + req.Headers.Set("Content-Type", "application/json") + req.Headers.Set("Accept", graphBinaryMimeType) + + // Set Body to *RequestMessage + req.Body = &RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}} + + err = interceptor(req) + require.NoError(t, err, "SigV4Auth should auto-serialize *RequestMessage") + + // Body should now be []byte (serialized JSON) + bodyBytes, ok := req.Body.([]byte) + assert.True(t, ok, "Body should be []byte after SigV4Auth auto-serialization") + assert.NotEmpty(t, bodyBytes, "serialized body should be non-empty") + + // Verify it's valid JSON + var parsed map[string]interface{} + err = json.Unmarshal(bodyBytes, &parsed) + require.NoError(t, err, "body should be valid JSON after auto-serialization") + assert.Equal(t, "g.V()", parsed["gremlin"]) + + // SigV4 headers should be set + assert.NotEmpty(t, req.Headers.Get("Authorization"), "Authorization header should be set") + assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"), "X-Amz-Date header should be set") + assert.Contains(t, req.Headers.Get("Authorization"), "AWS4-HMAC-SHA256") +} + +// TestSigV4Auth_RejectsNonByteBody verifies that SigV4Auth returns an error when Body +// is not []byte and not *RequestMessage (e.g., an io.Reader). +func TestSigV4Auth_RejectsNonByteBody(t *testing.T) { + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + } + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) + + req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") + require.NoError(t, err) + req.Headers.Set("Content-Type", "application/json") + req.Headers.Set("Accept", graphBinaryMimeType) + + // Set Body to an unsupported type (not []byte and not *RequestMessage) + req.Body = bytes.NewReader([]byte("not bytes")) + + err = interceptor(req) + require.Error(t, err, "SigV4Auth should reject non-[]byte, non-*RequestMessage body") + assert.Contains(t, err.Error(), "cannot serialize request body", + "error message should indicate unsupported body type") } -// TestMultipleInterceptors_SerializeThenAuth verifies that a custom interceptor can -// modify the raw request, then SerializeRequest serializes it, then BasicAuth adds headers. -func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { +// TestMultipleInterceptors_MutateThenAuth verifies that a custom interceptor can +// modify the raw request fields, then BasicAuth adds headers, and the driver +// auto-serializes to JSON. +func TestMultipleInterceptors_MutateThenAuth(t *testing.T) { var capturedAuthHeader string var capturedBody []byte @@ -182,9 +446,6 @@ func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { return nil }) - // SerializeRequest converts the modified *RequestMessage to []byte - conn.AddInterceptor(SerializeRequest()) - // BasicAuth adds the Authorization header (works on any body type) conn.AddInterceptor(BasicAuth("admin", "secret")) @@ -196,10 +457,13 @@ func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { assert.Equal(t, "Basic YWRtaW46c2VjcmV0", capturedAuthHeader, "Authorization header should be Basic base64(admin:secret)") - // Body should be valid serialized bytes (from SerializeRequest) - assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + // Body should be valid JSON (auto-serialized by driver after interceptors) + assert.NotEmpty(t, capturedBody, "body should be non-empty") + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "body should be valid JSON") + assert.Equal(t, "g.V()", parsed["gremlin"]) + assert.Equal(t, "customValue", parsed["customField"]) } // TestInterceptor_IoReaderBody verifies that an interceptor can set Body to an io.Reader @@ -235,23 +499,6 @@ func TestInterceptor_IoReaderBody(t *testing.T) { "server should receive the custom payload set via io.Reader") } -// TestInterceptor_NilSerializerNoSerialization verifies that when serializer is nil -// and no interceptor serializes, the correct error message is produced. -func TestInterceptor_NilSerializerNoSerialization(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) - conn.serializer = nil // explicitly nil serializer - - _, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) - require.Error(t, err, "should get an error when serializer is nil and no interceptor serializes") - assert.Contains(t, err.Error(), "request body was not serialized", - "error message should indicate the body was not serialized") -} - // TestInterceptor_HttpRequestBody verifies that an interceptor can set Body to *http.Request // and the driver sends it directly, using the *http.Request's headers and body instead of // HttpRequest.Headers. @@ -285,28 +532,13 @@ func TestInterceptor_HttpRequestBody(t *testing.T) { return nil }) - // Also set a header on HttpRequest.Headers that should NOT appear, - // because *http.Request body bypasses HttpRequest.Headers - conn.AddInterceptor(func(req *HttpRequest) error { - req.Headers.Set("X-Should-Not-Appear", "ignored") - return nil - }) - rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) require.NoError(t, err) _, _ = rs.All() // drain - // The server should receive headers from the *http.Request, not from HttpRequest.Headers - assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header"), - "server should receive custom header from *http.Request") - assert.Equal(t, "application/octet-stream", capturedHeaders.Get("Content-Type"), - "server should receive Content-Type from *http.Request") - assert.Empty(t, capturedHeaders.Get("X-Should-Not-Appear"), - "headers set on HttpRequest.Headers should not appear when Body is *http.Request") - - // The server should receive the body from the *http.Request - assert.Equal(t, customBody, capturedBody, - "server should receive body from the *http.Request") + assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header")) + assert.Equal(t, "application/octet-stream", capturedHeaders.Get("Content-Type")) + assert.Equal(t, customBody, capturedBody) } // TestInterceptor_ErrorPropagation verifies that when an interceptor returns an error, @@ -351,88 +583,92 @@ func TestInterceptor_UnsupportedBodyType(t *testing.T) { "error message should indicate unsupported body type") } -// TestInterceptor_ChainOrder verifies that interceptors run in the order they are added. -func TestInterceptor_ChainOrder(t *testing.T) { +// TestDriverAutoSerializes verifies that without any interceptors, the driver +// auto-serializes the request body to JSON. +func TestDriverAutoSerializes(t *testing.T) { + var capturedBody []byte + var capturedContentType string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedContentType = r.Header.Get("Content-Type") + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } w.WriteHeader(http.StatusOK) })) defer server.Close() conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) - var order []int - - conn.AddInterceptor(func(req *HttpRequest) error { - order = append(order, 1) - return nil - }) - conn.AddInterceptor(func(req *HttpRequest) error { - order = append(order, 2) - return nil - }) - conn.AddInterceptor(func(req *HttpRequest) error { - order = append(order, 3) - return nil + // No interceptors, driver should auto-serialize + rs, err := conn.submit(&RequestMessage{ + Gremlin: "g.V().count()", + Fields: map[string]interface{}{"language": "gremlin-lang", "g": "g"}, }) - - rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) require.NoError(t, err) - _, _ = rs.All() // drain - - assert.Equal(t, []int{1, 2, 3}, order, - "interceptors should run in the order they were added") -} - -// TestSigV4Auth_RejectsNonByteBody verifies that SigV4Auth returns an error when Body -// is not []byte and not *RequestMessage (e.g., an io.Reader). -func TestSigV4Auth_RejectsNonByteBody(t *testing.T) { - provider := &mockCredentialsProvider{ - accessKey: "MOCK_ID", - secretKey: "MOCK_KEY", - } - interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) - - req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") - require.NoError(t, err) - req.Headers.Set("Content-Type", graphBinaryMimeType) - req.Headers.Set("Accept", graphBinaryMimeType) + _, _ = rs.All() - // Set Body to an unsupported type (not []byte and not *RequestMessage) - req.Body = strings.NewReader("not bytes") + assert.Equal(t, "application/json", capturedContentType, + "Content-Type should be application/json") - err = interceptor(req) - require.Error(t, err, "SigV4Auth should reject non-[]byte, non-*RequestMessage body") - assert.Contains(t, err.Error(), "SigV4 signing requires body to be []byte", - "error message should indicate SigV4 requires []byte body") + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "body should be valid JSON") + assert.Equal(t, "g.V().count()", parsed["gremlin"]) + assert.Equal(t, "gremlin-lang", parsed["language"]) + assert.Equal(t, "g", parsed["g"]) } -// TestSigV4Auth_AutoSerializesRequestMessage verifies that SigV4Auth automatically -// serializes *RequestMessage to []byte before signing. -func TestSigV4Auth_AutoSerializesRequestMessage(t *testing.T) { - provider := &mockCredentialsProvider{ - accessKey: "MOCK_ID", - secretKey: "MOCK_KEY", - } - interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) - - req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") - require.NoError(t, err) - req.Headers.Set("Content-Type", graphBinaryMimeType) - req.Headers.Set("Accept", graphBinaryMimeType) - - // Set Body to *RequestMessage — SigV4Auth should auto-serialize it - req.Body = &RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}} +// TestAuthInterceptorIsAlwaysLast verifies that the Auth field is always appended +// to the end of the interceptor chain, regardless of configuration order. +func TestAuthInterceptorIsAlwaysLast(t *testing.T) { + var order []int - err = interceptor(req) - require.NoError(t, err, "SigV4Auth should auto-serialize *RequestMessage") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() - // Body should now be []byte (serialized) - bodyBytes, ok := req.Body.([]byte) - assert.True(t, ok, "Body should be []byte after SigV4Auth auto-serialization") - assert.NotEmpty(t, bodyBytes, "serialized body should be non-empty") + t.Run("Auth runs after RequestInterceptors on Client", func(t *testing.T) { + order = nil + client, err := NewClient(server.URL, + func(settings *ClientSettings) { + settings.Auth = func(req *HttpRequest) error { order = append(order, 3); return nil } + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { order = append(order, 1); return nil }, + func(req *HttpRequest) error { order = append(order, 2); return nil }, + } + }) + require.NoError(t, err) + defer client.Close() + + rs, err := client.conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, []int{1, 2, 3}, order, + "Auth interceptor should always run last") + }) - // SigV4 headers should be set - assert.NotEmpty(t, req.Headers.Get("Authorization"), "Authorization header should be set") - assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"), "X-Amz-Date header should be set") - assert.Contains(t, req.Headers.Get("Authorization"), "AWS4-HMAC-SHA256") + t.Run("Auth runs after RequestInterceptors on DriverRemoteConnection", func(t *testing.T) { + order = nil + remote, err := NewDriverRemoteConnection(server.URL, + func(settings *DriverRemoteConnectionSettings) { + settings.Auth = func(req *HttpRequest) error { order = append(order, 3); return nil } + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { order = append(order, 1); return nil }, + func(req *HttpRequest) error { order = append(order, 2); return nil }, + } + }) + require.NoError(t, err) + defer remote.Close() + + rs, err := remote.client.conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, []int{1, 2, 3}, order, + "Auth interceptor should always run last") + }) } diff --git a/gremlin-go/driver/request.go b/gremlin-go/driver/request.go index 65f8662cf0c..ebf375d49bf 100644 --- a/gremlin-go/driver/request.go +++ b/gremlin-go/driver/request.go @@ -19,8 +19,6 @@ under the License. package gremlingo -import "strconv" - // RequestMessage represents a request to the server. type RequestMessage struct { Gremlin string @@ -31,7 +29,7 @@ type RequestMessage struct { // // This function is exposed publicly to enable alternative transport protocols (gRPC, HTTP/2, etc.) // to construct properly formatted requests outside the standard HTTP client. The returned -// request can then be serialized using SerializeMessage(). +// request can be placed in an HttpRequest.Body and serialized via SerializeBody(). // // Parameters: // - stringGremlin: The Gremlin query string to execute @@ -44,8 +42,9 @@ type RequestMessage struct { // Example for alternative transports: // // req := MakeStringRequest("g.V().count()", "g", RequestOptions{}) -// serializer := newGraphBinarySerializer(nil) -// bytes, _ := serializer.(graphBinarySerializer).SerializeMessage(&req) +// httpReq, _ := NewHttpRequest("POST", "http://localhost:8182/gremlin") +// httpReq.Body = &req +// bytes, _ := httpReq.SerializeBody() // // Send bytes over gRPC, HTTP/2, etc. func MakeStringRequest(stringGremlin string, traversalSource string, requestOptions RequestOptions) (req RequestMessage) { newFields := map[string]interface{}{ @@ -74,7 +73,7 @@ func MakeStringRequest(stringGremlin string, traversalSource string, requestOpti } if requestOptions.bulkResults != nil { - newFields["bulkResults"] = strconv.FormatBool(*requestOptions.bulkResults) + newFields["bulkResults"] = *requestOptions.bulkResults } if requestOptions.transactionId != "" { diff --git a/gremlin-go/driver/request_test.go b/gremlin-go/driver/request_test.go index 2da846cc929..f1752b0f6c7 100644 --- a/gremlin-go/driver/request_test.go +++ b/gremlin-go/driver/request_test.go @@ -55,7 +55,7 @@ func TestRequest(t *testing.T) { t.Run("Test makeStringRequest() with bulkResults", func(t *testing.T) { r := MakeStringRequest("g.V()", "g", new(RequestOptionsBuilder).SetBulkResults(true).Create()) - assert.Equal(t, "true", r.Fields["bulkResults"]) + assert.Equal(t, true, r.Fields["bulkResults"]) }) t.Run("Test makeStringRequest() with string bindings", func(t *testing.T) { diff --git a/gremlin-go/driver/serializer.go b/gremlin-go/driver/serializer.go index 38d618600cc..509d51657f9 100644 --- a/gremlin-go/driver/serializer.go +++ b/gremlin-go/driver/serializer.go @@ -28,7 +28,7 @@ import ( const graphBinaryMimeType = "application/vnd.graphbinary-v4.0" -// Serializer interface for serializers. +// Serializer interface for serializing requests and deserializing responses. type Serializer interface { SerializeMessage(request *RequestMessage) ([]byte, error) DeserializeMessage(message []byte) (Response, error) @@ -58,27 +58,17 @@ const versionByte byte = 0x84 // SerializeMessage serializes a request message into GraphBinary format. // -// This method is part of the serializer interface and is used internally by the HTTP driver. -// It is also exposed publicly to enable alternative transport protocols (gRPC, HTTP/2, etc.) to -// serialize requests created with MakeBytecodeRequest() or MakeStringRequest(). -// -// The serialized bytes can be transmitted over any transport protocol that supports binary data. +// This method is part of the Serializer interface. It is no longer used by the default driver +// flow (which now serializes requests as JSON via HttpRequest.SerializeBody), but remains +// available for custom interceptors or alternative transport protocols that require +// GraphBinary-encoded requests. // // Parameters: // - request: The request to serialize (created via MakeBytecodeRequest or MakeStringRequest) // // Returns: -// - []byte: The GraphBinary-encoded request ready for transmission +// - serialized: The GraphBinary-encoded request bytes // - error: Any serialization error encountered -// -// Example for alternative transports: -// -// req := MakeBytecodeRequest(bytecode, "g", "") -// serializer := newGraphBinarySerializer(nil) -// bytes, err := serializer.(graphBinarySerializer).SerializeMessage(&req) -// // Send bytes over custom transport -// -// SerializeMessage serializes a request message into GraphBinary. func (gs *GraphBinarySerializer) SerializeMessage(request *RequestMessage) ([]byte, error) { finalMessage, err := gs.buildMessage(request.Gremlin, request.Fields) if err != nil { diff --git a/gremlin-js/gremlin-javascript/lib/driver/auth.ts b/gremlin-js/gremlin-javascript/lib/driver/auth.ts index 41f5573b27a..4d14f7c4f36 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/auth.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/auth.ts @@ -18,12 +18,11 @@ */ import { Buffer } from 'buffer'; -import type { HttpRequest, RequestInterceptor } from './connection.js'; +import type { HttpRequest, RequestInterceptor } from './http-request.js'; export function basic(username: string, password: string): RequestInterceptor { return (request: HttpRequest) => { request.headers['authorization'] = 'Basic ' + Buffer.from(`${username}:${password}`).toString('base64'); - return request; }; } @@ -40,6 +39,10 @@ export function sigv4(region: string, service: string, credentialsProvider?: Aws let resolvedProvider: AwsCredentialsProvider; return async (request: HttpRequest) => { + // Ensure body is serialized to JSON bytes before signing. + // serializeBody is idempotent: safe to call even if already serialized. + request.serializeBody(); + // Lazy-initialize signer and credentials provider on first use if (!signer) { const { SignatureV4 } = await import('@smithy/signature-v4'); @@ -76,6 +79,5 @@ export function sigv4(region: string, service: string, credentialsProvider?: Aws }); request.headers = signed.headers; - return request; }; } diff --git a/gremlin-js/gremlin-javascript/lib/driver/connection.ts b/gremlin-js/gremlin-javascript/lib/driver/connection.ts index fb91060af3f..d007fd04ea4 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/connection.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/connection.ts @@ -30,10 +30,11 @@ import StreamReader from '../structure/io/binary/internals/StreamReader.js'; import * as utils from '../utils.js'; import ResultSet from './result-set.js'; import {RequestMessage} from "./request-message.js"; +import { HttpRequest, RequestInterceptor } from './http-request.js'; import ResponseError from './response-error.js'; import { Traverser } from '../process/traversal.js'; -const { graphBinaryWriter } = ioc; +const { graphBinaryReader } = ioc; const responseStatusCode = { success: 200, @@ -41,15 +42,6 @@ const responseStatusCode = { partialContent: 206, }; -export type HttpRequest = { - url: string; - method: string; - headers: Record; - body: any; -}; - -export type RequestInterceptor = (request: HttpRequest) => HttpRequest | Promise; - export type ConnectionOptions = { ca?: string[]; cert?: string | string[] | Buffer; @@ -59,11 +51,13 @@ export type ConnectionOptions = { reader?: any; rejectUnauthorized?: boolean; traversalSource?: string; - writer?: any | null; headers?: Record; enableUserAgentOnConnect?: boolean; agent?: Agent; interceptors?: RequestInterceptor | RequestInterceptor[]; + /** An optional auth interceptor. As a convenience, this is always appended to the end of the + * interceptor list so it runs last, after any user interceptors have modified the request. */ + auth?: RequestInterceptor; }; /** @@ -71,7 +65,6 @@ export type ConnectionOptions = { */ export default class Connection extends EventEmitter { private readonly _reader: any; - private readonly _writer: any | null; isOpen = true; traversalSource: string; @@ -90,8 +83,7 @@ export default class Connection extends EventEmitter { ) { super(); - this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : new GraphBinaryReader(ioc)); - this._writer = 'writer' in options ? options.writer : graphBinaryWriter; + this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : graphBinaryReader); if (options.pdtRegistry) { this._reader.pdtRegistry = options.pdtRegistry; } @@ -102,12 +94,17 @@ export default class Connection extends EventEmitter { if (typeof interceptors === 'function') { this._interceptors = [interceptors]; } else if (Array.isArray(interceptors)) { - this._interceptors = interceptors; + this._interceptors = [...interceptors]; } else if (interceptors === undefined || interceptors === null) { this._interceptors = []; } else { throw new TypeError('interceptors must be a function, array, or undefined'); } + + // Auth interceptor is always last so it runs after user interceptors + if (options.auth) { + this._interceptors.push(options.auth); + } } /** @@ -123,8 +120,7 @@ export default class Connection extends EventEmitter { * Send a request and buffer the entire response. Returns a Promise. */ async submit(request: RequestMessage) { - const body = this._writer ? this._writer.writeRequest(request) : request; - const response = await this.#makeHttpRequest(body); + const response = await this.#makeHttpRequest(request); return this.#handleResponse(response); } @@ -142,12 +138,11 @@ export default class Connection extends EventEmitter { * @returns {AsyncGenerator} */ async *stream(request: RequestMessage): AsyncGenerator { - const body = this._writer ? this._writer.writeRequest(request) : request; const abortController = new AbortController(); let response: Response; try { - response = await this.#makeHttpRequest(body, abortController.signal); + response = await this.#makeHttpRequest(request, abortController.signal); } catch (e: any) { throw new Error(`Stream request failed: ${e.message}`, { cause: e }); } @@ -215,15 +210,11 @@ export default class Connection extends EventEmitter { return null; } - async #makeHttpRequest(body: any, signal?: AbortSignal): Promise { + async #makeHttpRequest(request: RequestMessage, signal?: AbortSignal): Promise { const headers: Record = { - 'Accept': this._reader.mimeType + 'Accept': this._reader.mimeType, }; - if (this._writer) { - headers['Content-Type'] = this._writer.mimeType; - } - if (this._enableUserAgentOnConnect) { const userAgent = await utils.getUserAgent(); if (userAgent !== undefined) { @@ -237,18 +228,13 @@ export default class Connection extends EventEmitter { }); } - let httpRequest: HttpRequest = { - url: this.url, - method: 'POST', - headers, - body, - }; + const httpRequest = new HttpRequest('POST', this.url, headers, request); // Promote transactionId to HTTP header before interceptors run. // The field remains in the serialized body as well (dual transmission // per the HTTP transaction protocol specification). - if (body instanceof RequestMessage) { - const fields = body.getFields(); + if (request instanceof RequestMessage) { + const fields = request.getFields(); if (fields.has('transactionId')) { httpRequest.headers['X-Transaction-Id'] = fields.get('transactionId'); } @@ -256,12 +242,15 @@ export default class Connection extends EventEmitter { for (let i = 0; i < this._interceptors.length; i++) { try { - httpRequest = await this._interceptors[i](httpRequest); + await this._interceptors[i](httpRequest); } catch (e: any) { throw new Error(`Request interceptor at index ${i} failed: ${e.message}`, { cause: e }); } } + // Auto-serialize body to JSON after interceptors run (idempotent if already serialized) + httpRequest.serializeBody(); + return fetch(httpRequest.url, { method: httpRequest.method, headers: httpRequest.headers, diff --git a/gremlin-js/gremlin-javascript/lib/driver/http-request.ts b/gremlin-js/gremlin-javascript/lib/driver/http-request.ts new file mode 100644 index 00000000000..f47fb62bb25 --- /dev/null +++ b/gremlin-js/gremlin-javascript/lib/driver/http-request.ts @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Buffer } from 'buffer'; +import { RequestMessage } from './request-message.js'; + +/** + * Represents an HTTP request that is passed through interceptors before being sent to the server. + * Interceptors may mutate any field. The {@link serializeBody} method serializes the body to JSON + * if it is still a RequestMessage, or returns the existing bytes if already serialized. + */ +export class HttpRequest { + url: string; + method: string; + headers: Record; + body: RequestMessage | Buffer | any; + + constructor(method: string, url: string, headers: Record, body: any) { + this.method = method; + this.url = url; + this.headers = headers; + this.body = body; + } + + /** + * Serializes the request body to JSON if it is still a RequestMessage. + * If body is already a Buffer, returns it as-is (idempotent). + * Sets Content-Type and Content-Length headers on successful serialization. + * + * @returns The serialized body as a Buffer. + * @throws Error if body is neither a RequestMessage nor a Buffer. + */ + serializeBody(): Buffer { + if (Buffer.isBuffer(this.body)) { + return this.body; + } + + if (this.body instanceof RequestMessage) { + // RequestMessage.toJSON() produces the flattened wire object (standard + custom fields). + const data = Buffer.from(JSON.stringify(this.body), 'utf-8'); + this.body = data; + this.headers['Content-Type'] = 'application/json'; + this.headers['Content-Length'] = String(data.length); + return data; + } + + const typeName = this.body === null ? 'null' : typeof this.body; + throw new Error(`unsupported body type: ${typeName}`); + } +} + +/** + * A request interceptor receives an HttpRequest and mutates it in place. + * Interceptors must not return a value. + */ +export type RequestInterceptor = (request: HttpRequest) => void | Promise; diff --git a/gremlin-js/gremlin-javascript/lib/driver/request-message.ts b/gremlin-js/gremlin-javascript/lib/driver/request-message.ts index f81d011cd37..c8783099354 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/request-message.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/request-message.ts @@ -100,6 +100,44 @@ export class RequestMessage { return this.customFields; } + /** + * Builds the plain object that represents this message on the wire. Standard fields are + * included when set, and custom fields are flattened to the top level. This method is + * invoked automatically by {@link JSON.stringify}. + * + * When a new standard field is added to this class, it should be added here as well so that + * it is included in the serialized request body. + */ + toJSON(): Record { + const payload: Record = { gremlin: this.gremlin }; + + if (this.language) { + payload['language'] = this.language; + } + if (this.g) { + payload['g'] = this.g; + } + if (this.bindings !== undefined) { + payload['bindings'] = this.bindings; + } + if (this.timeoutMs !== undefined) { + payload['timeoutMs'] = this.timeoutMs; + } + if (this.materializeProperties) { + payload['materializeProperties'] = this.materializeProperties; + } + if (this.bulkResults !== undefined) { + payload['bulkResults'] = this.bulkResults; + } + + // Flatten custom/provider fields to the top level + this.customFields.forEach((v, k) => { + payload[k] = v; + }); + + return payload; + } + static build(gremlin: string): Builder { return new Builder(gremlin); } diff --git a/gremlin-js/gremlin-javascript/lib/index.ts b/gremlin-js/gremlin-javascript/lib/index.ts index ef0e7ce1334..a2443b8d1d4 100644 --- a/gremlin-js/gremlin-javascript/lib/index.ts +++ b/gremlin-js/gremlin-javascript/lib/index.ts @@ -33,6 +33,7 @@ import DriverRemoteConnection from './driver/driver-remote-connection.js'; import ResponseError from './driver/response-error.js'; import Client from './driver/client.js'; import ResultSet from './driver/result-set.js'; +import { HttpRequest } from './driver/http-request.js'; import { basic, sigv4 } from './driver/auth.js'; import AnonymousTraversalSource from './process/anonymous-traversal.js'; @@ -44,6 +45,7 @@ export const driver = { DriverRemoteConnection, Client, ResultSet, + HttpRequest, auth: { basic, sigv4, diff --git a/gremlin-js/gremlin-javascript/test/integration/client-tests.js b/gremlin-js/gremlin-javascript/test/integration/client-tests.js index cf439aaa56b..eaf7a738c6f 100644 --- a/gremlin-js/gremlin-javascript/test/integration/client-tests.js +++ b/gremlin-js/gremlin-javascript/test/integration/client-tests.js @@ -273,4 +273,90 @@ describe('ProviderDefinedType - Client', function () { assert.strictEqual(list[1].fields.y, 4); }); }); +}); + +describe('Client interceptor integration', function () { + it('should auto serialize request message with interceptor mutation', async function () { + const { RequestMessage } = await import('../../lib/driver/request-message.js'); + const interceptor = (request) => { + if (request.body instanceof RequestMessage) { + request.body = RequestMessage.build('g.inject(99)').addG('gmodern').create(); + } + }; + const interceptorClient = new Client(serverUrl, { + traversalSource: 'gmodern', + interceptors: [interceptor], + }); + await interceptorClient.open(); + try { + const result = await interceptorClient.submit('g.inject(1)'); + assert.strictEqual(result.first(), 99); + } finally { + await interceptorClient.close(); + } + }); + + it('should propagate exception thrown during interceptor', async function () { + let callCount = 0; + const interceptor = () => { + callCount++; + if (callCount === 1) { + throw new Error('interceptor broke'); + } + }; + const interceptorClient = new Client(serverUrl, { + traversalSource: 'gmodern', + interceptors: [interceptor], + }); + await interceptorClient.open(); + try { + // First request should fail with interceptor error + await assert.rejects( + () => interceptorClient.submit('g.inject(1)'), + (err) => { + assert.ok(err.message.includes('interceptor') || err.message.includes('broke'), + `Expected error about interceptor, got: ${err.message}`); + return true; + } + ); + + // Subsequent request should succeed, proving connection recovery + const result = await interceptorClient.submit('g.inject(2)'); + assert.strictEqual(result.first(), 2); + } finally { + await interceptorClient.close(); + } + }); + + it('should propagate error when interceptor sets unsupported body type', async function () { + let callCount = 0; + const interceptor = (request) => { + callCount++; + if (callCount === 1) { + request.body = 42; + } + }; + const interceptorClient = new Client(serverUrl, { + traversalSource: 'gmodern', + interceptors: [interceptor], + }); + await interceptorClient.open(); + try { + // First request should fail with serialization error + await assert.rejects( + () => interceptorClient.submit('g.inject(1)'), + (err) => { + assert.ok(err.message.includes('unsupported body type') || err.message.includes('serialize'), + `Expected error about unsupported body type, got: ${err.message}`); + return true; + } + ); + + // Subsequent request should succeed, proving connection recovery + const result = await interceptorClient.submit('g.inject(2)'); + assert.strictEqual(result.first(), 2); + } finally { + await interceptorClient.close(); + } + }); }); \ No newline at end of file diff --git a/gremlin-js/gremlin-javascript/test/unit/auth-test.js b/gremlin-js/gremlin-javascript/test/unit/auth-test.js index cab2a5d24f5..e35fb1f8f47 100644 --- a/gremlin-js/gremlin-javascript/test/unit/auth-test.js +++ b/gremlin-js/gremlin-javascript/test/unit/auth-test.js @@ -20,18 +20,14 @@ import assert from 'assert'; import { Buffer } from 'buffer'; import { basic, sigv4 } from '../../lib/driver/auth.js'; +import { HttpRequest } from '../../lib/driver/http-request.js'; describe('auth', function () { describe('basic', function () { function createMockRequest() { - return { - url: 'https://localhost:8182/gremlin', - method: 'POST', - headers: { - 'accept': 'application/vnd.graphbinary-v4.0', - }, - body: new Uint8Array(0), - }; + return new HttpRequest('POST', 'https://localhost:8182/gremlin', { + 'accept': 'application/vnd.graphbinary-v4.0', + }, Buffer.from('')); } it('should add authorization header', function () { @@ -39,40 +35,28 @@ describe('auth', function () { assert.strictEqual(request.headers['authorization'], undefined); const interceptor = basic('username', 'password'); - const result = interceptor(request); + interceptor(request); - assert.ok(result.headers['authorization'].startsWith('Basic ')); + assert.ok(request.headers['authorization'].startsWith('Basic ')); }); it('should encode credentials correctly', function () { const request = createMockRequest(); const interceptor = basic('username', 'password'); - const result = interceptor(request); + interceptor(request); - const encoded = result.headers['authorization'].substring('Basic '.length); + const encoded = request.headers['authorization'].substring('Basic '.length); const decoded = Buffer.from(encoded, 'base64').toString(); assert.strictEqual(decoded, 'username:password'); }); - - it('should return the same request object', function () { - const request = createMockRequest(); - const interceptor = basic('username', 'password'); - const result = interceptor(request); - - assert.strictEqual(result, request); - }); }); describe('sigv4', function () { function createMockRequest() { - return { - url: 'https://localhost:8182/gremlin', - method: 'POST', - headers: { - 'accept': 'application/vnd.graphbinary-v4.0', - }, - body: new Uint8Array(Buffer.from('{"gremlin":"g.V()"}')), - }; + return new HttpRequest('POST', 'https://localhost:8182/gremlin', { + 'accept': 'application/vnd.graphbinary-v4.0', + 'content-type': 'application/json', + }, Buffer.from('{"gremlin":"g.V()"}')); } const mockProvider = () => ({ @@ -85,10 +69,10 @@ describe('auth', function () { assert.strictEqual(request.headers['authorization'], undefined); const interceptor = sigv4('xx-dummy-1', 'test-service', mockProvider); - const result = await interceptor(request); + await interceptor(request); - assert.ok(result.headers['x-amz-date']); - const authHeader = result.headers['authorization']; + assert.ok(request.headers['x-amz-date']); + const authHeader = request.headers['authorization']; assert.ok(authHeader.startsWith('AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY')); assert.ok(authHeader.includes('xx-dummy-1/test-service/aws4_request')); assert.ok(authHeader.includes('Signature=')); @@ -103,20 +87,33 @@ describe('auth', function () { }); const interceptor = sigv4('xx-dummy-1', 'test-service', providerWithToken); - const result = await interceptor(request); + await interceptor(request); - assert.strictEqual(result.headers['x-amz-security-token'], 'MOCK_SESSION_TOKEN'); - const authHeader = result.headers['authorization']; + assert.strictEqual(request.headers['x-amz-security-token'], 'MOCK_SESSION_TOKEN'); + const authHeader = request.headers['authorization']; assert.ok(authHeader.startsWith('AWS4-HMAC-SHA256 Credential=')); assert.ok(authHeader.includes('Signature=')); }); - it('should return the same request object', async function () { + it('should preserve pre-existing headers after signing', async function () { const request = createMockRequest(); - const interceptor = sigv4('xx-dummy-1', 'test-service', mockProvider); - const result = await interceptor(request); + // Capture the headers present before signing (accept + content-type) + const preSignKeys = Object.keys(request.headers); - assert.strictEqual(result, request); + const interceptor = sigv4('xx-dummy-1', 'test-service', mockProvider); + await interceptor(request); + + // The original headers must still be present (not dropped by wholesale replacement) + for (const key of preSignKeys) { + assert.ok( + key in request.headers, + `expected header '${key}' to be preserved after signing`); + } + assert.strictEqual(request.headers['accept'], 'application/vnd.graphbinary-v4.0'); + assert.strictEqual(request.headers['content-type'], 'application/json'); + + // Signing adds at least authorization and x-amz-date on top of the originals + assert.ok(Object.keys(request.headers).length >= preSignKeys.length + 2); }); }); }); diff --git a/gremlin-js/gremlin-javascript/test/unit/connection-test.js b/gremlin-js/gremlin-javascript/test/unit/connection-test.js new file mode 100644 index 00000000000..2ef6cc477ca --- /dev/null +++ b/gremlin-js/gremlin-javascript/test/unit/connection-test.js @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import assert from 'assert'; +import Connection from '../../lib/driver/connection.js'; +import { RequestMessage } from '../../lib/driver/request-message.js'; + +// Connection-level unit tests that verify interceptor wiring and auto-serialization +// by mocking the global fetch and capturing what the Connection sends. +describe('Connection (request pipeline)', function () { + let originalFetch; + let captured; + + beforeEach(function () { + captured = null; + originalFetch = global.fetch; + // Return an error response so we don't need a valid GraphBinary body; the test only + // cares about what was sent to fetch. The resulting ResponseError is swallowed per-test. + global.fetch = (url, init) => { + captured = { url, init }; + return Promise.resolve({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + headers: { get: () => 'text/plain' }, + arrayBuffer: () => Promise.resolve(Buffer.from('error')), + }); + }; + }); + + afterEach(function () { + global.fetch = originalFetch; + }); + + function makeConnection(options = {}) { + const conn = new Connection('http://localhost:8182/gremlin', + { enableUserAgentOnConnect: false, ...options }); + conn.isOpen = true; + return conn; + } + + async function submitAndIgnoreError(conn, msg) { + try { + await conn.submit(msg); + } catch (e) { + // Expected: the mock returns a 500 response. + } + } + + it('auto-serializes the body to JSON when no interceptor does', async function () { + const conn = makeConnection(); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.ok(captured, 'fetch should have been called'); + assert.strictEqual(captured.init.headers['Content-Type'], 'application/json'); + const parsed = JSON.parse(Buffer.from(captured.init.body).toString('utf-8')); + assert.strictEqual(parsed.gremlin, 'g.V()'); + assert.strictEqual(parsed.g, 'g'); + }); + + it('keeps the GraphBinary Accept header for responses', async function () { + const conn = makeConnection(); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.strictEqual(captured.init.headers['Accept'], 'application/vnd.graphbinary-v4.0'); + }); + + it('runs interceptors in registration order', async function () { + const order = []; + const conn = makeConnection({ + interceptors: [ + (req) => { order.push(1); }, + (req) => { order.push(2); }, + (req) => { order.push(3); }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.deepStrictEqual(order, [1, 2, 3]); + }); + + it('reflects interceptor body mutation in the serialized payload', async function () { + const conn = makeConnection({ + interceptors: [ + (req) => { + req.body = RequestMessage.build('g.inject(99)').addG('gmodern').create(); + }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + const parsed = JSON.parse(Buffer.from(captured.init.body).toString('utf-8')); + assert.strictEqual(parsed.gremlin, 'g.inject(99)'); + assert.strictEqual(parsed.g, 'gmodern'); + }); + + it('lets an interceptor add headers that reach the request', async function () { + const conn = makeConnection({ + interceptors: [ + (req) => { req.headers['X-Custom'] = 'value'; }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.strictEqual(captured.init.headers['X-Custom'], 'value'); + }); + + it('propagates interceptor errors to the caller', async function () { + const conn = makeConnection({ + interceptors: [ + () => { throw new Error('interceptor broke'); }, + ], + }); + + await assert.rejects( + () => conn.submit(RequestMessage.build('g.V()').addG('g').create()), + /interceptor broke/); + assert.strictEqual(captured, null, 'fetch should not be called when an interceptor throws'); + }); + + it('auth interceptor always runs last regardless of option order', async function () { + const order = []; + const conn = makeConnection({ + auth: (req) => { order.push(3); }, + interceptors: [ + (req) => { order.push(1); }, + (req) => { order.push(2); }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.deepStrictEqual(order, [1, 2, 3], 'auth interceptor should always run last'); + }); +}); diff --git a/gremlin-js/gremlin-javascript/test/unit/http-request-test.js b/gremlin-js/gremlin-javascript/test/unit/http-request-test.js new file mode 100644 index 00000000000..b3d03e442a1 --- /dev/null +++ b/gremlin-js/gremlin-javascript/test/unit/http-request-test.js @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import assert from 'assert'; +import { Buffer } from 'buffer'; +import { HttpRequest } from '../../lib/driver/http-request.js'; +import { RequestMessage } from '../../lib/driver/request-message.js'; + +describe('HttpRequest', function () { + describe('serializeBody()', function () { + it('should serialize RequestMessage to JSON bytes and set body to Buffer', function () { + const msg = RequestMessage.build("g.V().has('name','marko')") + .addG('g') + .create(); + + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + const data = httpReq.serializeBody(); + + assert(Buffer.isBuffer(data), 'should return a Buffer'); + assert(Buffer.isBuffer(httpReq.body), 'body should now be a Buffer'); + + const parsed = JSON.parse(data.toString('utf-8')); + assert.strictEqual(parsed.gremlin, "g.V().has('name','marko')"); + assert.strictEqual(parsed.g, 'g'); + assert.strictEqual(parsed.language, 'gremlin-lang'); + }); + + it('should set Content-Type to application/json', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + httpReq.serializeBody(); + + assert.strictEqual(httpReq.headers['Content-Type'], 'application/json'); + }); + + it('should set Content-Length to byte length of the serialized body', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const data = httpReq.serializeBody(); + + assert.strictEqual(httpReq.headers['Content-Length'], String(data.length)); + }); + + it('should be idempotent when body is already a Buffer', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const data1 = httpReq.serializeBody(); + const data2 = httpReq.serializeBody(); + + assert(data1.equals(data2), 'subsequent calls should return identical bytes'); + }); + + it('should produce identical results on multiple calls', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const results = []; + for (let i = 0; i < 3; i++) { + results.push(httpReq.serializeBody()); + } + + assert(results[0].equals(results[1])); + assert(results[1].equals(results[2])); + }); + + it('should return existing bytes if body is already a Buffer', function () { + const existing = Buffer.from('{"gremlin":"g.V()"}', 'utf-8'); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, existing); + + const data = httpReq.serializeBody(); + + assert.strictEqual(data, existing, 'should return the same Buffer reference'); + }); + + it('should include all fields from RequestMessage', function () { + const msg = RequestMessage.build("g.V().has('age',x)") + .addG('gCustom') + .addTimeoutMillis(5000) + .addMaterializeProperties('all') + .addBulkResults(true) + .addField('customKey', 'customValue') + .create(); + + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + const data = httpReq.serializeBody(); + const parsed = JSON.parse(data.toString('utf-8')); + + assert.strictEqual(parsed.gremlin, "g.V().has('age',x)"); + assert.strictEqual(parsed.g, 'gCustom'); + assert.strictEqual(parsed.language, 'gremlin-lang'); + assert.strictEqual(parsed.timeoutMs, 5000); + assert.strictEqual(parsed.materializeProperties, 'all'); + assert.strictEqual(parsed.bulkResults, true); + assert.strictEqual(parsed.customKey, 'customValue'); + }); + + it('should throw for unsupported body types', function () { + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, 42); + + assert.throws(() => httpReq.serializeBody(), /unsupported body type/); + }); + + it('should throw for null body', function () { + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, null); + + assert.throws(() => httpReq.serializeBody(), /unsupported body type/); + }); + }); + + describe('interceptors', function () { + it('should receive HttpRequest with body as RequestMessage', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', { 'Accept': 'application/vnd.graphbinary-v4.0' }, msg); + + let receivedBody; + const interceptor = (req) => { + receivedBody = req.body; + }; + + interceptor(httpReq); + + assert(receivedBody instanceof RequestMessage, 'interceptor should receive RequestMessage as body'); + }); + + it('should allow reading and modifying headers', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', { 'Accept': 'application/vnd.graphbinary-v4.0' }, msg); + + const interceptor = (req) => { + assert.strictEqual(req.headers['Accept'], 'application/vnd.graphbinary-v4.0'); + req.headers['X-Custom'] = 'test-value'; + }; + + interceptor(httpReq); + + assert.strictEqual(httpReq.headers['X-Custom'], 'test-value'); + }); + + it('should allow reading and modifying url', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const interceptor = (req) => { + req.url = 'http://other-host:8182/gremlin'; + }; + + interceptor(httpReq); + + assert.strictEqual(httpReq.url, 'http://other-host:8182/gremlin'); + }); + + it('should run in registration order', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + const order = []; + + const interceptor1 = (req) => { order.push(1); }; + const interceptor2 = (req) => { order.push(2); }; + const interceptor3 = (req) => { order.push(3); }; + + interceptor1(httpReq); + interceptor2(httpReq); + interceptor3(httpReq); + + assert.deepStrictEqual(order, [1, 2, 3]); + }); + + it('should allow field mutation before serialization to affect output', function () { + const msg = RequestMessage.build('g.V()') + .addG('g') + .addField('providerField', 'original') + .create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + // Interceptor that calls serializeBody after modifying a header + // but a field-mutating interceptor must work at the RequestMessage level + // For field mutation, the interceptor replaces the body with a new RequestMessage + const interceptor = (req) => { + // Build a new request message with modified fields + const newMsg = RequestMessage.build('g.V().count()') + .addG('gModified') + .addField('providerField', 'modified') + .create(); + req.body = newMsg; + }; + + interceptor(httpReq); + const data = httpReq.serializeBody(); + const parsed = JSON.parse(data.toString('utf-8')); + + assert.strictEqual(parsed.gremlin, 'g.V().count()'); + assert.strictEqual(parsed.g, 'gModified'); + assert.strictEqual(parsed.providerField, 'modified'); + }); + + it('should allow interceptor to call serializeBody for payload hashing', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + // Simulates a SigV4 interceptor that needs the serialized bytes + const signingInterceptor = (req) => { + const bytes = req.serializeBody(); + // Use bytes for hashing (simulated) + req.headers['X-Payload-Hash'] = String(bytes.length); + }; + + signingInterceptor(httpReq); + + // Body should now be serialized + assert(Buffer.isBuffer(httpReq.body)); + assert.strictEqual(httpReq.headers['Content-Type'], 'application/json'); + assert(httpReq.headers['X-Payload-Hash'].length > 0); + + // Subsequent serializeBody call should be idempotent + const data = httpReq.serializeBody(); + assert(Buffer.isBuffer(data)); + }); + }); +}); diff --git a/gremlin-python/src/main/python/examples/connections.py b/gremlin-python/src/main/python/examples/connections.py index 654079d027e..a496d730df0 100644 --- a/gremlin-python/src/main/python/examples/connections.py +++ b/gremlin-python/src/main/python/examples/connections.py @@ -23,7 +23,6 @@ from gremlin_python.process.anonymous_traversal import traversal from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.auth import basic -from gremlin_python.driver.serializer import GraphBinarySerializersV4 VERTEX_LABEL = os.getenv('VERTEX_LABEL', 'connection') @@ -83,7 +82,6 @@ def with_configs(): server_url = os.getenv('GREMLIN_SERVER_URL', 'http://localhost:8182/gremlin').format(45940) rc = DriverRemoteConnection( server_url, 'g', - request_serializer=GraphBinarySerializersV4(), headers=None, ) g = traversal().with_remote(rc) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py index 08a24624279..bdd100ce246 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py @@ -132,11 +132,6 @@ async def async_connect(): def write(self, message): # Inner function to perform async write. async def async_write(): - # To pass url into message for request authentication processing - message.update({'url': self._url}) - if message['auth']: - message['auth'](message) - async with async_timeout.timeout(self._write_timeout): self._http_req_resp = await self._client_session.post(url=self._url, data=message['payload'], diff --git a/gremlin-python/src/main/python/gremlin_python/driver/auth.py b/gremlin-python/src/main/python/gremlin_python/driver/auth.py index e7bd453b1d3..89d6ccf8891 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/auth.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/auth.py @@ -16,24 +16,29 @@ # specific language governing permissions and limitations # under the License. # +import base64 def basic(username, password): - from aiohttp import BasicAuth as aiohttpBasicAuth + """Returns an interceptor that adds Basic auth to the request.""" + def interceptor(request): + credentials = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("utf-8") + request.headers['authorization'] = f"Basic {credentials}" - def apply(request): - return request['headers'].update({'authorization': aiohttpBasicAuth(username, password).encode()}) - - return apply + return interceptor def sigv4(region, service): + """Returns an interceptor that signs the request with AWS SigV4.""" import os from boto3 import Session from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest - def apply(request): + def interceptor(request): + # Ensure body is serialized so we can sign it + body_bytes = request.serialize_body() + access_key = os.environ.get('AWS_ACCESS_KEY_ID', '') secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '') session_token = os.environ.get('AWS_SESSION_TOKEN', '') @@ -45,11 +50,8 @@ def apply(request): region_name=region ) - sigv4_request = AWSRequest(method="POST", url=request['url'], data=request['payload']) + sigv4_request = AWSRequest(method=request.method, url=request.url, data=body_bytes) SigV4Auth(session.get_credentials(), service, region).add_auth(sigv4_request) - request['headers'].update(sigv4_request.headers) - request['payload'] = sigv4_request.data - return request - - return apply + request.headers.update(dict(sigv4_request.headers)) + return interceptor diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index fa4dbc32c2b..c20556851e0 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -39,7 +39,6 @@ def cpu_count(): class Client: def __init__(self, url, traversal_source, pool_size=None, max_workers=None, - request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=None, interceptors=None, auth=None, headers=None, enable_user_agent_on_connect=True, bulk_results=False, pdt_registry=None, **transport_kwargs): @@ -63,7 +62,6 @@ def __init__(self, url, traversal_source, pool_size=None, max_workers=None, self._auth = auth self._response_serializer = response_serializer - self._request_serializer = request_serializer self._interceptors = interceptors self._transport_kwargs = transport_kwargs @@ -149,7 +147,6 @@ def _get_connection(self): return connection.Connection( self._url, self._traversal_source, self._executor, self._pool, - request_serializer=self._request_serializer, response_serializer=self._response_serializer, auth=self._auth, interceptors=self._interceptors, headers=self._headers, diff --git a/gremlin-python/src/main/python/gremlin_python/driver/connection.py b/gremlin-python/src/main/python/gremlin_python/driver/connection.py index 64a3f000818..c9b65eab788 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/connection.py @@ -19,6 +19,7 @@ from gremlin_python.driver import resultset, useragent from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport +from gremlin_python.driver.http_request import HttpRequest __author__ = 'David M. Brown (davebshow@gmail.com)' @@ -34,17 +35,23 @@ def __init__(self, status): class Connection: def __init__(self, url, traversal_source, - executor, pool, request_serializer=None, + executor, pool, response_serializer=None, auth=None, interceptors=None, headers=None, enable_user_agent_on_connect=True, bulk_results=False, **transport_kwargs): if callable(interceptors): interceptors = [interceptors] - elif not (isinstance(interceptors, tuple) - or isinstance(interceptors, list) - or interceptors is None): + elif isinstance(interceptors, tuple): + interceptors = list(interceptors) + elif not (isinstance(interceptors, list) or interceptors is None): raise TypeError("interceptors must be a callable, tuple, list or None") + # Auth is just an interceptor. As a convenience (and for discoverability), the auth + # interceptor is appended to the end of the interceptor list so it runs last, after + # any user interceptors have modified the request. + if auth is not None: + interceptors = (interceptors or []) + [auth] + self._url = url self._headers = headers self._traversal_source = traversal_source @@ -54,9 +61,7 @@ def __init__(self, url, traversal_source, self._pool = pool self._result_set = None self._inited = False - self._request_serializer = request_serializer self._response_serializer = response_serializer - self._auth = auth self._interceptors = interceptors self._enable_user_agent_on_connect = enable_user_agent_on_connect if self._enable_user_agent_on_connect: @@ -78,22 +83,33 @@ def close(self): def _write_request(self, request_message): accept = str(self._response_serializer.version, encoding='utf-8') - message = { - 'headers': {'accept': accept}, - 'payload': self._request_serializer.serialize_message(request_message) - if self._request_serializer is not None else request_message, - 'auth': self._auth - } - if self._request_serializer is not None: - content_type = str(self._request_serializer.version, encoding='utf-8') - message['headers']['content-type'] = content_type + + headers = {'accept': accept} + # Promote transactionId to HTTP header before interceptors run. # The field remains in the serialized body as well (dual transmission # per the HTTP transaction protocol specification). if hasattr(request_message, 'fields') and 'transactionId' in request_message.fields: - message['headers']['X-Transaction-Id'] = request_message.fields['transactionId'] + headers['X-Transaction-Id'] = request_message.fields['transactionId'] + + http_request = HttpRequest( + method="POST", + url=self._url, + headers=headers, + body=request_message + ) + for interceptor in self._interceptors or []: - message = interceptor(message) + interceptor(http_request) + + # Auto-serialize if no interceptor already did so + http_request.serialize_body() + + # Build the transport message in the format the transport expects + message = { + 'headers': http_request.headers, + 'payload': http_request.body + } self._transport.write(message) def write(self, request_message): diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py index 279c2f48453..8c82b3d7294 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py @@ -33,7 +33,6 @@ class DriverRemoteConnection(RemoteConnection): def __init__(self, url, traversal_source="g", pool_size=None, max_workers=None, - request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=None, interceptors=None, auth=None, headers=None, enable_user_agent_on_connect=True, bulk_results=False, pdt_registry=None, **transport_kwargs): @@ -54,7 +53,6 @@ def __init__(self, url, traversal_source="g", self._client = client.Client(url, traversal_source, pool_size=pool_size, max_workers=max_workers, - request_serializer=request_serializer, response_serializer=response_serializer, interceptors=interceptors, auth=auth, headers=headers, diff --git a/gremlin-python/src/main/python/gremlin_python/driver/http_request.py b/gremlin-python/src/main/python/gremlin_python/driver/http_request.py new file mode 100644 index 00000000000..304acf28a27 --- /dev/null +++ b/gremlin-python/src/main/python/gremlin_python/driver/http_request.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json + +from gremlin_python.driver.request import RequestMessage + + +class HttpRequest: + """Represents the HTTP request passed through the interceptor chain. + + The body starts as a RequestMessage and can be serialized to JSON bytes + via serialize_body(). Interceptors mutate this object in place. + """ + + def __init__(self, method, url, headers, body): + self.method = method + self.url = url + self.headers = headers + self.body = body + + def serialize_body(self): + """Serialize the body to JSON bytes if it is still a RequestMessage. + + If the body is already bytes, returns them as-is (idempotent). + Sets the Content-Type header to application/json and Content-Length + to the byte length of the serialized body. + + Returns: + bytes: the serialized body + + Raises: + TypeError: if the body is neither a RequestMessage nor bytes + """ + if isinstance(self.body, bytes): + return self.body + + if not isinstance(self.body, RequestMessage): + raise TypeError( + f"Cannot serialize body of type {type(self.body).__name__}. " + "Expected RequestMessage or bytes." + ) + + payload = {"gremlin": self.body.gremlin} + payload.update(self.body.fields) + data = json.dumps(payload).encode("utf-8") + self.body = data + self.headers["content-type"] = "application/json" + self.headers["content-length"] = str(len(data)) + return data diff --git a/gremlin-python/src/main/python/tests/feature/terrain.py b/gremlin-python/src/main/python/tests/feature/terrain.py index d81db72c19d..781216bb639 100644 --- a/gremlin-python/src/main/python/tests/feature/terrain.py +++ b/gremlin-python/src/main/python/tests/feature/terrain.py @@ -107,5 +107,5 @@ def __create_remote(server_graph_name): bulking = world.config.user_data["bulking"] == "true" if "bulking" in world.config.user_data else False return DriverRemoteConnection(test_no_auth_url, server_graph_name, - request_serializer=s, response_serializer=s, + response_serializer=s, bulk_results=bulking) diff --git a/gremlin-python/src/main/python/tests/integration/conftest.py b/gremlin-python/src/main/python/tests/integration/conftest.py index de685fb4b91..8776859ed2c 100644 --- a/gremlin-python/src/main/python/tests/integration/conftest.py +++ b/gremlin-python/src/main/python/tests/integration/conftest.py @@ -19,7 +19,6 @@ import concurrent.futures from collections import namedtuple -from json import dumps import os import ssl import pytest @@ -62,7 +61,6 @@ def connection(request): try: conn = Connection(anonymous_url, 'gmodern', executor, pool, - request_serializer=GraphBinarySerializersV4(), response_serializer=GraphBinarySerializersV4(), auth=basic('stephen', 'password')) except OSError: @@ -122,7 +120,6 @@ def graphbinary_serializer_v4(request): def remote_connection(request): try: remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', - request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=serializer.GraphBinarySerializersV4()) except OSError: pytest.skip('Gremlin Server is not running') @@ -192,8 +189,7 @@ def fin(): def remote_connection_with_interceptor(request): try: remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', - request_serializer=None, - interceptors=json_interceptor) + interceptors=mutating_interceptor) except OSError: pytest.skip('Gremlin Server is not running') else: @@ -207,9 +203,8 @@ def fin(): @pytest.fixture() def client_with_interceptor(request): try: - client = Client(anonymous_url, 'gmodern', request_serializer=None, - response_serializer=GraphBinarySerializersV4(), - interceptors=json_interceptor) + client = Client(anonymous_url, 'gmodern', + interceptors=mutating_interceptor) except OSError: pytest.skip('Gremlin Server is not running') else: @@ -241,7 +236,8 @@ def fin(): return remote_conn -def json_interceptor(request): - request['headers']['content-type'] = "application/json" - request['payload'] = dumps({"gremlin": "g.inject(2)", "g": "g"}) - return request +def mutating_interceptor(http_request): + """Interceptor that replaces the gremlin query with g.inject(2) for testing.""" + from gremlin_python.driver.request import RequestMessage + if isinstance(http_request.body, RequestMessage): + http_request.body = RequestMessage(fields={"g": "g"}, gremlin="g.inject(2)") diff --git a/gremlin-python/src/main/python/tests/integration/driver/test_auth.py b/gremlin-python/src/main/python/tests/integration/driver/test_auth.py index 7efd1fe0422..fa6837098b3 100644 --- a/gremlin-python/src/main/python/tests/integration/driver/test_auth.py +++ b/gremlin-python/src/main/python/tests/integration/driver/test_auth.py @@ -16,51 +16,58 @@ # specific language governing permissions and limitations # under the License. # +import base64 import os -from aiohttp import BasicAuth as aiohttpBasicAuth + from gremlin_python.driver.auth import basic, sigv4 +from gremlin_python.driver.http_request import HttpRequest def create_mock_request(): - return {'headers': - {'content-type': 'application/vnd.graphbinary-v4.0', - 'accept': 'application/vnd.graphbinary-v4.0'}, - 'payload': b'', - 'url': 'https://test_url:8182/gremlin'} + return HttpRequest( + method="POST", + url="https://test_url:8182/gremlin", + headers={ + 'content-type': 'application/vnd.graphbinary-v4.0', + 'accept': 'application/vnd.graphbinary-v4.0', + }, + body=b'', + ) class TestAuth(object): def test_basic_auth_request(self): mock_request = create_mock_request() - assert 'authorization' not in mock_request['headers'] + assert 'authorization' not in mock_request.headers basic('username', 'password')(mock_request) - assert 'authorization' in mock_request['headers'] - assert aiohttpBasicAuth('username', 'password').encode() == mock_request['headers']['authorization'] + assert 'authorization' in mock_request.headers + expected = 'Basic ' + base64.b64encode('username:password'.encode('utf-8')).decode('utf-8') + assert expected == mock_request.headers['authorization'] def test_sigv4_auth_request(self): mock_request = create_mock_request() - assert 'Authorization' not in mock_request['headers'] - assert 'X-Amz-Date' not in mock_request['headers'] + assert 'Authorization' not in mock_request.headers + assert 'X-Amz-Date' not in mock_request.headers os.environ['AWS_ACCESS_KEY_ID'] = 'MOCK_ID' os.environ['AWS_SECRET_ACCESS_KEY'] = 'MOCK_KEY' sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request) - assert mock_request['headers']['X-Amz-Date'] is not None - assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID') - assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization'] - assert 'Signature=' in mock_request['headers']['Authorization'] + assert mock_request.headers['X-Amz-Date'] is not None + assert mock_request.headers['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID') + assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request.headers['Authorization'] + assert 'Signature=' in mock_request.headers['Authorization'] def test_sigv4_auth_request_session_token(self): mock_request = create_mock_request() - assert 'Authorization' not in mock_request['headers'] - assert 'X-Amz-Date' not in mock_request['headers'] - assert 'X-Amz-Security-Token' not in mock_request['headers'] + assert 'Authorization' not in mock_request.headers + assert 'X-Amz-Date' not in mock_request.headers + assert 'X-Amz-Security-Token' not in mock_request.headers + os.environ['AWS_ACCESS_KEY_ID'] = 'MOCK_ID' + os.environ['AWS_SECRET_ACCESS_KEY'] = 'MOCK_KEY' os.environ['AWS_SESSION_TOKEN'] = 'MOCK_TOKEN' sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request) - assert mock_request['headers']['X-Amz-Date'] is not None - assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=') - assert mock_request['headers']['X-Amz-Security-Token'] == 'MOCK_TOKEN' - assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization'] - assert 'Signature=' in mock_request['headers']['Authorization'] - - + assert mock_request.headers['X-Amz-Date'] is not None + assert mock_request.headers['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=') + assert mock_request.headers['X-Amz-Security-Token'] == 'MOCK_TOKEN' + assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request.headers['Authorization'] + assert 'Signature=' in mock_request.headers['Authorization'] diff --git a/gremlin-python/src/main/python/tests/integration/driver/test_client.py b/gremlin-python/src/main/python/tests/integration/driver/test_client.py index 34fa004e125..b6a2236eda9 100644 --- a/gremlin-python/src/main/python/tests/integration/driver/test_client.py +++ b/gremlin-python/src/main/python/tests/integration/driver/test_client.py @@ -612,3 +612,47 @@ def test_pdt_in_collection(client): assert pdt_list[1].name == 'Point' assert pdt_list[1].fields['x'] == 3 assert pdt_list[1].fields['y'] == 4 + + +def test_auto_serializes_request_message_with_interceptor_mutation(): + """Verifies the driver auto-serializes when an interceptor modifies the RequestMessage body.""" + from gremlin_python.driver.request import RequestMessage + + def swap_query(http_request): + if isinstance(http_request.body, RequestMessage): + http_request.body = RequestMessage(fields={"g": "gmodern"}, gremlin="g.inject(99)") + + client = Client(test_no_auth_url, 'gmodern', + pool_size=1, interceptors=swap_query) + try: + result = client.submit("g.inject(1)").next() + assert 99 == result + finally: + client.close() + + +def test_interceptor_errors_propagate(): + """Verifies that an interceptor error propagates to the caller, the request is not sent, + and the client remains usable for subsequent requests.""" + call_count = [0] + + def failing_interceptor(http_request): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("interceptor broke") + + client = Client(test_no_auth_url, 'gmodern', + pool_size=1, interceptors=failing_interceptor) + try: + # First request should fail with interceptor error + try: + client.submit("g.inject(1)").next() + assert False, "Should have thrown an exception" + except RuntimeError as e: + assert "interceptor broke" in str(e) + + # Subsequent request should succeed, proving the client is still usable + result = client.submit("g.inject(2)").next() + assert 2 == result + finally: + client.close() diff --git a/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py b/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py index 23415fdef32..1f6d2e1f90b 100644 --- a/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py +++ b/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py @@ -29,6 +29,7 @@ import asyncio import io +import json import queue import struct from concurrent.futures import Future @@ -1247,12 +1248,11 @@ def test_receive_uses_serializer_version_for_content_type_check(self): class TestConnectionWriteRequest: """ - Tests for Connection._write_request() which handles serialization, - header construction, auth, and interceptors before calling transport.write(). + Tests for Connection._write_request() which builds an HttpRequest, runs interceptors, + auto-serializes the body to JSON, and calls transport.write(). """ - def _make_connection(self, request_serializer=None, response_serializer=None, - auth=None, interceptors=None): + def _make_connection(self, response_serializer=None, interceptors=None): from gremlin_python.driver.connection import Connection from gremlin_python.driver.serializer import GraphBinarySerializersV4 @@ -1260,31 +1260,21 @@ def _make_connection(self, request_serializer=None, response_serializer=None, response_serializer = GraphBinarySerializersV4() conn = Connection.__new__(Connection) - conn._response_serializer = GraphBinarySerializersV4() - conn._request_serializer = request_serializer + conn._url = 'http://localhost:8182/gremlin' conn._response_serializer = response_serializer - conn._auth = auth conn._interceptors = interceptors conn._transport = MagicMock() return conn - def test_none_request_serializer_passes_raw_message(self): - conn = self._make_connection(request_serializer=None) - msg = RequestMessage(fields={}, gremlin="g.V()") + def test_auto_serializes_body_to_json(self): + conn = self._make_connection() + msg = RequestMessage(fields={"g": "g"}, gremlin="g.V()") conn._write_request(msg) written = conn._transport.write.call_args[0][0] - assert written['payload'] == msg - assert 'content-type' not in written['headers'] - - def test_graphbinary_serializer_serializes_payload(self): - from gremlin_python.driver.serializer import GraphBinarySerializersV4 - gb = GraphBinarySerializersV4() - conn = self._make_connection(request_serializer=gb, response_serializer=gb) - msg = RequestMessage(fields={}, gremlin="g.V()") - conn._write_request(msg) - written = conn._transport.write.call_args[0][0] - assert written['payload'] == gb.serialize_message(msg) - assert written['headers']['content-type'] == str(gb.version, encoding='utf-8') + payload = json.loads(written['payload']) + assert payload['gremlin'] == "g.V()" + assert payload['g'] == "g" + assert written['headers']['content-type'] == "application/json" def test_accept_header_set_from_response_serializer(self): from gremlin_python.driver.serializer import GraphBinarySerializersV4 @@ -1294,45 +1284,30 @@ def test_accept_header_set_from_response_serializer(self): written = conn._transport.write.call_args[0][0] assert written['headers']['accept'] == str(gb.version, encoding='utf-8') - def test_auth_passed_in_message(self): - auth_fn = lambda req: req - conn = self._make_connection(auth=auth_fn) - conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) - written = conn._transport.write.call_args[0][0] - assert written['auth'] is auth_fn - def test_single_interceptor_runs(self): - changed = RequestMessage(fields={}, gremlin="changed") def interceptor(request): - request['payload'] = changed - return request + request.body = RequestMessage(fields={"g": "g"}, gremlin="changed") conn = self._make_connection(interceptors=[interceptor]) conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) written = conn._transport.write.call_args[0][0] - assert written['payload'] == changed + assert json.loads(written['payload'])['gremlin'] == "changed" def test_interceptors_run_sequentially(self): - def one(req): req['payload'].gremlin.append(1); return req - def two(req): req['payload'].gremlin.append(2); return req - def three(req): req['payload'].gremlin.append(3); return req + order = [] + def one(req): order.append(1) + def two(req): order.append(2) + def three(req): order.append(3) conn = self._make_connection(interceptors=[one, two, three]) - conn._write_request(RequestMessage(fields={}, gremlin=[])) - written = conn._transport.write.call_args[0][0] - assert written['payload'].gremlin == [1, 2, 3] + conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) + assert order == [1, 2, 3] - def test_interceptor_works_with_serializer(self): - from gremlin_python.driver.serializer import GraphBinarySerializersV4 - gb = GraphBinarySerializersV4() - msg = RequestMessage(fields={}, gremlin="g.E()") - def assert_interceptor(request): - assert request['payload'] == gb.serialize_message(msg) - request['payload'] = "changed" - return request - conn = self._make_connection(request_serializer=gb, response_serializer=gb, - interceptors=[assert_interceptor]) - conn._write_request(msg) + def test_interceptor_can_modify_headers(self): + def interceptor(request): + request.headers['x-custom'] = "value" + conn = self._make_connection(interceptors=[interceptor]) + conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) written = conn._transport.write.call_args[0][0] - assert written['payload'] == "changed" + assert written['headers']['x-custom'] == "value" class TestConnectionInterceptorValidation: diff --git a/gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py b/gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py new file mode 100644 index 00000000000..456d7fb4a58 --- /dev/null +++ b/gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py @@ -0,0 +1,221 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json + +from gremlin_python.driver.http_request import HttpRequest +from gremlin_python.driver.request import RequestMessage + + +def make_request(gremlin="g.V()", fields=None): + msg = RequestMessage(fields=fields or {"g": "g"}, gremlin=gremlin) + return HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={"accept": "application/vnd.graphbinary-v4.0"}, body=msg) + + +class TestBasicInterceptorExecution: + + def test_interceptor_receives_request_message_in_body(self): + request = make_request() + captured = [] + + def interceptor(req): + captured.append(req.body) + + interceptor(request) + assert len(captured) == 1 + assert isinstance(captured[0], RequestMessage) + + def test_interceptor_can_read_and_modify_headers(self): + request = make_request() + request.headers["x-existing"] = "original" + + def interceptor(req): + assert req.headers["x-existing"] == "original" + req.headers["x-existing"] = "modified" + req.headers["x-new"] = "added" + + interceptor(request) + assert request.headers["x-existing"] == "modified" + assert request.headers["x-new"] == "added" + + def test_interceptor_can_modify_uri(self): + request = make_request() + + def interceptor(req): + req.url = "http://other-host:9999/gremlin" + + interceptor(request) + assert request.url == "http://other-host:9999/gremlin" + + def test_interceptors_run_in_registration_order(self): + request = make_request() + order = [] + + interceptors = [ + lambda req: order.append(1), + lambda req: order.append(2), + lambda req: order.append(3), + ] + + for i in interceptors: + i(request) + + assert order == [1, 2, 3] + + +class TestSerializeBody: + + def test_converts_request_message_to_json_bytes(self): + request = make_request(gremlin="g.V().count()") + result = request.serialize_body() + + parsed = json.loads(result) + assert parsed["gremlin"] == "g.V().count()" + assert parsed["g"] == "g" + + def test_sets_content_type_header(self): + request = make_request() + request.serialize_body() + assert request.headers["content-type"] == "application/json" + + def test_sets_content_length_header(self): + request = make_request() + result = request.serialize_body() + assert request.headers["content-length"] == str(len(result)) + + def test_is_idempotent_with_pre_serialized_bytes(self): + existing = b'{"gremlin":"g.V()"}' + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=existing) + first = request.serialize_body() + second = request.serialize_body() + assert first is existing + assert first is second + + def test_is_idempotent_with_request_message(self): + request = make_request() + first = request.serialize_body() + second = request.serialize_body() + assert first is second + + def test_includes_all_fields_in_json(self): + msg = RequestMessage( + fields={"g": "g", "language": "gremlin-lang", "timeoutMs": 5000}, + gremlin="g.V()" + ) + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=msg) + result = request.serialize_body() + parsed = json.loads(result) + + assert parsed["gremlin"] == "g.V()" + assert parsed["g"] == "g" + assert parsed["language"] == "gremlin-lang" + assert parsed["timeoutMs"] == 5000 + + def test_raises_on_unsupported_body_type(self): + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=42) + try: + request.serialize_body() + assert False, "expected TypeError for unsupported body type" + except TypeError as e: + assert "int" in str(e) + + def test_raises_on_none_body(self): + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=None) + try: + request.serialize_body() + assert False, "expected TypeError for None body" + except TypeError as e: + assert "NoneType" in str(e) + + +class TestFieldMutation: + + def test_interceptor_can_replace_body_before_serialization(self): + request = make_request() + + def interceptor(req): + req.body = RequestMessage(fields={"g": "gmodern"}, gremlin="g.E()") + + interceptor(request) + request.serialize_body() + + parsed = json.loads(request.body) + assert parsed["gremlin"] == "g.E()" + assert parsed["g"] == "gmodern" + + + + +class TestAuthInterceptorOrdering: + + def test_auth_interceptor_is_always_last(self): + """Auth interceptor should always be appended to the end of the interceptor list.""" + from unittest.mock import MagicMock + from gremlin_python.driver.connection import Connection + + order = [] + + def interceptor1(req): + order.append(1) + + def interceptor2(req): + order.append(2) + + def auth_interceptor(req): + order.append(3) + + conn = Connection( + url="http://localhost:8182/gremlin", + traversal_source="g", + executor=MagicMock(), + pool=MagicMock(), + auth=auth_interceptor, + interceptors=[interceptor1, interceptor2], + enable_user_agent_on_connect=False + ) + + # Verify the internal interceptor list has auth at the end + assert len(conn._interceptors) == 3 + assert conn._interceptors[0] is interceptor1 + assert conn._interceptors[1] is interceptor2 + assert conn._interceptors[2] is auth_interceptor + + def test_auth_interceptor_is_last_even_without_other_interceptors(self): + """Auth interceptor works when no other interceptors are provided.""" + from unittest.mock import MagicMock + from gremlin_python.driver.connection import Connection + + def auth_interceptor(req): + pass + + conn = Connection( + url="http://localhost:8182/gremlin", + traversal_source="g", + executor=MagicMock(), + pool=MagicMock(), + auth=auth_interceptor, + enable_user_agent_on_connect=False + ) + + assert len(conn._interceptors) == 1 + assert conn._interceptors[0] is auth_interceptor diff --git a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java index 99007d71155..8a688b6d6bd 100644 --- a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java +++ b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java @@ -210,6 +210,12 @@ private RequestMessage getRequestMessageFromHttpRequest(final FullHttpRequest re final JsonNode txIdNode = body.get(Tokens.ARGS_TRANSACTION_ID); if (null != txIdNode) builder.addTransactionId(txIdNode.asText()); + // bulkResults was previously only sent as an HTTP header by GLV drivers using + // GraphBinary. With the move to JSON requests in 4.x, drivers may include it in + // the body instead. + final JsonNode bulkResultsNode = body.get(Tokens.BULK_RESULTS); + if (null != bulkResultsNode) builder.addBulkResults(bulkResultsNode.asBoolean()); + return builder.create(); } } diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java index dd277cf8dbf..99f24283f24 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java @@ -29,7 +29,6 @@ import org.apache.tinkerpop.gremlin.driver.ResultSet; import org.apache.tinkerpop.gremlin.driver.exception.NoHostAvailableException; import org.apache.tinkerpop.gremlin.driver.exception.ResponseException; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.jsr223.ScriptFileGremlinPlugin; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; @@ -45,8 +44,7 @@ import org.apache.tinkerpop.gremlin.util.ExceptionHelper; import org.apache.tinkerpop.gremlin.util.TimeUtil; import org.apache.tinkerpop.gremlin.util.function.FunctionUtils; -import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; -import org.apache.tinkerpop.gremlin.util.ser.GraphSONMessageSerializerV4; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.ser.Serializers; import org.junit.AfterClass; import org.junit.Before; @@ -174,10 +172,7 @@ public void shouldInterceptRequests() throws Exception { final AtomicInteger httpRequests = new AtomicInteger(0); final Cluster cluster = TestClientFactory.build(). - addInterceptor("counter", r -> { - httpRequests.incrementAndGet(); - return r; - }).create(); + interceptors(r -> httpRequests.incrementAndGet()).create(); try { final Client client = cluster.connect(); @@ -195,18 +190,63 @@ public void shouldInterceptRequests() throws Exception { public void shouldRunInterceptorsInOrder() throws Exception { AtomicReference body = new AtomicReference<>(); final Cluster cluster = TestClientFactory.build(). - addInterceptor("first", r -> { - body.set(r.getBody()); - r.setBody(null); - return r; - }). - addInterceptor("second", r -> { - r.setBody(body.get()); - return r; + interceptors( + r -> { + body.set(r.getBody()); + r.setBody(null); + }, + r -> r.setBody(body.get()) + ).create(); + + try { + final Client client = cluster.connect(); + assertEquals(2, client.submit("g.inject(2)").all().get().get(0).getInt()); + } finally { + cluster.close(); + } + } + + @Test + public void shouldAutoSerializeRequestMessageWithInterceptorMutation() throws Exception { + // Verifies the driver auto-serializes when an interceptor modifies the RequestMessage + // body but does not call serializeBody() itself. + final Cluster cluster = TestClientFactory.build(). + interceptors(r -> { + if (r.getBody() instanceof RequestMessage) { + r.setBody(RequestMessage.build("g.inject(99)").create()); + } }).create(); try { final Client client = cluster.connect(); + assertEquals(99, client.submit("g.inject(1)").all().get().get(0).getInt()); + } finally { + cluster.close(); + } + } + + @Test + public void shouldPropagateExceptionThrownDuringInterceptor() throws Exception { + final AtomicInteger callCount = new AtomicInteger(0); + final Cluster cluster = TestClientFactory.build(). + interceptors(r -> { + // Only throw on the first request to verify recovery + if (callCount.incrementAndGet() == 1) { + throw new RuntimeException("interceptor broke"); + } + }).create(); + try { + final Client client = cluster.connect(); + + // First request should fail with interceptor error + try { + client.submit("g.inject(1)").all().get(); + fail("Should have thrown an exception"); + } catch (Exception ex) { + assertTrue(ex.getCause().getMessage().contains("interceptor broke")); + } + + // Subsequent request should succeed, proving the connection is still usable assertEquals(2, client.submit("g.inject(2)").all().get().get(0).getInt()); } finally { cluster.close(); @@ -214,8 +254,39 @@ public void shouldRunInterceptorsInOrder() throws Exception { } @Test - public void shouldWorkWithGraphSONSerializer() throws Exception { - final Cluster cluster = TestClientFactory.build(new PayloadSerializingInterceptor(new GraphSONMessageSerializerV4())) + public void shouldPropagateErrorWhenInterceptorSetsUnsupportedBodyType() throws Exception { + // This error occurs after interceptors run, when serializeBody() encounters + // a body that is neither RequestMessage nor byte[]. + final AtomicInteger callCount = new AtomicInteger(0); + final Cluster cluster = TestClientFactory.build(). + interceptors(r -> { + if (callCount.incrementAndGet() == 1) { + r.setBody(42); + } + }).create(); + try { + final Client client = cluster.connect(); + + // First request should fail with body type error + try { + client.submit("g.inject(1)").all().get(); + fail("Should have thrown an exception"); + } catch (Exception ex) { + final String msg = ex.getCause().getMessage(); + assertTrue(msg.contains("Cannot serialize body of type")); + assertTrue(msg.contains("Integer")); + } + + // Subsequent request should succeed, proving the connection is still usable + assertEquals(3, client.submit("g.inject(3)").all().get().get(0).getInt()); + } finally { + cluster.close(); + } + } + + @Test + public void shouldWorkWithGraphSONResponse() throws Exception { + final Cluster cluster = TestClientFactory.build() .serializer(Serializers.GRAPHSON_V4.simpleInstance()).create(); try { @@ -232,8 +303,8 @@ public void shouldWorkWithGraphSONSerializer() throws Exception { } @Test - public void shouldWorkWithGraphSONRequestAndGraphBinaryResponse() throws Exception { - final Cluster cluster = TestClientFactory.build(new PayloadSerializingInterceptor(new GraphSONMessageSerializerV4())) + public void shouldWorkWithJsonRequestAndGraphBinaryResponse() throws Exception { + final Cluster cluster = TestClientFactory.build() .serializer(Serializers.GRAPHBINARY_V4.simpleInstance()).create(); try { @@ -245,8 +316,8 @@ public void shouldWorkWithGraphSONRequestAndGraphBinaryResponse() throws Excepti } @Test - public void shouldWorkWithGraphBinaryRequestAndGraphSONResponse() throws Exception { - final Cluster cluster = TestClientFactory.build(new PayloadSerializingInterceptor(new GraphBinaryMessageSerializerV4())) + public void shouldWorkWithJsonRequestAndGraphSONResponse() throws Exception { + final Cluster cluster = TestClientFactory.build() .serializer(Serializers.GRAPHSON_V4.simpleInstance()).create(); try { @@ -264,10 +335,7 @@ public void shouldInterceptRequestsWithHandshake() throws Exception { final Cluster cluster = TestClientFactory.build(). maxConnectionPoolSize(1). - addInterceptor("counter", r -> { - handshakeRequests.incrementAndGet(); - return r; - }).create(); + interceptors(r -> handshakeRequests.incrementAndGet()).create(); try { final Client client = cluster.connect(); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java index 8c200c5128b..897b3740b24 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java @@ -94,11 +94,10 @@ public void shouldPassSigv4ToServer() throws Exception { final AtomicReference httpRequest = new AtomicReference<>(); final Cluster cluster = TestClientFactory.build() - .auth(sigv4("us-west2", credentialsProvider, "service-name")) - .addInterceptor("header-checker", r -> { - httpRequest.set(r); - return r; - }) + .interceptors( + sigv4("us-west2", credentialsProvider, "service-name"), + r -> httpRequest.set(r) + ) .create(); final Client client = cluster.connect(); client.submit("g.inject(2)").all().get(); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java index 8c31e41c221..2f2e1445642 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java @@ -19,7 +19,6 @@ package org.apache.tinkerpop.gremlin.server; import org.apache.tinkerpop.gremlin.driver.Cluster; -import org.apache.tinkerpop.gremlin.driver.RequestInterceptor; import org.apache.tinkerpop.gremlin.driver.simple.SimpleHttpClient; import java.net.URI; @@ -43,10 +42,6 @@ public static Cluster.Builder build(final String address) { return Cluster.build(address).port(PORT); } - public static Cluster.Builder build(final RequestInterceptor serializingInterceptor) { - return Cluster.build(serializingInterceptor).port(PORT); - } - public static Cluster open() { return build().create(); } diff --git a/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java b/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java index 3ea75f2f7d3..cdc49ff93f8 100644 --- a/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java +++ b/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java @@ -31,15 +31,17 @@ import org.apache.tinkerpop.gremlin.structure.Graph; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerFactory; -import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseMessage; import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; import org.apache.tinkerpop.gremlin.util.ser.SerTokens; import org.apache.tinkerpop.gremlin.util.ser.SerializationException; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Random; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; @@ -56,10 +58,17 @@ public class TestHttpGremlinHandler extends SimpleChannelInboundHandler Date: Fri, 12 Jun 2026 14:24:52 -0700 Subject: [PATCH 2/3] missing import --- .../tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java index 99f24283f24..d6ccc6707ac 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java @@ -45,6 +45,7 @@ import org.apache.tinkerpop.gremlin.util.TimeUtil; import org.apache.tinkerpop.gremlin.util.function.FunctionUtils; import org.apache.tinkerpop.gremlin.util.message.RequestMessage; +import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; import org.apache.tinkerpop.gremlin.util.ser.Serializers; import org.junit.AfterClass; import org.junit.Before; From a637b0e86fa4306d04b90eba4b27434b06a938bd Mon Sep 17 00:00:00 2001 From: Ken Hu <106191785+kenhuuu@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:23:58 -0700 Subject: [PATCH 3/3] fixes to PDT rebase --- gremlin-js/gremlin-javascript/lib/driver/connection.ts | 4 +--- .../src/main/python/gremlin_python/driver/client.py | 2 -- .../tests/unit/structure/io/test_provider_defined_type.py | 1 - 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/gremlin-js/gremlin-javascript/lib/driver/connection.ts b/gremlin-js/gremlin-javascript/lib/driver/connection.ts index d007fd04ea4..1be591ca9a2 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/connection.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/connection.ts @@ -34,8 +34,6 @@ import { HttpRequest, RequestInterceptor } from './http-request.js'; import ResponseError from './response-error.js'; import { Traverser } from '../process/traversal.js'; -const { graphBinaryReader } = ioc; - const responseStatusCode = { success: 200, noContent: 204, @@ -83,7 +81,7 @@ export default class Connection extends EventEmitter { ) { super(); - this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : graphBinaryReader); + this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : new GraphBinaryReader(ioc)); if (options.pdtRegistry) { this._reader.pdtRegistry = options.pdtRegistry; } diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index c20556851e0..a32e8f9d0b3 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -56,8 +56,6 @@ def __init__(self, url, traversal_source, pool_size=None, max_workers=None, if response_serializer is None: response_serializer = serializer.GraphBinarySerializersV4() if pdt_registry is not None: - if request_serializer is not None: - request_serializer.configure_pdt_registry(pdt_registry) response_serializer.configure_pdt_registry(pdt_registry) self._auth = auth diff --git a/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py b/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py index aac685982c7..4f9b086fc7f 100644 --- a/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py +++ b/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py @@ -230,7 +230,6 @@ def test_client_passes_registry_to_serializers(self): registry = ProviderDefinedTypeRegistry() with patch.object(Client, '_fill_pool'): c = Client("ws://localhost:8182/gremlin", "g", pdt_registry=registry) - assert c._request_serializer._graphbinary_reader.pdt_registry is registry assert c._response_serializer._graphbinary_reader.pdt_registry is registry def test_driver_remote_connection_passes_registry(self):