diff --git a/docs/src/reference/gremlin-variants.asciidoc b/docs/src/reference/gremlin-variants.asciidoc index 7a8405b2924..8003ffb73ac 100644 --- a/docs/src/reference/gremlin-variants.asciidoc +++ b/docs/src/reference/gremlin-variants.asciidoc @@ -271,8 +271,50 @@ More details can be found in provider docs link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_graph_driver_provider_requirements[here].|true |RequestInterceptors |Functions that modify HTTP requests before sending. Used for authentication and custom headers. |empty |PDTRegistry |A `*PDTRegistry` for hydrating and dehydrating <>. |`nil` +|Auth |A single RequestInterceptor for authentication (e.g. `BasicAuth`). Always appended to the end of the interceptor list so it runs last. |nil |========================================================= +[[gremlin-go-interceptors]] +=== RequestInterceptor + +Gremlin-Go allows modification of the underlying HTTP request through `RequestInterceptor` functions. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin-Go is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A `RequestInterceptor` is a function with the signature `func(*HttpRequest) error` that mutates the `HttpRequest` +in place. A slice of these is maintained and will be run sequentially for each request. When creating a +`DriverRemoteConnection` or `Client`, the `RequestInterceptors` field on the settings struct accepts an ordered +slice of interceptors. Order matters, so if one interceptor depends on another's output, ensure they are added in +the correct order. Note that authentication (e.g. `BasicAuth`, `SigV4Auth`) is also implemented using interceptors. +The `auth` convenience on connection settings appends the auth interceptor to the end of the list so it runs last. + +[source,go] +---- +remote, err := gremlingo.NewDriverRemoteConnection("http://localhost:8182/gremlin", + func(settings *gremlingo.DriverRemoteConnectionSettings) { + settings.RequestInterceptors = []gremlingo.RequestInterceptor{ + gremlingo.BasicAuth("username", "password"), + func(req *gremlingo.HttpRequest) error { + req.Headers.Set("X-Custom-Header", "value") + return nil + }, + } + }) +---- + +Each interceptor receives a `*gremlingo.HttpRequest` whose `Body` field initially contains a `*RequestMessage`. +The `RequestMessage` struct has mutable fields (`Gremlin string` and `Fields map[string]interface{}`), so an +interceptor can modify them directly or replace the body entirely. By default, the driver serializes the body +to JSON (`application/json`) by calling `HttpRequest.SerializeBody()` after all interceptors have run. An +interceptor that needs the serialized bytes (for example, to compute a signature hash) can call +`SerializeBody()` itself; the method is idempotent. If you require a GraphBinary-encoded request body instead +of JSON, you can write a custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer +and sets the resulting `[]byte` as the body. + +For an example of a simple `RequestInterceptor` that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-go/driver/auth.go[basic authentication]. + [[gremlin-go-strategies]] === Traversal Strategies @@ -967,9 +1009,11 @@ The following table describes the various configuration options for the Gremlin |connectionPool.trustStore |File location for a SSL Certificate Chain to use when SSL is enabled. If this value is not provided and SSL is enabled, the default `TrustManager` will be used. |_none_ |connectionPool.trustStorePassword |The password of the `trustStore` if it is password-protected |_none_ |connectionPool.validationRequest |A script that is used to test server connectivity. A good script to use is one that evaluates quickly and returns no data. The default simply returns an empty string, but if a graph is required by a particular provider, a good traversal might be `g.inject()`. |_''_ +|auth |An authentication `RequestInterceptor` (e.g. `Auth.basic()`). Always appended to the end of the interceptor list so it runs last. |_none_ |bulkResults |Sets whether the server should attempt to get bulk results or not. |false |enableUserAgentOnConnect |Enables sending a user agent to the server during connection requests. More details can be found in provider docs link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_graph_driver_provider_requirements[here].|true |hosts |The list of hosts that the driver will connect to. |localhost +|interceptors |A list of `RequestInterceptor` instances that modify the HTTP request before sending. |empty |nioPoolSize |Size of the pool for handling request/response operations. |available processors |password |The password to submit on requests that require authentication. |_none_ |path |The URL path to the Gremlin Server. |_/gremlin_ @@ -1273,21 +1317,26 @@ to "g" and "g1" and "g2" are automatically rebound into "g" on the server-side. Gremlin-Java allows for modification of the underlying HTTP request through the use of `RequestInterceptors`. This is intended to be an advanced feature which means that you will need to understand how the implementation works in order -to safely utilize it. Gremlin-Java is written in a way that you should be able to interact with a TinkerPop-enabled -server without having to use interceptors. This is intended for cases where the server has special capabilities. - -A `RequestInterceptor` is simply a `UnaryOperator`. A list of these are maintained and will be run sequentially for -each request. When building a `Cluster` instance, the methods `addInterceptorAfter()`, `addInterceptorBefore()`, -`addInterceptor()`, and `removeInterceptor()` can be used to add or remove interceptors. It's important to remember -that order matters so if one interceptor depends on another's output then ensure they are added in the correct order. -Note that `Auth` is also implemented using interceptors, and `Auth` is always run last after your list of interceptors -has already ran. By default, the `PayloadSerializingInterceptor` with the name `serializer` is added to your list of -interceptors. This interceptor is used for serializing the body of the request. The first interceptor is provided with -a `org.apache.tinkerpop.gremlin.driver.HttpRequest` that contains a `RequestMessage` in the body. As a reminder -`RequestMessage` is immutable and only certain keys can be added to them. If you want to customize the body by adding -other fields, you will need to make a different copy of the `RequestMessage` or completely change the body to contain a -different data type. The very last interceptor should have a `org.apache.tinkerpop.gremlin.driver.HttpRequest` that -contains a byte[] in the body. +to safely utilize it. Gremlin-Java is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A `RequestInterceptor` is a function that mutates the `HttpRequest` in place (it returns nothing). +A list of these is maintained and will be run sequentially for each request. When building a +`Cluster` instance, the `interceptors()` method accepts an ordered list of interceptors. It's +important to remember that order matters so if one interceptor depends on another's output then +ensure they are added in the correct order. Note that `Auth` is also implemented using interceptors. +The `auth()` convenience on the `Cluster.Builder` appends the auth interceptor to the end of the list so it runs last. + +Each interceptor is provided with a `org.apache.tinkerpop.gremlin.driver.HttpRequest` that contains +a `RequestMessage` in the body. As a reminder `RequestMessage` is immutable and only certain keys +can be added to them. If you want to customize the body by adding other fields, you will need to +make a different copy of the `RequestMessage` or completely change the body to contain a different +data type. By default, the driver serializes the body to JSON (`application/json`) by calling +`HttpRequest.serializeBody()` after all interceptors have run. An interceptor that needs the +serialized bytes (for example, to compute a signature hash) can call `serializeBody()` itself; the +method is idempotent. If you require a GraphBinary-encoded request body instead of JSON, you can +write a custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer and +sets the resulting `byte[]` as the body. For an example of a simple `RequestInterceptor` that only modifies the header of the request take a look at link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java[basic authentication]. @@ -1826,6 +1875,7 @@ can be passed in the constructor of a new `Client` or `DriverRemoteConnection` : |options.traversalSource |String |The traversal source. |'g' |options.headers |Object |Additional HTTP header key/values included with each request. |undefined |options.interceptors |RequestInterceptor/RequestInterceptor[] |One or more functions that can modify the HTTP request before it is sent. |undefined +|options.auth |RequestInterceptor |An auth interceptor that is always appended to the end of the interceptor list so it runs last. |undefined |options.preciseNumbers |Boolean |When `true`, wraps deserialized numbers in typed wrappers that preserve the server's original type. |undefined |options.reader |GraphBinaryReader |The reader to use for deserializing responses. |GraphBinaryReader |options.writer |GraphBinaryWriter |The writer to use for serializing requests. |GraphBinaryWriter @@ -1834,6 +1884,47 @@ can be passed in the constructor of a new `Client` or `DriverRemoteConnection` : |options.pdtRegistry |ProviderDefinedTypeRegistry |A registry for hydrating and dehydrating <>. |undefined |========================================================= +[[gremlin-javascript-interceptors]] +=== RequestInterceptor + +Gremlin-JavaScript allows modification of the underlying HTTP request through `RequestInterceptor` functions. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin-JavaScript is written in a way that you should be able to interact with most + TinkerPop-enabled servers without having to use interceptors. This is intended for cases where the server has special + capabilities. + +A `RequestInterceptor` is a function with the signature `(request: HttpRequest) => void | Promise` that mutates +the `HttpRequest` in place (it returns nothing). A list of these is maintained and will be run sequentially for each +request. When creating a `DriverRemoteConnection` or `Client`, the `interceptors` option accepts one interceptor or an +array of interceptors. Order matters, so if one interceptor depends on another's output, ensure they are added in the +correct order. Note that authentication (e.g. `auth.basic()`, `auth.sigv4()`) is also implemented using interceptors. +The `auth` convenience on connection options appends the auth interceptor to the end of the list so it runs last. + +[source,javascript] +---- +const { auth } = gremlin.driver; +const g = traversal().with_(new DriverRemoteConnection('http://localhost:8182/gremlin', { + interceptors: [ + auth.basic('myuser', 'mypassword'), + (request) => { + request.headers['X-Custom-Header'] = 'value'; + } + ] +})); +---- + +Each interceptor receives an `HttpRequest` whose `body` field initially contains a `RequestMessage`. The +`RequestMessage` class is immutable (fields are private, accessed via getters). To customize the body by adding other +fields, you must build a new `RequestMessage` via `RequestMessage.build(gremlin)` and its builder methods, or set a +completely different body (such as a `Buffer`). By default, the driver serializes the body to JSON +(`application/json`) by calling `HttpRequest.serializeBody()` after all interceptors have run. An interceptor that +needs the serialized bytes (for example, to compute a signature hash) can call `serializeBody()` itself; the method +is idempotent. If you require a GraphBinary-encoded request body instead of JSON, you can write a custom interceptor +that serializes the `RequestMessage` with the GraphBinary writer and sets the resulting `Buffer` as the body. + +For an example of a simple `RequestInterceptor` that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-js/gremlin-javascript/lib/driver/auth.ts[basic authentication]. + [[gremlin-javascript-logging]] === Logging @@ -2458,8 +2549,52 @@ The following options can be passed to the `GremlinClient` constructor: |loggerFactory |An `ILoggerFactory` for logging. |`NullLoggerFactory` |interceptors |A list of `Func` that modify HTTP requests before sending. |_none_ |pdtRegistry |A `ProviderDefinedTypeRegistry` for hydrating and dehydrating <>. |`null` +|auth |A single `Func` for authentication. Always appended to the end of the interceptor list so it runs last. |_none_ |========================================================= +[[gremlin-dotnet-interceptors]] +=== RequestInterceptor + +Gremlin.Net allows modification of the underlying HTTP request through request interceptors. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin.Net is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A request interceptor is a `Func` that mutates the `HttpRequestContext` in place (it returns +`Task` for async support but does not produce a value). A list of these is maintained and will be run sequentially for +each request. When creating a `GremlinClient`, the `interceptors` parameter accepts an ordered collection of +interceptors. Order matters, so if one interceptor depends on another's output, ensure they are added in the correct +order. Note that authentication (e.g. `Auth.BasicAuth()`, `Auth.SigV4Auth()`) is also implemented using interceptors. +These factory methods return interceptor delegates that can be included in the `interceptors` list. Alternatively, the +`auth` parameter on `GremlinClient` appends the auth interceptor to the end of the list so it runs last. + +[source,csharp] +---- +var server = new GremlinServer("localhost", 8182); +using var client = new GremlinClient(server, + interceptors: new Func[] + { + Auth.BasicAuth("username", "password"), + context => + { + context.Headers["X-Custom-Header"] = "value"; + return Task.CompletedTask; + } + }); +---- + +Each interceptor receives an `HttpRequestContext` whose `Body` property initially contains a `RequestMessage`. +The `RequestMessage` has an immutable `Gremlin` property and a mutable `Fields` dictionary, so an interceptor can +add or modify fields directly. Alternatively, you can replace the body entirely. By default, the driver serializes +the body to JSON (`application/json`) by calling `HttpRequestContext.SerializeBody()` after all interceptors have +run. An interceptor that needs the serialized bytes (for example, to compute a signature hash) can call +`SerializeBody()` itself; the method is idempotent. If you require a GraphBinary-encoded request body instead of +JSON, you can write a custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer and +sets the resulting `byte[]` as the body. + +For an example of a simple request interceptor that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs[basic authentication]. + [[gremlin-dotnet-logging]] === Logging @@ -3058,7 +3193,7 @@ can be passed to the `Client` or `DriverRemoteConnection` instance as keyword ar |request_serializer |The request serializer implementation.|`gremlin_python.driver.serializer.GraphBinarySerializersV4` |response_serializer |The response serializer implementation.|`gremlin_python.driver.serializer.GraphBinarySerializersV4` |interceptors |The request interceptors to run after request serialization.|`None` -|auth |The authentication scheme to use when submitting requests that require authentication. |`None` +|auth |An authentication interceptor. Always appended to the end of the interceptor list so it runs last. |`None` |pool_size |The number of connections used by the pool. |4 |enable_user_agent_on_connect |Enables sending a user agent to the server during connection requests. More details can be found in provider docs @@ -3080,6 +3215,47 @@ g = traversal().with_( read_timeout=30)) ---- +[[gremlin-python-interceptors]] +=== RequestInterceptor + +Gremlin-Python allows modification of the underlying HTTP request through request interceptors. This is +intended to be an advanced feature which means that you will need to understand how the implementation works in order +to safely utilize it. Gremlin-Python is written in a way that you should be able to interact with most TinkerPop-enabled +servers without having to use interceptors. This is intended for cases where the server has special capabilities. + +A request interceptor is a callable that receives an `HttpRequest` and mutates it in place (it returns nothing). +A list of these is maintained and will be run sequentially for each request. When creating a +`DriverRemoteConnection` or `Client`, the `interceptors` keyword argument accepts an ordered list of interceptors. +Order matters, so if one interceptor depends on another's output, ensure they are added in the correct order. Note +that authentication (e.g. `basic()`, `sigv4()`) is also implemented using interceptors. The `auth` convenience +parameter on `DriverRemoteConnection` and `Client` appends the auth interceptor to the end of the list so it +runs last. + +[source,python] +---- +from gremlin_python.driver.auth import basic + +g = traversal().with_(DriverRemoteConnection( + 'http://localhost:8182/gremlin', 'g', + auth=basic('username', 'password'), + interceptors=[ + lambda req: req.headers.update({'X-Custom-Header': 'value'}) + ])) +---- + +Each interceptor receives an `HttpRequest` whose `body` attribute initially contains a `RequestMessage`. +The `RequestMessage` is a named tuple with mutable `fields` (a `dict`) and a `gremlin` string. Because +`fields` is a regular dictionary, an interceptor can add or modify entries directly. To replace the body +entirely, assign a new value to `req.body`. By default, the driver serializes the body to JSON +(`application/json`) by calling `HttpRequest.serialize_body()` after all interceptors have run. An interceptor +that needs the serialized bytes (for example, to compute a signature hash) can call `serialize_body()` itself; +the method is idempotent. If you require a GraphBinary-encoded request body instead of JSON, you can write a +custom interceptor that serializes the `RequestMessage` with the GraphBinary serializer and sets the resulting +`bytes` as the body. + +For an example of a simple request interceptor that only modifies the header of the request, see +link:https://github.com/apache/tinkerpop/blob/x.y.z/gremlin-python/src/main/python/gremlin_python/driver/auth.py[basic authentication]. + [[gremlin-python-strategies]] === Traversal Strategies diff --git a/docs/src/upgrade/release-4.x.x.asciidoc b/docs/src/upgrade/release-4.x.x.asciidoc index ac8debb0a99..0f78dde55b9 100644 --- a/docs/src/upgrade/release-4.x.x.asciidoc +++ b/docs/src/upgrade/release-4.x.x.asciidoc @@ -32,6 +32,37 @@ complete list of all the modifications that are part of this release. === Upgrading for Users +==== Request Interceptors + +When TinkerPop supported WebSockets prior to 4.0.0, the Java driver offered a `RequestInterceptor` interface (and +its predecessor, `HandshakeInterceptor`) that allowed modification of the raw Netty `FullHttpRequest`. For WebSocket +connections, the interceptor only ran on the initial HTTP upgrade handshake. With the move to HTTP, the notion of +the "interceptor" has shifted to a per-request concern that is now standardized across all GLVs. + +All GLVs now support request interceptors, which allow modification of the HTTP request before it is sent to the +server. An interceptor is a function that receives the mutable HTTP request object and can modify headers, the +request body, the URI, or the HTTP method. Interceptors are run in the order they are registered. + +The most common use case for interceptors is authentication (e.g., SigV4 signing), but they can also be used to add +provider-specific fields, inject custom headers, or transform the request body. + +Here is a simple Java example that adds a custom header: + +[source,java] +---- +Cluster cluster = Cluster.build("localhost") + .interceptors(request -> request.headers().put("X-Custom-Header", "value")) + .create(); +---- + +Authentication is also an interceptor. Each GLV provides convenience methods (e.g., `Auth.basic()`, `Auth.sigv4()`) +that return interceptors and can be registered alongside custom ones. + +For full details on the interceptor API for each language variant, refer to the RequestInterceptor section in +each GLV's documentation in the +link:https://tinkerpop.apache.org/docs/x.y.z/reference/#gremlin-drivers-variants[Gremlin Drivers and Variants] +reference. + ==== Gremlator link:https://gremlator.com[Gremlator] has been rebuilt entirely in JavaScript as a browser-based single-page @@ -567,6 +598,20 @@ registration. ==== Graph Driver Providers +===== Request Interceptors + +Graph driver providers should implement request interceptor support in their drivers. Interceptors allow users to +modify the HTTP request before it is sent, which is essential for authentication schemes (like SigV4), adding +provider-specific request fields, and other server-specific capabilities. + +All TinkerPop reference drivers now include interceptor support. The interceptor contract is standardized across all +GLVs: interceptors receive a mutable HTTP request object, can modify headers/body/URI, and the driver auto-serializes +the request body to JSON after all interceptors have run. + +For the full specification of how interceptors should behave, see the +link:https://tinkerpop.apache.org/docs/x.y.z/dev/provider/#_http_request_interceptor[HTTP Request Interceptor] +section in the provider documentation. + == TinkerPop 4.0.0-beta.2 *Release Date: April 1, 2026* diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs index e10aafe55fb..eaf706f6dda 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Auth.cs @@ -94,6 +94,9 @@ public static Func SigV4Auth( } } + // Ensure the body is serialized before signing so we have bytes to hash. + context.SerializeBody(); + // Use the async path — important for credential providers that perform // network I/O (e.g. IMDS on EC2, ECS task role endpoint). var immutableCreds = await cachedProvider.GetCredentialsAsync() @@ -116,7 +119,7 @@ private static void SignRequest(HttpRequestContext context, ? bytes : throw new InvalidOperationException( "SigV4 signing requires Body to be byte[]. " + - "Ensure serialization occurs before the SigV4 interceptor."), + "Ensure SerializeBody() was called before signing."), AuthenticationRegion = clientConfig.AuthenticationRegion, OverrideSigningServiceName = clientConfig.AuthenticationServiceName, }; diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs index e5c94611d4a..82a0d17724a 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Connection.cs @@ -32,8 +32,6 @@ using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; using Gremlin.Net.Process; -using Gremlin.Net.Process.Traversal; -using Gremlin.Net.Structure.IO; namespace Gremlin.Net.Driver { @@ -44,7 +42,6 @@ internal class Connection : IDisposable { private readonly HttpClient _httpClient; private readonly Uri _uri; - private readonly IMessageSerializer? _requestSerializer; private readonly IMessageSerializer _responseSerializer; private readonly ConnectionSettings _settings; private readonly IReadOnlyList> _interceptors; @@ -55,24 +52,15 @@ internal class Connection : IDisposable /// so a single instance handles concurrent requests efficiently. /// /// The Gremlin Server URI. - /// - /// The serializer for outgoing requests. When non-null, the request body is serialized - /// to byte[] before interceptors run and the Content-Type header is set - /// automatically. When null, the body is passed as a - /// and an interceptor is responsible for serializing it to byte[] and setting - /// Content-Type. This follows the Python driver's request_serializer=None - /// pattern. - /// /// The serializer for incoming responses (always required). /// Connection settings. /// Optional request interceptors. - public Connection(Uri uri, IMessageSerializer? requestSerializer, + public Connection(Uri uri, IMessageSerializer responseSerializer, ConnectionSettings settings, IReadOnlyList>? interceptors = null) { _uri = uri; - _requestSerializer = requestSerializer; _responseSerializer = responseSerializer; _settings = settings; _interceptors = interceptors ?? Array.Empty>(); @@ -97,13 +85,12 @@ public Connection(Uri uri, IMessageSerializer? requestSerializer, /// /// Constructor that accepts a pre-configured HttpClient (for testing). /// - internal Connection(Uri uri, IMessageSerializer? requestSerializer, + internal Connection(Uri uri, IMessageSerializer responseSerializer, ConnectionSettings settings, HttpClient httpClient, IReadOnlyList>? interceptors = null) { _uri = uri; - _requestSerializer = requestSerializer; _responseSerializer = responseSerializer; _settings = settings; _httpClient = httpClient; @@ -139,7 +126,7 @@ public async Task> SubmitAsync(RequestMessage requestMessage, headers["bulkResults"] = "true"; } - // Promote transactionId to HTTP header before serialization. + // Promote transactionId to HTTP header before interceptors run. // The field remains in the serialized body as well (dual transmission // per the HTTP transaction protocol specification). if (requestMessage.Fields.TryGetValue(Tokens.ArgsTransactionId, out var txIdObj) && @@ -148,26 +135,20 @@ public async Task> SubmitAsync(RequestMessage requestMessage, headers["X-Transaction-Id"] = txId; } - object body; - if (_requestSerializer != null) - { - var requestBytes = await _requestSerializer.SerializeMessageAsync(requestMessage, cancellationToken) - .ConfigureAwait(false); - body = requestBytes; - headers["Content-Type"] = _requestSerializer.MimeType; - } - else - { - body = requestMessage; - } - - var context = new HttpRequestContext("POST", _uri, headers, body); + var context = new HttpRequestContext("POST", _uri, headers, requestMessage); foreach (var interceptor in _interceptors) { await interceptor(context).ConfigureAwait(false); } + // Auto-serialize after interceptors: idempotent if already serialized by an interceptor. + // Skip if body is HttpContent (an escape hatch for full wire-format control). + if (context.Body is not System.Net.Http.HttpContent) + { + context.SerializeBody(); + } + // The HttpResponseMessage is NOT disposed here — ownership transfers to // StreamingResponseContext via the background task. HttpResponseMessage response; @@ -184,10 +165,8 @@ public async Task> SubmitAsync(RequestMessage requestMessage, else { throw new InvalidOperationException( - "Request body must be byte[] or HttpContent after all interceptors complete, " + - "but found " + (context.Body?.GetType().Name ?? "null") + - ". Either provide a requestSerializer or add an interceptor " + - "that serializes the RequestMessage."); + "Request body must be byte[] or HttpContent after serialization, " + + "but found " + (context.Body?.GetType().Name ?? "null") + "."); } foreach (var header in context.Headers) @@ -196,6 +175,10 @@ public async Task> SubmitAsync(RequestMessage requestMessage, { httpRequest.Content.Headers.ContentType = new MediaTypeHeaderValue(header.Value); } + else if (string.Equals(header.Key, "Content-Length", StringComparison.OrdinalIgnoreCase)) + { + // Content-Length is set automatically by ByteArrayContent; skip to avoid conflict. + } else { httpRequest.Headers.TryAddWithoutValidation(header.Key, header.Value); @@ -345,5 +328,7 @@ protected virtual void Dispose(bool disposing) } #endregion + + internal IReadOnlyList> Interceptors => _interceptors; } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs index 4599e47de5e..402a8f77111 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/GremlinClient.cs @@ -24,6 +24,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Gremlin.Net.Driver.Messages; @@ -48,86 +49,64 @@ public class GremlinClient : IGremlinClient /// Initializes a new instance of the class for the specified Gremlin Server. /// /// The the requests should be sent to. - /// - /// A instance to serialize outgoing request messages. - /// When null, the request body is passed as a to - /// interceptors, and an interceptor must serialize it to byte[] and set the - /// Content-Type header. This follows the Python driver's - /// request_serializer=None pattern. - /// /// /// A instance to deserialize incoming response messages. - /// Always required. + /// Defaults to . /// /// The for the HTTP connection. /// A factory to create loggers. If not provided, then nothing will be logged. /// /// An optional list of request interceptors. Each interceptor receives a mutable /// and can modify headers, body, URI, and method - /// before the request is sent. + /// before the request is sent. Interceptors that need the serialized bytes (e.g. + /// for payload signing) should call . + /// + /// + /// An optional auth interceptor. As a convenience, this is appended to the end of the + /// interceptor list so it runs last (after any user interceptors have modified the request). + /// This is equivalent to including the auth interceptor as the last element of . /// /// /// An optional for automatic hydration of /// provider-defined types. /// - public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? requestSerializer, - IMessageSerializer responseSerializer, + public GremlinClient(GremlinServer gremlinServer, + IMessageSerializer? responseSerializer = null, ConnectionSettings? connectionSettings = null, ILoggerFactory? loggerFactory = null, IReadOnlyList>? interceptors = null, + Func? auth = null, ProviderDefinedTypeRegistry? pdtRegistry = null) { connectionSettings ??= new ConnectionSettings(); LoggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + var actualResponseSerializer = responseSerializer ?? new GraphBinary4MessageSerializer(); + if (pdtRegistry != null) { - requestSerializer?.SetPdtRegistry(pdtRegistry); - responseSerializer.SetPdtRegistry(pdtRegistry); + actualResponseSerializer.SetPdtRegistry(pdtRegistry); + } + + // Append auth interceptor to the end of the list so it runs last. + IReadOnlyList>? allInterceptors = interceptors; + if (auth != null) + { + var list = interceptors?.ToList() ?? new List>(); + list.Add(auth); + allInterceptors = list; } _connection = new Connection( gremlinServer.Uri, - requestSerializer, - responseSerializer, + actualResponseSerializer, connectionSettings, - interceptors); + allInterceptors); var logger = LoggerFactory.CreateLogger(); logger.InitializedHttpConnection(gremlinServer.Uri); } - /// - /// Initializes a new instance of the class with a single - /// serializer used for both request serialization and response deserialization. - /// This is the backward-compatible convenience constructor. - /// - /// The the requests should be sent to. - /// - /// A instance used for both request serialization and - /// response deserialization. Defaults to . - /// - /// The for the HTTP connection. - /// A factory to create loggers. If not provided, then nothing will be logged. - /// - /// An optional list of request interceptors. - /// - /// - /// An optional for automatic hydration of - /// provider-defined types. - /// - public GremlinClient(GremlinServer gremlinServer, IMessageSerializer? messageSerializer = null, - ConnectionSettings? connectionSettings = null, - ILoggerFactory? loggerFactory = null, - IReadOnlyList>? interceptors = null, - ProviderDefinedTypeRegistry? pdtRegistry = null) - : this(gremlinServer, - messageSerializer ?? new GraphBinary4MessageSerializer(), - messageSerializer ?? new GraphBinary4MessageSerializer(), - connectionSettings, loggerFactory, interceptors, pdtRegistry) - { - } - /// public async Task> SubmitAsync(RequestMessage requestMessage, CancellationToken cancellationToken = default) @@ -199,5 +178,7 @@ internal void UntrackTransaction(RemoteTransaction tx) } #endregion + + internal Connection Connection => _connection; } } diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs index 9fc170b45fe..ac98a91d8f7 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/HttpRequestContext.cs @@ -24,6 +24,8 @@ using System; using System.Collections.Generic; using System.Security.Cryptography; +using System.Text.Json; +using Gremlin.Net.Driver.Messages; namespace Gremlin.Net.Driver { @@ -48,10 +50,10 @@ public class HttpRequestContext public Dictionary Headers { get; } /// - /// Gets or sets the request body. This is byte[] when serialization has occurred - /// (default path), or RequestMessage when serialization is deferred to interceptors - /// (requestSerializer = null). Interceptors may also set this to an - /// instance for full control over the wire format. + /// Gets or sets the request body. Initially a , becomes + /// byte[] after is called. Interceptors may also + /// set this to an instance for full control + /// over the wire format. /// public object Body { get; set; } @@ -61,8 +63,8 @@ public class HttpRequestContext /// The HTTP method. /// The request URI. /// The HTTP headers. - /// The request body. Typically byte[] (post-serialization) or - /// RequestMessage (pre-serialization). + /// The request body. Typically RequestMessage (pre-serialization) + /// or byte[] (post-serialization). public HttpRequestContext(string method, Uri uri, Dictionary headers, object body) { Method = method ?? throw new ArgumentNullException(nameof(method)); @@ -71,6 +73,46 @@ public HttpRequestContext(string method, Uri uri, Dictionary hea Body = body; } + /// + /// Serializes the body to JSON if it is still a . + /// Sets the body to the resulting byte[], and sets the Content-Type + /// and Content-Length headers. This method is idempotent: if the body is + /// already byte[], it returns those bytes without re-serializing. + /// + /// The serialized body bytes. + /// + /// Thrown if the body is neither nor byte[]. + /// + public byte[] SerializeBody() + { + if (Body is byte[] existing) + { + return existing; + } + + if (Body is RequestMessage message) + { + var payload = new Dictionary + { + [Tokens.ArgsGremlin] = message.Gremlin + }; + foreach (var field in message.Fields) + { + payload[field.Key] = field.Value; + } + + var jsonBytes = JsonSerializer.SerializeToUtf8Bytes(payload); + Body = jsonBytes; + Headers["Content-Type"] = "application/json"; + Headers["Content-Length"] = jsonBytes.Length.ToString(); + return jsonBytes; + } + + throw new InvalidOperationException( + "Cannot serialize body of type " + (Body?.GetType().Name ?? "null") + + ". Body must be RequestMessage or byte[]."); + } + /// /// Returns the lowercase hex-encoded SHA-256 digest of the body. /// Throws if is not byte[], diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs index 831a8ee379a..4d69834e90a 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/IMessageSerializer.cs @@ -31,23 +31,24 @@ namespace Gremlin.Net.Driver { /// - /// Serializes data to and from Gremlin Server. + /// Serializes and deserializes data to and from Gremlin Server. /// public interface IMessageSerializer { /// /// Gets the MIME type produced by this serializer (e.g. /// "application/vnd.graphbinary-v4.0"). Used by the driver to set - /// Content-Type and Accept headers automatically. + /// the Accept header. /// string MimeType { get; } /// - /// Serializes a . + /// Serializes a to bytes. This can be called from + /// interceptors to produce a serialized request body. /// /// The to serialize. /// The token to cancel the operation. The default value is None. - /// The serialized message. + /// The serialized message bytes. Task SerializeMessageAsync(RequestMessage requestMessage, CancellationToken cancellationToken = default); diff --git a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs index e6d5c032f90..3112bf783ce 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Driver/Remote/DriverRemoteConnection.cs @@ -137,11 +137,11 @@ public async Task> SubmitAsync(GremlinLan } } - // Default bulkResults to "true" if not set per-request + // Default bulkResults to true if not set per-request // (consistent with Java RequestOptions.fromGremlinLang and Python extract_request_options) if (!requestMsg.HasField(Tokens.ArgsBulkResults)) { - requestMsg.AddField(Tokens.ArgsBulkResults, "true"); + requestMsg.AddField(Tokens.ArgsBulkResults, true); } var resultSet = await _client.SubmitAsync(requestMsg.Create(), cancellationToken) diff --git a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs index 675f01ee2ec..1c35bb822d6 100644 --- a/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs +++ b/gremlin-dotnet/src/Gremlin.Net/Structure/IO/GraphBinary4/RequestMessageSerializer.cs @@ -1,4 +1,4 @@ -#region License +#region License /* * Licensed to the Apache Software Foundation (ASF) under one @@ -31,6 +31,12 @@ namespace Gremlin.Net.Structure.IO.GraphBinary4 /// /// Serializes a in the GraphBinary 4.0 wire format. /// + /// + /// This serializer is no longer used by the default driver flow (which now serializes + /// requests as JSON via ), + /// but remains available for custom interceptors or alternative transport protocols that + /// require GraphBinary-encoded requests. + /// public class RequestMessageSerializer { /// diff --git a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs index a5fdac7c8e3..6370a5bb185 100644 --- a/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.IntegrationTest/Driver/GremlinClientTests.cs @@ -272,5 +272,94 @@ public async Task ShouldHandlePdtInCollection() Assert.Equal(3, p2.Fields["x"]); Assert.Equal(4, p2.Fields["y"]); } + + [Fact] + public async Task ShouldAutoSerializeRequestMessageWithInterceptorMutation() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + var interceptors = new List> + { + ctx => + { + if (ctx.Body is RequestMessage msg) + { + var g = msg.Fields.ContainsKey("g") ? (string)msg.Fields["g"] : "g"; + ctx.Body = RequestMessage.Build("g.inject(99)").AddG(g).Create(); + } + return Task.CompletedTask; + } + }; + + using var gremlinClient = new GremlinClient(gremlinServer, interceptors: interceptors); + + var response = await gremlinClient.SubmitWithSingleResultAsync("g.inject(1)"); + Assert.Equal(99, response); + } + + [Fact] + public async Task ShouldPropagateExceptionThrownDuringInterceptor() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + var callCount = 0; + var interceptors = new List> + { + ctx => + { + callCount++; + if (callCount == 1) + { + throw new InvalidOperationException("interceptor broke"); + } + return Task.CompletedTask; + } + }; + + using var gremlinClient = new GremlinClient(gremlinServer, interceptors: interceptors); + + // First request should fail with interceptor error + var ex = await Assert.ThrowsAsync(async () => + { + var resultSet = await gremlinClient.SubmitAsync("g.inject(1)"); + await resultSet.ToListAsync(); + }); + Assert.Contains("interceptor broke", ex.Message); + + // Subsequent request should succeed, proving connection recovery + var response = await gremlinClient.SubmitWithSingleResultAsync("g.inject(2)"); + Assert.Equal(2, response); + } + + [Fact] + public async Task ShouldPropagateErrorWhenInterceptorSetsUnsupportedBodyType() + { + var gremlinServer = new GremlinServer(TestHost, TestPort); + var callCount = 0; + var interceptors = new List> + { + ctx => + { + callCount++; + if (callCount == 1) + { + ctx.Body = 42; + } + return Task.CompletedTask; + } + }; + + using var gremlinClient = new GremlinClient(gremlinServer, interceptors: interceptors); + + // First request should fail with serialization error + var ex = await Assert.ThrowsAsync(async () => + { + var resultSet = await gremlinClient.SubmitAsync("g.inject(1)"); + await resultSet.ToListAsync(); + }); + Assert.Contains("Cannot serialize body", ex.Message); + + // Subsequent request should succeed, proving connection recovery + var response = await gremlinClient.SubmitWithSingleResultAsync("g.inject(2)"); + Assert.Equal(2, response); + } } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs index 6a8c2b78569..56010293637 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/ConnectionTests.cs @@ -33,7 +33,6 @@ using System.Threading.Tasks; using Gremlin.Net.Driver; using Gremlin.Net.Driver.Messages; -using Gremlin.Net.Process.Traversal; using Gremlin.Net.Structure.IO; using NSubstitute; using Xunit; @@ -85,12 +84,12 @@ public async Task ShouldSetContentTypeHeader() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(handler.CapturedRequest); - Assert.Equal(SerializationTokens.GraphBinary4MimeType, + Assert.Equal("application/json", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); } @@ -100,7 +99,7 @@ public async Task ShouldSetAcceptHeader() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -115,7 +114,7 @@ public async Task ShouldSendPostRequest() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -129,7 +128,7 @@ public async Task ShouldSendToCorrectUri() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -143,7 +142,7 @@ public async Task ShouldSetAcceptEncodingWhenCompressionEnabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -158,7 +157,7 @@ public async Task ShouldNotSetAcceptEncodingWhenCompressionDisabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = false }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -173,7 +172,7 @@ public async Task ShouldSetUserAgentWhenEnabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableUserAgentOnConnect = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -187,7 +186,7 @@ public async Task ShouldNotSetUserAgentWhenDisabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableUserAgentOnConnect = false }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -201,7 +200,7 @@ public async Task ShouldSetBulkResultsHeaderWhenEnabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { BulkResults = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -216,7 +215,7 @@ public async Task ShouldNotSetBulkResultsHeaderWhenDisabled() var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { BulkResults = false }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -242,7 +241,7 @@ public async Task ShouldDecompressDeflateResponse() var (httpClient, handler) = CreateMockHttpClient(compressedBytes, "deflate"); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings { EnableCompression = true }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); // Should not throw — decompression should work var result = await connection.SubmitAsync(CreateTestRequest()); @@ -256,7 +255,7 @@ public void ShouldDisposeWithoutError() var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + var connection = new Connection(TestUri, serializer, settings, httpClient); connection.Dispose(); // Double dispose should not throw @@ -280,8 +279,6 @@ private static IMessageSerializer CreateMockSerializer( { var serializer = Substitute.For(); serializer.MimeType.Returns(mimeType); - serializer.SerializeMessageAsync(Arg.Any(), Arg.Any()) - .Returns(Task.FromResult(new byte[] { 0x84 })); serializer.DeserializeMessageAsync(Arg.Any(), Arg.Any()) .Returns(callInfo => ToAsyncEnumerable(results)); return serializer; @@ -311,7 +308,7 @@ public async Task ShouldCallInterceptorsInOrder() ctx => { callOrder.Add(3); return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -331,7 +328,7 @@ public async Task ShouldPropagateInterceptorException() _ => throw expectedException, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var ex = await Assert.ThrowsAsync( () => connection.SubmitAsync(CreateTestRequest())); @@ -357,7 +354,7 @@ public async Task ShouldAllowInterceptorToModifyHeaders() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -390,7 +387,7 @@ public async Task ShouldSeeEarlierInterceptorModifications() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, interceptors); + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -401,7 +398,7 @@ public async Task ShouldSeeEarlierInterceptorModifications() } [Fact] - public async Task ShouldSerializeBeforeInterceptorsWhenRequestSerializerProvided() + public async Task ShouldPassRequestMessageToInterceptorsBeforeSerialization() { var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -417,66 +414,43 @@ public async Task ShouldSerializeBeforeInterceptorsWhenRequestSerializerProvided }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); - Assert.IsType(observedBody); + Assert.IsType(observedBody); } [Fact] - public async Task ShouldPassRequestMessageWhenRequestSerializerIsNull() + public async Task ShouldThrowWhenBodyIsUnsupportedTypeAfterInterceptors() { - var (httpClient, _) = CreateMockHttpClient(); + var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - object? observedBody = null; var interceptors = new List> { ctx => { - observedBody = ctx.Body; - // Serialize the body so the request can proceed - ctx.Body = new byte[] { 0x84 }; - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; + ctx.Body = "unsupported type"; return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, - interceptors); - - await connection.SubmitAsync(CreateTestRequest()); - - Assert.IsType(observedBody); - } - - [Fact] - public async Task ShouldThrowWhenBodyIsNotByteArrayAfterInterceptors() - { - var (httpClient, handler) = CreateMockHttpClient(); - var serializer = CreateMockSerializer(); - var settings = new ConnectionSettings(); - - // No interceptor serializes the body - var interceptors = new List>(); - - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var ex = await Assert.ThrowsAsync( () => connection.SubmitAsync(CreateTestRequest())); - Assert.Contains("byte[] or HttpContent", ex.Message); - Assert.Contains("RequestMessage", ex.Message); + Assert.Contains("String", ex.Message); // HTTP request should not have been sent Assert.Null(handler.CapturedRequest); } [Fact] - public async Task ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSerializer() + public async Task ShouldSucceedWhenInterceptorPreSerializesBody() { var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -484,17 +458,15 @@ public async Task ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSeria var interceptors = new List> { - async ctx => + ctx => { - if (ctx.Body is RequestMessage msg) - { - ctx.Body = await serializer.SerializeMessageAsync(msg); - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; - } + // Interceptor calls SerializeBody() early (e.g. for signing) + ctx.SerializeBody(); + return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -504,7 +476,7 @@ public async Task ShouldSucceedWhenInterceptorSerializesBodyWithNullRequestSeria } [Fact] - public async Task ShouldNotSetContentTypeWhenRequestSerializerIsNull() + public async Task ShouldNotSetContentTypeBeforeInterceptorsRun() { var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -516,21 +488,18 @@ public async Task ShouldNotSetContentTypeWhenRequestSerializerIsNull() ctx => { hadContentType = ctx.Headers.ContainsKey("Content-Type"); - // Serialize so the request can proceed - ctx.Body = new byte[] { 0x84 }; - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); - Assert.False(hadContentType, "Content-Type should not be set before interceptors when requestSerializer is null"); + Assert.False(hadContentType, "Content-Type should not be set before interceptors run"); Assert.NotNull(handler.CapturedRequest); - Assert.Equal("application/vnd.graphbinary-v4.0", + Assert.Equal("application/json", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); } @@ -542,7 +511,7 @@ public async Task ShouldWorkWithEmptyInterceptorList() var settings = new ConnectionSettings(); var interceptors = new List>(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -558,7 +527,7 @@ public async Task ShouldWorkWithNoInterceptorsParameter() var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -583,7 +552,7 @@ public async Task ShouldAllowInterceptorToModifyUri() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -609,7 +578,7 @@ public async Task ShouldAllowInterceptorToReplaceBody() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -637,7 +606,7 @@ public async Task ShouldStopInterceptorChainOnException() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await Assert.ThrowsAsync( @@ -666,7 +635,7 @@ public async Task ShouldSupportAsyncInterceptors() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -678,7 +647,7 @@ public async Task ShouldSupportAsyncInterceptors() } [Fact] - public async Task ShouldAllowInterceptorToReadSerializedBody() + public async Task ShouldAllowInterceptorToCallSerializeBody() { var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -689,19 +658,19 @@ public async Task ShouldAllowInterceptorToReadSerializedBody() { ctx => { - capturedBody = ctx.Body as byte[]; + capturedBody = ctx.SerializeBody(); return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(capturedBody); - // The mock serializer returns { 0x84 } - Assert.Equal(new byte[] { 0x84 }, capturedBody); + // Should be valid JSON containing "gremlin" field + Assert.Contains("gremlin", System.Text.Encoding.UTF8.GetString(capturedBody!)); } [Fact] @@ -721,7 +690,7 @@ public async Task ShouldWorkWithSingleInterceptor() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -746,7 +715,7 @@ public async Task ShouldAllowInterceptorToRemoveHeader() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -772,7 +741,7 @@ public async Task ShouldThrowWhenBodyIsNullAfterInterceptors() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); var ex = await Assert.ThrowsAsync( @@ -808,7 +777,7 @@ public async Task ShouldPreserveMultipleInterceptorHeaderModifications() }, }; - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -820,7 +789,7 @@ public async Task ShouldPreserveMultipleInterceptorHeaderModifications() } [Fact] - public async Task ShouldAllowCustomContentTypeWhenRequestSerializerIsNull() + public async Task ShouldAllowInterceptorToOverrideContentType() { var (httpClient, handler) = CreateMockHttpClient(); var serializer = CreateMockSerializer(); @@ -830,53 +799,37 @@ public async Task ShouldAllowCustomContentTypeWhenRequestSerializerIsNull() { ctx => { + // Pre-serialize with custom content type ctx.Body = new byte[] { 0x01 }; - ctx.Headers["Content-Type"] = "application/json"; + ctx.Headers["Content-Type"] = "application/custom"; return Task.CompletedTask; }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(handler.CapturedRequest); - Assert.Equal("application/json", + Assert.Equal("application/custom", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); } [Fact] - public async Task ShouldUseResponseSerializerWhenRequestSerializerIsNull() + public async Task ShouldUseResponseSerializerForDeserialization() { var (httpClient, _) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer(); var responseSerializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - var interceptors = new List> - { - async ctx => - { - if (ctx.Body is RequestMessage msg) - { - ctx.Body = await requestSerializer.SerializeMessageAsync(msg); - ctx.Headers["Content-Type"] = "application/vnd.graphbinary-v4.0"; - } - }, - }; - - using var connection = new Connection(TestUri, null, responseSerializer, settings, httpClient, - interceptors); + using var connection = new Connection(TestUri, responseSerializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); // Verify the response serializer was called for deserialization responseSerializer.Received(1) .DeserializeMessageAsync(Arg.Any(), Arg.Any()); - // Verify the request serializer was NOT called by Connection (interceptor called it directly) - requestSerializer.DidNotReceive() - .DeserializeMessageAsync(Arg.Any(), Arg.Any()); } [Fact] @@ -896,7 +849,7 @@ public async Task ShouldAcceptHttpContentBodyFromInterceptor() }, }; - using var connection = new Connection(TestUri, null, serializer, settings, httpClient, + using var connection = new Connection(TestUri, serializer, settings, httpClient, interceptors); await connection.SubmitAsync(CreateTestRequest()); @@ -910,10 +863,9 @@ public async Task ShouldAcceptHttpContentBodyFromInterceptor() public async Task ShouldUseResponseSerializerMimeTypeForAcceptHeader() { var (httpClient, handler) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer("application/custom-request"); var responseSerializer = CreateMockSerializer("application/custom-response"); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + using var connection = new Connection(TestUri, responseSerializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); @@ -923,39 +875,18 @@ public async Task ShouldUseResponseSerializerMimeTypeForAcceptHeader() } [Fact] - public async Task ShouldUseRequestSerializerMimeTypeForContentTypeHeader() - { - var (httpClient, handler) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer("application/custom-request"); - var responseSerializer = CreateMockSerializer("application/custom-response"); - var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); - - await connection.SubmitAsync(CreateTestRequest()); - - Assert.NotNull(handler.CapturedRequest); - Assert.Equal("application/custom-request", - handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); - } - - [Fact] - public async Task ShouldUseDifferentMimeTypesForRequestAndResponseSerializers() + public async Task ShouldSetContentTypeToApplicationJson() { var (httpClient, handler) = CreateMockHttpClient(); - var requestSerializer = CreateMockSerializer("application/vnd.custom-request-v1.0"); - var responseSerializer = CreateMockSerializer("application/vnd.custom-response-v2.0"); + var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, requestSerializer, responseSerializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); await connection.SubmitAsync(CreateTestRequest()); Assert.NotNull(handler.CapturedRequest); - // Content-Type comes from request serializer - Assert.Equal("application/vnd.custom-request-v1.0", + Assert.Equal("application/json", handler.CapturedRequest!.Content!.Headers.ContentType!.MediaType); - // Accept comes from response serializer - Assert.Contains(handler.CapturedRequest.Headers.Accept, - h => h.MediaType == "application/vnd.custom-response-v2.0"); } [Fact] @@ -964,7 +895,7 @@ public async Task ShouldReturnStreamingResultSet() var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(new List { "hello", 42 }); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -980,7 +911,7 @@ public async Task ShouldReturnEmptyResultSetWhenNoResults() var (httpClient, _) = CreateMockHttpClient(); var serializer = CreateMockSerializer(new List()); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); @@ -1001,7 +932,7 @@ public async Task ShouldStreamResponseWithoutFullBuffering() var httpClient = new HttpClient(handler); var serializer = CreateMockSerializer(); var settings = new ConnectionSettings(); - using var connection = new Connection(TestUri, serializer, serializer, settings, httpClient); + using var connection = new Connection(TestUri, serializer, settings, httpClient); var result = await connection.SubmitAsync(CreateTestRequest()); diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs index 7d832e96e0a..e7ea25f6480 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/DriverRemoteConnectionTests.cs @@ -133,7 +133,7 @@ public async Task ShouldDefaultBulkResultsToTrue() await connection.SubmitAsync(gl); Assert.NotNull(capturedRequest); - Assert.Equal("true", capturedRequest!.Fields[Tokens.ArgsBulkResults]); + Assert.Equal(true, capturedRequest!.Fields[Tokens.ArgsBulkResults]); } [Fact] @@ -189,13 +189,13 @@ public async Task ShouldRespectPerRequestBulkResults() gl.AddStep("V", Array.Empty()); gl.OptionsStrategies.Add(new OptionsStrategy(new Dictionary { - { Tokens.ArgsBulkResults, "false" } + { Tokens.ArgsBulkResults, false } })); await connection.SubmitAsync(gl); Assert.NotNull(capturedRequest); - Assert.Equal("false", capturedRequest!.Fields[Tokens.ArgsBulkResults]); + Assert.Equal(false, capturedRequest!.Fields[Tokens.ArgsBulkResults]); } /// diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs index f187ada59cd..58442eb7d12 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/GremlinClientTests.cs @@ -21,6 +21,8 @@ #endregion +using System; +using System.Threading.Tasks; using Gremlin.Net.Driver; using Gremlin.Net.Structure.IO.GraphBinary4; using Xunit; @@ -37,10 +39,10 @@ public void ShouldCreateClientWithDefaultSettings() } [Fact] - public void ShouldCreateClientWithCustomSerializer() + public void ShouldCreateClientWithCustomResponseSerializer() { var serializer = new GraphBinary4MessageSerializer(); - using var client = new GremlinClient(new GremlinServer(), messageSerializer: serializer); + using var client = new GremlinClient(new GremlinServer(), responseSerializer: serializer); // Should not throw } @@ -64,5 +66,21 @@ public void ShouldDisposeWithoutError() // Should not throw on double dispose client.Dispose(); } + + [Fact] + public void ShouldAppendAuthInterceptorToEndOfList() + { + Func userInterceptor = ctx => Task.CompletedTask; + Func auth = ctx => Task.CompletedTask; + + using var client = new GremlinClient(new GremlinServer(), + auth: auth, + interceptors: new[] { userInterceptor }); + + var interceptors = client.Connection.Interceptors; + Assert.Equal(2, interceptors.Count); + Assert.Same(userInterceptor, interceptors[0]); + Assert.Same(auth, interceptors[1]); + } } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs index 86d2bc678d3..10522026445 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Driver/HttpRequestContextTests.cs @@ -24,6 +24,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Text.Json; using Gremlin.Net.Driver; using Gremlin.Net.Driver.Messages; using Xunit; @@ -37,7 +38,7 @@ public void ShouldConstructWithByteArrayBody() { var method = "POST"; var uri = new Uri("http://localhost:8182/gremlin"); - var headers = new Dictionary { { "Content-Type", "application/vnd.graphbinary-v4.0" } }; + var headers = new Dictionary { { "Content-Type", "application/json" } }; var body = new byte[] { 0x01, 0x02, 0x03 }; var context = new HttpRequestContext(method, uri, headers, body); @@ -127,5 +128,170 @@ public void ShouldThrowWhenComputingPayloadHashForNullBody() Assert.Contains("null", ex.Message); } + + #region SerializeBody Tests + + [Fact] + public void SerializeBodyShouldReturnJsonBytesForRequestMessage() + { + var message = RequestMessage.Build("g.V().has('name','marko')").AddG("g").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + Assert.IsType(context.Body); + Assert.Same(context.Body, result); + + var json = JsonDocument.Parse(result); + Assert.Equal("g.V().has('name','marko')", json.RootElement.GetProperty("gremlin").GetString()); + Assert.Equal("g", json.RootElement.GetProperty("g").GetString()); + Assert.Equal("gremlin-lang", json.RootElement.GetProperty("language").GetString()); + } + + [Fact] + public void SerializeBodyShouldSetContentTypeHeader() + { + var message = RequestMessage.Build("g.V()").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + context.SerializeBody(); + + Assert.Equal("application/json", context.Headers["Content-Type"]); + } + + [Fact] + public void SerializeBodyShouldSetContentLengthHeader() + { + var message = RequestMessage.Build("g.V()").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + Assert.Equal(result.Length.ToString(), context.Headers["Content-Length"]); + } + + [Fact] + public void SerializeBodyShouldBeIdempotentWhenBodyIsBytes() + { + var originalBytes = new byte[] { 0x01, 0x02, 0x03 }; + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), originalBytes); + + var result = context.SerializeBody(); + + Assert.Same(originalBytes, result); + Assert.Same(originalBytes, context.Body); + } + + [Fact] + public void SerializeBodyShouldBeIdempotentOnMultipleCalls() + { + var message = RequestMessage.Build("g.V()").AddG("g").Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var first = context.SerializeBody(); + var second = context.SerializeBody(); + + Assert.Same(first, second); + } + + [Fact] + public void SerializeBodyShouldIncludeAllFields() + { + var message = RequestMessage.Build("g.V()") + .AddG("g") + .AddLanguage("gremlin-lang") + .AddBatchSize(100) + .AddEvaluationTimeout(30000) + .Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("g.V()", json.RootElement.GetProperty("gremlin").GetString()); + Assert.Equal("g", json.RootElement.GetProperty("g").GetString()); + Assert.Equal("gremlin-lang", json.RootElement.GetProperty("language").GetString()); + Assert.Equal(100, json.RootElement.GetProperty("batchSize").GetInt32()); + Assert.Equal(30000, json.RootElement.GetProperty("evaluationTimeout").GetInt32()); + } + + [Fact] + public void SerializeBodyShouldThrowForUnsupportedType() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), "unsupported"); + + var ex = Assert.Throws(() => context.SerializeBody()); + + Assert.Contains("String", ex.Message); + } + + [Fact] + public void SerializeBodyShouldThrowForNullBody() + { + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), null!); + + var ex = Assert.Throws(() => context.SerializeBody()); + + Assert.Contains("null", ex.Message); + } + + [Fact] + public void SerializeBodyShouldReflectFieldMutations() + { + var message = RequestMessage.Build("g.V()").AddG("g").Create(); + // Mutate fields before serialization (simulating an interceptor) + message.Fields["customField"] = "customValue"; + + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("customValue", json.RootElement.GetProperty("customField").GetString()); + } + + [Fact] + public void SerializeBodyShouldReflectBodyReplacement() + { + var original = RequestMessage.Build("g.V()").AddG("g").Create(); + var replacement = RequestMessage.Build("g.E()").AddG("traversal").Create(); + + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), original); + + // Interceptor replaces the body + context.Body = replacement; + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("g.E()", json.RootElement.GetProperty("gremlin").GetString()); + Assert.Equal("traversal", json.RootElement.GetProperty("g").GetString()); + } + + [Fact] + public void SerializeBodyShouldIncludeBindingsField() + { + var message = RequestMessage.Build("g.V(x)") + .AddBindingsString("[x:1,y:'marko']") + .Create(); + var context = new HttpRequestContext("POST", new Uri("http://localhost:8182/gremlin"), + new Dictionary(), message); + + var result = context.SerializeBody(); + + var json = JsonDocument.Parse(result); + Assert.Equal("[x:1,y:'marko']", json.RootElement.GetProperty("bindings").GetString()); + } + + #endregion } } diff --git a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs index 41298a6f7e6..dbee0f209f5 100644 --- a/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs +++ b/gremlin-dotnet/test/Gremlin.Net.UnitTest/Structure/IO/GraphBinary4/GraphBinary4MessageSerializerTests.cs @@ -41,7 +41,7 @@ public async Task ShouldSerializeRequestMessageStartingWithVersionByte() var actual = await serializer.SerializeMessageAsync(msg); - // First byte should be version byte 0x84 — no MIME prefix + // First byte should be version byte 0x84, no MIME prefix Assert.Equal(0x84, actual[0]); } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java index d4a3967f029..a40e0f6a725 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/Cluster.java @@ -29,10 +29,8 @@ import io.netty.util.concurrent.Future; import org.apache.commons.configuration2.Configuration; import org.apache.commons.lang3.concurrent.BasicThreadFactory; -import org.apache.commons.lang3.tuple.Pair; import org.apache.tinkerpop.gremlin.driver.auth.Auth; import org.apache.tinkerpop.gremlin.driver.exception.NoHostAvailableException; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; import org.apache.tinkerpop.gremlin.driver.remote.HttpRemoteTransaction; import org.apache.tinkerpop.gremlin.structure.Transaction; import org.apache.tinkerpop.gremlin.util.MessageSerializer; @@ -63,7 +61,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.LinkedList; import java.util.List; import java.util.Optional; import java.util.Set; @@ -86,7 +83,6 @@ * @author Stephen Mallette (http://stephen.genoprime.com) */ public final class Cluster { - public static final String SERIALIZER_INTERCEPTOR_NAME = "serializer"; private static final Logger logger = LoggerFactory.getLogger(Cluster.class); private final Manager manager; @@ -141,10 +137,6 @@ public static Builder build(final String address) { return new Builder(address); } - public static Builder build(final RequestInterceptor serializingInterceptor) { - return new Builder(serializingInterceptor); - } - public static Builder build(final File configurationFile) throws FileNotFoundException { final Settings settings = Settings.read(new FileInputStream(configurationFile)); return getBuilderFromSettings(settings); @@ -370,8 +362,8 @@ MessageSerializer getSerializer() { return manager.serializer; } - List> getRequestInterceptors() { - return manager.interceptors; + List getRequestInterceptors() { + return Collections.unmodifiableList(manager.interceptors); } ScheduledExecutorService executor() { @@ -483,7 +475,6 @@ public boolean isBulkResultsEnabled() { } public final static class Builder { - private static int INTERCEPTOR_NOT_FOUND = -1; private final List addresses = new ArrayList<>(); private int port = 8182; @@ -510,25 +501,18 @@ public final static class Builder { private boolean sslSkipCertValidation = false; private SslContext sslContext = null; private LoadBalancingStrategy loadBalancingStrategy = new LoadBalancingStrategy.RoundRobin(); - private LinkedList> interceptors = new LinkedList<>(); + private List interceptors = new ArrayList<>(); + private Auth auth = null; private long connectionSetupTimeoutMillis = Connection.CONNECTION_SETUP_TIMEOUT_MILLIS; private long idleConnectionTimeoutMillis = Connection.CONNECTION_IDLE_TIMEOUT_MILLIS; private boolean enableUserAgentOnConnect = true; private boolean bulkResults = false; private Builder() { - addInterceptor(SERIALIZER_INTERCEPTOR_NAME, - new PayloadSerializingInterceptor(new GraphBinaryMessageSerializerV4())); } private Builder(final String address) { addContactPoint(address); - addInterceptor(SERIALIZER_INTERCEPTOR_NAME, - new PayloadSerializingInterceptor(new GraphBinaryMessageSerializerV4())); - } - - private Builder(final RequestInterceptor bodySerializer) { - addInterceptor(SERIALIZER_INTERCEPTOR_NAME, bodySerializer); } /** @@ -751,82 +735,28 @@ public Builder loadBalancingStrategy(final LoadBalancingStrategy loadBalancingSt } /** - * Adds a {@link RequestInterceptor} after another one that will allow manipulation of the {@code HttpRequest} - * prior to its being sent to the server. + * Sets the list of {@link RequestInterceptor} instances that will be run in order to allow + * modification of the {@link HttpRequest} prior to its being sent to the server. */ - public Builder addInterceptorAfter(final String priorInterceptorName, final String nameOfInterceptor, - final RequestInterceptor interceptor) { - final int index = getInterceptorIndex(priorInterceptorName); - if (INTERCEPTOR_NOT_FOUND == index) { - throw new IllegalArgumentException(priorInterceptorName + " interceptor not found"); - } else if (getInterceptorIndex(nameOfInterceptor) != INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(nameOfInterceptor + " interceptor already exists"); - } - interceptors.add(index + 1, Pair.of(nameOfInterceptor, interceptor)); - + public Builder interceptors(final List interceptors) { + this.interceptors = new ArrayList<>(interceptors); return this; } /** - * Adds a {@link RequestInterceptor} before another one that will allow manipulation of the {@code HttpRequest} - * prior to its being sent to the server. + * Sets the list of {@link RequestInterceptor} instances that will be run in order. */ - public Builder addInterceptorBefore(final String subsequentInterceptorName, final String nameOfInterceptor, - final RequestInterceptor interceptor) { - final int index = getInterceptorIndex(subsequentInterceptorName); - if (INTERCEPTOR_NOT_FOUND == index) { - throw new IllegalArgumentException(subsequentInterceptorName + " interceptor not found"); - } else if (getInterceptorIndex(nameOfInterceptor) != INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(nameOfInterceptor + " interceptor already exists"); - } else if (index == 0) { - interceptors.addFirst(Pair.of(nameOfInterceptor, interceptor)); - } else { - interceptors.add(index - 1, Pair.of(nameOfInterceptor, interceptor)); - } - + public Builder interceptors(final RequestInterceptor... interceptors) { + this.interceptors = new ArrayList<>(List.of(interceptors)); return this; } /** - * Adds a {@link RequestInterceptor} to the end of the list that will allow manipulation of the - * {@code HttpRequest} prior to its being sent to the server. - */ - public Builder addInterceptor(final String name, final RequestInterceptor interceptor) { - if (getInterceptorIndex(name) != INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(name + " interceptor already exists"); - } - interceptors.add(Pair.of(name, interceptor)); - return this; - } - - /** - * Removes a {@link RequestInterceptor} from the list. This can be used to remove the default interceptors that - * aren't needed. - */ - public Builder removeInterceptor(final String name) { - final int index = getInterceptorIndex(name); - if (index == INTERCEPTOR_NOT_FOUND) { - throw new IllegalArgumentException(name + " interceptor not found"); - } - interceptors.remove(index); - return this; - } - - private int getInterceptorIndex(final String name) { - for (int i = 0; i < interceptors.size(); i++) { - if (interceptors.get(i).getLeft().equals(name)) { - return i; - } - } - - return INTERCEPTOR_NOT_FOUND; - } - - /** - * Adds an Auth {@link RequestInterceptor} to the end of list of interceptors. + * Adds an Auth {@link RequestInterceptor} that will always be appended to the end of the interceptor list + * when the {@link Cluster} is created, regardless of the order in which builder methods are called. */ public Builder auth(final Auth auth) { - addInterceptor(auth.getClass().getSimpleName().toLowerCase() + "-auth", auth); + this.auth = auth; return this; } @@ -906,6 +836,7 @@ List getContactPoints() { public Cluster create() { if (addresses.isEmpty()) addContactPoint("localhost"); if (null == serializer) serializer = Serializers.GRAPHBINARY_V4.simpleInstance(); + if (null != auth) interceptors.add(auth); return new Cluster(this); } } @@ -946,7 +877,7 @@ class Manager { private final LoadBalancingStrategy loadBalancingStrategy; private final Optional sslContextOptional; private final Supplier validationRequest; - private final List> interceptors; + private final List interceptors; /** * Thread pool for requests. @@ -985,7 +916,7 @@ private Manager(final Builder builder) { this.loadBalancingStrategy = builder.loadBalancingStrategy; this.contactPoints = builder.getContactPoints(); - this.interceptors = builder.interceptors; + this.interceptors = Collections.unmodifiableList(new ArrayList<>(builder.interceptors)); this.enableUserAgentOnConnect = builder.enableUserAgentOnConnect; this.bulkResults = builder.bulkResults; diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java index 253274940f1..ebe5011b7e7 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/HttpRequest.java @@ -18,16 +18,27 @@ */ package org.apache.tinkerpop.gremlin.driver; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; +import org.apache.tinkerpop.shaded.jackson.core.JsonProcessingException; +import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; + import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; import java.util.Map; /** - * HttpRequest represents the data that will be used to create the actual request to the remote endpoint. It will be - * passed to different {@link RequestInterceptor} that can update its values. The body can be anything as the - * interceptor may change what the payload is. Also contains some convenience Strings for common HTTP header key and - * values and HTTP methods. + * Represents the HTTP request that will be sent to the server. It is passed through the + * {@link RequestInterceptor} chain where interceptors can modify headers, body, URI, and method. + *

