From 6870fca7afaa97e870ca50da166edd2c310827eb Mon Sep 17 00:00:00 2001 From: Aniekan Victory Date: Sun, 28 Jun 2026 15:48:59 +0100 Subject: [PATCH] Add cross-contract communication contract --- Cargo.toml | 1 + contracts/cross_contract/Cargo.toml | 16 + contracts/cross_contract/src/lib.rs | 713 +++++++++++++++++++++++++++ contracts/cross_contract/src/test.rs | 575 +++++++++++++++++++++ 4 files changed, 1305 insertions(+) create mode 100644 contracts/cross_contract/Cargo.toml create mode 100644 contracts/cross_contract/src/lib.rs create mode 100644 contracts/cross_contract/src/test.rs diff --git a/Cargo.toml b/Cargo.toml index 2af0642..12718f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -92,6 +92,7 @@ members = [ "contracts/governance_token", "contracts/event_logging", "contracts/revenue_share", + "contracts/cross_contract", ] [workspace.dependencies] diff --git a/contracts/cross_contract/Cargo.toml b/contracts/cross_contract/Cargo.toml new file mode 100644 index 0000000..cd8977d --- /dev/null +++ b/contracts/cross_contract/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "cross_contract" +version = "0.1.0" +edition = "2021" + +[dependencies] +soroban-sdk = { workspace = true } + +[dev-dependencies] +soroban-sdk = { workspace = true, features = ["testutils"] } + +[lib] +crate-type = ["cdylib"] + +[features] +default = [] diff --git a/contracts/cross_contract/src/lib.rs b/contracts/cross_contract/src/lib.rs new file mode 100644 index 0000000..0d8f153 --- /dev/null +++ b/contracts/cross_contract/src/lib.rs @@ -0,0 +1,713 @@ +#![no_std] + +use soroban_sdk::{ + contract, contracterror, contractimpl, contracttype, symbol_short, Address, Bytes, Env, + IntoVal, Symbol, Vec, +}; + +#[contracterror] +#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)] +#[repr(u32)] +pub enum CrossContractError { + InvalidConfig = 1, + AlreadyInitialized = 2, + NotInitialized = 3, + Unauthorized = 4, + RouteNotFound = 5, + RouteDisabled = 6, + InvalidCallbackConfig = 7, + RateLimited = 8, + QueueEmpty = 9, + MessageNotFound = 10, + MessageNotQueued = 11, + TargetInvocationFailed = 12, + CallbackInvocationFailed = 13, + UnexpectedQueueState = 14, +} + +#[contracttype] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u32)] +pub enum MessageStatus { + Queued = 0, + Delivered = 1, + Failed = 2, + CallbackFailed = 3, +} + +#[contracttype] +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(u32)] +pub enum AuditAction { + Enqueued = 0, + Routed = 1, + CallbackSucceeded = 2, + Delivered = 3, + Failed = 4, + CallbackFailed = 5, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RateLimitConfig { + pub window_secs: u64, + pub max_messages: u32, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct SenderWindow { + pub window: u64, + pub count: u32, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct RouteConfig { + pub key: Symbol, + pub target_contract: Address, + pub target_method: Symbol, + pub default_callback_contract: Option
, + pub default_callback_method: Option, + pub enabled: bool, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Message { + pub id: u64, + pub route: Symbol, + pub sender: Address, + pub payload: Bytes, + pub atomic: bool, + pub target_contract: Address, + pub target_method: Symbol, + pub callback_contract: Option
, + pub callback_method: Option, + pub status: MessageStatus, + pub queued_at: u64, + pub processed_at: Option, + pub response: Option, + pub last_error: Option, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct AuditEntry { + pub timestamp: u64, + pub action: AuditAction, + pub status: MessageStatus, + pub error: Option, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ProcessOutcome { + pub message_id: u64, + pub status: MessageStatus, + pub response: Option, +} + +#[contracttype] +#[derive(Clone, Debug, Eq, PartialEq)] +enum DataKey { + Admin, + RateLimit, + QueueHead, + QueueTail, + NextMessageId, + Route(Symbol), + Message(u64), + Audit(u64), + QueueSlot(u64), + SenderWindow(Address), +} + +#[contract] +pub struct CrossContractCommunication; + +#[contractimpl] +impl CrossContractCommunication { + pub fn initialize( + env: Env, + admin: Address, + window_secs: u64, + max_messages: u32, + ) -> Result<(), CrossContractError> { + admin.require_auth(); + if env.storage().instance().has(&DataKey::Admin) { + return Err(CrossContractError::AlreadyInitialized); + } + + let config = validate_rate_limit(window_secs, max_messages)?; + env.storage().instance().set(&DataKey::Admin, &admin); + env.storage().instance().set(&DataKey::RateLimit, &config); + env.storage().instance().set(&DataKey::QueueHead, &0u64); + env.storage().instance().set(&DataKey::QueueTail, &0u64); + env.storage().instance().set(&DataKey::NextMessageId, &1u64); + Ok(()) + } + + pub fn set_rate_limit( + env: Env, + admin: Address, + window_secs: u64, + max_messages: u32, + ) -> Result<(), CrossContractError> { + require_admin(&env, &admin)?; + let config = validate_rate_limit(window_secs, max_messages)?; + env.storage().instance().set(&DataKey::RateLimit, &config); + Ok(()) + } + + pub fn register_route( + env: Env, + admin: Address, + key: Symbol, + target_contract: Address, + target_method: Symbol, + callback_contract: Option
, + callback_method: Option, + ) -> Result<(), CrossContractError> { + require_admin(&env, &admin)?; + validate_callback_pair(&callback_contract, &callback_method)?; + + let route = RouteConfig { + key: key.clone(), + target_contract, + target_method, + default_callback_contract: callback_contract, + default_callback_method: callback_method, + enabled: true, + }; + env.storage() + .persistent() + .set(&DataKey::Route(key), &route); + Ok(()) + } + + pub fn set_route_enabled( + env: Env, + admin: Address, + key: Symbol, + enabled: bool, + ) -> Result<(), CrossContractError> { + require_admin(&env, &admin)?; + let mut route = get_route_or_err(&env, &key)?; + route.enabled = enabled; + env.storage() + .persistent() + .set(&DataKey::Route(key), &route); + Ok(()) + } + + pub fn queue_message( + env: Env, + sender: Address, + route: Symbol, + payload: Bytes, + atomic: bool, + callback_contract: Option
, + callback_method: Option, + ) -> Result { + ensure_initialized(&env)?; + sender.require_auth(); + + let route_cfg = get_route_or_err(&env, &route)?; + if !route_cfg.enabled { + return Err(CrossContractError::RouteDisabled); + } + + let (resolved_callback_contract, resolved_callback_method) = resolve_callback( + &route_cfg, + callback_contract, + callback_method, + )?; + + enforce_rate_limit(&env, &sender)?; + + let message_id = next_message_id(&env); + let message = Message { + id: message_id, + route: route.clone(), + sender: sender.clone(), + payload, + atomic, + target_contract: route_cfg.target_contract, + target_method: route_cfg.target_method, + callback_contract: resolved_callback_contract, + callback_method: resolved_callback_method, + status: MessageStatus::Queued, + queued_at: env.ledger().timestamp(), + processed_at: None, + response: None, + last_error: None, + }; + + env.storage() + .persistent() + .set(&DataKey::Message(message_id), &message); + push_queue(&env, message_id); + append_audit( + &env, + message_id, + AuditAction::Enqueued, + MessageStatus::Queued, + None, + ); + env.events() + .publish((symbol_short!("queued"), route), (message_id, sender)); + + Ok(message_id) + } + + pub fn process_next(env: Env) -> Result { + ensure_initialized(&env)?; + let message_id = peek_queue(&env)?; + let message = get_message_or_err(&env, message_id)?; + + if message.status != MessageStatus::Queued { + return Err(CrossContractError::MessageNotQueued); + } + + if message.atomic { + Self::process_atomic(env, message) + } else { + Self::process_non_atomic(env, message) + } + } + + pub fn get_message(env: Env, message_id: u64) -> Option { + env.storage().persistent().get(&DataKey::Message(message_id)) + } + + pub fn get_route(env: Env, key: Symbol) -> Option { + env.storage().persistent().get(&DataKey::Route(key)) + } + + pub fn get_audit_trail(env: Env, message_id: u64) -> Vec { + env.storage() + .persistent() + .get(&DataKey::Audit(message_id)) + .unwrap_or(Vec::new(&env)) + } + + pub fn get_queue_size(env: Env) -> Result { + ensure_initialized(&env)?; + Ok(queue_size(&env)) + } + + pub fn get_rate_limit(env: Env) -> Result { + ensure_initialized(&env)?; + get_rate_limit_config(&env) + } + + pub fn get_sender_window(env: Env, sender: Address) -> Option { + env.storage() + .persistent() + .get(&DataKey::SenderWindow(sender)) + } + + fn process_atomic(env: Env, mut message: Message) -> Result { + let response: Bytes = match env.try_invoke_contract::( + &message.target_contract, + &message.target_method, + build_target_args(&env, &message), + ) { + Ok(Ok(response)) => response, + _ => return Err(CrossContractError::TargetInvocationFailed), + }; + + if let (Some(callback_contract), Some(callback_method)) = + (message.callback_contract.clone(), message.callback_method.clone()) + { + let accepted = match env.try_invoke_contract::( + &callback_contract, + &callback_method, + build_callback_args(&env, &message, &response), + ) { + Ok(Ok(accepted)) => accepted, + _ => return Err(CrossContractError::CallbackInvocationFailed), + }; + + if !accepted { + return Err(CrossContractError::CallbackInvocationFailed); + } + + append_audit( + &env, + message.id, + AuditAction::Routed, + MessageStatus::Queued, + None, + ); + append_audit( + &env, + message.id, + AuditAction::CallbackSucceeded, + MessageStatus::Queued, + None, + ); + } else { + append_audit( + &env, + message.id, + AuditAction::Routed, + MessageStatus::Queued, + None, + ); + } + + message.status = MessageStatus::Delivered; + message.processed_at = Some(env.ledger().timestamp()); + message.response = Some(response.clone()); + message.last_error = None; + + env.storage() + .persistent() + .set(&DataKey::Message(message.id), &message); + pop_queue(&env, message.id)?; + append_audit( + &env, + message.id, + AuditAction::Delivered, + MessageStatus::Delivered, + None, + ); + env.events() + .publish((symbol_short!("done"), message.route), message.id); + + Ok(ProcessOutcome { + message_id: message.id, + status: MessageStatus::Delivered, + response: Some(response), + }) + } + + fn process_non_atomic( + env: Env, + mut message: Message, + ) -> Result { + let response: Bytes = match env.try_invoke_contract::( + &message.target_contract, + &message.target_method, + build_target_args(&env, &message), + ) { + Ok(Ok(response)) => response, + _ => { + message.status = MessageStatus::Failed; + message.processed_at = Some(env.ledger().timestamp()); + message.last_error = Some(CrossContractError::TargetInvocationFailed as u32); + env.storage() + .persistent() + .set(&DataKey::Message(message.id), &message); + pop_queue(&env, message.id)?; + append_audit( + &env, + message.id, + AuditAction::Failed, + MessageStatus::Failed, + Some(CrossContractError::TargetInvocationFailed as u32), + ); + env.events() + .publish((symbol_short!("failed"), message.route), message.id); + return Err(CrossContractError::TargetInvocationFailed); + } + }; + + append_audit( + &env, + message.id, + AuditAction::Routed, + MessageStatus::Queued, + None, + ); + + if let (Some(callback_contract), Some(callback_method)) = + (message.callback_contract.clone(), message.callback_method.clone()) + { + match env.try_invoke_contract::( + &callback_contract, + &callback_method, + build_callback_args(&env, &message, &response), + ) { + Ok(Ok(true)) => append_audit( + &env, + message.id, + AuditAction::CallbackSucceeded, + MessageStatus::Queued, + None, + ), + _ => { + message.status = MessageStatus::CallbackFailed; + message.processed_at = Some(env.ledger().timestamp()); + message.response = Some(response.clone()); + message.last_error = Some(CrossContractError::CallbackInvocationFailed as u32); + env.storage() + .persistent() + .set(&DataKey::Message(message.id), &message); + pop_queue(&env, message.id)?; + append_audit( + &env, + message.id, + AuditAction::CallbackFailed, + MessageStatus::CallbackFailed, + Some(CrossContractError::CallbackInvocationFailed as u32), + ); + env.events() + .publish((symbol_short!("cbfail"), message.route), message.id); + return Err(CrossContractError::CallbackInvocationFailed); + } + } + } + + message.status = MessageStatus::Delivered; + message.processed_at = Some(env.ledger().timestamp()); + message.response = Some(response.clone()); + message.last_error = None; + + env.storage() + .persistent() + .set(&DataKey::Message(message.id), &message); + pop_queue(&env, message.id)?; + append_audit( + &env, + message.id, + AuditAction::Delivered, + MessageStatus::Delivered, + None, + ); + env.events() + .publish((symbol_short!("done"), message.route), message.id); + + Ok(ProcessOutcome { + message_id: message.id, + status: MessageStatus::Delivered, + response: Some(response), + }) + } +} + +fn ensure_initialized(env: &Env) -> Result<(), CrossContractError> { + if !env.storage().instance().has(&DataKey::Admin) { + return Err(CrossContractError::NotInitialized); + } + Ok(()) +} + +fn require_admin(env: &Env, admin: &Address) -> Result<(), CrossContractError> { + ensure_initialized(env)?; + admin.require_auth(); + let stored: Address = env + .storage() + .instance() + .get(&DataKey::Admin) + .ok_or(CrossContractError::NotInitialized)?; + + if stored != *admin { + return Err(CrossContractError::Unauthorized); + } + Ok(()) +} + +fn validate_rate_limit( + window_secs: u64, + max_messages: u32, +) -> Result { + if window_secs == 0 || max_messages == 0 { + return Err(CrossContractError::InvalidConfig); + } + + Ok(RateLimitConfig { + window_secs, + max_messages, + }) +} + +fn get_rate_limit_config(env: &Env) -> Result { + env.storage() + .instance() + .get(&DataKey::RateLimit) + .ok_or(CrossContractError::NotInitialized) +} + +fn get_route_or_err(env: &Env, key: &Symbol) -> Result { + env.storage() + .persistent() + .get(&DataKey::Route(key.clone())) + .ok_or(CrossContractError::RouteNotFound) +} + +fn get_message_or_err(env: &Env, message_id: u64) -> Result { + env.storage() + .persistent() + .get(&DataKey::Message(message_id)) + .ok_or(CrossContractError::MessageNotFound) +} + +fn validate_callback_pair( + callback_contract: &Option
, + callback_method: &Option, +) -> Result<(), CrossContractError> { + match (callback_contract, callback_method) { + (None, None) | (Some(_), Some(_)) => Ok(()), + _ => Err(CrossContractError::InvalidCallbackConfig), + } +} + +fn resolve_callback( + route: &RouteConfig, + callback_contract: Option
, + callback_method: Option, +) -> Result<(Option
, Option), CrossContractError> { + validate_callback_pair(&callback_contract, &callback_method)?; + validate_callback_pair( + &route.default_callback_contract, + &route.default_callback_method, + )?; + + if callback_contract.is_some() { + Ok((callback_contract, callback_method)) + } else { + Ok(( + route.default_callback_contract.clone(), + route.default_callback_method.clone(), + )) + } +} + +fn next_message_id(env: &Env) -> u64 { + let next: u64 = env.storage().instance().get(&DataKey::NextMessageId).unwrap_or(1); + env.storage() + .instance() + .set(&DataKey::NextMessageId, &(next + 1)); + next +} + +fn push_queue(env: &Env, message_id: u64) { + let tail: u64 = env.storage().instance().get(&DataKey::QueueTail).unwrap_or(0); + env.storage() + .persistent() + .set(&DataKey::QueueSlot(tail), &message_id); + env.storage().instance().set(&DataKey::QueueTail, &(tail + 1)); +} + +fn peek_queue(env: &Env) -> Result { + let head: u64 = env.storage().instance().get(&DataKey::QueueHead).unwrap_or(0); + let tail: u64 = env.storage().instance().get(&DataKey::QueueTail).unwrap_or(0); + if head >= tail { + return Err(CrossContractError::QueueEmpty); + } + + env.storage() + .persistent() + .get(&DataKey::QueueSlot(head)) + .ok_or(CrossContractError::UnexpectedQueueState) +} + +fn pop_queue(env: &Env, expected_message_id: u64) -> Result<(), CrossContractError> { + let head: u64 = env.storage().instance().get(&DataKey::QueueHead).unwrap_or(0); + let actual: u64 = env + .storage() + .persistent() + .get(&DataKey::QueueSlot(head)) + .ok_or(CrossContractError::UnexpectedQueueState)?; + + if actual != expected_message_id { + return Err(CrossContractError::UnexpectedQueueState); + } + + env.storage().persistent().remove(&DataKey::QueueSlot(head)); + env.storage().instance().set(&DataKey::QueueHead, &(head + 1)); + Ok(()) +} + +fn queue_size(env: &Env) -> u64 { + let head: u64 = env.storage().instance().get(&DataKey::QueueHead).unwrap_or(0); + let tail: u64 = env.storage().instance().get(&DataKey::QueueTail).unwrap_or(0); + tail.saturating_sub(head) +} + +fn enforce_rate_limit(env: &Env, sender: &Address) -> Result<(), CrossContractError> { + let config = get_rate_limit_config(env)?; + let current_window = env.ledger().timestamp() / config.window_secs; + let key = DataKey::SenderWindow(sender.clone()); + let usage = env + .storage() + .persistent() + .get::<_, SenderWindow>(&key) + .unwrap_or(SenderWindow { + window: current_window, + count: 0, + }); + + let next = if usage.window == current_window { + if usage.count >= config.max_messages { + return Err(CrossContractError::RateLimited); + } + SenderWindow { + window: current_window, + count: usage.count + 1, + } + } else { + SenderWindow { + window: current_window, + count: 1, + } + }; + + env.storage().persistent().set(&key, &next); + Ok(()) +} + +fn build_target_args(env: &Env, message: &Message) -> Vec { + soroban_sdk::vec![ + env, + message.id.into_val(env), + message.sender.clone().into_val(env), + message.route.clone().into_val(env), + message.payload.clone().into_val(env) + ] +} + +fn build_callback_args(env: &Env, message: &Message, response: &Bytes) -> Vec { + soroban_sdk::vec![ + env, + message.id.into_val(env), + message.route.clone().into_val(env), + response.clone().into_val(env), + message.sender.clone().into_val(env) + ] +} + +fn append_audit( + env: &Env, + message_id: u64, + action: AuditAction, + status: MessageStatus, + error: Option, +) { + let key = DataKey::Audit(message_id); + let mut trail: Vec = env + .storage() + .persistent() + .get(&key) + .unwrap_or(Vec::new(env)); + + trail.push_back(AuditEntry { + timestamp: env.ledger().timestamp(), + action, + status, + error, + }); + + env.storage().persistent().set(&key, &trail); +} + +fn panic_with_error(env: &Env, err: CrossContractError) -> ! { + env.events().publish((symbol_short!("xerr"),), err as u32); + panic!("cross contract error"); +} + +#[cfg(test)] +mod test; diff --git a/contracts/cross_contract/src/test.rs b/contracts/cross_contract/src/test.rs new file mode 100644 index 0000000..1e0a302 --- /dev/null +++ b/contracts/cross_contract/src/test.rs @@ -0,0 +1,575 @@ +#![cfg(test)] +extern crate std; + +use super::*; +use soroban_sdk::{ + contract, contractimpl, contracttype, + testutils::{Address as _, Ledger, LedgerInfo}, + Address, Bytes, Env, Symbol, +}; + +#[contracttype] +#[derive(Clone)] +enum TargetDataKey { + CallCount, + AltCallCount, + LastSender, + LastRoute, + LastPayload, + MutationCount, +} + +#[contract] +struct MockTargetContract; + +#[contractimpl] +impl MockTargetContract { + pub fn receive(env: Env, _message_id: u64, sender: Address, route: Symbol, payload: Bytes) -> Bytes { + let calls: u32 = env.storage().instance().get(&TargetDataKey::CallCount).unwrap_or(0); + env.storage().instance().set(&TargetDataKey::CallCount, &(calls + 1)); + env.storage().instance().set(&TargetDataKey::LastSender, &sender); + env.storage().instance().set(&TargetDataKey::LastRoute, &route); + env.storage().instance().set(&TargetDataKey::LastPayload, &payload); + payload + } + + pub fn alternate( + env: Env, + _message_id: u64, + _sender: Address, + _route: Symbol, + payload: Bytes, + ) -> Bytes { + let calls: u32 = env.storage().instance().get(&TargetDataKey::AltCallCount).unwrap_or(0); + env.storage() + .instance() + .set(&TargetDataKey::AltCallCount, &(calls + 1)); + payload + } + + pub fn fail(_env: Env, _message_id: u64, _sender: Address, _route: Symbol, _payload: Bytes) -> Bytes { + panic!("target failed") + } + + pub fn mutate_then_fail( + env: Env, + _message_id: u64, + _sender: Address, + _route: Symbol, + _payload: Bytes, + ) -> Bytes { + let count: u32 = env.storage().instance().get(&TargetDataKey::MutationCount).unwrap_or(0); + env.storage() + .instance() + .set(&TargetDataKey::MutationCount, &(count + 1)); + panic!("target failed") + } + + pub fn get_call_count(env: Env) -> u32 { + env.storage().instance().get(&TargetDataKey::CallCount).unwrap_or(0) + } + + pub fn get_alt_call_count(env: Env) -> u32 { + env.storage() + .instance() + .get(&TargetDataKey::AltCallCount) + .unwrap_or(0) + } + + pub fn get_mutation_count(env: Env) -> u32 { + env.storage() + .instance() + .get(&TargetDataKey::MutationCount) + .unwrap_or(0) + } + + pub fn get_last_sender(env: Env) -> Option
{ + env.storage().instance().get(&TargetDataKey::LastSender) + } + + pub fn get_last_route(env: Env) -> Option { + env.storage().instance().get(&TargetDataKey::LastRoute) + } + + pub fn get_last_payload(env: Env) -> Option { + env.storage().instance().get(&TargetDataKey::LastPayload) + } +} + +#[contracttype] +#[derive(Clone)] +enum CallbackDataKey { + CallCount, + LastMessageId, + LastRoute, + LastResponse, + LastSender, +} + +#[contract] +struct MockCallbackContract; + +#[contractimpl] +impl MockCallbackContract { + pub fn accept( + env: Env, + message_id: u64, + route: Symbol, + response: Bytes, + sender: Address, + ) -> bool { + let calls: u32 = env.storage().instance().get(&CallbackDataKey::CallCount).unwrap_or(0); + env.storage() + .instance() + .set(&CallbackDataKey::CallCount, &(calls + 1)); + env.storage() + .instance() + .set(&CallbackDataKey::LastMessageId, &message_id); + env.storage().instance().set(&CallbackDataKey::LastRoute, &route); + env.storage() + .instance() + .set(&CallbackDataKey::LastResponse, &response); + env.storage().instance().set(&CallbackDataKey::LastSender, &sender); + true + } + + pub fn reject( + _env: Env, + _message_id: u64, + _route: Symbol, + _response: Bytes, + _sender: Address, + ) -> bool { + false + } + + pub fn get_call_count(env: Env) -> u32 { + env.storage().instance().get(&CallbackDataKey::CallCount).unwrap_or(0) + } + + pub fn get_last_response(env: Env) -> Option { + env.storage().instance().get(&CallbackDataKey::LastResponse) + } + + pub fn get_last_sender(env: Env) -> Option
{ + env.storage().instance().get(&CallbackDataKey::LastSender) + } +} + +fn bytes(env: &Env, input: &[u8]) -> Bytes { + Bytes::from_slice(env, input) +} + +fn setup() -> ( + Env, + CrossContractCommunicationClient<'static>, + Address, + Address, + MockTargetContractClient<'static>, + Address, + MockCallbackContractClient<'static>, +) { + let env = Env::default(); + env.mock_all_auths(); + env.ledger().set(LedgerInfo { + protocol_version: 20, + sequence_number: 1, + timestamp: 0, + network_id: Default::default(), + base_reserve: 10, + min_persistent_entry_ttl: 100, + min_temp_entry_ttl: 100, + max_entry_ttl: 100000, + }); + + let admin = Address::generate(&env); + let sender = Address::generate(&env); + + let contract_id = env.register_contract(None, CrossContractCommunication); + let client = CrossContractCommunicationClient::new(&env, &contract_id); + client.initialize(&admin, &60u64, &2u32).unwrap(); + + let target_id = env.register_contract(None, MockTargetContract); + let target_client = MockTargetContractClient::new(&env, &target_id); + + let callback_id = env.register_contract(None, MockCallbackContract); + let callback_client = MockCallbackContractClient::new(&env, &callback_id); + + (env, client, admin, sender, target_client, callback_id, callback_client) +} + +#[test] +fn registers_routes_and_queues_messages() { + let (env, client, admin, sender, target_client, callback_id, _) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "receive"), + &Some(callback_id.clone()), + &Some(Symbol::new(&env, "accept")), + ) + .unwrap(); + + let payload = bytes(&env, &[1, 2, 3]); + let message_id = client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &payload, + &true, + &None, + &None, + ) + .unwrap(); + + assert_eq!(message_id, 1); + assert_eq!(client.get_queue_size().unwrap(), 1); + + let message = client.get_message(&message_id).unwrap(); + assert_eq!(message.route, Symbol::new(&env, "quests")); + assert_eq!(message.sender, sender); + assert_eq!(message.payload, payload); + assert_eq!(message.callback_contract, Some(callback_id)); + + let audit = client.get_audit_trail(&message_id); + assert_eq!(audit.len(), 1); + assert_eq!(audit.get(0).unwrap().action, AuditAction::Enqueued); +} + +#[test] +fn routes_messages_to_expected_target_and_callback() { + let (env, client, admin, sender, target_client, callback_id, callback_client) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "receive"), + &Some(callback_id), + &Some(Symbol::new(&env, "accept")), + ) + .unwrap(); + + let payload = bytes(&env, &[9, 9, 9]); + let message_id = client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &payload.clone(), + &true, + &None, + &None, + ) + .unwrap(); + + let outcome = client.process_next().unwrap(); + assert_eq!(outcome.message_id, message_id); + assert_eq!(outcome.status, MessageStatus::Delivered); + assert_eq!(outcome.response, Some(payload.clone())); + + assert_eq!(target_client.get_call_count(), 1); + assert_eq!(target_client.get_last_sender().unwrap(), sender); + assert_eq!(target_client.get_last_route().unwrap(), Symbol::new(&env, "quests")); + assert_eq!(target_client.get_last_payload().unwrap(), payload); + + assert_eq!(callback_client.get_call_count(), 1); + assert_eq!(callback_client.get_last_response().unwrap(), bytes(&env, &[9, 9, 9])); + assert_eq!(callback_client.get_last_sender().unwrap(), sender); + + let message = client.get_message(&message_id).unwrap(); + assert_eq!(message.status, MessageStatus::Delivered); + assert_eq!(message.response, Some(bytes(&env, &[9, 9, 9]))); + assert_eq!(client.get_queue_size().unwrap(), 0); + + let audit = client.get_audit_trail(&message_id); + assert_eq!(audit.len(), 4); + assert_eq!(audit.get(1).unwrap().action, AuditAction::Routed); + assert_eq!(audit.get(2).unwrap().action, AuditAction::CallbackSucceeded); + assert_eq!(audit.get(3).unwrap().action, AuditAction::Delivered); +} + +#[test] +fn route_keys_dispatch_to_different_methods() { + let (env, client, admin, sender, target_client, _, _) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "primary"), + &target_client.address, + &Symbol::new(&env, "receive"), + &None, + &None, + ) + .unwrap(); + client.register_route( + &admin, + &Symbol::new(&env, "secondary"), + &target_client.address, + &Symbol::new(&env, "alternate"), + &None, + &None, + ) + .unwrap(); + + client.queue_message( + &sender, + &Symbol::new(&env, "secondary"), + &bytes(&env, &[7]), + &true, + &None, + &None, + ) + .unwrap(); + + let outcome = client.process_next().unwrap(); + assert_eq!(outcome.status, MessageStatus::Delivered); + assert_eq!(target_client.get_call_count(), 0); + assert_eq!(target_client.get_alt_call_count(), 1); +} + +#[test] +fn rate_limiting_blocks_excess_messages_and_resets_next_window() { + let (env, client, admin, sender, target_client, _, _) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "receive"), + &None, + &None, + ) + .unwrap(); + + assert!(client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[1]), + &true, + &None, + &None, + ) + .is_ok()); + assert!(client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[2]), + &true, + &None, + &None, + ) + .is_ok()); + assert_eq!( + client.queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[3]), + &true, + &None, + &None, + ), + Err(CrossContractError::RateLimited) + ); + + env.ledger().with_mut(|ledger| ledger.timestamp = 61); + + assert!(client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[4]), + &true, + &None, + &None, + ) + .is_ok()); + + let window = client.get_sender_window(&sender).unwrap(); + assert_eq!(window.window, 1); + assert_eq!(window.count, 1); +} + +#[test] +fn disabled_routes_are_rejected() { + let (env, client, admin, sender, target_client, _, _) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "receive"), + &None, + &None, + ) + .unwrap(); + client + .set_route_enabled(&admin, &Symbol::new(&env, "quests"), &false) + .unwrap(); + + let result = client.queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[1]), + &true, + &None, + &None, + ); + assert_eq!(result, Err(CrossContractError::RouteDisabled)); +} + +#[test] +fn non_atomic_target_failures_are_recorded_and_removed_from_queue() { + let (env, client, admin, sender, target_client, _, _) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "fail"), + &None, + &None, + ) + .unwrap(); + + let message_id = client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[8]), + &false, + &None, + &None, + ) + .unwrap(); + + let result = client.process_next(); + assert_eq!(result, Err(CrossContractError::TargetInvocationFailed)); + + let message = client.get_message(&message_id).unwrap(); + assert_eq!(message.status, MessageStatus::Failed); + assert_eq!( + message.last_error, + Some(CrossContractError::TargetInvocationFailed as u32) + ); + assert_eq!(client.get_queue_size().unwrap(), 0); + + let audit = client.get_audit_trail(&message_id); + assert_eq!(audit.len(), 2); + assert_eq!(audit.get(1).unwrap().action, AuditAction::Failed); +} + +#[test] +fn non_atomic_callback_failures_propagate_and_preserve_target_response() { + let (env, client, admin, sender, target_client, callback_id, callback_client) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "receive"), + &Some(callback_id), + &Some(Symbol::new(&env, "reject")), + ) + .unwrap(); + + let message_id = client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[4, 5]), + &false, + &None, + &None, + ) + .unwrap(); + + let result = client.process_next(); + assert_eq!(result, Err(CrossContractError::CallbackInvocationFailed)); + + assert_eq!(target_client.get_call_count(), 1); + assert_eq!(callback_client.get_call_count(), 0); + + let message = client.get_message(&message_id).unwrap(); + assert_eq!(message.status, MessageStatus::CallbackFailed); + assert_eq!(message.response, Some(bytes(&env, &[4, 5]))); + assert_eq!( + message.last_error, + Some(CrossContractError::CallbackInvocationFailed as u32) + ); +} + +#[test] +fn atomic_failures_keep_message_queued_and_rollback_target_side_effects() { + let (env, client, admin, sender, target_client, _, _) = setup(); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "mutate_then_fail"), + &None, + &None, + ) + .unwrap(); + + let message_id = client + .queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[1, 2]), + &true, + &None, + &None, + ) + .unwrap(); + + let result = client.process_next(); + assert_eq!(result, Err(CrossContractError::TargetInvocationFailed)); + + assert_eq!(target_client.get_mutation_count(), 0); + assert_eq!(client.get_queue_size().unwrap(), 1); + + let message = client.get_message(&message_id).unwrap(); + assert_eq!(message.status, MessageStatus::Queued); + assert!(message.response.is_none()); + + let audit = client.get_audit_trail(&message_id); + assert_eq!(audit.len(), 1); + assert_eq!(audit.get(0).unwrap().action, AuditAction::Enqueued); +} + +#[test] +fn callback_override_is_used_instead_of_default_route_callback() { + let (env, client, admin, sender, target_client, callback_id, callback_client) = setup(); + let alt_callback_id = env.register_contract(None, MockCallbackContract); + let alt_callback_client = MockCallbackContractClient::new(&env, &alt_callback_id); + + client.register_route( + &admin, + &Symbol::new(&env, "quests"), + &target_client.address, + &Symbol::new(&env, "receive"), + &Some(callback_id), + &Some(Symbol::new(&env, "accept")), + ) + .unwrap(); + + client.queue_message( + &sender, + &Symbol::new(&env, "quests"), + &bytes(&env, &[6]), + &true, + &Some(alt_callback_id), + &Some(Symbol::new(&env, "accept")), + ) + .unwrap(); + + let outcome = client.process_next().unwrap(); + assert_eq!(outcome.status, MessageStatus::Delivered); + assert_eq!(callback_client.get_call_count(), 0); + assert_eq!(alt_callback_client.get_call_count(), 1); +}