diff --git a/contrib/temporal-payload-storage-s3/build.gradle b/contrib/temporal-payload-storage-s3/build.gradle new file mode 100644 index 0000000000..67ee2403fb --- /dev/null +++ b/contrib/temporal-payload-storage-s3/build.gradle @@ -0,0 +1,19 @@ +description = '''Temporal Java SDK External Storage Driver for AWS S3''' + +ext { + awsSdkVersion = '2.31.0' +} + +dependencies { + compileOnly project(':temporal-serviceclient') + compileOnly project(':temporal-sdk') + + api platform("software.amazon.awssdk:bom:$awsSdkVersion") + api "software.amazon.awssdk:s3" + + testImplementation project(':temporal-serviceclient') + testImplementation project(':temporal-sdk') + testImplementation "junit:junit:${junitVersion}" + testImplementation "org.mockito:mockito-core:${mockitoVersion}" + testRuntimeOnly group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}" +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/BucketResolver.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/BucketResolver.java new file mode 100644 index 0000000000..e31d1bc3d5 --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/BucketResolver.java @@ -0,0 +1,18 @@ +package io.temporal.payload.storage.s3; + +import io.temporal.api.common.v1.Payload; +import io.temporal.common.Experimental; +import io.temporal.payload.storage.StorageDriverStoreContext; +import javax.annotation.Nonnull; + +/** + * Resolves the target S3 bucket for a payload. Use {@link + * S3StorageDriver.Builder#setBucket(String)} for a fixed bucket, or supply a resolver via {@link + * S3StorageDriver.Builder#setBucketResolver(BucketResolver)} to choose a bucket per payload. 
+ */ +@Experimental +@FunctionalInterface +public interface BucketResolver { + @Nonnull + String resolveBucket(@Nonnull StorageDriverStoreContext context, @Nonnull Payload payload); +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/CompletableFutures.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/CompletableFutures.java new file mode 100644 index 0000000000..4f8d8bc97e --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/CompletableFutures.java @@ -0,0 +1,57 @@ +package io.temporal.payload.storage.s3; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +final class CompletableFutures { + private CompletableFutures() {} + + /** + * Completes with the results in input order once every future succeeds, or fails fast with the + * first failure's (unwrapped) cause as soon as any future fails. When the returned future fails + * or is cancelled, {@code cancel} is called on every input future. 
+ */ + static CompletableFuture> allAsList(List> futures) { + CompletableFuture> result = new CompletableFuture<>(); + if (futures.isEmpty()) { + result.complete(new ArrayList<>()); + return result; + } + AtomicInteger remaining = new AtomicInteger(futures.size()); + for (CompletableFuture future : futures) { + future.whenComplete( + (value, ex) -> { + if (ex != null) { + result.completeExceptionally(unwrap(ex)); + } else if (remaining.decrementAndGet() == 0) { + List results = new ArrayList<>(futures.size()); + for (CompletableFuture completed : futures) { + results.add(completed.join()); + } + result.complete(results); + } + }); + } + result.whenComplete( + (value, ex) -> { + if (ex != null) { + for (CompletableFuture future : futures) { + future.cancel(true); + } + } + }); + return result; + } + + static Throwable unwrap(Throwable t) { + while ((t instanceof CompletionException || t instanceof ExecutionException) + && t.getCause() != null) { + t = t.getCause(); + } + return t; + } +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/README.md b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/README.md new file mode 100644 index 0000000000..2a79b3ad4b --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/README.md @@ -0,0 +1,115 @@ +# AWS S3 Driver + +Temporal's S3 Driver for External Storage. Uses the official [AWS S3 Java SDK](https://github.com/aws/aws-sdk-java-v2). 
+ +## Usage + +Construct the S3 storage driver: + +```java +import io.temporal.payload.storage.s3.S3AsyncClientAdapter; +import io.temporal.payload.storage.s3.S3StorageDriver; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; + +S3AsyncClient s3Client = + S3AsyncClient.builder().region(Region.US_EAST_1).build(); + +S3StorageDriver driver = + S3StorageDriver.newBuilder() + .setClient(new S3AsyncClientAdapter(s3Client)) + .setBucket("temporal-payloads") + .build(); +``` + +Register the driver in external storage config: + +```java +import io.temporal.payload.storage.ExternalStorage; + +ExternalStorage externalStorage = + ExternalStorage.newBuilder() + .setDriver(driver) + .build(); +``` + +Use `setBucketResolver(...)` instead of `setBucket(...)` when bucket selection must vary per +payload. + +## S3 Storage Key Specification + +All Temporal S3 drivers generate S3 keys in a consistent manner. + +### Key format + +Workflow key: +```text +v0/ns/{namespace}/wt/{workflow-type}/wi/{workflow-id}/ri/{run-id}/d/{hash-algorithm}/{hex-digest} +``` + +Activity key: +```text +v0/ns/{namespace}/at/{activity-type}/ai/{activity-id}/ri/{run-id}/d/{hash-algorithm}/{hex-digest} +``` + +Fallback key (unknown target): +```text +v0/d/{hash-algorithm}/{hex-digest} +``` + +- If the target is unknown (neither workflow nor activity information is available), the fallback is used. +- Dynamic path segments are percent-encoded (rules below). +- Missing or empty values (including a missing `run-id`) are encoded as the literal string `null`. +- `hex-digest` is lower-case SHA-256 hex (64 characters). + +### Percent-encoding rules + +1. Treat each key path component as UTF-8 bytes. +2. Leave ASCII letters and digits unescaped. +3. Leave the following ASCII characters unescaped: `- _ . ~ $ & + : = @` +4. Encode all other bytes as `%` followed by two uppercase hexadecimal digits. +5. Empty or null values are encoded as the literal string `null`. +6. 
This is path-segment escaping, not form encoding (`+` stays `+`). + +### Examples + +Workflow key example: + +```text +input: + namespace=payments prod + workflow-type=ChargeWorkflow + workflow-id=order+123=abc + run-id=3f1d6c7a-8b2e-4f7a-9d0a-87a6f95e4d31 + hash-algorithm=sha256 + hex-digest=9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08 + +output: + v0/ns/payments%20prod/wt/ChargeWorkflow/wi/order+123=abc/ri/3f1d6c7a-8b2e-4f7a-9d0a-87a6f95e4d31/d/sha256/9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08 +``` + +Activity key example: + +```text +input: + namespace=payments prod + activity-type=Capture/Charge + activity-id=activity id+42 + run-id=9e1d1fd9-2f8a-4c40-93e2-731f31b9268b + hash-algorithm=sha256 + hex-digest=2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824 + +output: + v0/ns/payments%20prod/at/Capture%2FCharge/ai/activity%20id+42/ri/9e1d1fd9-2f8a-4c40-93e2-731f31b9268b/d/sha256/2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824 +``` + +Fallback key example: + +```text +input: + hash-algorithm=sha256 + hex-digest=486ea46224d1bb4fb680f34f7c9ad96a8f24ec88be73ea8e5a6c65260e9cb8a7 + +output: + v0/d/sha256/486ea46224d1bb4fb680f34f7c9ad96a8f24ec88be73ea8e5a6c65260e9cb8a7 +``` diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3AsyncClientAdapter.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3AsyncClientAdapter.java new file mode 100644 index 0000000000..ebfdd3cea8 --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3AsyncClientAdapter.java @@ -0,0 +1,109 @@ +package io.temporal.payload.storage.s3; + +import io.temporal.common.Experimental; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import javax.annotation.Nonnull; +import 
software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.HeadObjectResponse; +import software.amazon.awssdk.services.s3.model.NoSuchKeyException; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +/** + * {@link S3Client} backed by the AWS SDK for Java v2 {@link S3AsyncClient}. The wrapped client must + * be configured with credentials and a region by the caller. + */ +@Experimental +public final class S3AsyncClientAdapter implements S3Client { + private final S3AsyncClient client; + + public S3AsyncClientAdapter(@Nonnull S3AsyncClient client) { + this.client = Objects.requireNonNull(client, "client"); + } + + @Nonnull + @Override + public CompletableFuture putObject( + @Nonnull String bucket, @Nonnull String key, @Nonnull byte[] data) { + CompletableFuture request = + client.putObject( + PutObjectRequest.builder().bucket(bucket).key(key).build(), + AsyncRequestBody.fromBytesUnsafe(data)); // avoids a defensive copy + return abortRequestOnCancel(request, request.thenApply(response -> (Void) null)); + } + + @Nonnull + @Override + public CompletableFuture objectExists(@Nonnull String bucket, @Nonnull String key) { + CompletableFuture request = + client.headObject(HeadObjectRequest.builder().bucket(bucket).key(key).build()); + return abortRequestOnCancel( + request, + request.handle( + (response, ex) -> { + if (ex == null) { + return true; + } + Throwable 
cause = + (ex instanceof CompletionException && ex.getCause() != null) ? ex.getCause() : ex; + if (cause instanceof NoSuchKeyException) { + return false; + } + if (cause instanceof S3Exception && ((S3Exception) cause).statusCode() == 404) { + return false; + } + if (cause instanceof RuntimeException) { + throw (RuntimeException) cause; + } + throw new RuntimeException(cause); + })); + } + + @Nonnull + @Override + public CompletableFuture getObject(@Nonnull String bucket, @Nonnull String key) { + CompletableFuture> request = + client.getObject( + GetObjectRequest.builder().bucket(bucket).key(key).build(), + AsyncResponseTransformer.toBytes()); + return abortRequestOnCancel(request, request.thenApply(ResponseBytes::asByteArrayUnsafe)); + } + + /** + * Returns {@code result}, wired so that cancelling it cancels the underlying {@code request}. The + * AWS SDK aborts an async request when the future it returns is cancelled. Cancellation does not + * otherwise propagate across the {@code thenApply}/{@code handle} boundary. 
+ */ + private static CompletableFuture abortRequestOnCancel( + CompletableFuture request, CompletableFuture result) { + result.whenComplete( + (value, ex) -> { + if (result.isCancelled()) { + request.cancel(true); + } + }); + return result; + } + + @Nonnull + @Override + public Map describe() { + Region region = client.serviceClientConfiguration().region(); + if (region == null) { + return Collections.emptyMap(); + } + return Collections.singletonMap("client_region", region.id()); + } +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3Client.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3Client.java new file mode 100644 index 0000000000..97302bce4a --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3Client.java @@ -0,0 +1,47 @@ +package io.temporal.payload.storage.s3; + +import io.temporal.common.Experimental; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import javax.annotation.Nonnull; + +/** + * Interface for S3 {@link S3StorageDriver} operations: upload, existence check, and download. + * + *

Cancelling a returned future makes a best-effort attempt to abort the in-flight requests. + */ +@Experimental +public interface S3Client { + /** + * Uploads {@code data} to the given {@code bucket} and {@code key}, overwriting any existing + * object at that key. Implementations must be safe to call concurrently for different keys. + */ + @Nonnull + CompletableFuture putObject( + @Nonnull String bucket, @Nonnull String key, @Nonnull byte[] data); + + /** + * Reports whether an object exists at the given {@code bucket} and {@code key}. The future + * completes with {@code false} when the object is absent, and completes exceptionally when + * existence cannot be determined (e.g. a network or permission failure). + */ + @Nonnull + CompletableFuture objectExists(@Nonnull String bucket, @Nonnull String key); + + /** + * Downloads the bytes stored at the given {@code bucket} and {@code key}. The future completes + * exceptionally if the object does not exist. + */ + @Nonnull + CompletableFuture getObject(@Nonnull String bucket, @Nonnull String key); + + /** + * Diagnostic metadata about the client configuration, such as {@code {"client_region": + * "us-west-2"}}, that the driver appends to error messages. Returns an empty map by default. 
+ */ + @Nonnull + default Map describe() { + return Collections.emptyMap(); + } +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageDriver.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageDriver.java new file mode 100644 index 0000000000..2d342ac2da --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageDriver.java @@ -0,0 +1,323 @@ +package io.temporal.payload.storage.s3; + +import com.google.protobuf.InvalidProtocolBufferException; +import io.temporal.api.common.v1.Payload; +import io.temporal.common.Experimental; +import io.temporal.payload.storage.PayloadHasher; +import io.temporal.payload.storage.StorageDriver; +import io.temporal.payload.storage.StorageDriverClaim; +import io.temporal.payload.storage.StorageDriverRetrieveContext; +import io.temporal.payload.storage.StorageDriverStoreContext; +import io.temporal.payload.storage.StorageDriverTargetInfo; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nonnull; + +/** + * {@link StorageDriver} that stores payloads in Amazon S3 under content-addressable keys derived + * from the SHA-256 hash of the serialized payload. + * + *

Construct via {@link #newBuilder()}. + */ +@Experimental +public final class S3StorageDriver implements StorageDriver { + private static final String DRIVER_TYPE = "aws.s3driver"; + private static final String DEFAULT_DRIVER_NAME = "aws.s3driver"; + private static final int DEFAULT_MAX_PAYLOAD_SIZE = 50 * 1024 * 1024; + private static final String HASH_ALGORITHM = "sha256"; + + private static final String CLAIM_BUCKET = "bucket"; + private static final String CLAIM_KEY = "key"; + private static final String CLAIM_HASH_ALGORITHM = "hash_algorithm"; + private static final String CLAIM_HASH_VALUE = "hash_value"; + + public static Builder newBuilder() { + return new Builder(); + } + + private final @Nonnull S3Client client; + private final @Nonnull BucketResolver bucketResolver; + private final @Nonnull String name; + private final int maxPayloadSize; + + private S3StorageDriver( + @Nonnull S3Client client, + @Nonnull BucketResolver bucketResolver, + @Nonnull String name, + int maxPayloadSize) { + this.client = client; + this.bucketResolver = bucketResolver; + this.name = name; + this.maxPayloadSize = maxPayloadSize; + } + + @Nonnull + @Override + public String getName() { + return name; + } + + @Nonnull + @Override + public String getType() { + return DRIVER_TYPE; + } + + @Nonnull + @Override + public CompletableFuture> store( + @Nonnull StorageDriverStoreContext context, @Nonnull List payloads) { + for (Payload payload : payloads) { + int size = payload.getSerializedSize(); + if (size > maxPayloadSize) { + return failedFuture( + new S3StorageException("payload size " + size + " exceeds maximum " + maxPayloadSize)); + } + } + + StorageDriverTargetInfo target = context.getTarget(); + String describeSuffix = describeSuffix(); + List> claimFutures = new ArrayList<>(payloads.size()); + for (Payload payload : payloads) { + byte[] data = payload.toByteArray(); + String hexDigest = PayloadHasher.sha256Hex(data); + String bucket = bucketResolver.resolveBucket(context, 
payload); + String key = S3StorageKey.forPayload(target, HASH_ALGORITHM, hexDigest); + String location = storageLocation(bucket, key, describeSuffix); + + CompletableFuture existsRequest = client.objectExists(bucket, key); + // We track current inflight request for cancellation + AtomicReference> inFlightRequest = new AtomicReference<>(existsRequest); + CompletableFuture claimFuture = + withFailureContext(existsRequest, "existence check failed " + location) + .thenCompose( + exists -> { + if (exists) { + return CompletableFuture.completedFuture(null); + } + CompletableFuture uploadRequest = client.putObject(bucket, key, data); + inFlightRequest.set(uploadRequest); + return withFailureContext(uploadRequest, "upload failed " + location); + }) + .thenApply(ignored -> claimFor(bucket, key, hexDigest)); + cancelRequestWhenCancelled(claimFuture, inFlightRequest); + claimFutures.add(claimFuture); + } + return CompletableFutures.allAsList(claimFutures); + } + + @Nonnull + @Override + public CompletableFuture> retrieve( + @Nonnull StorageDriverRetrieveContext context, @Nonnull List claims) { + String describeSuffix = describeSuffix(); + List> payloadFutures = new ArrayList<>(claims.size()); + for (StorageDriverClaim claim : claims) { + Map claimData = claim.getClaimData(); + String bucket = claimData.get(CLAIM_BUCKET); + if (bucket == null) { + payloadFutures.add(failedFuture(missingField(CLAIM_BUCKET))); + continue; + } + String key = claimData.get(CLAIM_KEY); + if (key == null) { + payloadFutures.add(failedFuture(missingField(CLAIM_KEY))); + continue; + } + String location = storageLocation(bucket, key, describeSuffix); + CompletableFuture downloadRequest = client.getObject(bucket, key); + CompletableFuture payloadFuture = + withFailureContext(downloadRequest, "download failed " + location) + .thenApply(data -> verifyAndParse(claimData, bucket, key, data)); + cancelRequestWhenCancelled(payloadFuture, downloadRequest); + payloadFutures.add(payloadFuture); + } + return 
CompletableFutures.allAsList(payloadFutures); + } + + private StorageDriverClaim claimFor(String bucket, String key, String hexDigest) { + Map claimData = new HashMap<>(); + claimData.put(CLAIM_BUCKET, bucket); + claimData.put(CLAIM_KEY, key); + claimData.put(CLAIM_HASH_ALGORITHM, HASH_ALGORITHM); + claimData.put(CLAIM_HASH_VALUE, hexDigest); + return new StorageDriverClaim(claimData); + } + + private Payload verifyAndParse( + Map claimData, String bucket, String key, byte[] data) { + String algorithm = claimData.get(CLAIM_HASH_ALGORITHM); + if (algorithm == null) { + throw missingField(CLAIM_HASH_ALGORITHM); + } + if (!HASH_ALGORITHM.equals(algorithm)) { + throw new S3StorageException("unsupported hash algorithm \"" + algorithm + "\""); + } + String expectedHash = claimData.get(CLAIM_HASH_VALUE); + if (expectedHash == null) { + throw missingField(CLAIM_HASH_VALUE); + } + String actualHash = PayloadHasher.sha256Hex(data); + if (!actualHash.equals(expectedHash)) { + throw new S3StorageException( + "integrity check failed [bucket=" + + bucket + + ", key=" + + key + + "]: expected hash " + + expectedHash + + ", got " + + actualHash); + } + try { + return Payload.parseFrom(data); + } catch (InvalidProtocolBufferException e) { + throw new S3StorageException( + "failed to unmarshal payload [bucket=" + bucket + ", key=" + key + "]", e); + } + } + + private static String storageLocation(String bucket, String key, String describeSuffix) { + return "[bucket=" + bucket + ", key=" + key + describeSuffix + "]"; + } + + /** Renders {@link S3Client#describe()} as a {@code ", k=v"} suffix for failure messages. 
*/ + private String describeSuffix() { + Map describe = client.describe(); + if (describe == null || describe.isEmpty()) { + return ""; + } + StringBuilder sb = new StringBuilder(); + for (Map.Entry entry : describe.entrySet()) { + sb.append(", ").append(entry.getKey()).append("=").append(entry.getValue()); + } + return sb.toString(); + } + + private static S3StorageException missingField(String field) { + return new S3StorageException("claim missing field \"" + field + "\""); + } + + /** + * Cancels {@code request} when {@code pipeline} is cancelled. + */ + private static void cancelRequestWhenCancelled( + CompletableFuture pipeline, CompletableFuture request) { + pipeline.whenComplete( + (value, ex) -> { + if (pipeline.isCancelled()) { + request.cancel(true); + } + }); + } + + /** + * Cancels the request held by {@code inFlightRequest} when {@code pipeline} is cancelled. + */ + private static void cancelRequestWhenCancelled( + CompletableFuture pipeline, AtomicReference> inFlightRequest) { + pipeline.whenComplete( + (value, ex) -> { + if (pipeline.isCancelled()) { + inFlightRequest.get().cancel(true); + } + }); + } + + private static CompletableFuture withFailureContext( + CompletableFuture future, String failureMessage) { + return future.handle( + (value, ex) -> { + if (ex == null) { + return value; + } + Throwable cause = CompletableFutures.unwrap(ex); + String causeMessage = cause.getMessage() != null ? cause.getMessage() : cause.toString(); + throw new S3StorageException(failureMessage + ": " + causeMessage, cause); + }); + } + + private static CompletableFuture failedFuture(Throwable t) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(t); + return future; + } + + public static final class Builder { + private S3Client client; + private String staticBucket; + private BucketResolver bucketResolver; + private String name = DEFAULT_DRIVER_NAME; + private int maxPayloadSize = DEFAULT_MAX_PAYLOAD_SIZE; + + private Builder() {} + + /** Required. 
The S3 client used for storage operations. */ + public Builder setClient(@Nonnull S3Client client) { + this.client = Objects.requireNonNull(client, "client"); + return this; + } + + /** + * Stores every payload in a fixed bucket. Mutually exclusive with {@link #setBucketResolver}: + * setting both makes {@link #build()} throw {@link IllegalStateException}. + */ + public Builder setBucket(@Nonnull String bucket) { + this.staticBucket = Objects.requireNonNull(bucket, "bucket"); + return this; + } + + /** + * Selects the bucket per payload. Mutually exclusive with {@link #setBucket}: setting both + * makes {@link #build()} throw {@link IllegalStateException}. + */ + public Builder setBucketResolver(@Nonnull BucketResolver bucketResolver) { + this.bucketResolver = Objects.requireNonNull(bucketResolver, "bucketResolver"); + return this; + } + + /** + * Stable, unique identifier for this driver instance. Defaults to {@code "aws.s3driver"}; + * override it when registering multiple S3 drivers with distinct configurations. + */ + public Builder setName(@Nonnull String name) { + this.name = Objects.requireNonNull(name, "name"); + return this; + } + + /** + * Maximum serialized payload size in bytes the driver accepts. Defaults to 50 MiB. Storing a + * larger payload fails the {@code store} call. Must be positive, or {@link #build()} throws. 
+ */ + public Builder setMaxPayloadSize(int maxPayloadSize) { + this.maxPayloadSize = maxPayloadSize; + return this; + } + + public S3StorageDriver build() { + if (client == null) { + throw new IllegalStateException("client is required"); + } + if (staticBucket != null && bucketResolver != null) { + throw new IllegalStateException("setBucket and setBucketResolver are mutually exclusive"); + } + BucketResolver resolver = bucketResolver; + if (resolver == null && staticBucket != null) { + String bucket = staticBucket; + resolver = (context, payload) -> bucket; + } + if (resolver == null) { + throw new IllegalStateException("a bucket or bucket resolver is required"); + } + if (maxPayloadSize <= 0) { + throw new IllegalStateException("maxPayloadSize must be positive, got " + maxPayloadSize); + } + return new S3StorageDriver(client, resolver, name, maxPayloadSize); + } + } +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageException.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageException.java new file mode 100644 index 0000000000..8a49674c27 --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageException.java @@ -0,0 +1,15 @@ +package io.temporal.payload.storage.s3; + +import io.temporal.common.Experimental; + +/** Thrown when an {@link S3StorageDriver} store or retrieve operation fails. 
*/ +@Experimental +public class S3StorageException extends RuntimeException { + public S3StorageException(String message) { + super(message); + } + + public S3StorageException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageKey.java b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageKey.java new file mode 100644 index 0000000000..dccf49bd5f --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/main/java/io/temporal/payload/storage/s3/S3StorageKey.java @@ -0,0 +1,69 @@ +package io.temporal.payload.storage.s3; + +import io.temporal.payload.storage.StorageDriverActivityInfo; +import io.temporal.payload.storage.StorageDriverTargetInfo; +import io.temporal.payload.storage.StorageDriverWorkflowInfo; +import java.nio.charset.StandardCharsets; + +/** + * Builds the content-addressable S3 object key. The key format and percent-encoding rules are the + * cross-SDK specification documented in this package's {@code README.md}. 
+ */ +final class S3StorageKey { + private static final String KEY_VERSION = "v0"; + private static final String PATH_SEGMENT_UNRESERVED = "-_.~$&+:=@"; + + private S3StorageKey() {} + + static String forPayload(StorageDriverTargetInfo target, String hashAlgorithm, String hexDigest) { + String digestSegment = "/d/" + hashAlgorithm + "/" + hexDigest; + if (target instanceof StorageDriverWorkflowInfo) { + StorageDriverWorkflowInfo wf = (StorageDriverWorkflowInfo) target; + return KEY_VERSION + + "/ns/" + + escapePathSegment(wf.getNamespace()) + + "/wt/" + + escapePathSegment(wf.getType()) + + "/wi/" + + escapePathSegment(wf.getId()) + + "/ri/" + + escapePathSegment(wf.getRunId()) + + digestSegment; + } + if (target instanceof StorageDriverActivityInfo) { + StorageDriverActivityInfo act = (StorageDriverActivityInfo) target; + return KEY_VERSION + + "/ns/" + + escapePathSegment(act.getNamespace()) + + "/at/" + + escapePathSegment(act.getType()) + + "/ai/" + + escapePathSegment(act.getId()) + + "/ri/" + + escapePathSegment(act.getRunId()) + + digestSegment; + } + return KEY_VERSION + digestSegment; + } + + static String escapePathSegment(String value) { + if (value == null || value.isEmpty()) { + return "null"; + } + StringBuilder sb = new StringBuilder(value.length()); + for (byte b : value.getBytes(StandardCharsets.UTF_8)) { + int c = b & 0xFF; + if ((c >= 'A' && c <= 'Z') + || (c >= 'a' && c <= 'z') + || (c >= '0' && c <= '9') + || PATH_SEGMENT_UNRESERVED.indexOf(c) >= 0) { + sb.append((char) c); + } else { + sb.append('%'); + sb.append(Character.toUpperCase(Character.forDigit((c >> 4) & 0xF, 16))); + sb.append(Character.toUpperCase(Character.forDigit(c & 0xF, 16))); + } + } + return sb.toString(); + } +} diff --git a/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3AsyncClientAdapterTest.java b/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3AsyncClientAdapterTest.java new file mode 100644 index 
0000000000..3049a2c6bb --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3AsyncClientAdapterTest.java @@ -0,0 +1,38 @@ +package io.temporal.payload.storage.s3; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletableFuture; +import org.junit.Test; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.PutObjectResponse; + +public class S3AsyncClientAdapterTest { + + /** + * Cancelling the future the adapter returns must abort the underlying AWS request. The adapter + * wraps the AWS future with {@code thenApply}, which does not propagate cancellation upstream, so + * this verifies the explicit forwarding does its job. 
+ */ + @Test + public void cancellingReturnedFutureAbortsTheUnderlyingRequest() { + S3AsyncClient s3 = mock(S3AsyncClient.class); + CompletableFuture awsRequest = new CompletableFuture<>(); + when(s3.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))) + .thenReturn(awsRequest); + + CompletableFuture result = + new S3AsyncClientAdapter(s3).putObject("bucket", "key", new byte[] {1, 2, 3}); + + assertFalse(awsRequest.isCancelled()); + result.cancel(true); + assertTrue( + "cancelling the adapter's future should abort the AWS request", awsRequest.isCancelled()); + } +} diff --git a/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3StorageDriverTest.java b/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3StorageDriverTest.java new file mode 100644 index 0000000000..aa793237e6 --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3StorageDriverTest.java @@ -0,0 +1,584 @@ +package io.temporal.payload.storage.s3; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import com.google.protobuf.ByteString; +import io.temporal.api.common.v1.Payload; +import io.temporal.payload.storage.StorageDriverActivityInfo; +import io.temporal.payload.storage.StorageDriverClaim; +import io.temporal.payload.storage.StorageDriverRetrieveContext; +import io.temporal.payload.storage.StorageDriverStoreContext; +import io.temporal.payload.storage.StorageDriverTargetInfo; +import io.temporal.payload.storage.StorageDriverWorkflowInfo; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import 
java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; + +public class S3StorageDriverTest { + + private static Payload payload(String data) { + return Payload.newBuilder() + .putMetadata("encoding", ByteString.copyFromUtf8("binary/plain")) + .setData(ByteString.copyFromUtf8(data)) + .build(); + } + + private static S3StorageDriver driver(S3Client client) { + return S3StorageDriver.newBuilder().setClient(client).setBucket("test-bucket").build(); + } + + private static StorageDriverStoreContext storeContext() { + return () -> null; + } + + private static StorageDriverStoreContext storeContext(StorageDriverTargetInfo target) { + return () -> target; + } + + private static final StorageDriverRetrieveContext RETRIEVE_CONTEXT = + new StorageDriverRetrieveContext() {}; + + /** Joins a future expected to fail and returns the message of the underlying cause. */ + private static String failureMessage(CompletableFuture future) { + try { + future.join(); + } catch (CompletionException e) { + Throwable cause = e; + while (cause instanceof CompletionException && cause.getCause() != null) { + cause = cause.getCause(); + } + return cause.getMessage(); + } + fail("expected the future to fail"); + return null; + } + + // --- Builder --- + + @Test + public void builderDefaults() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + assertEquals("aws.s3driver", driver.getName()); + assertEquals("aws.s3driver", driver.getType()); + } + + @Test + public void builderCustomName() { + S3StorageDriver driver = + S3StorageDriver.newBuilder() + .setClient(new InMemoryS3Client()) + .setBucket("b") + .setName("custom-name") + .build(); + assertEquals("custom-name", driver.getName()); + } + + @Test(expected = IllegalStateException.class) + public void builderRequiresClient() { + S3StorageDriver.newBuilder().setBucket("b").build(); + } + + @Test(expected = IllegalStateException.class) + public void builderRequiresBucket() { + 
S3StorageDriver.newBuilder().setClient(new InMemoryS3Client()).build(); + } + + @Test(expected = IllegalStateException.class) + public void builderRejectsNonPositiveMaxPayloadSize() { + S3StorageDriver.newBuilder() + .setClient(new InMemoryS3Client()) + .setBucket("b") + .setMaxPayloadSize(0) + .build(); + } + + @Test(expected = IllegalStateException.class) + public void builderRejectsBothBucketAndResolver() { + S3StorageDriver.newBuilder() + .setClient(new InMemoryS3Client()) + .setBucket("b") + .setBucketResolver((context, payload) -> "other") + .build(); + } + + // --- Store --- + + @Test + public void storeSinglePayloadProducesClaim() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + + List claims = + driver.store(storeContext(), Collections.singletonList(payload("hello"))).join(); + + assertEquals(1, claims.size()); + Map claimData = claims.get(0).getClaimData(); + assertEquals("test-bucket", claimData.get("bucket")); + assertEquals("sha256", claimData.get("hash_algorithm")); + assertFalse(claimData.get("hash_value").isEmpty()); + assertEquals("v0/d/sha256/" + claimData.get("hash_value"), claimData.get("key")); + } + + @Test + public void storeEmptyPayloadsProducesNoClaims() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + assertTrue(driver.store(storeContext(), Collections.emptyList()).join().isEmpty()); + } + + @Test + public void storeDeduplicatesIdenticalPayloads() { + InMemoryS3Client client = new InMemoryS3Client(); + S3StorageDriver driver = driver(client); + Payload p = payload("duplicate-me"); + + driver.store(storeContext(), Collections.singletonList(p)).join(); + assertEquals(1, client.putCount.get()); + + driver.store(storeContext(), Collections.singletonList(p)).join(); + assertEquals(1, client.putCount.get()); + } + + @Test + public void storeMultiplePayloadsProducesDistinctKeys() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + + List claims = + driver + .store(storeContext(), Arrays.asList(payload("a"), 
payload("b"), payload("c"))) + .join(); + + assertEquals(3, claims.size()); + assertEquals(3, claims.stream().map(c -> c.getClaimData().get("key")).distinct().count()); + } + + @Test + public void storeRejectsOversizedPayload() { + S3StorageDriver driver = + S3StorageDriver.newBuilder() + .setClient(new InMemoryS3Client()) + .setBucket("b") + .setMaxPayloadSize(10) + .build(); + + String message = + failureMessage( + driver.store( + storeContext(), + Collections.singletonList(payload("definitely longer than ten bytes")))); + assertTrue( + message, message.contains("payload size ") && message.contains("exceeds maximum 10")); + } + + @Test + public void storeUploadsNothingWhenAnyPayloadFailsValidation() { + InMemoryS3Client client = new InMemoryS3Client(); + Payload small = payload("small"); + Payload oversized = payload(String.join("", Collections.nCopies(1000, "x"))); + S3StorageDriver driver = + S3StorageDriver.newBuilder() + .setClient(client) + .setBucket("b") + .setMaxPayloadSize(small.getSerializedSize()) + .build(); + + // The valid payload precedes the oversized one; validation must reject the batch before any + // upload starts, leaving nothing written to S3. + failureMessage(driver.store(storeContext(), Arrays.asList(small, oversized))); + assertEquals(0, client.putCount.get()); + } + + @Test + public void storeResolvesBucketPerPayload() { + S3StorageDriver driver = + S3StorageDriver.newBuilder() + .setClient(new InMemoryS3Client()) + .setBucketResolver( + (context, payload) -> + "a".equals(payload.getData().toStringUtf8()) ? 
"bucket-a" : "bucket-b") + .build(); + + List claims = + driver.store(storeContext(), Arrays.asList(payload("a"), payload("b"))).join(); + + assertEquals("bucket-a", claims.get(0).getClaimData().get("bucket")); + assertEquals("bucket-b", claims.get(1).getClaimData().get("bucket")); + } + + @Test + public void storeWrapsUploadErrorWithContext() { + InMemoryS3Client client = new InMemoryS3Client(); + client.putError = new RuntimeException("access denied"); + S3StorageDriver driver = driver(client); + + String message = + failureMessage(driver.store(storeContext(), Collections.singletonList(payload("x")))); + assertTrue(message, message.startsWith("upload failed [bucket=test-bucket, key=")); + assertTrue(message, message.endsWith(", client_region=ap-southeast-2]: access denied")); + } + + @Test + public void storeWrapsExistenceCheckErrorWithContext() { + InMemoryS3Client client = new InMemoryS3Client(); + client.existsError = new RuntimeException("network timeout"); + S3StorageDriver driver = driver(client); + + String message = + failureMessage(driver.store(storeContext(), Collections.singletonList(payload("x")))); + assertTrue(message, message.startsWith("existence check failed [bucket=test-bucket, key=")); + assertTrue(message, message.endsWith(", client_region=ap-southeast-2]: network timeout")); + } + + // --- Store with target identity --- + + @Test + public void storeKeyIncludesWorkflowTarget() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + StorageDriverTargetInfo target = + new StorageDriverWorkflowInfo("default", "wf-123", "run-456", "MyWorkflow"); + + String key = + driver + .store(storeContext(target), Collections.singletonList(payload("p"))) + .join() + .get(0) + .getClaimData() + .get("key"); + assertTrue(key, key.startsWith("v0/ns/default/wt/MyWorkflow/wi/wf-123/ri/run-456/d/sha256/")); + } + + @Test + public void storeKeyIncludesActivityTarget() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + StorageDriverTargetInfo target 
= + new StorageDriverActivityInfo("default", "act-789", "run-abc", "MyActivity"); + + String key = + driver + .store(storeContext(target), Collections.singletonList(payload("p"))) + .join() + .get(0) + .getClaimData() + .get("key"); + assertTrue(key, key.startsWith("v0/ns/default/at/MyActivity/ai/act-789/ri/run-abc/d/sha256/")); + } + + @Test + public void storeKeyPercentEncodesSpecialChars() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + StorageDriverTargetInfo target = + new StorageDriverWorkflowInfo("my namespace", "wf id+1", "run=abc", "my/workflow"); + + String key = + driver + .store(storeContext(target), Collections.singletonList(payload("p"))) + .join() + .get(0) + .getClaimData() + .get("key"); + assertTrue( + key, + key.startsWith("v0/ns/my%20namespace/wt/my%2Fworkflow/wi/wf%20id+1/ri/run=abc/d/sha256/")); + } + + @Test + public void storageKeyEscapesPathSegmentsByContract() { + assertEquals("null", S3StorageKey.escapePathSegment(null)); + assertEquals("null", S3StorageKey.escapePathSegment("")); + assertEquals("azAZ09-_.~$&+:=@", S3StorageKey.escapePathSegment("azAZ09-_.~$&+:=@")); + assertEquals( + "space%20slash%2Fpercent%25snowman%E2%98%83", + S3StorageKey.escapePathSegment("space slash/percent%snowman\u2603")); + } + + @Test + public void storageKeyReadmeExamples() { + // Segment encoding examples. + assertEquals("my%20namespace", S3StorageKey.escapePathSegment("my namespace")); + assertEquals("my%2Fworkflow", S3StorageKey.escapePathSegment("my/workflow")); + assertEquals("wf%20id+1", S3StorageKey.escapePathSegment("wf id+1")); + assertEquals("attempt=1", S3StorageKey.escapePathSegment("attempt=1")); + + String workflowDigest = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"; + String activityDigest = "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"; + String fallbackDigest = "486ea46224d1bb4fb680f34f7c9ad96a8f24ec88be73ea8e5a6c65260e9cb8a7"; + + // Workflow full-key example. 
+ assertEquals( + "v0/ns/payments%20prod/wt/ChargeWorkflow/wi/order+123=abc/ri/3f1d6c7a-8b2e-4f7a-9d0a-87a6f95e4d31/d/sha256/" + + workflowDigest, + S3StorageKey.forPayload( + new StorageDriverWorkflowInfo( + "payments prod", + "order+123=abc", + "3f1d6c7a-8b2e-4f7a-9d0a-87a6f95e4d31", + "ChargeWorkflow"), + "sha256", + workflowDigest)); + + // Activity full-key example. + assertEquals( + "v0/ns/payments%20prod/at/Capture%2FCharge/ai/activity%20id+42/ri/9e1d1fd9-2f8a-4c40-93e2-731f31b9268b/d/sha256/" + + activityDigest, + S3StorageKey.forPayload( + new StorageDriverActivityInfo( + "payments prod", + "activity id+42", + "9e1d1fd9-2f8a-4c40-93e2-731f31b9268b", + "Capture/Charge"), + "sha256", + activityDigest)); + + // Fallback full-key example. + assertEquals( + "v0/d/sha256/" + fallbackDigest, S3StorageKey.forPayload(null, "sha256", fallbackDigest)); + } + + @Test + public void storeSamePayloadDifferentTargetsProducesDifferentKeys() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + Payload p = payload("shared"); + + String wfKey = + driver + .store( + storeContext(new StorageDriverWorkflowInfo("ns", "wf-1", "run-1", "WF")), + Collections.singletonList(p)) + .join() + .get(0) + .getClaimData() + .get("key"); + String actKey = + driver + .store( + storeContext(new StorageDriverActivityInfo("ns", "act-1", "run-1", "ACT")), + Collections.singletonList(p)) + .join() + .get(0) + .getClaimData() + .get("key"); + assertNotEquals(wfKey, actKey); + } + + // --- Retrieve --- + + @Test + public void retrieveRoundTrip() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + Payload original = payload("round-trip data"); + + List claims = + driver.store(storeContext(), Collections.singletonList(original)).join(); + List restored = driver.retrieve(RETRIEVE_CONTEXT, claims).join(); + + assertEquals(1, restored.size()); + assertEquals(original, restored.get(0)); + } + + @Test + public void retrieveRoundTripMultiplePreservesOrder() { + S3StorageDriver driver = 
driver(new InMemoryS3Client()); + List originals = Arrays.asList(payload("x"), payload("y"), payload("z")); + + List claims = driver.store(storeContext(), originals).join(); + List restored = driver.retrieve(RETRIEVE_CONTEXT, claims).join(); + + assertEquals(originals, restored); + } + + @Test + public void retrieveDetectsCorruptedData() { + InMemoryS3Client client = new InMemoryS3Client(); + S3StorageDriver driver = driver(client); + + List claims = + driver.store(storeContext(), Collections.singletonList(payload("legit"))).join(); + client.objects.replaceAll((k, v) -> "corrupted".getBytes()); + + String message = failureMessage(driver.retrieve(RETRIEVE_CONTEXT, claims)); + assertTrue(message, message.startsWith("integrity check failed [bucket=test-bucket, key=")); + } + + @Test + public void retrieveRejectsUnsupportedHashAlgorithm() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + List claims = + driver.store(storeContext(), Collections.singletonList(payload("data"))).join(); + + Map tampered = new HashMap<>(claims.get(0).getClaimData()); + tampered.put("hash_algorithm", "md5"); + + String message = + failureMessage( + driver.retrieve( + RETRIEVE_CONTEXT, Collections.singletonList(new StorageDriverClaim(tampered)))); + assertEquals("unsupported hash algorithm \"md5\"", message); + } + + @Test + public void retrieveRejectsClaimMissingBucket() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + StorageDriverClaim claim = + new StorageDriverClaim(Collections.singletonMap("key", "v0/d/sha256/abc")); + + assertEquals( + "claim missing field \"bucket\"", + failureMessage(driver.retrieve(RETRIEVE_CONTEXT, Collections.singletonList(claim)))); + } + + @Test + public void retrieveRejectsClaimMissingKey() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + StorageDriverClaim claim = + new StorageDriverClaim(Collections.singletonMap("bucket", "test-bucket")); + + assertEquals( + "claim missing field \"key\"", + 
failureMessage(driver.retrieve(RETRIEVE_CONTEXT, Collections.singletonList(claim)))); + } + + @Test + public void retrieveRejectsClaimMissingHashAlgorithm() { + S3StorageDriver driver = driver(new InMemoryS3Client()); + List claims = + driver.store(storeContext(), Collections.singletonList(payload("x"))).join(); + + Map tampered = new HashMap<>(claims.get(0).getClaimData()); + tampered.remove("hash_algorithm"); + + assertEquals( + "claim missing field \"hash_algorithm\"", + failureMessage( + driver.retrieve( + RETRIEVE_CONTEXT, Collections.singletonList(new StorageDriverClaim(tampered))))); + } + + @Test + public void retrieveWrapsDownloadErrorWithContext() { + InMemoryS3Client client = new InMemoryS3Client(); + S3StorageDriver driver = driver(client); + List claims = + driver.store(storeContext(), Collections.singletonList(payload("data"))).join(); + + client.getError = new RuntimeException("throttled"); + + String message = failureMessage(driver.retrieve(RETRIEVE_CONTEXT, claims)); + assertTrue(message, message.startsWith("download failed [bucket=test-bucket, key=")); + assertTrue(message, message.endsWith(", client_region=ap-southeast-2]: throttled")); + } + + @Test(timeout = 5000) + public void storeFailsFastAndCancelsInFlightUploads() { + // The first upload fails; the second stays pending. The batch must surface the failure promptly + // (rather than blocking on the pending upload), as an unwrapped S3StorageException, and must + // cancel the still-running upload. 
+ HoldSecondUploadClient client = new HoldSecondUploadClient(); + S3StorageDriver driver = S3StorageDriver.newBuilder().setClient(client).setBucket("b").build(); + + CompletableFuture> future = + driver.store(storeContext(), Arrays.asList(payload("a"), payload("b"))); + + try { + future.join(); + fail("expected the future to fail"); + } catch (CompletionException e) { + assertTrue(String.valueOf(e.getCause()), e.getCause() instanceof S3StorageException); + assertTrue(e.getCause().getMessage(), e.getCause().getMessage().endsWith(": boom")); + } + assertTrue("the in-flight upload should be cancelled", client.secondUpload.isCancelled()); + } + + /** + * Fails the first upload and leaves the second pending (cancellable), to exercise fail-fast and + * in-flight cancellation. + */ + private static final class HoldSecondUploadClient implements S3Client { + private final AtomicInteger puts = new AtomicInteger(); + final CompletableFuture secondUpload = new CompletableFuture<>(); + + @Override + public CompletableFuture putObject(String bucket, String key, byte[] data) { + if (puts.incrementAndGet() == 1) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(new RuntimeException("boom")); + return failed; + } + return secondUpload; + } + + @Override + public CompletableFuture objectExists(String bucket, String key) { + return CompletableFuture.completedFuture(false); + } + + @Override + public CompletableFuture getObject(String bucket, String key) { + CompletableFuture failed = new CompletableFuture<>(); + failed.completeExceptionally(new UnsupportedOperationException()); + return failed; + } + } + + /** In-memory {@link S3Client} with optional error injection, for unit tests. 
*/ + private static final class InMemoryS3Client implements S3Client { + final Map objects = new ConcurrentHashMap<>(); + final AtomicInteger putCount = new AtomicInteger(); + RuntimeException putError; + RuntimeException getError; + RuntimeException existsError; + + private static String objectKey(String bucket, String key) { + return bucket + "/" + key; + } + + @Override + public CompletableFuture putObject(String bucket, String key, byte[] data) { + if (putError != null) { + return failed(putError); + } + putCount.incrementAndGet(); + objects.put(objectKey(bucket, key), data.clone()); + return CompletableFuture.completedFuture(null); + } + + @Override + public CompletableFuture objectExists(String bucket, String key) { + if (existsError != null) { + return failed(existsError); + } + return CompletableFuture.completedFuture(objects.containsKey(objectKey(bucket, key))); + } + + @Override + public CompletableFuture getObject(String bucket, String key) { + if (getError != null) { + return failed(getError); + } + byte[] data = objects.get(objectKey(bucket, key)); + if (data == null) { + return failed(new RuntimeException("not found: " + objectKey(bucket, key))); + } + return CompletableFuture.completedFuture(data.clone()); + } + + @Override + public Map describe() { + return Collections.singletonMap("client_region", "ap-southeast-2"); + } + + private static CompletableFuture failed(Throwable t) { + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(t); + return future; + } + } +} diff --git a/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3StorageKeyTest.java b/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3StorageKeyTest.java new file mode 100644 index 0000000000..62a3d576bc --- /dev/null +++ b/contrib/temporal-payload-storage-s3/src/test/java/io/temporal/payload/storage/s3/S3StorageKeyTest.java @@ -0,0 +1,31 @@ +package io.temporal.payload.storage.s3; + +import 
static org.junit.Assert.assertEquals; + +import org.junit.Test; + +public class S3StorageKeyTest { + + @Test + public void escapesEmptyAndNullAsNull() { + assertEquals("null", S3StorageKey.escapePathSegment("")); + assertEquals("null", S3StorageKey.escapePathSegment(null)); + } + + @Test + public void leavesUnreservedCharactersUnescaped() { + assertEquals("AZaz09-_.~$&+:=@", S3StorageKey.escapePathSegment("AZaz09-_.~$&+:=@")); + } + + @Test + public void percentEncodesReservedCharactersAndSpace() { + assertEquals("a%2Fb%20c", S3StorageKey.escapePathSegment("a/b c")); + } + + @Test + public void percentEncodesMultibyteUtf8() { + // 'é' (U+00E9) is the two UTF-8 bytes C3 A9; '€' (U+20AC) is the three bytes E2 82 AC. + assertEquals("caf%C3%A9", S3StorageKey.escapePathSegment("café")); + assertEquals("%E2%82%AC", S3StorageKey.escapePathSegment("€")); + } +} diff --git a/settings.gradle b/settings.gradle index fe80370b0c..fb4b3dbff1 100644 --- a/settings.gradle +++ b/settings.gradle @@ -9,6 +9,8 @@ project(':temporal-opentracing').projectDir = file('contrib/temporal-opentracing include 'temporal-kotlin' include 'temporal-spring-ai' project(':temporal-spring-ai').projectDir = file('contrib/temporal-spring-ai') +include 'temporal-payload-storage-s3' +project(':temporal-payload-storage-s3').projectDir = file('contrib/temporal-payload-storage-s3') include 'temporal-spring-boot-autoconfigure' include 'temporal-spring-boot-starter' include 'temporal-remote-data-encoder' diff --git a/temporal-sdk/src/main/java/io/temporal/payload/storage/PayloadHasher.java b/temporal-sdk/src/main/java/io/temporal/payload/storage/PayloadHasher.java new file mode 100644 index 0000000000..4e3c8fe509 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/payload/storage/PayloadHasher.java @@ -0,0 +1,31 @@ +package io.temporal.payload.storage; + +import io.temporal.common.Experimental; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import 
javax.annotation.Nonnull; + +/** Computes payload hashes shared by external storage drivers. */ +@Experimental +public final class PayloadHasher { + private static final char[] HEX = "0123456789abcdef".toCharArray(); + + private PayloadHasher() {} + + /** Returns the lower-case SHA-256 hex digest of {@code data}. */ + @Nonnull + public static String sha256Hex(@Nonnull byte[] data) { + byte[] digest; + try { + // If we ever move to Java 17+ we can use HexFormat.of().formatHex() instead. + digest = MessageDigest.getInstance("SHA-256").digest(data); + } catch (NoSuchAlgorithmException e) { + throw new AssertionError("SHA-256 MessageDigest cannot be found", e); + } + StringBuilder sb = new StringBuilder(digest.length * 2); + for (byte b : digest) { + sb.append(HEX[(b >> 4) & 0xF]).append(HEX[b & 0xF]); + } + return sb.toString(); + } +}