+ * The body starts as a {@link RequestMessage} and can be serialized to JSON bytes via {@link #serializeBody()}. + * After all interceptors run, if the body is still a {@code RequestMessage}, the driver will call + * {@code serializeBody()} automatically before sending. */ public class HttpRequest { + + private static final ObjectMapper mapper = new ObjectMapper(); + public static class Headers { // Add as needed. Headers are case-insensitive; lower case for now to match Netty. public static final String ACCEPT = "accept"; @@ -79,7 +90,7 @@ public Map headers() { /** * Get the body of the request. * - * @return an Object representing the body. + * @return an Object representing the body ({@link RequestMessage} or {@code byte[]}). */ public Object getBody() { return body; @@ -104,8 +115,7 @@ public String getMethod() { } /** - * Set the HTTP body of the request. During processing, the body can be any type but the final interceptor must set - * the body to a {@code byte[]}. + * Set the HTTP body of the request. * * @return this HttpRequest for method chaining. */ @@ -133,4 +143,45 @@ public HttpRequest setUri(final URI uri) { this.uri = uri; return this; } + + /** + * Serialize the body to JSON bytes if it is still a {@link RequestMessage}. If the body is already + * {@code byte[]}, this method is idempotent and returns the existing bytes. This method also sets the + * {@code Content-Type} header to {@code application/json} and the {@code Content-Length} header to the + * byte length of the serialized body. + *

+ * Interceptors that need the serialized payload (e.g., for computing a signature hash) should call + * this method rather than serializing independently. + * + * @return the serialized body bytes + * @throws IllegalStateException if the body is neither a {@link RequestMessage} nor {@code byte[]} + */ + public byte[] serializeBody() { + if (body instanceof byte[]) { + return (byte[]) body; + } + + if (!(body instanceof RequestMessage)) { + throw new IllegalStateException("Cannot serialize body of type " + + (body == null ? "null" : body.getClass().getSimpleName()) + + ". Expected RequestMessage or byte[]."); + } + + final RequestMessage requestMessage = (RequestMessage) body; + + // Build JSON map: gremlin is top-level, plus all fields from the message + final Map jsonMap = new LinkedHashMap<>(); + jsonMap.put("gremlin", requestMessage.getGremlin()); + jsonMap.putAll(requestMessage.getFields()); + + try { + final byte[] jsonBytes = mapper.writeValueAsBytes(jsonMap); + this.body = jsonBytes; + this.headers.put(Headers.CONTENT_TYPE, "application/json"); + this.headers.put(Headers.CONTENT_LENGTH, String.valueOf(jsonBytes.length)); + return jsonBytes; + } catch (JsonProcessingException e) { + throw new IllegalStateException("Failed to serialize RequestMessage to JSON", e); + } + } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java index 74055303321..0a4a3d49706 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/RequestInterceptor.java @@ -18,14 +18,20 @@ */ package org.apache.tinkerpop.gremlin.driver; -import java.util.function.UnaryOperator; - /** - * Interceptors are run as a list to allow modification of the HTTP request before it is sent to the server. The first - * interceptor will be provided with a {@link HttpRequest} that holds a - * {@link org.apache.tinkerpop.gremlin.util.message.RequestMessage} in the body. The final interceptor should contain a - * {@code byte[]} in the body. + * Interceptors are run as an ordered list to allow modification of the {@link HttpRequest} before it is sent to the + * server. The interceptor receives an {@link HttpRequest} whose body starts as a + * {@link org.apache.tinkerpop.gremlin.util.message.RequestMessage}. Interceptors mutate the request in place. + * After all interceptors run, if the body is still a {@code RequestMessage} the driver will auto-serialize it to JSON. + * Interceptors that need the serialized bytes (e.g., for signing) should call {@link HttpRequest#serializeBody()}. */ -public interface RequestInterceptor extends UnaryOperator { +@FunctionalInterface +public interface RequestInterceptor { + /** + * Intercept and mutate the HTTP request before it is sent. + * + * @param httpRequest the mutable HTTP request + */ + void intercept(HttpRequest httpRequest); } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java index f5a48c21d6e..f33173958e0 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Basic.java @@ -18,10 +18,9 @@ */ package org.apache.tinkerpop.gremlin.driver.auth; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaderNames; import org.apache.tinkerpop.gremlin.driver.HttpRequest; +import java.nio.charset.StandardCharsets; import java.util.Base64; public class Basic implements Auth { @@ -35,10 +34,9 @@ public Basic(final String username, final String password) { } @Override - public HttpRequest apply(final HttpRequest httpRequest) { + public void intercept(final HttpRequest httpRequest) { final String valueToEncode = username + ":" + password; httpRequest.headers().put(HttpRequest.Headers.AUTHORIZATION, - "Basic " + Base64.getEncoder().encodeToString(valueToEncode.getBytes())); - return httpRequest; + "Basic " + Base64.getEncoder().encodeToString(valueToEncode.getBytes(StandardCharsets.UTF_8))); } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java index fb51aa80736..a19e75d7499 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; import java.util.Set; + import org.apache.tinkerpop.gremlin.driver.HttpRequest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,8 +48,9 @@ import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_SECURITY_TOKEN; /** - * A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that provides headers required for SigV4. Because - * the signing process requires final header and body data, this interceptor should almost always be last. + * A {@link org.apache.tinkerpop.gremlin.driver.RequestInterceptor} that signs requests with AWS SigV4. + * This interceptor calls {@link HttpRequest#serializeBody()} to ensure the body is serialized before + * computing the payload hash for signing. It should typically be the last interceptor in the chain. */ public class Sigv4 implements Auth { private static final Logger logger = LoggerFactory.getLogger(Sigv4.class); @@ -70,8 +72,11 @@ public Sigv4(final String regionName, final AwsCredentialsProvider awsCredential } @Override - public HttpRequest apply(final HttpRequest httpRequest) { + public void intercept(final HttpRequest httpRequest) { try { + // Ensure the body is serialized to bytes so we can compute the payload hash. + httpRequest.serializeBody(); + final ContentStreamProvider content = toContentStream(httpRequest); // Convert Http request into an AWS SDK signable request final SdkHttpRequest awsSignableRequest = toSignableRequest(httpRequest); @@ -91,7 +96,6 @@ public HttpRequest apply(final HttpRequest httpRequest) { logger.error("Error signing HTTP request: {}", ex.getMessage(), ex); throw new AuthenticationException(ex); } - return httpRequest; } private void setSessionToken(final Map headers, final AwsCredentials credentials) { @@ -123,9 +127,6 @@ private String getSingleHeaderValue(final Map> headers, fin private ContentStreamProvider toContentStream(final HttpRequest httpRequest) { // carry over the entity (or an empty entity, if no entity is provided) - if (!(httpRequest.getBody() instanceof byte[])) { - throw new IllegalArgumentException("Expected byte[] in HttpRequest body but got " + httpRequest.getBody().getClass()); - } final byte[] body = (byte[]) httpRequest.getBody(); return (body.length != 0) ? ContentStreamProvider.fromByteArray(body) : ContentStreamProvider.fromUtf8String(""); } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java index 0f8ede4c7fd..c1f5eca0ea5 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/handler/HttpGremlinRequestEncoder.java @@ -29,7 +29,6 @@ import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpVersion; -import org.apache.commons.lang3.tuple.Pair; import org.apache.tinkerpop.gremlin.driver.HttpRequest; import org.apache.tinkerpop.gremlin.driver.RequestInterceptor; import org.apache.tinkerpop.gremlin.driver.UserAgent; @@ -49,7 +48,7 @@ import static org.apache.tinkerpop.gremlin.driver.handler.InactiveChannelHandler.REQUEST_SENT; /** - * Converts {@link RequestMessage} to a {@code HttpRequest}. + * Converts {@link RequestMessage} to an HTTP request, running interceptors and serializing the body to JSON. */ @ChannelHandler.Sharable public final class HttpGremlinRequestEncoder extends MessageToMessageEncoder { @@ -57,11 +56,11 @@ public final class HttpGremlinRequestEncoder extends MessageToMessageEncoder serializer; private final boolean userAgentEnabled; private final boolean bulkResults; - private final List> interceptors; + private final List interceptors; private final URI uri; public HttpGremlinRequestEncoder(final MessageSerializer serializer, - final List> interceptors, + final List interceptors, final boolean userAgentEnabled, boolean bulkResults, final URI uri) { this.serializer = serializer; this.interceptors = interceptors; @@ -72,7 +71,6 @@ public HttpGremlinRequestEncoder(final MessageSerializer serializer, @Override protected void encode(final ChannelHandlerContext channelHandlerContext, final RequestMessage requestMessage, final List objects) throws Exception { - final String mimeType = serializer.mimeTypesSupported()[0]; if (requestMessage.getField("gremlin") instanceof GremlinLang) { throw new ResponseException(HttpResponseStatus.BAD_REQUEST, String.format( "An error occurred during serialization of this request [%s] - it could not be sent to the server - Reason: GremlinLang is not intended to be send as query.", @@ -81,9 +79,10 @@ protected void encode(final ChannelHandlerContext channelHandlerContext, final R final InetSocketAddress remoteAddress = getRemoteAddress(channelHandlerContext.channel()); try { - Map headersMap = new HashMap<>(); + final Map headersMap = new HashMap<>(); headersMap.put(HttpRequest.Headers.HOST, remoteAddress.getAddress().getHostAddress()); - headersMap.put(HttpRequest.Headers.ACCEPT, mimeType); + // Accept header uses the response serializer's mime type (GraphBinary for responses) + headersMap.put(HttpRequest.Headers.ACCEPT, serializer.mimeTypesSupported()[0]); headersMap.put(HttpRequest.Headers.ACCEPT_ENCODING, HttpRequest.Headers.DEFLATE); if (userAgentEnabled) { headersMap.put(HttpRequest.Headers.USER_AGENT, UserAgent.USER_AGENT); @@ -91,21 +90,26 @@ protected void encode(final ChannelHandlerContext channelHandlerContext, final R if (bulkResults) { headersMap.put(Tokens.BULK_RESULTS, "true"); } - - // Add X-Transaction-Id header to comply with specification's dual transmission (header and body) + + // Promote transactionId to HTTP header for dual transmission (header and body) final String transactionId = requestMessage.getField(Tokens.ARGS_TRANSACTION_ID); if (transactionId != null) { headersMap.put(Tokens.Headers.TRANSACTION_ID, transactionId); } - - HttpRequest gremlinRequest = new HttpRequest(headersMap, requestMessage, uri); - for (final Pair interceptor : interceptors) { - gremlinRequest = interceptor.getRight().apply(gremlinRequest); + final HttpRequest gremlinRequest = new HttpRequest(headersMap, requestMessage, uri); + + for (final RequestInterceptor interceptor : interceptors) { + interceptor.intercept(gremlinRequest); } + // Auto-serialize if interceptors did not already produce bytes + gremlinRequest.serializeBody(); + + final byte[] bodyBytes = (byte[]) gremlinRequest.getBody(); + final FullHttpRequest finalRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, - uri.getPath(), convertBody(gremlinRequest)); + uri.getPath(), Unpooled.wrappedBuffer(bodyBytes)); gremlinRequest.headers().forEach((k, v) -> finalRequest.headers().add(k, v)); objects.add(finalRequest); @@ -128,15 +132,4 @@ private static InetSocketAddress getRemoteAddress(Channel channel) { } return remoteAddress; } - - private static ByteBuf convertBody(final HttpRequest request) { - final Object body = request.getBody(); - if (body instanceof byte[]) { - request.headers().put(HttpRequest.Headers.CONTENT_LENGTH, String.valueOf(((byte[]) body).length)); - return Unpooled.wrappedBuffer((byte[]) body); - } else { - throw new IllegalArgumentException("Final body must be byte[] but found " - + body.getClass().getSimpleName()); - } - } } diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java deleted file mode 100644 index ae35bc05af4..00000000000 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/interceptor/PayloadSerializingInterceptor.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.tinkerpop.gremlin.driver.interceptor; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.ByteBufUtil; -import io.netty.handler.codec.http.HttpResponseStatus; -import org.apache.tinkerpop.gremlin.driver.HttpRequest; -import org.apache.tinkerpop.gremlin.driver.RequestInterceptor; -import org.apache.tinkerpop.gremlin.driver.exception.ResponseException; -import org.apache.tinkerpop.gremlin.util.MessageSerializer; -import org.apache.tinkerpop.gremlin.util.message.RequestMessage; -import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; -import org.apache.tinkerpop.gremlin.util.ser.SerializationException; -import org.apache.tinkerpop.gremlin.util.ser.Serializers; - -import java.util.Map; - -/** - * A {@link RequestInterceptor} that serializes the request body usng the provided {@link MessageSerializer}. This - * interceptor should be run before other interceptors that need to calculate values based on the request body. - */ -public class PayloadSerializingInterceptor implements RequestInterceptor { - // Should be thread-safe as the GraphBinaryWriter/GraphSONMessageSerializer doesn't maintain state. - private final MessageSerializer serializer; - - public PayloadSerializingInterceptor(final MessageSerializer serializer) { - this.serializer = serializer; - } - - @Override - public HttpRequest apply(HttpRequest httpRequest) { - if (!(httpRequest.getBody() instanceof RequestMessage)) { - throw new IllegalArgumentException("Only RequestMessage serialization is supported"); - } - - final RequestMessage request = (RequestMessage) httpRequest.getBody(); - final ByteBuf requestBuf; - try { - requestBuf = serializer.serializeRequestAsBinary(request, ByteBufAllocator.DEFAULT); - } catch (SerializationException se) { - throw new RuntimeException(new ResponseException(HttpResponseStatus.BAD_REQUEST, String.format( - "An error occurred during serialization of this request [%s] - it could not be sent to the server - Reason: %s", - request, se))); - } - - // Convert from ByteBuf to bytes[] because that's what the final request body should contain. - final byte[] requestBytes = ByteBufUtil.getBytes(requestBuf); - requestBuf.release(); - - httpRequest.setBody(requestBytes); - httpRequest.headers().put(HttpRequest.Headers.CONTENT_TYPE, serializer.mimeTypesSupported()[0]); - - return httpRequest; - } -} diff --git a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java index 984d6c7817f..faa7aac7df0 100644 --- a/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java +++ b/gremlin-driver/src/main/java/org/apache/tinkerpop/gremlin/driver/simple/SimpleHttpClient.java @@ -23,12 +23,10 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; -import org.apache.commons.lang3.tuple.Pair; import org.apache.tinkerpop.gremlin.driver.Channelizer; import org.apache.tinkerpop.gremlin.driver.handler.HttpContentDecompressionHandler; import org.apache.tinkerpop.gremlin.driver.handler.HttpGremlinResponseStreamDecoder; import org.apache.tinkerpop.gremlin.driver.handler.HttpGremlinRequestEncoder; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; import org.apache.tinkerpop.gremlin.util.MessageSerializer; import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import io.netty.bootstrap.Bootstrap; @@ -109,9 +107,7 @@ protected void initChannel(final SocketChannel ch) { new HttpContentDecompressionHandler(), new HttpGremlinResponseStreamDecoder(serializer, Integer.MAX_VALUE), new HttpGremlinRequestEncoder(serializer, - Collections.singletonList( - Pair.of("serializer", new PayloadSerializingInterceptor( - new GraphBinaryMessageSerializerV4()))), + Collections.emptyList(), false, false, uri), callbackResponseHandler); } diff --git a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java deleted file mode 100644 index 0edc12cd0ae..00000000000 --- a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/ClusterTest.java +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.tinkerpop.gremlin.driver; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; -import org.junit.Test; - -import java.util.List; - -import static org.apache.tinkerpop.gremlin.driver.Cluster.SERIALIZER_INTERCEPTOR_NAME; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.core.IsInstanceOf.instanceOf; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -/** - * Test Cluster and Cluster.Builder. - */ -public class ClusterTest { - private static final RequestInterceptor TEST_INTERCEPTOR = httpRequest -> new HttpRequest(null, null, null); - - @Test - public void shouldNotAllowModifyingRelativeToNonExistentInterceptor() { - try { - Cluster.build().addInterceptorAfter("none", "test", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("none interceptor not found", e.getMessage()); - } - - try { - Cluster.build().addInterceptorBefore("none", "test", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("none interceptor not found", e.getMessage()); - } - - try { - Cluster.build().removeInterceptor("nonexistent"); - fail("Should not have allowed interceptor to be removed."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("nonexistent interceptor not found", e.getMessage()); - } - } - - @Test - public void shouldAddToInterceptorToBeginningIfBeforeFirst() { - final Cluster testCluster = Cluster.build() - .addInterceptor("b", req -> req) - .addInterceptorBefore(SERIALIZER_INTERCEPTOR_NAME, "a", TEST_INTERCEPTOR) - .create(); - assertEquals("a", testCluster.getRequestInterceptors().get(0).getLeft()); - assertEquals(TEST_INTERCEPTOR, testCluster.getRequestInterceptors().get(0).getRight()); - assertEquals(SERIALIZER_INTERCEPTOR_NAME, testCluster.getRequestInterceptors().get(1).getLeft()); - } - - @Test - public void shouldAddToInterceptorAfter() { - final Cluster testCluster = Cluster.build() - .addInterceptor("b", req -> req) - .addInterceptorAfter(SERIALIZER_INTERCEPTOR_NAME, "a", TEST_INTERCEPTOR) - .create(); - assertEquals(SERIALIZER_INTERCEPTOR_NAME, testCluster.getRequestInterceptors().get(0).getLeft()); - assertEquals("a", testCluster.getRequestInterceptors().get(1).getLeft()); - assertEquals(TEST_INTERCEPTOR, testCluster.getRequestInterceptors().get(1).getRight()); - assertEquals("b", testCluster.getRequestInterceptors().get(2).getLeft()); - - } - - @Test - public void shouldAddToInterceptorLast() { - final Cluster testCluster = Cluster.build() - .addInterceptor("c", req -> req) - .addInterceptor("b", req -> req) - .addInterceptor("a", req -> req) - .create(); - assertEquals(SERIALIZER_INTERCEPTOR_NAME, testCluster.getRequestInterceptors().get(0).getLeft()); - assertEquals("c", testCluster.getRequestInterceptors().get(1).getLeft()); - assertEquals("b", testCluster.getRequestInterceptors().get(2).getLeft()); - assertEquals("a", testCluster.getRequestInterceptors().get(3).getLeft()); - } - - @Test - public void shouldNotAllowAddingDuplicateName() { - try { - Cluster.build().addInterceptor("name", req -> req).addInterceptor("name", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("name interceptor already exists", e.getMessage()); - } - - try { - Cluster.build().addInterceptor("name", req -> req).addInterceptorAfter("name", "name", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("name interceptor already exists", e.getMessage()); - } - - try { - Cluster.build().addInterceptor("name", req -> req).addInterceptorBefore("name", "name", req -> req); - fail("Should not have allowed interceptor to be added."); - } catch (Exception e) { - assertThat(e, instanceOf(IllegalArgumentException.class)); - assertEquals("name interceptor already exists", e.getMessage()); - } - } - - @Test - public void shouldContainBodySerializerByDefault() { - final List> interceptors = Cluster.build().create().getRequestInterceptors(); - assertEquals(1, interceptors.size()); - assertTrue(interceptors.get(0).getRight() instanceof PayloadSerializingInterceptor); - } - - @Test - public void shouldRemoveDefaultSerializer() { - final Cluster testCluster = Cluster.build().removeInterceptor(SERIALIZER_INTERCEPTOR_NAME).create(); - assertEquals(0, testCluster.getRequestInterceptors().size()); - } -} diff --git a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java new file mode 100644 index 00000000000..a10c36ae32a --- /dev/null +++ b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/InterceptorTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.tinkerpop.gremlin.driver; + +import org.apache.tinkerpop.gremlin.driver.auth.Auth; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; +import org.apache.tinkerpop.shaded.jackson.databind.JsonNode; +import org.apache.tinkerpop.shaded.jackson.databind.ObjectMapper; +import org.junit.Test; + +import java.net.URI; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +public class InterceptorTest { + + private static final ObjectMapper mapper = new ObjectMapper(); + + private HttpRequest createRequest(final RequestMessage msg) throws Exception { + return new HttpRequest(new HashMap<>(), msg, new URI("http://localhost:8182/gremlin")); + } + + private HttpRequest createRequest() throws Exception { + return createRequest(RequestMessage.build("g.V()").addG("g").create()); + } + + @Test + public void interceptorReceivesRequestMessageInBody() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V()").create(); + final HttpRequest request = createRequest(msg); + + final List captured = new ArrayList<>(); + final RequestInterceptor interceptor = req -> captured.add(req.getBody()); + interceptor.intercept(request); + + assertEquals(1, captured.size()); + assertSame(msg, captured.get(0)); + } + + @Test + public void interceptorCanReadAndModifyHeaders() throws Exception { + final HttpRequest request = createRequest(); + request.headers().put("X-Existing", "original"); + + final RequestInterceptor interceptor = req -> { + assertEquals("original", req.headers().get("X-Existing")); + req.headers().put("X-Existing", "modified"); + req.headers().put("X-New", "added"); + }; + interceptor.intercept(request); + + assertEquals("modified", request.headers().get("X-Existing")); + assertEquals("added", request.headers().get("X-New")); + } + + @Test + public void interceptorCanModifyUri() throws Exception { + final HttpRequest request = createRequest(); + final URI newUri = new URI("http://other-host:9999/gremlin"); + + final RequestInterceptor interceptor = req -> req.setUri(newUri); + interceptor.intercept(request); + + assertEquals(newUri, request.getUri()); + } + + @Test + public void interceptorsRunInRegistrationOrder() throws Exception { + final HttpRequest request = createRequest(); + final List order = new ArrayList<>(); + + final List interceptors = List.of( + req -> order.add(1), + req -> order.add(2), + req -> order.add(3) + ); + + for (final RequestInterceptor i : interceptors) { + i.intercept(request); + } + + assertEquals(List.of(1, 2, 3), order); + } + + @Test + public void serializeBodyConvertsRequestMessageToJsonBytes() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V().count()").addG("g").create(); + final HttpRequest request = createRequest(msg); + + final byte[] result = request.serializeBody(); + + final JsonNode json = mapper.readTree(result); + assertEquals("g.V().count()", json.get("gremlin").asText()); + assertEquals("g", json.get("g").asText()); + } + + @Test + public void serializeBodySetsContentTypeHeader() throws Exception { + final HttpRequest request = createRequest(); + + request.serializeBody(); + + assertEquals("application/json", request.headers().get(HttpRequest.Headers.CONTENT_TYPE)); + } + + @Test + public void serializeBodySetsContentLengthHeader() throws Exception { + final HttpRequest request = createRequest(); + + final byte[] result = request.serializeBody(); + + assertEquals(String.valueOf(result.length), request.headers().get(HttpRequest.Headers.CONTENT_LENGTH)); + } + + @Test + public void serializeBodyIsIdempotentWithPreSerializedBytes() throws Exception { + final byte[] existing = "{\"gremlin\":\"g.V()\"}".getBytes(); + final HttpRequest request = new HttpRequest(new HashMap<>(), existing, new URI("http://localhost:8182/gremlin")); + + final byte[] first = request.serializeBody(); + final byte[] second = request.serializeBody(); + + assertSame(existing, first); + assertSame(first, second); + } + + @Test + public void serializeBodyIsIdempotentWithRequestMessage() throws Exception { + final HttpRequest request = createRequest(); + + final byte[] first = request.serializeBody(); + final byte[] second = request.serializeBody(); + + assertSame(first, second); + } + + @Test + public void serializeBodyIncludesAllFields() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V()") + .addG("g") + .addLanguage("gremlin-lang") + .addTimeoutMillis(5000L) + .addBulkResults(true) + .create(); + final HttpRequest request = createRequest(msg); + + final byte[] result = request.serializeBody(); + final JsonNode json = mapper.readTree(result); + + assertEquals("g.V()", json.get("gremlin").asText()); + assertEquals("g", json.get("g").asText()); + assertEquals("gremlin-lang", json.get("language").asText()); + assertEquals(5000, json.get("timeoutMs").asLong()); + assertEquals("true", json.get("bulkResults").asText()); + } + + @Test + public void interceptorCanReplaceBodyBeforeSerialization() throws Exception { + final RequestMessage original = RequestMessage.build("g.V()").addG("g").create(); + final HttpRequest request = createRequest(original); + + final RequestInterceptor interceptor = req -> { + req.setBody(RequestMessage.build("g.E()").addG("gmodern").create()); + }; + interceptor.intercept(request); + request.serializeBody(); + + final JsonNode json = mapper.readTree((byte[]) request.getBody()); + assertEquals("g.E()", json.get("gremlin").asText()); + assertEquals("gmodern", json.get("g").asText()); + } + + @Test + public void authIsAlwaysLastInterceptorRegardlessOfBuilderCallOrder() throws Exception { + // auth called before interceptors + final Cluster cluster1 = Cluster.build("localhost") + .auth(Auth.basic("user", "pass")) + .interceptors(req -> req.headers().put("X-Custom", "value")) + .create(); + + final List interceptors1 = cluster1.getRequestInterceptors(); + assertEquals(2, interceptors1.size()); + assertTrue("Auth should be last interceptor", interceptors1.get(interceptors1.size() - 1) instanceof Auth); + cluster1.close(); + + // auth called after interceptors + final Cluster cluster2 = Cluster.build("localhost") + .interceptors(req -> req.headers().put("X-Custom", "value")) + .auth(Auth.basic("user", "pass")) + .create(); + + final List interceptors2 = cluster2.getRequestInterceptors(); + assertEquals(2, interceptors2.size()); + assertTrue("Auth should be last interceptor", interceptors2.get(interceptors2.size() - 1) instanceof Auth); + cluster2.close(); + } + +} diff --git a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java index 5c7ecd1fd70..a72bcd09467 100644 --- a/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java +++ b/gremlin-driver/src/test/java/org/apache/tinkerpop/gremlin/driver/auth/Sigv4Test.java @@ -20,10 +20,9 @@ package org.apache.tinkerpop.gremlin.driver.auth; import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; import java.util.HashMap; import org.apache.tinkerpop.gremlin.driver.HttpRequest; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -52,7 +51,6 @@ public class Sigv4Test { private static final String REGION = "us-west-2"; private static final String SERVICE_NAME = "service-name"; - private static final byte[] REQUEST_BODY = "{\"gremlin\":\"2-1\"}".getBytes(StandardCharsets.UTF_8); private static final String HOST = "localhost"; private static final String URI_WITH_QUERY_PARAMS = "http://" + HOST + ":8182?a=1&b=2"; private static final String KEY = "foo"; @@ -71,8 +69,16 @@ public void setup() { @Test public void shouldAddSignedHeaders() throws Exception { when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); - HttpRequest httpRequest = createRequest(); - sigv4.apply(httpRequest); + HttpRequest httpRequest = createRequestWithRequestMessage(); + sigv4.intercept(httpRequest); + validateExpectedHeaders(httpRequest); + } + + @Test + public void shouldSignPreSerializedBody() throws Exception { + when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); + HttpRequest httpRequest = createRequestWithBytes(); + sigv4.intercept(httpRequest); validateExpectedHeaders(httpRequest); } @@ -80,35 +86,48 @@ public void shouldAddSignedHeaders() throws Exception { public void shouldAddSignedHeadersAndSessionToken() throws Exception { String sessionToken = "foobarz"; when(credentialsProvider.resolveCredentials()).thenReturn(AwsSessionCredentials.create(KEY, SECRET, sessionToken)); - HttpRequest httpRequest = createRequest(); - sigv4.apply(httpRequest); + HttpRequest httpRequest = createRequestWithRequestMessage(); + sigv4.intercept(httpRequest); validateExpectedHeaders(httpRequest); assertEquals(sessionToken, httpRequest.headers().get(X_AMZ_SECURITY_TOKEN)); } @Test - public void shouldThrowIfRequestNonByteArray() { + public void shouldThrowIfBodyIsUnsupportedType() { + // serializeBody() will throw because body is a String Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class, - () -> sigv4.apply(new HttpRequest(new HashMap<>(), "not byte array", new URI(URI_WITH_QUERY_PARAMS)))); - assertTrue(ex.getMessage().contains("Expected byte[] in HttpRequest body")); + () -> sigv4.intercept(new HttpRequest(new HashMap<>(), "not valid", URI.create(URI_WITH_QUERY_PARAMS)))); + assertTrue(ex.getCause().getMessage().contains("Cannot serialize body of type")); } @Test - public void shouldThrowIfNoRequestMethod() { + public void shouldThrowIfNoRequestMethod() throws Exception { + when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); + final byte[] body = "{\"gremlin\":\"g.V()\"}".getBytes(); Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class, - () -> sigv4.apply(new HttpRequest(new HashMap<>(), REQUEST_BODY, new URI(URI_WITH_QUERY_PARAMS), null))); - assertTrue(ex.getMessage().contains("The request method must not be null")); + () -> sigv4.intercept(new HttpRequest(new HashMap<>(), body, URI.create(URI_WITH_QUERY_PARAMS), null))); + assertTrue(ex.getCause().getMessage().contains("The request method must not be null")); } @Test public void shouldThrowIfNoRequestURI() { + when(credentialsProvider.resolveCredentials()).thenReturn(AwsBasicCredentials.create(KEY, SECRET)); + final byte[] body = "{\"gremlin\":\"g.V()\"}".getBytes(); Auth.AuthenticationException ex = assertThrows(Auth.AuthenticationException.class, - () -> sigv4.apply(new HttpRequest(new HashMap<>(), REQUEST_BODY, null))); - assertTrue(ex.getMessage().contains("The request URI must not be null")); + () -> sigv4.intercept(new HttpRequest(new HashMap<>(), body, null))); + assertTrue(ex.getCause().getMessage().contains("The request URI must not be null")); + } + + private HttpRequest createRequestWithRequestMessage() throws Exception { + final RequestMessage msg = RequestMessage.build("g.V()").addG("g").create(); + HttpRequest httpRequest = new HttpRequest(new HashMap<>(), msg, new URI(URI_WITH_QUERY_PARAMS)); + httpRequest.headers().put("Host", "this-should-be-ignored-for-signed-host-header"); + return httpRequest; } - private HttpRequest createRequest() throws URISyntaxException { - HttpRequest httpRequest = new HttpRequest(new HashMap<>(), REQUEST_BODY, new URI(URI_WITH_QUERY_PARAMS)); + private HttpRequest createRequestWithBytes() throws Exception { + final byte[] body = "{\"gremlin\":\"2-1\"}".getBytes(); + HttpRequest httpRequest = new HttpRequest(new HashMap<>(), body, new URI(URI_WITH_QUERY_PARAMS)); httpRequest.headers().put("Content-Type", "application/json"); httpRequest.headers().put("Host", "this-should-be-ignored-for-signed-host-header"); return httpRequest; @@ -123,5 +142,4 @@ private void validateExpectedHeaders(HttpRequest httpRequest) { containsString("/" + REGION + "/service-name/aws4_request"), containsString("Signature="))); } - -} \ No newline at end of file +} diff --git a/gremlin-go/driver/auth.go b/gremlin-go/driver/auth.go index e0f279110ab..411458e65c3 100644 --- a/gremlin-go/driver/auth.go +++ b/gremlin-go/driver/auth.go @@ -49,13 +49,12 @@ func SigV4Auth(region, service string) RequestInterceptor { // SigV4AuthWithCredentials returns a RequestInterceptor that signs requests using AWS SigV4 // with the provided credentials provider. If provider is nil, uses default credential chain. // If the request body has not been serialized yet (*RequestMessage), it is automatically -// serialized to GraphBinary before signing. +// serialized to JSON before signing via SerializeBody(). // // Caches the signer and credentials provider for efficiency. func SigV4AuthWithCredentials(region, service string, credentialsProvider aws.CredentialsProvider) RequestInterceptor { // Create signer once - it's stateless and safe to reuse signer := v4.NewSigner() - serialize := SerializeRequest() // Cache for resolved credentials provider (lazy initialization) var cachedProvider aws.CredentialsProvider @@ -63,15 +62,10 @@ func SigV4AuthWithCredentials(region, service string, credentialsProvider aws.Cr var providerErr error return func(req *HttpRequest) error { - // If Body is still *RequestMessage, serialize it to GraphBinary before signing. - if _, ok := req.Body.(*RequestMessage); ok { - if err := serialize(req); err != nil { - return fmt.Errorf("SigV4 auto-serialization failed: %w", err) - } - } - - if _, ok := req.Body.([]byte); !ok { - return fmt.Errorf("SigV4 signing requires body to be []byte; got %T", req.Body) + // Ensure body is serialized to JSON bytes before signing. + // SerializeBody is idempotent: safe to call even if already serialized. + if _, err := req.SerializeBody(); err != nil { + return fmt.Errorf("SigV4 signing requires a serialized body: %w", err) } ctx := context.Background() diff --git a/gremlin-go/driver/client.go b/gremlin-go/driver/client.go index ced49a1184d..aa34406fcd6 100644 --- a/gremlin-go/driver/client.go +++ b/gremlin-go/driver/client.go @@ -67,6 +67,11 @@ type ClientSettings struct { // RequestInterceptors are functions that modify HTTP requests before sending. RequestInterceptors []RequestInterceptor + + // Auth is a RequestInterceptor for authentication (e.g. BasicAuth, SigV4Auth). + // As a convenience, this is always appended to the end of the interceptor list + // so it runs last, after any user interceptors have modified the request. + Auth RequestInterceptor } // Client is used to connect and interact with a Gremlin-supported server. @@ -122,6 +127,11 @@ func NewClient(url string, configurations ...func(settings *ClientSettings)) (*C conn.AddInterceptor(interceptor) } + // Auth interceptor is always last so it runs after user interceptors + if settings.Auth != nil { + conn.AddInterceptor(settings.Auth) + } + client := &Client{ url: url, traversalSource: settings.TraversalSource, diff --git a/gremlin-go/driver/connection.go b/gremlin-go/driver/connection.go index 6a61500132f..8673414d254 100644 --- a/gremlin-go/driver/connection.go +++ b/gremlin-go/driver/connection.go @@ -53,7 +53,6 @@ type connection struct { httpClient *http.Client connSettings *connectionSettings logHandler *logHandler - serializer *GraphBinarySerializer interceptors []RequestInterceptor wg sync.WaitGroup } @@ -111,7 +110,6 @@ func newConnection(handler *logHandler, url string, connSettings *connectionSett httpClient: &http.Client{Transport: transport}, // No Timeout - allows streaming connSettings: connSettings, logHandler: handler, - serializer: newGraphBinarySerializer(handler), } } @@ -205,19 +203,12 @@ func (c *connection) sendRequest(req *RequestMessage) (*http.Response, error) { } } - // After interceptors, serialize if Body is still *RequestMessage - if r, ok := httpReq.Body.(*RequestMessage); ok { - if c.serializer != nil { - data, err := c.serializer.SerializeMessage(r) - if err != nil { - c.logHandler.logf(Error, failedToSendRequest, err.Error()) - return nil, err - } - httpReq.Body = data - } else { - errMsg := "request body was not serialized; either provide a serializer or add an interceptor that serializes the request" - c.logHandler.logf(Error, failedToSendRequest, errMsg) - return nil, fmt.Errorf("%s", errMsg) + // After interceptors, auto-serialize the body to JSON if still a *RequestMessage. + // SerializeBody is idempotent: if an interceptor already called it, this is a no-op. + if _, ok := httpReq.Body.(*RequestMessage); ok { + if _, err := httpReq.SerializeBody(); err != nil { + c.logHandler.logf(Error, failedToSendRequest, err.Error()) + return nil, err } } @@ -277,7 +268,6 @@ func (c *connection) streamResponse(resp *http.Response, rs ResultSet) { // setHttpRequestHeaders sets default headers on HttpRequest (for interceptors) func (c *connection) setHttpRequestHeaders(req *HttpRequest) { - req.Headers.Set(HeaderContentType, graphBinaryMimeType) req.Headers.Set(HeaderAccept, graphBinaryMimeType) if c.connSettings.enableUserAgentOnConnect { diff --git a/gremlin-go/driver/connection_test.go b/gremlin-go/driver/connection_test.go index 572cc954600..9314b0616d7 100644 --- a/gremlin-go/driver/connection_test.go +++ b/gremlin-go/driver/connection_test.go @@ -907,13 +907,12 @@ func TestNewConnection(t *testing.T) { } func TestSetHttpRequestHeaders(t *testing.T) { - t.Run("sets content type and accept headers", func(t *testing.T) { + t.Run("sets accept header", func(t *testing.T) { conn := newConnection(newTestLogHandler(), "http://localhost/gremlin", &connectionSettings{}) req, _ := NewHttpRequest(http.MethodPost, "http://localhost/gremlin") conn.setHttpRequestHeaders(req) - assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Content-Type")) assert.Equal(t, graphBinaryMimeType, req.Headers.Get("Accept")) }) @@ -1012,7 +1011,7 @@ func TestConnectionWithMockServer(t *testing.T) { select { case receivedHeaders := <-headersCh: - assert.Equal(t, graphBinaryMimeType, receivedHeaders.Get("Content-Type")) + assert.Equal(t, "application/json", receivedHeaders.Get("Content-Type")) assert.Equal(t, "deflate", receivedHeaders.Get("Accept-Encoding")) assert.NotEmpty(t, receivedHeaders.Get(userAgentHeader)) case <-time.After(time.Second): @@ -1105,8 +1104,7 @@ func TestConnectionWithMockServer(t *testing.T) { require.NoError(t, err) _, _ = rs.All() - // Interceptor should see the default headers - assert.Equal(t, graphBinaryMimeType, interceptorHeaders.Get("Content-Type")) + // Interceptor should see the default headers (Content-Type is not set until SerializeBody runs) assert.Equal(t, graphBinaryMimeType, interceptorHeaders.Get("Accept")) }) @@ -1311,6 +1309,142 @@ func TestDriverRemoteConnectionSettingsWiring(t *testing.T) { }) } +func TestInterceptorIntegration(t *testing.T) { + testNoAuthUrl := getEnvOrDefaultString("GREMLIN_SERVER_URL", noAuthUrl) + testNoAuthEnable := getEnvOrDefaultBool("RUN_INTEGRATION_TESTS", true) + + t.Run("should auto serialize request message with interceptor mutation", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + // Interceptor replaces the RequestMessage body with a different query. + // The driver should auto-serialize the modified RequestMessage and the server + // should execute the modified query. + client, err := NewClient(testNoAuthUrl, + func(settings *ClientSettings) { + settings.TraversalSource = testServerModernGraphAlias + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { + if msg, ok := req.Body.(*RequestMessage); ok { + req.Body = &RequestMessage{ + Gremlin: "g.inject(99)", + Fields: map[string]interface{}{"g": msg.Fields["g"], "language": "gremlin-lang"}, + } + } + return nil + }, + } + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + rs, err := client.Submit("g.inject(1)") + assert.Nil(t, err) + result, ok, err := rs.One() + assert.Nil(t, err) + assert.True(t, ok) + val, err := result.GetInt() + assert.Nil(t, err) + assert.Equal(t, 99, val) + }) + + t.Run("should propagate exception thrown during interceptor", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + // Only throw on the first request to verify recovery. + var callCount int + var mu sync.Mutex + + client, err := NewClient(testNoAuthUrl, + func(settings *ClientSettings) { + settings.TraversalSource = testServerModernGraphAlias + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { + mu.Lock() + callCount++ + count := callCount + mu.Unlock() + if count == 1 { + return fmt.Errorf("interceptor broke") + } + return nil + }, + } + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + // First request should fail with interceptor error + rs, err := client.Submit("g.inject(1)") + if err != nil { + assert.Contains(t, err.Error(), "interceptor broke") + } else { + _, err = rs.All() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "interceptor broke") + } + + // Subsequent request should succeed, proving connection recovery + rs, err = client.Submit("g.inject(2)") + assert.Nil(t, err) + result, ok, err := rs.One() + assert.Nil(t, err) + assert.True(t, ok) + val, err := result.GetInt() + assert.Nil(t, err) + assert.Equal(t, 2, val) + }) + + t.Run("should propagate error when interceptor sets unsupported body type", func(t *testing.T) { + skipTestsIfNotEnabled(t, integrationTestSuiteName, testNoAuthEnable) + + // Only set invalid body on the first request to verify recovery. + var callCount int + var mu sync.Mutex + + client, err := NewClient(testNoAuthUrl, + func(settings *ClientSettings) { + settings.TraversalSource = testServerModernGraphAlias + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { + mu.Lock() + callCount++ + count := callCount + mu.Unlock() + if count == 1 { + req.Body = 42 + } + return nil + }, + } + }) + assert.Nil(t, err) + assert.NotNil(t, client) + defer client.Close() + + // First request should fail with serialization error + rs, err := client.Submit("g.inject(1)") + if err != nil { + assert.Contains(t, err.Error(), "unsupported body type") + } else { + _, err = rs.All() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "unsupported body type") + } + + // Subsequent request should succeed, proving connection recovery + rs, err = client.Submit("g.inject(2)") + assert.Nil(t, err) + result, ok, err := rs.One() + assert.Nil(t, err) + assert.True(t, ok) + val, err := result.GetInt() + assert.Nil(t, err) + assert.Equal(t, 2, val) + }) +} + // TestConnectionWithMockServer_BasicAuth verifies that BasicAuth interceptor sets the correct // Authorization header and the body is still valid serialized bytes. func TestConnectionWithMockServer_BasicAuth(t *testing.T) { @@ -1340,6 +1474,6 @@ func TestConnectionWithMockServer_BasicAuth(t *testing.T) { // Body should still be valid serialized bytes assert.NotEmpty(t, capturedBody, "serialized body should be non-empty with BasicAuth") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + assert.Equal(t, byte('{'), capturedBody[0], + "body should start with '{' (JSON)") } diff --git a/gremlin-go/driver/driverRemoteConnection.go b/gremlin-go/driver/driverRemoteConnection.go index a5a47a8c517..3450b7e1792 100644 --- a/gremlin-go/driver/driverRemoteConnection.go +++ b/gremlin-go/driver/driverRemoteConnection.go @@ -60,6 +60,11 @@ type DriverRemoteConnectionSettings struct { // RequestInterceptors are functions that modify HTTP requests before sending. RequestInterceptors []RequestInterceptor + // Auth is a RequestInterceptor for authentication (e.g. BasicAuth, SigV4Auth). + // As a convenience, this is always appended to the end of the interceptor list + // so it runs last, after any user interceptors have modified the request. + Auth RequestInterceptor + // PDTRegistry enables registry-based dehydration in the gremlin-lang translator. PDTRegistry *PDTRegistry } @@ -118,6 +123,11 @@ func NewDriverRemoteConnection( conn.AddInterceptor(interceptor) } + // Auth interceptor is always last so it runs after user interceptors + if settings.Auth != nil { + conn.AddInterceptor(settings.Auth) + } + client := &Client{ url: url, traversalSource: settings.TraversalSource, diff --git a/gremlin-go/driver/interceptor.go b/gremlin-go/driver/interceptor.go index e7e0c8e0879..7485303d400 100644 --- a/gremlin-go/driver/interceptor.go +++ b/gremlin-go/driver/interceptor.go @@ -23,9 +23,12 @@ import ( "bytes" "crypto/sha256" "encoding/hex" + "encoding/json" + "fmt" "io" "net/http" "net/url" + "strconv" ) // Common HTTP header keys @@ -90,24 +93,33 @@ func (r *HttpRequest) PayloadHash() string { } } -// RequestInterceptor is a function that modifies an HTTP request before it is sent. -type RequestInterceptor func(*HttpRequest) error - -// SerializeRequest returns a RequestInterceptor that serializes the raw *RequestMessage body -// to GraphBinary []byte. Place this before auth interceptors (e.g., SigV4Auth) that -// need the serialized body bytes. -func SerializeRequest() RequestInterceptor { - serializer := newGraphBinarySerializer(nil) - return func(req *HttpRequest) error { - r, ok := req.Body.(*RequestMessage) - if !ok { - return nil // already serialized or not a *RequestMessage +// SerializeBody serializes the request body to JSON if it is still a *RequestMessage. +// If the body is already []byte, it returns those bytes (idempotent). +// On successful serialization from *RequestMessage, it sets the body to the resulting bytes +// and updates the Content-Type and Content-Length headers. +func (r *HttpRequest) SerializeBody() ([]byte, error) { + switch b := r.Body.(type) { + case []byte: + return b, nil + case *RequestMessage: + payload := make(map[string]interface{}) + payload["gremlin"] = b.Gremlin + for k, v := range b.Fields { + payload[k] = v } - data, err := serializer.SerializeMessage(r) + data, err := json.Marshal(payload) if err != nil { - return err + return nil, fmt.Errorf("failed to serialize request to JSON: %w", err) } - req.Body = data - return nil + r.Body = data + r.Headers.Set(HeaderContentType, "application/json") + r.Headers.Set("Content-Length", strconv.Itoa(len(data))) + return data, nil + default: + return nil, fmt.Errorf("cannot serialize request body of type %T; expected *RequestMessage or []byte. "+ + "If an interceptor modified the body, it must set it to []byte or leave it as *RequestMessage", r.Body) } } + +// RequestInterceptor is a function that modifies an HTTP request before it is sent. +type RequestInterceptor func(*HttpRequest) error diff --git a/gremlin-go/driver/interceptor_test.go b/gremlin-go/driver/interceptor_test.go index adb51786fc9..3931dff5795 100644 --- a/gremlin-go/driver/interceptor_test.go +++ b/gremlin-go/driver/interceptor_test.go @@ -21,12 +21,13 @@ package gremlingo import ( "bytes" + "encoding/json" "fmt" "io" "net/http" "net/http/httptest" "reflect" - "strings" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -70,15 +71,210 @@ func TestInterceptorReceivesRawRequest(t *testing.T) { } } -// TestSigV4AuthWithSerializeInterceptor verifies that SerializeRequest() + SigV4Auth -// works in a chain. SerializeRequest converts *RequestMessage to []byte, then SigV4Auth -// can sign the serialized body. -func TestSigV4AuthWithSerializeInterceptor(t *testing.T) { +// TestInterceptorCanModifyHeaders verifies that interceptors can read and modify headers. +func TestInterceptorCanModifyHeaders(t *testing.T) { var capturedHeaders http.Header - var capturedBody []byte server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + conn.AddInterceptor(func(req *HttpRequest) error { + req.Headers.Set("X-Custom-Header", "custom-value") + return nil + }) + + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header"), + "interceptor should be able to set custom headers") +} + +// TestInterceptorCanModifyURI verifies that interceptors can modify the request URI. +func TestInterceptorCanModifyURI(t *testing.T) { + var capturedPath string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL+"/gremlin", &connectionSettings{}) + + conn.AddInterceptor(func(req *HttpRequest) error { + req.URL.Path = "/custom-path" + return nil + }) + + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, "/custom-path", capturedPath, + "interceptor should be able to modify the URL path") +} + +// TestInterceptor_ChainOrder verifies that interceptors run in the order they are added. +func TestInterceptor_ChainOrder(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) + + var order []int + + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 1) + return nil + }) + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 2) + return nil + }) + conn.AddInterceptor(func(req *HttpRequest) error { + order = append(order, 3) + return nil + }) + + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, []int{1, 2, 3}, order, + "interceptors should run in the order they were added") +} + +// TestSerializeBody_JSONOutput verifies that SerializeBody produces valid JSON containing +// the gremlin field and all fields from RequestMessage.Fields. +func TestSerializeBody_JSONOutput(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V().has('name','marko')", + Fields: map[string]interface{}{"language": "gremlin-lang", "g": "g"}, + } + + data, err := req.SerializeBody() + require.NoError(t, err) + + var parsed map[string]interface{} + err = json.Unmarshal(data, &parsed) + require.NoError(t, err, "SerializeBody should produce valid JSON") + + assert.Equal(t, "g.V().has('name','marko')", parsed["gremlin"]) + assert.Equal(t, "gremlin-lang", parsed["language"]) + assert.Equal(t, "g", parsed["g"]) +} + +// TestSerializeBody_SetsContentTypeHeader verifies that SerializeBody sets Content-Type to application/json. +func TestSerializeBody_SetsContentTypeHeader(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V()", + Fields: map[string]interface{}{}, + } + + _, err = req.SerializeBody() + require.NoError(t, err) + + assert.Equal(t, "application/json", req.Headers.Get(HeaderContentType), + "SerializeBody should set Content-Type to application/json") +} + +// TestSerializeBody_SetsContentLengthHeader verifies that SerializeBody sets Content-Length +// to the byte length of the serialized body. +func TestSerializeBody_SetsContentLengthHeader(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V()", + Fields: map[string]interface{}{}, + } + + data, err := req.SerializeBody() + require.NoError(t, err) + + expected := strconv.Itoa(len(data)) + assert.Equal(t, expected, req.Headers.Get("Content-Length"), + "SerializeBody should set Content-Length to byte length of the body") +} + +// TestSerializeBody_Idempotent verifies that calling SerializeBody when body is already +// []byte returns the same bytes without re-serialization. +func TestSerializeBody_Idempotent(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V()", + Fields: map[string]interface{}{"g": "g"}, + } + + // First call serializes + data1, err := req.SerializeBody() + require.NoError(t, err) + + // Second call should return same bytes + data2, err := req.SerializeBody() + require.NoError(t, err) + + assert.Equal(t, data1, data2, + "SerializeBody should be idempotent, returning the same bytes on subsequent calls") +} + +// TestSerializeBody_MultipleCalls verifies that multiple calls produce identical results. +func TestSerializeBody_MultipleCalls(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = &RequestMessage{ + Gremlin: "g.V().count()", + Fields: map[string]interface{}{"language": "gremlin-lang"}, + } + + results := make([][]byte, 3) + for i := range results { + data, err := req.SerializeBody() + require.NoError(t, err) + results[i] = data + } + + assert.Equal(t, results[0], results[1]) + assert.Equal(t, results[1], results[2]) +} + +// TestSerializeBody_UnsupportedType verifies that SerializeBody returns an error when +// body is neither *RequestMessage nor []byte. +func TestSerializeBody_UnsupportedType(t *testing.T) { + req, err := NewHttpRequest("POST", "http://localhost:8182/gremlin") + require.NoError(t, err) + + req.Body = 42 + + _, err = req.SerializeBody() + require.Error(t, err, "SerializeBody should return error for unsupported body type") + assert.Contains(t, err.Error(), "cannot serialize request body") +} + +// TestFieldMutationBeforeSerialization verifies that an interceptor can modify +// RequestMessage fields and the serialized output reflects those changes. +func TestFieldMutationBeforeSerialization(t *testing.T) { + var capturedBody []byte + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err == nil { capturedBody = body @@ -89,35 +285,34 @@ func TestSigV4AuthWithSerializeInterceptor(t *testing.T) { conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) - mockProvider := &mockCredentialsProvider{ - accessKey: "MOCK_ID", - secretKey: "MOCK_KEY", - } - - conn.AddInterceptor(SerializeRequest()) - conn.AddInterceptor(SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", mockProvider)) + // Interceptor that adds a custom field before serialization + conn.AddInterceptor(func(req *HttpRequest) error { + r, ok := req.Body.(*RequestMessage) + if !ok { + return fmt.Errorf("expected *RequestMessage, got %T", req.Body) + } + r.Fields["customField"] = "customValue" + return nil + }) - rs, err := conn.submit(&RequestMessage{Gremlin: "g.V().count()", Fields: map[string]interface{}{}}) + rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{"g": "g"}}) require.NoError(t, err) - _, _ = rs.All() // drain + _, _ = rs.All() - // SigV4 should have added Authorization and X-Amz-Date headers - assert.NotEmpty(t, capturedHeaders.Get("Authorization"), - "SigV4Auth should set Authorization header after SerializeRequest") - assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"), - "SigV4Auth should set X-Amz-Date header") - assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256", - "Authorization header should use AWS4-HMAC-SHA256 signing algorithm") + // Parse the JSON body sent to the server + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "server should receive valid JSON") - // Body should be valid serialized bytes - assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + assert.Equal(t, "g.V()", parsed["gremlin"]) + assert.Equal(t, "g", parsed["g"]) + assert.Equal(t, "customValue", parsed["customField"], + "interceptor field mutation should be reflected in the serialized output") } -// TestSigV4Auth_AutoSerializesInChain verifies that SigV4Auth works as the only -// interceptor — it auto-serializes *RequestMessage before signing. -func TestSigV4Auth_AutoSerializesInChain(t *testing.T) { +// TestSigV4AuthWithSerializeBody verifies that SigV4Auth calls SerializeBody and signs +// the request properly. +func TestSigV4AuthWithSerializeBody(t *testing.T) { var capturedHeaders http.Header var capturedBody []byte @@ -145,17 +340,86 @@ func TestSigV4Auth_AutoSerializesInChain(t *testing.T) { require.NoError(t, err) _, _ = rs.All() + // SigV4 should have added Authorization and X-Amz-Date headers assert.NotEmpty(t, capturedHeaders.Get("Authorization"), "SigV4Auth should set Authorization header") - assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256") + assert.NotEmpty(t, capturedHeaders.Get("X-Amz-Date"), + "SigV4Auth should set X-Amz-Date header") + assert.Contains(t, capturedHeaders.Get("Authorization"), "AWS4-HMAC-SHA256", + "Authorization header should use AWS4-HMAC-SHA256 signing algorithm") + + // Body should be valid JSON assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "body should be valid JSON") + assert.Equal(t, "g.V().count()", parsed["gremlin"]) +} + +// TestSigV4Auth_AutoSerializesRequestMessage verifies that SigV4Auth automatically +// serializes *RequestMessage to JSON bytes before signing. +func TestSigV4Auth_AutoSerializesRequestMessage(t *testing.T) { + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + } + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) + + req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") + require.NoError(t, err) + req.Headers.Set("Content-Type", "application/json") + req.Headers.Set("Accept", graphBinaryMimeType) + + // Set Body to *RequestMessage + req.Body = &RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}} + + err = interceptor(req) + require.NoError(t, err, "SigV4Auth should auto-serialize *RequestMessage") + + // Body should now be []byte (serialized JSON) + bodyBytes, ok := req.Body.([]byte) + assert.True(t, ok, "Body should be []byte after SigV4Auth auto-serialization") + assert.NotEmpty(t, bodyBytes, "serialized body should be non-empty") + + // Verify it's valid JSON + var parsed map[string]interface{} + err = json.Unmarshal(bodyBytes, &parsed) + require.NoError(t, err, "body should be valid JSON after auto-serialization") + assert.Equal(t, "g.V()", parsed["gremlin"]) + + // SigV4 headers should be set + assert.NotEmpty(t, req.Headers.Get("Authorization"), "Authorization header should be set") + assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"), "X-Amz-Date header should be set") + assert.Contains(t, req.Headers.Get("Authorization"), "AWS4-HMAC-SHA256") +} + +// TestSigV4Auth_RejectsNonByteBody verifies that SigV4Auth returns an error when Body +// is not []byte and not *RequestMessage (e.g., an io.Reader). +func TestSigV4Auth_RejectsNonByteBody(t *testing.T) { + provider := &mockCredentialsProvider{ + accessKey: "MOCK_ID", + secretKey: "MOCK_KEY", + } + interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) + + req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") + require.NoError(t, err) + req.Headers.Set("Content-Type", "application/json") + req.Headers.Set("Accept", graphBinaryMimeType) + + // Set Body to an unsupported type (not []byte and not *RequestMessage) + req.Body = bytes.NewReader([]byte("not bytes")) + + err = interceptor(req) + require.Error(t, err, "SigV4Auth should reject non-[]byte, non-*RequestMessage body") + assert.Contains(t, err.Error(), "cannot serialize request body", + "error message should indicate unsupported body type") } -// TestMultipleInterceptors_SerializeThenAuth verifies that a custom interceptor can -// modify the raw request, then SerializeRequest serializes it, then BasicAuth adds headers. -func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { +// TestMultipleInterceptors_MutateThenAuth verifies that a custom interceptor can +// modify the raw request fields, then BasicAuth adds headers, and the driver +// auto-serializes to JSON. +func TestMultipleInterceptors_MutateThenAuth(t *testing.T) { var capturedAuthHeader string var capturedBody []byte @@ -182,9 +446,6 @@ func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { return nil }) - // SerializeRequest converts the modified *RequestMessage to []byte - conn.AddInterceptor(SerializeRequest()) - // BasicAuth adds the Authorization header (works on any body type) conn.AddInterceptor(BasicAuth("admin", "secret")) @@ -196,10 +457,13 @@ func TestMultipleInterceptors_SerializeThenAuth(t *testing.T) { assert.Equal(t, "Basic YWRtaW46c2VjcmV0", capturedAuthHeader, "Authorization header should be Basic base64(admin:secret)") - // Body should be valid serialized bytes (from SerializeRequest) - assert.NotEmpty(t, capturedBody, "body should be non-empty serialized bytes") - assert.Equal(t, byte(0x84), capturedBody[0], - "body should start with GraphBinary version byte 0x84") + // Body should be valid JSON (auto-serialized by driver after interceptors) + assert.NotEmpty(t, capturedBody, "body should be non-empty") + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "body should be valid JSON") + assert.Equal(t, "g.V()", parsed["gremlin"]) + assert.Equal(t, "customValue", parsed["customField"]) } // TestInterceptor_IoReaderBody verifies that an interceptor can set Body to an io.Reader @@ -235,23 +499,6 @@ func TestInterceptor_IoReaderBody(t *testing.T) { "server should receive the custom payload set via io.Reader") } -// TestInterceptor_NilSerializerNoSerialization verifies that when serializer is nil -// and no interceptor serializes, the correct error message is produced. -func TestInterceptor_NilSerializerNoSerialization(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - defer server.Close() - - conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) - conn.serializer = nil // explicitly nil serializer - - _, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) - require.Error(t, err, "should get an error when serializer is nil and no interceptor serializes") - assert.Contains(t, err.Error(), "request body was not serialized", - "error message should indicate the body was not serialized") -} - // TestInterceptor_HttpRequestBody verifies that an interceptor can set Body to *http.Request // and the driver sends it directly, using the *http.Request's headers and body instead of // HttpRequest.Headers. @@ -285,28 +532,13 @@ func TestInterceptor_HttpRequestBody(t *testing.T) { return nil }) - // Also set a header on HttpRequest.Headers that should NOT appear, - // because *http.Request body bypasses HttpRequest.Headers - conn.AddInterceptor(func(req *HttpRequest) error { - req.Headers.Set("X-Should-Not-Appear", "ignored") - return nil - }) - rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) require.NoError(t, err) _, _ = rs.All() // drain - // The server should receive headers from the *http.Request, not from HttpRequest.Headers - assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header"), - "server should receive custom header from *http.Request") - assert.Equal(t, "application/octet-stream", capturedHeaders.Get("Content-Type"), - "server should receive Content-Type from *http.Request") - assert.Empty(t, capturedHeaders.Get("X-Should-Not-Appear"), - "headers set on HttpRequest.Headers should not appear when Body is *http.Request") - - // The server should receive the body from the *http.Request - assert.Equal(t, customBody, capturedBody, - "server should receive body from the *http.Request") + assert.Equal(t, "custom-value", capturedHeaders.Get("X-Custom-Header")) + assert.Equal(t, "application/octet-stream", capturedHeaders.Get("Content-Type")) + assert.Equal(t, customBody, capturedBody) } // TestInterceptor_ErrorPropagation verifies that when an interceptor returns an error, @@ -351,88 +583,92 @@ func TestInterceptor_UnsupportedBodyType(t *testing.T) { "error message should indicate unsupported body type") } -// TestInterceptor_ChainOrder verifies that interceptors run in the order they are added. -func TestInterceptor_ChainOrder(t *testing.T) { +// TestDriverAutoSerializes verifies that without any interceptors, the driver +// auto-serializes the request body to JSON. +func TestDriverAutoSerializes(t *testing.T) { + var capturedBody []byte + var capturedContentType string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedContentType = r.Header.Get("Content-Type") + body, err := io.ReadAll(r.Body) + if err == nil { + capturedBody = body + } w.WriteHeader(http.StatusOK) })) defer server.Close() conn := newConnection(newTestLogHandler(), server.URL, &connectionSettings{}) - var order []int - - conn.AddInterceptor(func(req *HttpRequest) error { - order = append(order, 1) - return nil - }) - conn.AddInterceptor(func(req *HttpRequest) error { - order = append(order, 2) - return nil - }) - conn.AddInterceptor(func(req *HttpRequest) error { - order = append(order, 3) - return nil + // No interceptors, driver should auto-serialize + rs, err := conn.submit(&RequestMessage{ + Gremlin: "g.V().count()", + Fields: map[string]interface{}{"language": "gremlin-lang", "g": "g"}, }) - - rs, err := conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) require.NoError(t, err) - _, _ = rs.All() // drain - - assert.Equal(t, []int{1, 2, 3}, order, - "interceptors should run in the order they were added") -} - -// TestSigV4Auth_RejectsNonByteBody verifies that SigV4Auth returns an error when Body -// is not []byte and not *RequestMessage (e.g., an io.Reader). -func TestSigV4Auth_RejectsNonByteBody(t *testing.T) { - provider := &mockCredentialsProvider{ - accessKey: "MOCK_ID", - secretKey: "MOCK_KEY", - } - interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) - - req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") - require.NoError(t, err) - req.Headers.Set("Content-Type", graphBinaryMimeType) - req.Headers.Set("Accept", graphBinaryMimeType) + _, _ = rs.All() - // Set Body to an unsupported type (not []byte and not *RequestMessage) - req.Body = strings.NewReader("not bytes") + assert.Equal(t, "application/json", capturedContentType, + "Content-Type should be application/json") - err = interceptor(req) - require.Error(t, err, "SigV4Auth should reject non-[]byte, non-*RequestMessage body") - assert.Contains(t, err.Error(), "SigV4 signing requires body to be []byte", - "error message should indicate SigV4 requires []byte body") + var parsed map[string]interface{} + err = json.Unmarshal(capturedBody, &parsed) + require.NoError(t, err, "body should be valid JSON") + assert.Equal(t, "g.V().count()", parsed["gremlin"]) + assert.Equal(t, "gremlin-lang", parsed["language"]) + assert.Equal(t, "g", parsed["g"]) } -// TestSigV4Auth_AutoSerializesRequestMessage verifies that SigV4Auth automatically -// serializes *RequestMessage to []byte before signing. -func TestSigV4Auth_AutoSerializesRequestMessage(t *testing.T) { - provider := &mockCredentialsProvider{ - accessKey: "MOCK_ID", - secretKey: "MOCK_KEY", - } - interceptor := SigV4AuthWithCredentials("gremlin-east-1", "tinkerpop-sigv4", provider) - - req, err := NewHttpRequest("POST", "https://test_url:8182/gremlin") - require.NoError(t, err) - req.Headers.Set("Content-Type", graphBinaryMimeType) - req.Headers.Set("Accept", graphBinaryMimeType) - - // Set Body to *RequestMessage — SigV4Auth should auto-serialize it - req.Body = &RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}} +// TestAuthInterceptorIsAlwaysLast verifies that the Auth field is always appended +// to the end of the interceptor chain, regardless of configuration order. +func TestAuthInterceptorIsAlwaysLast(t *testing.T) { + var order []int - err = interceptor(req) - require.NoError(t, err, "SigV4Auth should auto-serialize *RequestMessage") + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() - // Body should now be []byte (serialized) - bodyBytes, ok := req.Body.([]byte) - assert.True(t, ok, "Body should be []byte after SigV4Auth auto-serialization") - assert.NotEmpty(t, bodyBytes, "serialized body should be non-empty") + t.Run("Auth runs after RequestInterceptors on Client", func(t *testing.T) { + order = nil + client, err := NewClient(server.URL, + func(settings *ClientSettings) { + settings.Auth = func(req *HttpRequest) error { order = append(order, 3); return nil } + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { order = append(order, 1); return nil }, + func(req *HttpRequest) error { order = append(order, 2); return nil }, + } + }) + require.NoError(t, err) + defer client.Close() + + rs, err := client.conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, []int{1, 2, 3}, order, + "Auth interceptor should always run last") + }) - // SigV4 headers should be set - assert.NotEmpty(t, req.Headers.Get("Authorization"), "Authorization header should be set") - assert.NotEmpty(t, req.Headers.Get("X-Amz-Date"), "X-Amz-Date header should be set") - assert.Contains(t, req.Headers.Get("Authorization"), "AWS4-HMAC-SHA256") + t.Run("Auth runs after RequestInterceptors on DriverRemoteConnection", func(t *testing.T) { + order = nil + remote, err := NewDriverRemoteConnection(server.URL, + func(settings *DriverRemoteConnectionSettings) { + settings.Auth = func(req *HttpRequest) error { order = append(order, 3); return nil } + settings.RequestInterceptors = []RequestInterceptor{ + func(req *HttpRequest) error { order = append(order, 1); return nil }, + func(req *HttpRequest) error { order = append(order, 2); return nil }, + } + }) + require.NoError(t, err) + defer remote.Close() + + rs, err := remote.client.conn.submit(&RequestMessage{Gremlin: "g.V()", Fields: map[string]interface{}{}}) + require.NoError(t, err) + _, _ = rs.All() + + assert.Equal(t, []int{1, 2, 3}, order, + "Auth interceptor should always run last") + }) } diff --git a/gremlin-go/driver/request.go b/gremlin-go/driver/request.go index 65f8662cf0c..ebf375d49bf 100644 --- a/gremlin-go/driver/request.go +++ b/gremlin-go/driver/request.go @@ -19,8 +19,6 @@ under the License. package gremlingo -import "strconv" - // RequestMessage represents a request to the server. type RequestMessage struct { Gremlin string @@ -31,7 +29,7 @@ type RequestMessage struct { // // This function is exposed publicly to enable alternative transport protocols (gRPC, HTTP/2, etc.) // to construct properly formatted requests outside the standard HTTP client. The returned -// request can then be serialized using SerializeMessage(). +// request can be placed in an HttpRequest.Body and serialized via SerializeBody(). // // Parameters: // - stringGremlin: The Gremlin query string to execute @@ -44,8 +42,9 @@ type RequestMessage struct { // Example for alternative transports: // // req := MakeStringRequest("g.V().count()", "g", RequestOptions{}) -// serializer := newGraphBinarySerializer(nil) -// bytes, _ := serializer.(graphBinarySerializer).SerializeMessage(&req) +// httpReq, _ := NewHttpRequest("POST", "http://localhost:8182/gremlin") +// httpReq.Body = &req +// bytes, _ := httpReq.SerializeBody() // // Send bytes over gRPC, HTTP/2, etc. func MakeStringRequest(stringGremlin string, traversalSource string, requestOptions RequestOptions) (req RequestMessage) { newFields := map[string]interface{}{ @@ -74,7 +73,7 @@ func MakeStringRequest(stringGremlin string, traversalSource string, requestOpti } if requestOptions.bulkResults != nil { - newFields["bulkResults"] = strconv.FormatBool(*requestOptions.bulkResults) + newFields["bulkResults"] = *requestOptions.bulkResults } if requestOptions.transactionId != "" { diff --git a/gremlin-go/driver/request_test.go b/gremlin-go/driver/request_test.go index 2da846cc929..f1752b0f6c7 100644 --- a/gremlin-go/driver/request_test.go +++ b/gremlin-go/driver/request_test.go @@ -55,7 +55,7 @@ func TestRequest(t *testing.T) { t.Run("Test makeStringRequest() with bulkResults", func(t *testing.T) { r := MakeStringRequest("g.V()", "g", new(RequestOptionsBuilder).SetBulkResults(true).Create()) - assert.Equal(t, "true", r.Fields["bulkResults"]) + assert.Equal(t, true, r.Fields["bulkResults"]) }) t.Run("Test makeStringRequest() with string bindings", func(t *testing.T) { diff --git a/gremlin-go/driver/serializer.go b/gremlin-go/driver/serializer.go index 38d618600cc..509d51657f9 100644 --- a/gremlin-go/driver/serializer.go +++ b/gremlin-go/driver/serializer.go @@ -28,7 +28,7 @@ import ( const graphBinaryMimeType = "application/vnd.graphbinary-v4.0" -// Serializer interface for serializers. +// Serializer interface for serializing requests and deserializing responses. type Serializer interface { SerializeMessage(request *RequestMessage) ([]byte, error) DeserializeMessage(message []byte) (Response, error) @@ -58,27 +58,17 @@ const versionByte byte = 0x84 // SerializeMessage serializes a request message into GraphBinary format. // -// This method is part of the serializer interface and is used internally by the HTTP driver. -// It is also exposed publicly to enable alternative transport protocols (gRPC, HTTP/2, etc.) to -// serialize requests created with MakeBytecodeRequest() or MakeStringRequest(). -// -// The serialized bytes can be transmitted over any transport protocol that supports binary data. +// This method is part of the Serializer interface. It is no longer used by the default driver +// flow (which now serializes requests as JSON via HttpRequest.SerializeBody), but remains +// available for custom interceptors or alternative transport protocols that require +// GraphBinary-encoded requests. // // Parameters: // - request: The request to serialize (created via MakeBytecodeRequest or MakeStringRequest) // // Returns: -// - []byte: The GraphBinary-encoded request ready for transmission +// - serialized: The GraphBinary-encoded request bytes // - error: Any serialization error encountered -// -// Example for alternative transports: -// -// req := MakeBytecodeRequest(bytecode, "g", "") -// serializer := newGraphBinarySerializer(nil) -// bytes, err := serializer.(graphBinarySerializer).SerializeMessage(&req) -// // Send bytes over custom transport -// -// SerializeMessage serializes a request message into GraphBinary. func (gs *GraphBinarySerializer) SerializeMessage(request *RequestMessage) ([]byte, error) { finalMessage, err := gs.buildMessage(request.Gremlin, request.Fields) if err != nil { diff --git a/gremlin-js/gremlin-javascript/lib/driver/auth.ts b/gremlin-js/gremlin-javascript/lib/driver/auth.ts index 41f5573b27a..4d14f7c4f36 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/auth.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/auth.ts @@ -18,12 +18,11 @@ */ import { Buffer } from 'buffer'; -import type { HttpRequest, RequestInterceptor } from './connection.js'; +import type { HttpRequest, RequestInterceptor } from './http-request.js'; export function basic(username: string, password: string): RequestInterceptor { return (request: HttpRequest) => { request.headers['authorization'] = 'Basic ' + Buffer.from(`${username}:${password}`).toString('base64'); - return request; }; } @@ -40,6 +39,10 @@ export function sigv4(region: string, service: string, credentialsProvider?: Aws let resolvedProvider: AwsCredentialsProvider; return async (request: HttpRequest) => { + // Ensure body is serialized to JSON bytes before signing. + // serializeBody is idempotent: safe to call even if already serialized. + request.serializeBody(); + // Lazy-initialize signer and credentials provider on first use if (!signer) { const { SignatureV4 } = await import('@smithy/signature-v4'); @@ -76,6 +79,5 @@ export function sigv4(region: string, service: string, credentialsProvider?: Aws }); request.headers = signed.headers; - return request; }; } diff --git a/gremlin-js/gremlin-javascript/lib/driver/connection.ts b/gremlin-js/gremlin-javascript/lib/driver/connection.ts index fb91060af3f..1be591ca9a2 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/connection.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/connection.ts @@ -30,26 +30,16 @@ import StreamReader from '../structure/io/binary/internals/StreamReader.js'; import * as utils from '../utils.js'; import ResultSet from './result-set.js'; import {RequestMessage} from "./request-message.js"; +import { HttpRequest, RequestInterceptor } from './http-request.js'; import ResponseError from './response-error.js'; import { Traverser } from '../process/traversal.js'; -const { graphBinaryWriter } = ioc; - const responseStatusCode = { success: 200, noContent: 204, partialContent: 206, }; -export type HttpRequest = { - url: string; - method: string; - headers: Record; - body: any; -}; - -export type RequestInterceptor = (request: HttpRequest) => HttpRequest | Promise; - export type ConnectionOptions = { ca?: string[]; cert?: string | string[] | Buffer; @@ -59,11 +49,13 @@ export type ConnectionOptions = { reader?: any; rejectUnauthorized?: boolean; traversalSource?: string; - writer?: any | null; headers?: Record; enableUserAgentOnConnect?: boolean; agent?: Agent; interceptors?: RequestInterceptor | RequestInterceptor[]; + /** An optional auth interceptor. As a convenience, this is always appended to the end of the + * interceptor list so it runs last, after any user interceptors have modified the request. */ + auth?: RequestInterceptor; }; /** @@ -71,7 +63,6 @@ export type ConnectionOptions = { */ export default class Connection extends EventEmitter { private readonly _reader: any; - private readonly _writer: any | null; isOpen = true; traversalSource: string; @@ -91,7 +82,6 @@ export default class Connection extends EventEmitter { super(); this._reader = options.reader || (options.preciseNumbers === true ? createPreciseReader() : new GraphBinaryReader(ioc)); - this._writer = 'writer' in options ? options.writer : graphBinaryWriter; if (options.pdtRegistry) { this._reader.pdtRegistry = options.pdtRegistry; } @@ -102,12 +92,17 @@ export default class Connection extends EventEmitter { if (typeof interceptors === 'function') { this._interceptors = [interceptors]; } else if (Array.isArray(interceptors)) { - this._interceptors = interceptors; + this._interceptors = [...interceptors]; } else if (interceptors === undefined || interceptors === null) { this._interceptors = []; } else { throw new TypeError('interceptors must be a function, array, or undefined'); } + + // Auth interceptor is always last so it runs after user interceptors + if (options.auth) { + this._interceptors.push(options.auth); + } } /** @@ -123,8 +118,7 @@ export default class Connection extends EventEmitter { * Send a request and buffer the entire response. Returns a Promise. */ async submit(request: RequestMessage) { - const body = this._writer ? this._writer.writeRequest(request) : request; - const response = await this.#makeHttpRequest(body); + const response = await this.#makeHttpRequest(request); return this.#handleResponse(response); } @@ -142,12 +136,11 @@ export default class Connection extends EventEmitter { * @returns {AsyncGenerator} */ async *stream(request: RequestMessage): AsyncGenerator { - const body = this._writer ? this._writer.writeRequest(request) : request; const abortController = new AbortController(); let response: Response; try { - response = await this.#makeHttpRequest(body, abortController.signal); + response = await this.#makeHttpRequest(request, abortController.signal); } catch (e: any) { throw new Error(`Stream request failed: ${e.message}`, { cause: e }); } @@ -215,15 +208,11 @@ export default class Connection extends EventEmitter { return null; } - async #makeHttpRequest(body: any, signal?: AbortSignal): Promise { + async #makeHttpRequest(request: RequestMessage, signal?: AbortSignal): Promise { const headers: Record = { - 'Accept': this._reader.mimeType + 'Accept': this._reader.mimeType, }; - if (this._writer) { - headers['Content-Type'] = this._writer.mimeType; - } - if (this._enableUserAgentOnConnect) { const userAgent = await utils.getUserAgent(); if (userAgent !== undefined) { @@ -237,18 +226,13 @@ export default class Connection extends EventEmitter { }); } - let httpRequest: HttpRequest = { - url: this.url, - method: 'POST', - headers, - body, - }; + const httpRequest = new HttpRequest('POST', this.url, headers, request); // Promote transactionId to HTTP header before interceptors run. // The field remains in the serialized body as well (dual transmission // per the HTTP transaction protocol specification). - if (body instanceof RequestMessage) { - const fields = body.getFields(); + if (request instanceof RequestMessage) { + const fields = request.getFields(); if (fields.has('transactionId')) { httpRequest.headers['X-Transaction-Id'] = fields.get('transactionId'); } @@ -256,12 +240,15 @@ export default class Connection extends EventEmitter { for (let i = 0; i < this._interceptors.length; i++) { try { - httpRequest = await this._interceptors[i](httpRequest); + await this._interceptors[i](httpRequest); } catch (e: any) { throw new Error(`Request interceptor at index ${i} failed: ${e.message}`, { cause: e }); } } + // Auto-serialize body to JSON after interceptors run (idempotent if already serialized) + httpRequest.serializeBody(); + return fetch(httpRequest.url, { method: httpRequest.method, headers: httpRequest.headers, diff --git a/gremlin-js/gremlin-javascript/lib/driver/http-request.ts b/gremlin-js/gremlin-javascript/lib/driver/http-request.ts new file mode 100644 index 00000000000..f47fb62bb25 --- /dev/null +++ b/gremlin-js/gremlin-javascript/lib/driver/http-request.ts @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { Buffer } from 'buffer'; +import { RequestMessage } from './request-message.js'; + +/** + * Represents an HTTP request that is passed through interceptors before being sent to the server. + * Interceptors may mutate any field. The {@link serializeBody} method serializes the body to JSON + * if it is still a RequestMessage, or returns the existing bytes if already serialized. + */ +export class HttpRequest { + url: string; + method: string; + headers: Record; + body: RequestMessage | Buffer | any; + + constructor(method: string, url: string, headers: Record, body: any) { + this.method = method; + this.url = url; + this.headers = headers; + this.body = body; + } + + /** + * Serializes the request body to JSON if it is still a RequestMessage. + * If body is already a Buffer, returns it as-is (idempotent). + * Sets Content-Type and Content-Length headers on successful serialization. + * + * @returns The serialized body as a Buffer. + * @throws Error if body is neither a RequestMessage nor a Buffer. + */ + serializeBody(): Buffer { + if (Buffer.isBuffer(this.body)) { + return this.body; + } + + if (this.body instanceof RequestMessage) { + // RequestMessage.toJSON() produces the flattened wire object (standard + custom fields). + const data = Buffer.from(JSON.stringify(this.body), 'utf-8'); + this.body = data; + this.headers['Content-Type'] = 'application/json'; + this.headers['Content-Length'] = String(data.length); + return data; + } + + const typeName = this.body === null ? 'null' : typeof this.body; + throw new Error(`unsupported body type: ${typeName}`); + } +} + +/** + * A request interceptor receives an HttpRequest and mutates it in place. + * Interceptors must not return a value. + */ +export type RequestInterceptor = (request: HttpRequest) => void | Promise; diff --git a/gremlin-js/gremlin-javascript/lib/driver/request-message.ts b/gremlin-js/gremlin-javascript/lib/driver/request-message.ts index f81d011cd37..c8783099354 100644 --- a/gremlin-js/gremlin-javascript/lib/driver/request-message.ts +++ b/gremlin-js/gremlin-javascript/lib/driver/request-message.ts @@ -100,6 +100,44 @@ export class RequestMessage { return this.customFields; } + /** + * Builds the plain object that represents this message on the wire. Standard fields are + * included when set, and custom fields are flattened to the top level. This method is + * invoked automatically by {@link JSON.stringify}. + * + * When a new standard field is added to this class, it should be added here as well so that + * it is included in the serialized request body. + */ + toJSON(): Record { + const payload: Record = { gremlin: this.gremlin }; + + if (this.language) { + payload['language'] = this.language; + } + if (this.g) { + payload['g'] = this.g; + } + if (this.bindings !== undefined) { + payload['bindings'] = this.bindings; + } + if (this.timeoutMs !== undefined) { + payload['timeoutMs'] = this.timeoutMs; + } + if (this.materializeProperties) { + payload['materializeProperties'] = this.materializeProperties; + } + if (this.bulkResults !== undefined) { + payload['bulkResults'] = this.bulkResults; + } + + // Flatten custom/provider fields to the top level + this.customFields.forEach((v, k) => { + payload[k] = v; + }); + + return payload; + } + static build(gremlin: string): Builder { return new Builder(gremlin); } diff --git a/gremlin-js/gremlin-javascript/lib/index.ts b/gremlin-js/gremlin-javascript/lib/index.ts index ef0e7ce1334..a2443b8d1d4 100644 --- a/gremlin-js/gremlin-javascript/lib/index.ts +++ b/gremlin-js/gremlin-javascript/lib/index.ts @@ -33,6 +33,7 @@ import DriverRemoteConnection from './driver/driver-remote-connection.js'; import ResponseError from './driver/response-error.js'; import Client from './driver/client.js'; import ResultSet from './driver/result-set.js'; +import { HttpRequest } from './driver/http-request.js'; import { basic, sigv4 } from './driver/auth.js'; import AnonymousTraversalSource from './process/anonymous-traversal.js'; @@ -44,6 +45,7 @@ export const driver = { DriverRemoteConnection, Client, ResultSet, + HttpRequest, auth: { basic, sigv4, diff --git a/gremlin-js/gremlin-javascript/test/integration/client-tests.js b/gremlin-js/gremlin-javascript/test/integration/client-tests.js index cf439aaa56b..eaf7a738c6f 100644 --- a/gremlin-js/gremlin-javascript/test/integration/client-tests.js +++ b/gremlin-js/gremlin-javascript/test/integration/client-tests.js @@ -273,4 +273,90 @@ describe('ProviderDefinedType - Client', function () { assert.strictEqual(list[1].fields.y, 4); }); }); +}); + +describe('Client interceptor integration', function () { + it('should auto serialize request message with interceptor mutation', async function () { + const { RequestMessage } = await import('../../lib/driver/request-message.js'); + const interceptor = (request) => { + if (request.body instanceof RequestMessage) { + request.body = RequestMessage.build('g.inject(99)').addG('gmodern').create(); + } + }; + const interceptorClient = new Client(serverUrl, { + traversalSource: 'gmodern', + interceptors: [interceptor], + }); + await interceptorClient.open(); + try { + const result = await interceptorClient.submit('g.inject(1)'); + assert.strictEqual(result.first(), 99); + } finally { + await interceptorClient.close(); + } + }); + + it('should propagate exception thrown during interceptor', async function () { + let callCount = 0; + const interceptor = () => { + callCount++; + if (callCount === 1) { + throw new Error('interceptor broke'); + } + }; + const interceptorClient = new Client(serverUrl, { + traversalSource: 'gmodern', + interceptors: [interceptor], + }); + await interceptorClient.open(); + try { + // First request should fail with interceptor error + await assert.rejects( + () => interceptorClient.submit('g.inject(1)'), + (err) => { + assert.ok(err.message.includes('interceptor') || err.message.includes('broke'), + `Expected error about interceptor, got: ${err.message}`); + return true; + } + ); + + // Subsequent request should succeed, proving connection recovery + const result = await interceptorClient.submit('g.inject(2)'); + assert.strictEqual(result.first(), 2); + } finally { + await interceptorClient.close(); + } + }); + + it('should propagate error when interceptor sets unsupported body type', async function () { + let callCount = 0; + const interceptor = (request) => { + callCount++; + if (callCount === 1) { + request.body = 42; + } + }; + const interceptorClient = new Client(serverUrl, { + traversalSource: 'gmodern', + interceptors: [interceptor], + }); + await interceptorClient.open(); + try { + // First request should fail with serialization error + await assert.rejects( + () => interceptorClient.submit('g.inject(1)'), + (err) => { + assert.ok(err.message.includes('unsupported body type') || err.message.includes('serialize'), + `Expected error about unsupported body type, got: ${err.message}`); + return true; + } + ); + + // Subsequent request should succeed, proving connection recovery + const result = await interceptorClient.submit('g.inject(2)'); + assert.strictEqual(result.first(), 2); + } finally { + await interceptorClient.close(); + } + }); }); \ No newline at end of file diff --git a/gremlin-js/gremlin-javascript/test/unit/auth-test.js b/gremlin-js/gremlin-javascript/test/unit/auth-test.js index cab2a5d24f5..e35fb1f8f47 100644 --- a/gremlin-js/gremlin-javascript/test/unit/auth-test.js +++ b/gremlin-js/gremlin-javascript/test/unit/auth-test.js @@ -20,18 +20,14 @@ import assert from 'assert'; import { Buffer } from 'buffer'; import { basic, sigv4 } from '../../lib/driver/auth.js'; +import { HttpRequest } from '../../lib/driver/http-request.js'; describe('auth', function () { describe('basic', function () { function createMockRequest() { - return { - url: 'https://localhost:8182/gremlin', - method: 'POST', - headers: { - 'accept': 'application/vnd.graphbinary-v4.0', - }, - body: new Uint8Array(0), - }; + return new HttpRequest('POST', 'https://localhost:8182/gremlin', { + 'accept': 'application/vnd.graphbinary-v4.0', + }, Buffer.from('')); } it('should add authorization header', function () { @@ -39,40 +35,28 @@ describe('auth', function () { assert.strictEqual(request.headers['authorization'], undefined); const interceptor = basic('username', 'password'); - const result = interceptor(request); + interceptor(request); - assert.ok(result.headers['authorization'].startsWith('Basic ')); + assert.ok(request.headers['authorization'].startsWith('Basic ')); }); it('should encode credentials correctly', function () { const request = createMockRequest(); const interceptor = basic('username', 'password'); - const result = interceptor(request); + interceptor(request); - const encoded = result.headers['authorization'].substring('Basic '.length); + const encoded = request.headers['authorization'].substring('Basic '.length); const decoded = Buffer.from(encoded, 'base64').toString(); assert.strictEqual(decoded, 'username:password'); }); - - it('should return the same request object', function () { - const request = createMockRequest(); - const interceptor = basic('username', 'password'); - const result = interceptor(request); - - assert.strictEqual(result, request); - }); }); describe('sigv4', function () { function createMockRequest() { - return { - url: 'https://localhost:8182/gremlin', - method: 'POST', - headers: { - 'accept': 'application/vnd.graphbinary-v4.0', - }, - body: new Uint8Array(Buffer.from('{"gremlin":"g.V()"}')), - }; + return new HttpRequest('POST', 'https://localhost:8182/gremlin', { + 'accept': 'application/vnd.graphbinary-v4.0', + 'content-type': 'application/json', + }, Buffer.from('{"gremlin":"g.V()"}')); } const mockProvider = () => ({ @@ -85,10 +69,10 @@ describe('auth', function () { assert.strictEqual(request.headers['authorization'], undefined); const interceptor = sigv4('xx-dummy-1', 'test-service', mockProvider); - const result = await interceptor(request); + await interceptor(request); - assert.ok(result.headers['x-amz-date']); - const authHeader = result.headers['authorization']; + assert.ok(request.headers['x-amz-date']); + const authHeader = request.headers['authorization']; assert.ok(authHeader.startsWith('AWS4-HMAC-SHA256 Credential=MOCK_ACCESS_KEY')); assert.ok(authHeader.includes('xx-dummy-1/test-service/aws4_request')); assert.ok(authHeader.includes('Signature=')); @@ -103,20 +87,33 @@ describe('auth', function () { }); const interceptor = sigv4('xx-dummy-1', 'test-service', providerWithToken); - const result = await interceptor(request); + await interceptor(request); - assert.strictEqual(result.headers['x-amz-security-token'], 'MOCK_SESSION_TOKEN'); - const authHeader = result.headers['authorization']; + assert.strictEqual(request.headers['x-amz-security-token'], 'MOCK_SESSION_TOKEN'); + const authHeader = request.headers['authorization']; assert.ok(authHeader.startsWith('AWS4-HMAC-SHA256 Credential=')); assert.ok(authHeader.includes('Signature=')); }); - it('should return the same request object', async function () { + it('should preserve pre-existing headers after signing', async function () { const request = createMockRequest(); - const interceptor = sigv4('xx-dummy-1', 'test-service', mockProvider); - const result = await interceptor(request); + // Capture the headers present before signing (accept + content-type) + const preSignKeys = Object.keys(request.headers); - assert.strictEqual(result, request); + const interceptor = sigv4('xx-dummy-1', 'test-service', mockProvider); + await interceptor(request); + + // The original headers must still be present (not dropped by wholesale replacement) + for (const key of preSignKeys) { + assert.ok( + key in request.headers, + `expected header '${key}' to be preserved after signing`); + } + assert.strictEqual(request.headers['accept'], 'application/vnd.graphbinary-v4.0'); + assert.strictEqual(request.headers['content-type'], 'application/json'); + + // Signing adds at least authorization and x-amz-date on top of the originals + assert.ok(Object.keys(request.headers).length >= preSignKeys.length + 2); }); }); }); diff --git a/gremlin-js/gremlin-javascript/test/unit/connection-test.js b/gremlin-js/gremlin-javascript/test/unit/connection-test.js new file mode 100644 index 00000000000..2ef6cc477ca --- /dev/null +++ b/gremlin-js/gremlin-javascript/test/unit/connection-test.js @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import assert from 'assert'; +import Connection from '../../lib/driver/connection.js'; +import { RequestMessage } from '../../lib/driver/request-message.js'; + +// Connection-level unit tests that verify interceptor wiring and auto-serialization +// by mocking the global fetch and capturing what the Connection sends. +describe('Connection (request pipeline)', function () { + let originalFetch; + let captured; + + beforeEach(function () { + captured = null; + originalFetch = global.fetch; + // Return an error response so we don't need a valid GraphBinary body; the test only + // cares about what was sent to fetch. The resulting ResponseError is swallowed per-test. + global.fetch = (url, init) => { + captured = { url, init }; + return Promise.resolve({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + headers: { get: () => 'text/plain' }, + arrayBuffer: () => Promise.resolve(Buffer.from('error')), + }); + }; + }); + + afterEach(function () { + global.fetch = originalFetch; + }); + + function makeConnection(options = {}) { + const conn = new Connection('http://localhost:8182/gremlin', + { enableUserAgentOnConnect: false, ...options }); + conn.isOpen = true; + return conn; + } + + async function submitAndIgnoreError(conn, msg) { + try { + await conn.submit(msg); + } catch (e) { + // Expected: the mock returns a 500 response. + } + } + + it('auto-serializes the body to JSON when no interceptor does', async function () { + const conn = makeConnection(); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.ok(captured, 'fetch should have been called'); + assert.strictEqual(captured.init.headers['Content-Type'], 'application/json'); + const parsed = JSON.parse(Buffer.from(captured.init.body).toString('utf-8')); + assert.strictEqual(parsed.gremlin, 'g.V()'); + assert.strictEqual(parsed.g, 'g'); + }); + + it('keeps the GraphBinary Accept header for responses', async function () { + const conn = makeConnection(); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.strictEqual(captured.init.headers['Accept'], 'application/vnd.graphbinary-v4.0'); + }); + + it('runs interceptors in registration order', async function () { + const order = []; + const conn = makeConnection({ + interceptors: [ + (req) => { order.push(1); }, + (req) => { order.push(2); }, + (req) => { order.push(3); }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.deepStrictEqual(order, [1, 2, 3]); + }); + + it('reflects interceptor body mutation in the serialized payload', async function () { + const conn = makeConnection({ + interceptors: [ + (req) => { + req.body = RequestMessage.build('g.inject(99)').addG('gmodern').create(); + }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + const parsed = JSON.parse(Buffer.from(captured.init.body).toString('utf-8')); + assert.strictEqual(parsed.gremlin, 'g.inject(99)'); + assert.strictEqual(parsed.g, 'gmodern'); + }); + + it('lets an interceptor add headers that reach the request', async function () { + const conn = makeConnection({ + interceptors: [ + (req) => { req.headers['X-Custom'] = 'value'; }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.strictEqual(captured.init.headers['X-Custom'], 'value'); + }); + + it('propagates interceptor errors to the caller', async function () { + const conn = makeConnection({ + interceptors: [ + () => { throw new Error('interceptor broke'); }, + ], + }); + + await assert.rejects( + () => conn.submit(RequestMessage.build('g.V()').addG('g').create()), + /interceptor broke/); + assert.strictEqual(captured, null, 'fetch should not be called when an interceptor throws'); + }); + + it('auth interceptor always runs last regardless of option order', async function () { + const order = []; + const conn = makeConnection({ + auth: (req) => { order.push(3); }, + interceptors: [ + (req) => { order.push(1); }, + (req) => { order.push(2); }, + ], + }); + await submitAndIgnoreError(conn, RequestMessage.build('g.V()').addG('g').create()); + + assert.deepStrictEqual(order, [1, 2, 3], 'auth interceptor should always run last'); + }); +}); diff --git a/gremlin-js/gremlin-javascript/test/unit/http-request-test.js b/gremlin-js/gremlin-javascript/test/unit/http-request-test.js new file mode 100644 index 00000000000..b3d03e442a1 --- /dev/null +++ b/gremlin-js/gremlin-javascript/test/unit/http-request-test.js @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import assert from 'assert'; +import { Buffer } from 'buffer'; +import { HttpRequest } from '../../lib/driver/http-request.js'; +import { RequestMessage } from '../../lib/driver/request-message.js'; + +describe('HttpRequest', function () { + describe('serializeBody()', function () { + it('should serialize RequestMessage to JSON bytes and set body to Buffer', function () { + const msg = RequestMessage.build("g.V().has('name','marko')") + .addG('g') + .create(); + + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + const data = httpReq.serializeBody(); + + assert(Buffer.isBuffer(data), 'should return a Buffer'); + assert(Buffer.isBuffer(httpReq.body), 'body should now be a Buffer'); + + const parsed = JSON.parse(data.toString('utf-8')); + assert.strictEqual(parsed.gremlin, "g.V().has('name','marko')"); + assert.strictEqual(parsed.g, 'g'); + assert.strictEqual(parsed.language, 'gremlin-lang'); + }); + + it('should set Content-Type to application/json', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + httpReq.serializeBody(); + + assert.strictEqual(httpReq.headers['Content-Type'], 'application/json'); + }); + + it('should set Content-Length to byte length of the serialized body', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const data = httpReq.serializeBody(); + + assert.strictEqual(httpReq.headers['Content-Length'], String(data.length)); + }); + + it('should be idempotent when body is already a Buffer', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const data1 = httpReq.serializeBody(); + const data2 = httpReq.serializeBody(); + + assert(data1.equals(data2), 'subsequent calls should return identical bytes'); + }); + + it('should produce identical results on multiple calls', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const results = []; + for (let i = 0; i < 3; i++) { + results.push(httpReq.serializeBody()); + } + + assert(results[0].equals(results[1])); + assert(results[1].equals(results[2])); + }); + + it('should return existing bytes if body is already a Buffer', function () { + const existing = Buffer.from('{"gremlin":"g.V()"}', 'utf-8'); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, existing); + + const data = httpReq.serializeBody(); + + assert.strictEqual(data, existing, 'should return the same Buffer reference'); + }); + + it('should include all fields from RequestMessage', function () { + const msg = RequestMessage.build("g.V().has('age',x)") + .addG('gCustom') + .addTimeoutMillis(5000) + .addMaterializeProperties('all') + .addBulkResults(true) + .addField('customKey', 'customValue') + .create(); + + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + const data = httpReq.serializeBody(); + const parsed = JSON.parse(data.toString('utf-8')); + + assert.strictEqual(parsed.gremlin, "g.V().has('age',x)"); + assert.strictEqual(parsed.g, 'gCustom'); + assert.strictEqual(parsed.language, 'gremlin-lang'); + assert.strictEqual(parsed.timeoutMs, 5000); + assert.strictEqual(parsed.materializeProperties, 'all'); + assert.strictEqual(parsed.bulkResults, true); + assert.strictEqual(parsed.customKey, 'customValue'); + }); + + it('should throw for unsupported body types', function () { + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, 42); + + assert.throws(() => httpReq.serializeBody(), /unsupported body type/); + }); + + it('should throw for null body', function () { + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, null); + + assert.throws(() => httpReq.serializeBody(), /unsupported body type/); + }); + }); + + describe('interceptors', function () { + it('should receive HttpRequest with body as RequestMessage', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', { 'Accept': 'application/vnd.graphbinary-v4.0' }, msg); + + let receivedBody; + const interceptor = (req) => { + receivedBody = req.body; + }; + + interceptor(httpReq); + + assert(receivedBody instanceof RequestMessage, 'interceptor should receive RequestMessage as body'); + }); + + it('should allow reading and modifying headers', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', { 'Accept': 'application/vnd.graphbinary-v4.0' }, msg); + + const interceptor = (req) => { + assert.strictEqual(req.headers['Accept'], 'application/vnd.graphbinary-v4.0'); + req.headers['X-Custom'] = 'test-value'; + }; + + interceptor(httpReq); + + assert.strictEqual(httpReq.headers['X-Custom'], 'test-value'); + }); + + it('should allow reading and modifying url', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + const interceptor = (req) => { + req.url = 'http://other-host:8182/gremlin'; + }; + + interceptor(httpReq); + + assert.strictEqual(httpReq.url, 'http://other-host:8182/gremlin'); + }); + + it('should run in registration order', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + const order = []; + + const interceptor1 = (req) => { order.push(1); }; + const interceptor2 = (req) => { order.push(2); }; + const interceptor3 = (req) => { order.push(3); }; + + interceptor1(httpReq); + interceptor2(httpReq); + interceptor3(httpReq); + + assert.deepStrictEqual(order, [1, 2, 3]); + }); + + it('should allow field mutation before serialization to affect output', function () { + const msg = RequestMessage.build('g.V()') + .addG('g') + .addField('providerField', 'original') + .create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + // Interceptor that calls serializeBody after modifying a header + // but a field-mutating interceptor must work at the RequestMessage level + // For field mutation, the interceptor replaces the body with a new RequestMessage + const interceptor = (req) => { + // Build a new request message with modified fields + const newMsg = RequestMessage.build('g.V().count()') + .addG('gModified') + .addField('providerField', 'modified') + .create(); + req.body = newMsg; + }; + + interceptor(httpReq); + const data = httpReq.serializeBody(); + const parsed = JSON.parse(data.toString('utf-8')); + + assert.strictEqual(parsed.gremlin, 'g.V().count()'); + assert.strictEqual(parsed.g, 'gModified'); + assert.strictEqual(parsed.providerField, 'modified'); + }); + + it('should allow interceptor to call serializeBody for payload hashing', function () { + const msg = RequestMessage.build('g.V()').addG('g').create(); + const httpReq = new HttpRequest('POST', 'http://localhost:8182/gremlin', {}, msg); + + // Simulates a SigV4 interceptor that needs the serialized bytes + const signingInterceptor = (req) => { + const bytes = req.serializeBody(); + // Use bytes for hashing (simulated) + req.headers['X-Payload-Hash'] = String(bytes.length); + }; + + signingInterceptor(httpReq); + + // Body should now be serialized + assert(Buffer.isBuffer(httpReq.body)); + assert.strictEqual(httpReq.headers['Content-Type'], 'application/json'); + assert(httpReq.headers['X-Payload-Hash'].length > 0); + + // Subsequent serializeBody call should be idempotent + const data = httpReq.serializeBody(); + assert(Buffer.isBuffer(data)); + }); + }); +}); diff --git a/gremlin-python/src/main/python/examples/connections.py b/gremlin-python/src/main/python/examples/connections.py index 654079d027e..a496d730df0 100644 --- a/gremlin-python/src/main/python/examples/connections.py +++ b/gremlin-python/src/main/python/examples/connections.py @@ -23,7 +23,6 @@ from gremlin_python.process.anonymous_traversal import traversal from gremlin_python.driver.driver_remote_connection import DriverRemoteConnection from gremlin_python.driver.auth import basic -from gremlin_python.driver.serializer import GraphBinarySerializersV4 VERTEX_LABEL = os.getenv('VERTEX_LABEL', 'connection') @@ -83,7 +82,6 @@ def with_configs(): server_url = os.getenv('GREMLIN_SERVER_URL', 'http://localhost:8182/gremlin').format(45940) rc = DriverRemoteConnection( server_url, 'g', - request_serializer=GraphBinarySerializersV4(), headers=None, ) g = traversal().with_remote(rc) diff --git a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py index 08a24624279..bdd100ce246 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/aiohttp/transport.py @@ -132,11 +132,6 @@ async def async_connect(): def write(self, message): # Inner function to perform async write. async def async_write(): - # To pass url into message for request authentication processing - message.update({'url': self._url}) - if message['auth']: - message['auth'](message) - async with async_timeout.timeout(self._write_timeout): self._http_req_resp = await self._client_session.post(url=self._url, data=message['payload'], diff --git a/gremlin-python/src/main/python/gremlin_python/driver/auth.py b/gremlin-python/src/main/python/gremlin_python/driver/auth.py index e7bd453b1d3..89d6ccf8891 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/auth.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/auth.py @@ -16,24 +16,29 @@ # specific language governing permissions and limitations # under the License. # +import base64 def basic(username, password): - from aiohttp import BasicAuth as aiohttpBasicAuth + """Returns an interceptor that adds Basic auth to the request.""" + def interceptor(request): + credentials = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode("utf-8") + request.headers['authorization'] = f"Basic {credentials}" - def apply(request): - return request['headers'].update({'authorization': aiohttpBasicAuth(username, password).encode()}) - - return apply + return interceptor def sigv4(region, service): + """Returns an interceptor that signs the request with AWS SigV4.""" import os from boto3 import Session from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest - def apply(request): + def interceptor(request): + # Ensure body is serialized so we can sign it + body_bytes = request.serialize_body() + access_key = os.environ.get('AWS_ACCESS_KEY_ID', '') secret_key = os.environ.get('AWS_SECRET_ACCESS_KEY', '') session_token = os.environ.get('AWS_SESSION_TOKEN', '') @@ -45,11 +50,8 @@ def apply(request): region_name=region ) - sigv4_request = AWSRequest(method="POST", url=request['url'], data=request['payload']) + sigv4_request = AWSRequest(method=request.method, url=request.url, data=body_bytes) SigV4Auth(session.get_credentials(), service, region).add_auth(sigv4_request) - request['headers'].update(sigv4_request.headers) - request['payload'] = sigv4_request.data - return request - - return apply + request.headers.update(dict(sigv4_request.headers)) + return interceptor diff --git a/gremlin-python/src/main/python/gremlin_python/driver/client.py b/gremlin-python/src/main/python/gremlin_python/driver/client.py index fa4dbc32c2b..a32e8f9d0b3 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/client.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/client.py @@ -39,7 +39,6 @@ def cpu_count(): class Client: def __init__(self, url, traversal_source, pool_size=None, max_workers=None, - request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=None, interceptors=None, auth=None, headers=None, enable_user_agent_on_connect=True, bulk_results=False, pdt_registry=None, **transport_kwargs): @@ -57,13 +56,10 @@ def __init__(self, url, traversal_source, pool_size=None, max_workers=None, if response_serializer is None: response_serializer = serializer.GraphBinarySerializersV4() if pdt_registry is not None: - if request_serializer is not None: - request_serializer.configure_pdt_registry(pdt_registry) response_serializer.configure_pdt_registry(pdt_registry) self._auth = auth self._response_serializer = response_serializer - self._request_serializer = request_serializer self._interceptors = interceptors self._transport_kwargs = transport_kwargs @@ -149,7 +145,6 @@ def _get_connection(self): return connection.Connection( self._url, self._traversal_source, self._executor, self._pool, - request_serializer=self._request_serializer, response_serializer=self._response_serializer, auth=self._auth, interceptors=self._interceptors, headers=self._headers, diff --git a/gremlin-python/src/main/python/gremlin_python/driver/connection.py b/gremlin-python/src/main/python/gremlin_python/driver/connection.py index 64a3f000818..c9b65eab788 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/connection.py @@ -19,6 +19,7 @@ from gremlin_python.driver import resultset, useragent from gremlin_python.driver.aiohttp.transport import AiohttpHTTPTransport +from gremlin_python.driver.http_request import HttpRequest __author__ = 'David M. Brown (davebshow@gmail.com)' @@ -34,17 +35,23 @@ def __init__(self, status): class Connection: def __init__(self, url, traversal_source, - executor, pool, request_serializer=None, + executor, pool, response_serializer=None, auth=None, interceptors=None, headers=None, enable_user_agent_on_connect=True, bulk_results=False, **transport_kwargs): if callable(interceptors): interceptors = [interceptors] - elif not (isinstance(interceptors, tuple) - or isinstance(interceptors, list) - or interceptors is None): + elif isinstance(interceptors, tuple): + interceptors = list(interceptors) + elif not (isinstance(interceptors, list) or interceptors is None): raise TypeError("interceptors must be a callable, tuple, list or None") + # Auth is just an interceptor. As a convenience (and for discoverability), the auth + # interceptor is appended to the end of the interceptor list so it runs last, after + # any user interceptors have modified the request. + if auth is not None: + interceptors = (interceptors or []) + [auth] + self._url = url self._headers = headers self._traversal_source = traversal_source @@ -54,9 +61,7 @@ def __init__(self, url, traversal_source, self._pool = pool self._result_set = None self._inited = False - self._request_serializer = request_serializer self._response_serializer = response_serializer - self._auth = auth self._interceptors = interceptors self._enable_user_agent_on_connect = enable_user_agent_on_connect if self._enable_user_agent_on_connect: @@ -78,22 +83,33 @@ def close(self): def _write_request(self, request_message): accept = str(self._response_serializer.version, encoding='utf-8') - message = { - 'headers': {'accept': accept}, - 'payload': self._request_serializer.serialize_message(request_message) - if self._request_serializer is not None else request_message, - 'auth': self._auth - } - if self._request_serializer is not None: - content_type = str(self._request_serializer.version, encoding='utf-8') - message['headers']['content-type'] = content_type + + headers = {'accept': accept} + # Promote transactionId to HTTP header before interceptors run. # The field remains in the serialized body as well (dual transmission # per the HTTP transaction protocol specification). if hasattr(request_message, 'fields') and 'transactionId' in request_message.fields: - message['headers']['X-Transaction-Id'] = request_message.fields['transactionId'] + headers['X-Transaction-Id'] = request_message.fields['transactionId'] + + http_request = HttpRequest( + method="POST", + url=self._url, + headers=headers, + body=request_message + ) + for interceptor in self._interceptors or []: - message = interceptor(message) + interceptor(http_request) + + # Auto-serialize if no interceptor already did so + http_request.serialize_body() + + # Build the transport message in the format the transport expects + message = { + 'headers': http_request.headers, + 'payload': http_request.body + } self._transport.write(message) def write(self, request_message): diff --git a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py index 279c2f48453..8c82b3d7294 100644 --- a/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py +++ b/gremlin-python/src/main/python/gremlin_python/driver/driver_remote_connection.py @@ -33,7 +33,6 @@ class DriverRemoteConnection(RemoteConnection): def __init__(self, url, traversal_source="g", pool_size=None, max_workers=None, - request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=None, interceptors=None, auth=None, headers=None, enable_user_agent_on_connect=True, bulk_results=False, pdt_registry=None, **transport_kwargs): @@ -54,7 +53,6 @@ def __init__(self, url, traversal_source="g", self._client = client.Client(url, traversal_source, pool_size=pool_size, max_workers=max_workers, - request_serializer=request_serializer, response_serializer=response_serializer, interceptors=interceptors, auth=auth, headers=headers, diff --git a/gremlin-python/src/main/python/gremlin_python/driver/http_request.py b/gremlin-python/src/main/python/gremlin_python/driver/http_request.py new file mode 100644 index 00000000000..304acf28a27 --- /dev/null +++ b/gremlin-python/src/main/python/gremlin_python/driver/http_request.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json + +from gremlin_python.driver.request import RequestMessage + + +class HttpRequest: + """Represents the HTTP request passed through the interceptor chain. + + The body starts as a RequestMessage and can be serialized to JSON bytes + via serialize_body(). Interceptors mutate this object in place. + """ + + def __init__(self, method, url, headers, body): + self.method = method + self.url = url + self.headers = headers + self.body = body + + def serialize_body(self): + """Serialize the body to JSON bytes if it is still a RequestMessage. + + If the body is already bytes, returns them as-is (idempotent). + Sets the Content-Type header to application/json and Content-Length + to the byte length of the serialized body. + + Returns: + bytes: the serialized body + + Raises: + TypeError: if the body is neither a RequestMessage nor bytes + """ + if isinstance(self.body, bytes): + return self.body + + if not isinstance(self.body, RequestMessage): + raise TypeError( + f"Cannot serialize body of type {type(self.body).__name__}. " + "Expected RequestMessage or bytes." + ) + + payload = {"gremlin": self.body.gremlin} + payload.update(self.body.fields) + data = json.dumps(payload).encode("utf-8") + self.body = data + self.headers["content-type"] = "application/json" + self.headers["content-length"] = str(len(data)) + return data diff --git a/gremlin-python/src/main/python/tests/feature/terrain.py b/gremlin-python/src/main/python/tests/feature/terrain.py index d81db72c19d..781216bb639 100644 --- a/gremlin-python/src/main/python/tests/feature/terrain.py +++ b/gremlin-python/src/main/python/tests/feature/terrain.py @@ -107,5 +107,5 @@ def __create_remote(server_graph_name): bulking = world.config.user_data["bulking"] == "true" if "bulking" in world.config.user_data else False return DriverRemoteConnection(test_no_auth_url, server_graph_name, - request_serializer=s, response_serializer=s, + response_serializer=s, bulk_results=bulking) diff --git a/gremlin-python/src/main/python/tests/integration/conftest.py b/gremlin-python/src/main/python/tests/integration/conftest.py index de685fb4b91..8776859ed2c 100644 --- a/gremlin-python/src/main/python/tests/integration/conftest.py +++ b/gremlin-python/src/main/python/tests/integration/conftest.py @@ -19,7 +19,6 @@ import concurrent.futures from collections import namedtuple -from json import dumps import os import ssl import pytest @@ -62,7 +61,6 @@ def connection(request): try: conn = Connection(anonymous_url, 'gmodern', executor, pool, - request_serializer=GraphBinarySerializersV4(), response_serializer=GraphBinarySerializersV4(), auth=basic('stephen', 'password')) except OSError: @@ -122,7 +120,6 @@ def graphbinary_serializer_v4(request): def remote_connection(request): try: remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', - request_serializer=serializer.GraphBinarySerializersV4(), response_serializer=serializer.GraphBinarySerializersV4()) except OSError: pytest.skip('Gremlin Server is not running') @@ -192,8 +189,7 @@ def fin(): def remote_connection_with_interceptor(request): try: remote_conn = DriverRemoteConnection(anonymous_url, 'gmodern', - request_serializer=None, - interceptors=json_interceptor) + interceptors=mutating_interceptor) except OSError: pytest.skip('Gremlin Server is not running') else: @@ -207,9 +203,8 @@ def fin(): @pytest.fixture() def client_with_interceptor(request): try: - client = Client(anonymous_url, 'gmodern', request_serializer=None, - response_serializer=GraphBinarySerializersV4(), - interceptors=json_interceptor) + client = Client(anonymous_url, 'gmodern', + interceptors=mutating_interceptor) except OSError: pytest.skip('Gremlin Server is not running') else: @@ -241,7 +236,8 @@ def fin(): return remote_conn -def json_interceptor(request): - request['headers']['content-type'] = "application/json" - request['payload'] = dumps({"gremlin": "g.inject(2)", "g": "g"}) - return request +def mutating_interceptor(http_request): + """Interceptor that replaces the gremlin query with g.inject(2) for testing.""" + from gremlin_python.driver.request import RequestMessage + if isinstance(http_request.body, RequestMessage): + http_request.body = RequestMessage(fields={"g": "g"}, gremlin="g.inject(2)") diff --git a/gremlin-python/src/main/python/tests/integration/driver/test_auth.py b/gremlin-python/src/main/python/tests/integration/driver/test_auth.py index 7efd1fe0422..fa6837098b3 100644 --- a/gremlin-python/src/main/python/tests/integration/driver/test_auth.py +++ b/gremlin-python/src/main/python/tests/integration/driver/test_auth.py @@ -16,51 +16,58 @@ # specific language governing permissions and limitations # under the License. # +import base64 import os -from aiohttp import BasicAuth as aiohttpBasicAuth + from gremlin_python.driver.auth import basic, sigv4 +from gremlin_python.driver.http_request import HttpRequest def create_mock_request(): - return {'headers': - {'content-type': 'application/vnd.graphbinary-v4.0', - 'accept': 'application/vnd.graphbinary-v4.0'}, - 'payload': b'', - 'url': 'https://test_url:8182/gremlin'} + return HttpRequest( + method="POST", + url="https://test_url:8182/gremlin", + headers={ + 'content-type': 'application/vnd.graphbinary-v4.0', + 'accept': 'application/vnd.graphbinary-v4.0', + }, + body=b'', + ) class TestAuth(object): def test_basic_auth_request(self): mock_request = create_mock_request() - assert 'authorization' not in mock_request['headers'] + assert 'authorization' not in mock_request.headers basic('username', 'password')(mock_request) - assert 'authorization' in mock_request['headers'] - assert aiohttpBasicAuth('username', 'password').encode() == mock_request['headers']['authorization'] + assert 'authorization' in mock_request.headers + expected = 'Basic ' + base64.b64encode('username:password'.encode('utf-8')).decode('utf-8') + assert expected == mock_request.headers['authorization'] def test_sigv4_auth_request(self): mock_request = create_mock_request() - assert 'Authorization' not in mock_request['headers'] - assert 'X-Amz-Date' not in mock_request['headers'] + assert 'Authorization' not in mock_request.headers + assert 'X-Amz-Date' not in mock_request.headers os.environ['AWS_ACCESS_KEY_ID'] = 'MOCK_ID' os.environ['AWS_SECRET_ACCESS_KEY'] = 'MOCK_KEY' sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request) - assert mock_request['headers']['X-Amz-Date'] is not None - assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID') - assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization'] - assert 'Signature=' in mock_request['headers']['Authorization'] + assert mock_request.headers['X-Amz-Date'] is not None + assert mock_request.headers['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=MOCK_ID') + assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request.headers['Authorization'] + assert 'Signature=' in mock_request.headers['Authorization'] def test_sigv4_auth_request_session_token(self): mock_request = create_mock_request() - assert 'Authorization' not in mock_request['headers'] - assert 'X-Amz-Date' not in mock_request['headers'] - assert 'X-Amz-Security-Token' not in mock_request['headers'] + assert 'Authorization' not in mock_request.headers + assert 'X-Amz-Date' not in mock_request.headers + assert 'X-Amz-Security-Token' not in mock_request.headers + os.environ['AWS_ACCESS_KEY_ID'] = 'MOCK_ID' + os.environ['AWS_SECRET_ACCESS_KEY'] = 'MOCK_KEY' os.environ['AWS_SESSION_TOKEN'] = 'MOCK_TOKEN' sigv4('gremlin-east-1', 'tinkerpop-sigv4')(mock_request) - assert mock_request['headers']['X-Amz-Date'] is not None - assert mock_request['headers']['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=') - assert mock_request['headers']['X-Amz-Security-Token'] == 'MOCK_TOKEN' - assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request['headers']['Authorization'] - assert 'Signature=' in mock_request['headers']['Authorization'] - - + assert mock_request.headers['X-Amz-Date'] is not None + assert mock_request.headers['Authorization'].startswith('AWS4-HMAC-SHA256 Credential=') + assert mock_request.headers['X-Amz-Security-Token'] == 'MOCK_TOKEN' + assert 'gremlin-east-1/tinkerpop-sigv4/aws4_request' in mock_request.headers['Authorization'] + assert 'Signature=' in mock_request.headers['Authorization'] diff --git a/gremlin-python/src/main/python/tests/integration/driver/test_client.py b/gremlin-python/src/main/python/tests/integration/driver/test_client.py index 34fa004e125..b6a2236eda9 100644 --- a/gremlin-python/src/main/python/tests/integration/driver/test_client.py +++ b/gremlin-python/src/main/python/tests/integration/driver/test_client.py @@ -612,3 +612,47 @@ def test_pdt_in_collection(client): assert pdt_list[1].name == 'Point' assert pdt_list[1].fields['x'] == 3 assert pdt_list[1].fields['y'] == 4 + + +def test_auto_serializes_request_message_with_interceptor_mutation(): + """Verifies the driver auto-serializes when an interceptor modifies the RequestMessage body.""" + from gremlin_python.driver.request import RequestMessage + + def swap_query(http_request): + if isinstance(http_request.body, RequestMessage): + http_request.body = RequestMessage(fields={"g": "gmodern"}, gremlin="g.inject(99)") + + client = Client(test_no_auth_url, 'gmodern', + pool_size=1, interceptors=swap_query) + try: + result = client.submit("g.inject(1)").next() + assert 99 == result + finally: + client.close() + + +def test_interceptor_errors_propagate(): + """Verifies that an interceptor error propagates to the caller, the request is not sent, + and the client remains usable for subsequent requests.""" + call_count = [0] + + def failing_interceptor(http_request): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("interceptor broke") + + client = Client(test_no_auth_url, 'gmodern', + pool_size=1, interceptors=failing_interceptor) + try: + # First request should fail with interceptor error + try: + client.submit("g.inject(1)").next() + assert False, "Should have thrown an exception" + except RuntimeError as e: + assert "interceptor broke" in str(e) + + # Subsequent request should succeed, proving the client is still usable + result = client.submit("g.inject(2)").next() + assert 2 == result + finally: + client.close() diff --git a/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py b/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py index 23415fdef32..1f6d2e1f90b 100644 --- a/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py +++ b/gremlin-python/src/main/python/tests/unit/driver/test_http_streaming.py @@ -29,6 +29,7 @@ import asyncio import io +import json import queue import struct from concurrent.futures import Future @@ -1247,12 +1248,11 @@ def test_receive_uses_serializer_version_for_content_type_check(self): class TestConnectionWriteRequest: """ - Tests for Connection._write_request() which handles serialization, - header construction, auth, and interceptors before calling transport.write(). + Tests for Connection._write_request() which builds an HttpRequest, runs interceptors, + auto-serializes the body to JSON, and calls transport.write(). """ - def _make_connection(self, request_serializer=None, response_serializer=None, - auth=None, interceptors=None): + def _make_connection(self, response_serializer=None, interceptors=None): from gremlin_python.driver.connection import Connection from gremlin_python.driver.serializer import GraphBinarySerializersV4 @@ -1260,31 +1260,21 @@ def _make_connection(self, request_serializer=None, response_serializer=None, response_serializer = GraphBinarySerializersV4() conn = Connection.__new__(Connection) - conn._response_serializer = GraphBinarySerializersV4() - conn._request_serializer = request_serializer + conn._url = 'http://localhost:8182/gremlin' conn._response_serializer = response_serializer - conn._auth = auth conn._interceptors = interceptors conn._transport = MagicMock() return conn - def test_none_request_serializer_passes_raw_message(self): - conn = self._make_connection(request_serializer=None) - msg = RequestMessage(fields={}, gremlin="g.V()") + def test_auto_serializes_body_to_json(self): + conn = self._make_connection() + msg = RequestMessage(fields={"g": "g"}, gremlin="g.V()") conn._write_request(msg) written = conn._transport.write.call_args[0][0] - assert written['payload'] == msg - assert 'content-type' not in written['headers'] - - def test_graphbinary_serializer_serializes_payload(self): - from gremlin_python.driver.serializer import GraphBinarySerializersV4 - gb = GraphBinarySerializersV4() - conn = self._make_connection(request_serializer=gb, response_serializer=gb) - msg = RequestMessage(fields={}, gremlin="g.V()") - conn._write_request(msg) - written = conn._transport.write.call_args[0][0] - assert written['payload'] == gb.serialize_message(msg) - assert written['headers']['content-type'] == str(gb.version, encoding='utf-8') + payload = json.loads(written['payload']) + assert payload['gremlin'] == "g.V()" + assert payload['g'] == "g" + assert written['headers']['content-type'] == "application/json" def test_accept_header_set_from_response_serializer(self): from gremlin_python.driver.serializer import GraphBinarySerializersV4 @@ -1294,45 +1284,30 @@ def test_accept_header_set_from_response_serializer(self): written = conn._transport.write.call_args[0][0] assert written['headers']['accept'] == str(gb.version, encoding='utf-8') - def test_auth_passed_in_message(self): - auth_fn = lambda req: req - conn = self._make_connection(auth=auth_fn) - conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) - written = conn._transport.write.call_args[0][0] - assert written['auth'] is auth_fn - def test_single_interceptor_runs(self): - changed = RequestMessage(fields={}, gremlin="changed") def interceptor(request): - request['payload'] = changed - return request + request.body = RequestMessage(fields={"g": "g"}, gremlin="changed") conn = self._make_connection(interceptors=[interceptor]) conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) written = conn._transport.write.call_args[0][0] - assert written['payload'] == changed + assert json.loads(written['payload'])['gremlin'] == "changed" def test_interceptors_run_sequentially(self): - def one(req): req['payload'].gremlin.append(1); return req - def two(req): req['payload'].gremlin.append(2); return req - def three(req): req['payload'].gremlin.append(3); return req + order = [] + def one(req): order.append(1) + def two(req): order.append(2) + def three(req): order.append(3) conn = self._make_connection(interceptors=[one, two, three]) - conn._write_request(RequestMessage(fields={}, gremlin=[])) - written = conn._transport.write.call_args[0][0] - assert written['payload'].gremlin == [1, 2, 3] + conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) + assert order == [1, 2, 3] - def test_interceptor_works_with_serializer(self): - from gremlin_python.driver.serializer import GraphBinarySerializersV4 - gb = GraphBinarySerializersV4() - msg = RequestMessage(fields={}, gremlin="g.E()") - def assert_interceptor(request): - assert request['payload'] == gb.serialize_message(msg) - request['payload'] = "changed" - return request - conn = self._make_connection(request_serializer=gb, response_serializer=gb, - interceptors=[assert_interceptor]) - conn._write_request(msg) + def test_interceptor_can_modify_headers(self): + def interceptor(request): + request.headers['x-custom'] = "value" + conn = self._make_connection(interceptors=[interceptor]) + conn._write_request(RequestMessage(fields={}, gremlin="g.V()")) written = conn._transport.write.call_args[0][0] - assert written['payload'] == "changed" + assert written['headers']['x-custom'] == "value" class TestConnectionInterceptorValidation: diff --git a/gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py b/gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py new file mode 100644 index 00000000000..456d7fb4a58 --- /dev/null +++ b/gremlin-python/src/main/python/tests/unit/driver/test_interceptor.py @@ -0,0 +1,221 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +import json + +from gremlin_python.driver.http_request import HttpRequest +from gremlin_python.driver.request import RequestMessage + + +def make_request(gremlin="g.V()", fields=None): + msg = RequestMessage(fields=fields or {"g": "g"}, gremlin=gremlin) + return HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={"accept": "application/vnd.graphbinary-v4.0"}, body=msg) + + +class TestBasicInterceptorExecution: + + def test_interceptor_receives_request_message_in_body(self): + request = make_request() + captured = [] + + def interceptor(req): + captured.append(req.body) + + interceptor(request) + assert len(captured) == 1 + assert isinstance(captured[0], RequestMessage) + + def test_interceptor_can_read_and_modify_headers(self): + request = make_request() + request.headers["x-existing"] = "original" + + def interceptor(req): + assert req.headers["x-existing"] == "original" + req.headers["x-existing"] = "modified" + req.headers["x-new"] = "added" + + interceptor(request) + assert request.headers["x-existing"] == "modified" + assert request.headers["x-new"] == "added" + + def test_interceptor_can_modify_uri(self): + request = make_request() + + def interceptor(req): + req.url = "http://other-host:9999/gremlin" + + interceptor(request) + assert request.url == "http://other-host:9999/gremlin" + + def test_interceptors_run_in_registration_order(self): + request = make_request() + order = [] + + interceptors = [ + lambda req: order.append(1), + lambda req: order.append(2), + lambda req: order.append(3), + ] + + for i in interceptors: + i(request) + + assert order == [1, 2, 3] + + +class TestSerializeBody: + + def test_converts_request_message_to_json_bytes(self): + request = make_request(gremlin="g.V().count()") + result = request.serialize_body() + + parsed = json.loads(result) + assert parsed["gremlin"] == "g.V().count()" + assert parsed["g"] == "g" + + def test_sets_content_type_header(self): + request = make_request() + request.serialize_body() + assert request.headers["content-type"] == "application/json" + + def test_sets_content_length_header(self): + request = make_request() + result = request.serialize_body() + assert request.headers["content-length"] == str(len(result)) + + def test_is_idempotent_with_pre_serialized_bytes(self): + existing = b'{"gremlin":"g.V()"}' + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=existing) + first = request.serialize_body() + second = request.serialize_body() + assert first is existing + assert first is second + + def test_is_idempotent_with_request_message(self): + request = make_request() + first = request.serialize_body() + second = request.serialize_body() + assert first is second + + def test_includes_all_fields_in_json(self): + msg = RequestMessage( + fields={"g": "g", "language": "gremlin-lang", "timeoutMs": 5000}, + gremlin="g.V()" + ) + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=msg) + result = request.serialize_body() + parsed = json.loads(result) + + assert parsed["gremlin"] == "g.V()" + assert parsed["g"] == "g" + assert parsed["language"] == "gremlin-lang" + assert parsed["timeoutMs"] == 5000 + + def test_raises_on_unsupported_body_type(self): + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=42) + try: + request.serialize_body() + assert False, "expected TypeError for unsupported body type" + except TypeError as e: + assert "int" in str(e) + + def test_raises_on_none_body(self): + request = HttpRequest(method="POST", url="http://localhost:8182/gremlin", + headers={}, body=None) + try: + request.serialize_body() + assert False, "expected TypeError for None body" + except TypeError as e: + assert "NoneType" in str(e) + + +class TestFieldMutation: + + def test_interceptor_can_replace_body_before_serialization(self): + request = make_request() + + def interceptor(req): + req.body = RequestMessage(fields={"g": "gmodern"}, gremlin="g.E()") + + interceptor(request) + request.serialize_body() + + parsed = json.loads(request.body) + assert parsed["gremlin"] == "g.E()" + assert parsed["g"] == "gmodern" + + + + +class TestAuthInterceptorOrdering: + + def test_auth_interceptor_is_always_last(self): + """Auth interceptor should always be appended to the end of the interceptor list.""" + from unittest.mock import MagicMock + from gremlin_python.driver.connection import Connection + + order = [] + + def interceptor1(req): + order.append(1) + + def interceptor2(req): + order.append(2) + + def auth_interceptor(req): + order.append(3) + + conn = Connection( + url="http://localhost:8182/gremlin", + traversal_source="g", + executor=MagicMock(), + pool=MagicMock(), + auth=auth_interceptor, + interceptors=[interceptor1, interceptor2], + enable_user_agent_on_connect=False + ) + + # Verify the internal interceptor list has auth at the end + assert len(conn._interceptors) == 3 + assert conn._interceptors[0] is interceptor1 + assert conn._interceptors[1] is interceptor2 + assert conn._interceptors[2] is auth_interceptor + + def test_auth_interceptor_is_last_even_without_other_interceptors(self): + """Auth interceptor works when no other interceptors are provided.""" + from unittest.mock import MagicMock + from gremlin_python.driver.connection import Connection + + def auth_interceptor(req): + pass + + conn = Connection( + url="http://localhost:8182/gremlin", + traversal_source="g", + executor=MagicMock(), + pool=MagicMock(), + auth=auth_interceptor, + enable_user_agent_on_connect=False + ) + + assert len(conn._interceptors) == 1 + assert conn._interceptors[0] is auth_interceptor diff --git a/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py b/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py index aac685982c7..4f9b086fc7f 100644 --- a/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py +++ b/gremlin-python/src/main/python/tests/unit/structure/io/test_provider_defined_type.py @@ -230,7 +230,6 @@ def test_client_passes_registry_to_serializers(self): registry = ProviderDefinedTypeRegistry() with patch.object(Client, '_fill_pool'): c = Client("ws://localhost:8182/gremlin", "g", pdt_registry=registry) - assert c._request_serializer._graphbinary_reader.pdt_registry is registry assert c._response_serializer._graphbinary_reader.pdt_registry is registry def test_driver_remote_connection_passes_registry(self): diff --git a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java index 99007d71155..8a688b6d6bd 100644 --- a/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java +++ b/gremlin-server/src/main/java/org/apache/tinkerpop/gremlin/server/handler/HttpRequestMessageDecoder.java @@ -210,6 +210,12 @@ private RequestMessage getRequestMessageFromHttpRequest(final FullHttpRequest re final JsonNode txIdNode = body.get(Tokens.ARGS_TRANSACTION_ID); if (null != txIdNode) builder.addTransactionId(txIdNode.asText()); + // bulkResults was previously only sent as an HTTP header by GLV drivers using + // GraphBinary. With the move to JSON requests in 4.x, drivers may include it in + // the body instead. + final JsonNode bulkResultsNode = body.get(Tokens.BULK_RESULTS); + if (null != bulkResultsNode) builder.addBulkResults(bulkResultsNode.asBoolean()); + return builder.create(); } } diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java index dd277cf8dbf..d6ccc6707ac 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinDriverIntegrateTest.java @@ -29,7 +29,6 @@ import org.apache.tinkerpop.gremlin.driver.ResultSet; import org.apache.tinkerpop.gremlin.driver.exception.NoHostAvailableException; import org.apache.tinkerpop.gremlin.driver.exception.ResponseException; -import org.apache.tinkerpop.gremlin.driver.interceptor.PayloadSerializingInterceptor; import org.apache.tinkerpop.gremlin.driver.remote.DriverRemoteConnection; import org.apache.tinkerpop.gremlin.jsr223.ScriptFileGremlinPlugin; import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource; @@ -45,8 +44,8 @@ import org.apache.tinkerpop.gremlin.util.ExceptionHelper; import org.apache.tinkerpop.gremlin.util.TimeUtil; import org.apache.tinkerpop.gremlin.util.function.FunctionUtils; +import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; -import org.apache.tinkerpop.gremlin.util.ser.GraphSONMessageSerializerV4; import org.apache.tinkerpop.gremlin.util.ser.Serializers; import org.junit.AfterClass; import org.junit.Before; @@ -174,10 +173,7 @@ public void shouldInterceptRequests() throws Exception { final AtomicInteger httpRequests = new AtomicInteger(0); final Cluster cluster = TestClientFactory.build(). - addInterceptor("counter", r -> { - httpRequests.incrementAndGet(); - return r; - }).create(); + interceptors(r -> httpRequests.incrementAndGet()).create(); try { final Client client = cluster.connect(); @@ -195,18 +191,63 @@ public void shouldInterceptRequests() throws Exception { public void shouldRunInterceptorsInOrder() throws Exception { AtomicReference body = new AtomicReference<>(); final Cluster cluster = TestClientFactory.build(). - addInterceptor("first", r -> { - body.set(r.getBody()); - r.setBody(null); - return r; - }). - addInterceptor("second", r -> { - r.setBody(body.get()); - return r; + interceptors( + r -> { + body.set(r.getBody()); + r.setBody(null); + }, + r -> r.setBody(body.get()) + ).create(); + + try { + final Client client = cluster.connect(); + assertEquals(2, client.submit("g.inject(2)").all().get().get(0).getInt()); + } finally { + cluster.close(); + } + } + + @Test + public void shouldAutoSerializeRequestMessageWithInterceptorMutation() throws Exception { + // Verifies the driver auto-serializes when an interceptor modifies the RequestMessage + // body but does not call serializeBody() itself. + final Cluster cluster = TestClientFactory.build(). + interceptors(r -> { + if (r.getBody() instanceof RequestMessage) { + r.setBody(RequestMessage.build("g.inject(99)").create()); + } }).create(); try { final Client client = cluster.connect(); + assertEquals(99, client.submit("g.inject(1)").all().get().get(0).getInt()); + } finally { + cluster.close(); + } + } + + @Test + public void shouldPropagateExceptionThrownDuringInterceptor() throws Exception { + final AtomicInteger callCount = new AtomicInteger(0); + final Cluster cluster = TestClientFactory.build(). + interceptors(r -> { + // Only throw on the first request to verify recovery + if (callCount.incrementAndGet() == 1) { + throw new RuntimeException("interceptor broke"); + } + }).create(); + try { + final Client client = cluster.connect(); + + // First request should fail with interceptor error + try { + client.submit("g.inject(1)").all().get(); + fail("Should have thrown an exception"); + } catch (Exception ex) { + assertTrue(ex.getCause().getMessage().contains("interceptor broke")); + } + + // Subsequent request should succeed, proving the connection is still usable assertEquals(2, client.submit("g.inject(2)").all().get().get(0).getInt()); } finally { cluster.close(); @@ -214,8 +255,39 @@ public void shouldRunInterceptorsInOrder() throws Exception { } @Test - public void shouldWorkWithGraphSONSerializer() throws Exception { - final Cluster cluster = TestClientFactory.build(new PayloadSerializingInterceptor(new GraphSONMessageSerializerV4())) + public void shouldPropagateErrorWhenInterceptorSetsUnsupportedBodyType() throws Exception { + // This error occurs after interceptors run, when serializeBody() encounters + // a body that is neither RequestMessage nor byte[]. + final AtomicInteger callCount = new AtomicInteger(0); + final Cluster cluster = TestClientFactory.build(). + interceptors(r -> { + if (callCount.incrementAndGet() == 1) { + r.setBody(42); + } + }).create(); + try { + final Client client = cluster.connect(); + + // First request should fail with body type error + try { + client.submit("g.inject(1)").all().get(); + fail("Should have thrown an exception"); + } catch (Exception ex) { + final String msg = ex.getCause().getMessage(); + assertTrue(msg.contains("Cannot serialize body of type")); + assertTrue(msg.contains("Integer")); + } + + // Subsequent request should succeed, proving the connection is still usable + assertEquals(3, client.submit("g.inject(3)").all().get().get(0).getInt()); + } finally { + cluster.close(); + } + } + + @Test + public void shouldWorkWithGraphSONResponse() throws Exception { + final Cluster cluster = TestClientFactory.build() .serializer(Serializers.GRAPHSON_V4.simpleInstance()).create(); try { @@ -232,8 +304,8 @@ public void shouldWorkWithGraphSONSerializer() throws Exception { } @Test - public void shouldWorkWithGraphSONRequestAndGraphBinaryResponse() throws Exception { - final Cluster cluster = TestClientFactory.build(new PayloadSerializingInterceptor(new GraphSONMessageSerializerV4())) + public void shouldWorkWithJsonRequestAndGraphBinaryResponse() throws Exception { + final Cluster cluster = TestClientFactory.build() .serializer(Serializers.GRAPHBINARY_V4.simpleInstance()).create(); try { @@ -245,8 +317,8 @@ public void shouldWorkWithGraphSONRequestAndGraphBinaryResponse() throws Excepti } @Test - public void shouldWorkWithGraphBinaryRequestAndGraphSONResponse() throws Exception { - final Cluster cluster = TestClientFactory.build(new PayloadSerializingInterceptor(new GraphBinaryMessageSerializerV4())) + public void shouldWorkWithJsonRequestAndGraphSONResponse() throws Exception { + final Cluster cluster = TestClientFactory.build() .serializer(Serializers.GRAPHSON_V4.simpleInstance()).create(); try { @@ -264,10 +336,7 @@ public void shouldInterceptRequestsWithHandshake() throws Exception { final Cluster cluster = TestClientFactory.build(). maxConnectionPoolSize(1). - addInterceptor("counter", r -> { - handshakeRequests.incrementAndGet(); - return r; - }).create(); + interceptors(r -> handshakeRequests.incrementAndGet()).create(); try { final Client client = cluster.connect(); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java index 8c200c5128b..897b3740b24 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/GremlinServerAuthIntegrateTest.java @@ -94,11 +94,10 @@ public void shouldPassSigv4ToServer() throws Exception { final AtomicReference httpRequest = new AtomicReference<>(); final Cluster cluster = TestClientFactory.build() - .auth(sigv4("us-west2", credentialsProvider, "service-name")) - .addInterceptor("header-checker", r -> { - httpRequest.set(r); - return r; - }) + .interceptors( + sigv4("us-west2", credentialsProvider, "service-name"), + r -> httpRequest.set(r) + ) .create(); final Client client = cluster.connect(); client.submit("g.inject(2)").all().get(); diff --git a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java index 8c31e41c221..2f2e1445642 100644 --- a/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java +++ b/gremlin-server/src/test/java/org/apache/tinkerpop/gremlin/server/TestClientFactory.java @@ -19,7 +19,6 @@ package org.apache.tinkerpop.gremlin.server; import org.apache.tinkerpop.gremlin.driver.Cluster; -import org.apache.tinkerpop.gremlin.driver.RequestInterceptor; import org.apache.tinkerpop.gremlin.driver.simple.SimpleHttpClient; import java.net.URI; @@ -43,10 +42,6 @@ public static Cluster.Builder build(final String address) { return Cluster.build(address).port(PORT); } - public static Cluster.Builder build(final RequestInterceptor serializingInterceptor) { - return Cluster.build(serializingInterceptor).port(PORT); - } - public static Cluster open() { return build().create(); } diff --git a/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java b/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java index 3ea75f2f7d3..cdc49ff93f8 100644 --- a/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java +++ b/gremlin-tools/gremlin-socket-server/src/main/java/org/apache/tinkerpop/gremlin/socket/server/TestHttpGremlinHandler.java @@ -31,15 +31,17 @@ import org.apache.tinkerpop.gremlin.structure.Graph; import org.apache.tinkerpop.gremlin.structure.Vertex; import org.apache.tinkerpop.gremlin.tinkergraph.structure.TinkerFactory; -import org.apache.tinkerpop.gremlin.util.message.RequestMessage; import org.apache.tinkerpop.gremlin.util.message.ResponseMessage; import org.apache.tinkerpop.gremlin.util.ser.GraphBinaryMessageSerializerV4; import org.apache.tinkerpop.gremlin.util.ser.SerTokens; import org.apache.tinkerpop.gremlin.util.ser.SerializationException; +import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Random; import java.util.concurrent.TimeUnit; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static io.netty.handler.codec.http.HttpResponseStatus.BAD_REQUEST; import static io.netty.handler.codec.http.HttpResponseStatus.INTERNAL_SERVER_ERROR; @@ -56,10 +58,17 @@ public class TestHttpGremlinHandler extends SimpleChannelInboundHandler