diff --git a/genesis.json b/genesis.json index fe44101..4668575 100644 --- a/genesis.json +++ b/genesis.json @@ -1,5 +1,5 @@ { - "chain_id": "minichain_testnet_1", + "chain_id": "minichain-default", "timestamp": 1716880000000, "difficulty": 4, "alloc": { diff --git a/main.py b/main.py index 44fecc6..2d02ed8 100644 --- a/main.py +++ b/main.py @@ -3,14 +3,14 @@ Usage: python main.py --port 9000 - python main.py --port 9001 --connect 127.0.0.1:9000 + python main.py --port 9001 --connect Commands (type in the terminal while the node is running): balance — show all account balances send — send coins to another address mine — mine a block from the mempool peers — show connected peers - connect : — connect to another node + connect — connect to another node address — show this node's public key help — show available commands quit — shut down the node @@ -26,6 +26,7 @@ from nacl.encoding import HexEncoder from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block +from minichain.rpc import JSONRPCServer from minichain.validators import is_valid_receiver from minichain.block import calculate_receipt_root @@ -113,7 +114,7 @@ def mine_and_process_block(chain, mempool, miner_pk): # Network message handler # ────────────────────────────────────────────── -def make_network_handler(chain, mempool): +def make_network_handler(chain, mempool, network): """Return an async callback that processes incoming P2P messages.""" async def handler(data): @@ -121,37 +122,45 @@ async def handler(data): payload = data.get("data") peer_addr = data.get("_peer_addr", "unknown") - if msg_type == "sync": - peer_host = peer_addr.rsplit(":", 1)[0] if ":" in peer_addr else peer_addr - peer_host = peer_host.strip("[]") - is_trusted = peer_addr in TRUSTED_PEERS or peer_host in TRUSTED_PEERS - is_localhost = peer_host in LOCALHOST_PEERS - if chain.state.accounts and not (is_trusted or is_localhost): - logger.warning("🔒 Rejected sync from untrusted peer %s", peer_addr) - return + if payload is None and msg_type in ("hello", "chain_request", "chain_response"): + return - # Merge remote state into local state (for accounts we don't have yet) - remote_accounts = payload.get("accounts") if isinstance(payload, dict) else None - if not isinstance(remote_accounts, dict): - logger.warning("🔒 Rejected sync from %s with invalid accounts payload", peer_addr) + if msg_type == "hello": + peer_chain_id = payload.get("chain_id") + peer_gen_hash = payload.get("genesis_hash") + if peer_chain_id != chain.chain_id: + logger.warning("🔒 Disconnecting peer %s: chain_id mismatch (got %s, expected %s)", peer_addr, peer_chain_id, chain.chain_id) + asyncio.create_task(network.disconnect_peer(peer_addr)) + return + if peer_gen_hash != chain.chain[0].hash: + logger.warning("🔒 Disconnecting peer %s: genesis hash mismatch", peer_addr) + asyncio.create_task(network.disconnect_peer(peer_addr)) return - for addr, acc in remote_accounts.items(): - if not isinstance(acc, dict): - logger.warning("🔒 Skipping malformed account %r from %s", addr, peer_addr) - continue - if addr not in chain.state.accounts: - chain.state.accounts[addr] = acc - logger.info("🔄 Synced account %s... (balance=%d)", addr[:12], acc.get("balance", 0)) - logger.info("🔄 Accepted state sync from %s — %d accounts", peer_addr, len(chain.state.accounts)) + logger.info("🔄 Handshake successful with %s", peer_addr) + peer_tip = payload.get("latest_block_index", 0) + if peer_tip > chain.last_block.index: + logger.info("📡 Peer %s is ahead (%d > %d). Initiating chunked sync...", peer_addr, peer_tip, chain.last_block.index) + req = {"type": "chain_request", "data": {"start_index": chain.last_block.index + 1, "limit": 500}} + asyncio.create_task(network._broadcast_raw(req)) elif msg_type == "tx": - tx = Transaction.from_dict(payload) - if mempool.add_transaction(tx): - logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) + try: + tx = Transaction.from_dict(payload) + if getattr(tx, "chain_id", None) != chain.chain_id: + logger.warning("Invalid chain_id in tx from %s", peer_addr) + return + if mempool.add_transaction(tx): + logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) + except Exception as e: + logger.warning("Invalid tx payload from %s: %s", peer_addr, e) elif msg_type == "block": - block = Block.from_dict(payload) + try: + block = Block.from_dict(payload) + except Exception as e: + logger.warning("Invalid block payload from %s: %s", peer_addr, e) + return if chain.add_block(block): logger.info("📥 Received Block #%d — added to chain", block.index) @@ -159,7 +168,74 @@ async def handler(data): # Drop only confirmed transactions so higher nonces can remain queued. mempool.remove_transactions(block.transactions) else: - logger.warning("📥 Received Block #%s — rejected", block.index) + if block.index > chain.last_block.index + 1: + logger.warning("📥 Received Block #%s — ahead of us (tip: %s). Requesting chunked sync...", block.index, chain.last_block.index) + req = {"type": "chain_request", "data": {"start_index": chain.last_block.index + 1, "limit": 500}} + asyncio.create_task(network._broadcast_raw(req)) + else: + logger.warning("📥 Received Block #%s — rejected. Fork detected, trigger reorg sync.", block.index) + # For a fork, request the full chain to use resolve_conflicts + req = {"type": "chain_request", "data": {"start_index": 0, "limit": 1000000}} # Request full chain for reorg + asyncio.create_task(network._broadcast_raw(req)) + + elif msg_type == "chain_request": + start_index = payload.get("start_index", 0) + limit = payload.get("limit", 500) + logger.info("📡 Peer requested blocks from %d (limit %d).", start_index, limit) + + if start_index < len(chain.chain): + blocks_slice = chain.chain[start_index : start_index + limit] + blocks_dicts = [b.to_dict() for b in blocks_slice] + else: + blocks_dicts = [] + + resp_payload = {"type": "chain_response", "data": {"blocks": blocks_dicts, "requested_limit": limit}} + asyncio.create_task(network._unicast_raw(peer_addr, resp_payload)) + + elif msg_type == "chain_response": + blocks_payload = payload.get("blocks", []) + requested_limit = payload.get("requested_limit", 500) + if not blocks_payload: + return + + new_chain = [] + try: + new_chain = [Block.from_dict(b) for b in blocks_payload] + except Exception as e: + logger.warning("❌ Failed to parse chain_response: %s", e) + return + + if new_chain: + # Distinguish between linear catch-up vs full reorg based on whether we received block 0 + if new_chain[0].index == 0: + # Fork / Reorg sync + success, orphans = chain.resolve_conflicts(new_chain) + if success: + logger.info("🔄 Reorg complete! Restoring %d orphaned txs to mempool.", len(orphans)) + for tx in orphans: + mempool.add_transaction(tx) + else: + # Linear Catch-up + all_added = True + for block in new_chain: + if block.index <= chain.last_block.index: + continue # Ignore already known blocks + if chain.add_block(block): + logger.info("📥 Synced Block #%d", block.index) + mempool.remove_transactions(block.transactions) + else: + logger.warning("❌ Sync failed at Block #%d. Fork detected. Requesting full chain.", block.index) + req = {"type": "chain_request", "data": {"start_index": 0, "limit": 1000000}} + asyncio.create_task(network._broadcast_raw(req)) + all_added = False + break + + # If we added all blocks and we hit the limit, request next batch + if all_added and len(new_chain) == requested_limit: + next_index = chain.last_block.index + 1 + logger.info("📡 Requesting next batch from index %d", next_index) + req = {"type": "chain_request", "data": {"start_index": next_index, "limit": requested_limit}} + asyncio.create_task(network._broadcast_raw(req)) return handler @@ -235,7 +311,7 @@ async def cli_loop(sk, pk, chain, mempool, network): continue nonce = chain.state.get_account(pk).get("nonce", 0) - tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce, fee=fee) + tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce, fee=fee, chain_id=chain.chain_id) tx.sign(sk) if mempool.add_transaction(tx): @@ -269,7 +345,7 @@ async def cli_loop(sk, pk, chain, mempool, network): continue nonce = chain.state.get_account(pk).get("nonce", 0) - tx = Transaction(sender=pk, receiver=None, amount=amount, nonce=nonce, fee=fee, data=code) + tx = Transaction(sender=pk, receiver=None, amount=amount, nonce=nonce, fee=fee, data=code, chain_id=chain.chain_id) tx.sign(sk) if mempool.add_transaction(tx): @@ -301,7 +377,7 @@ async def cli_loop(sk, pk, chain, mempool, network): continue nonce = chain.state.get_account(pk).get("nonce", 0) - tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce, fee=fee, data=payload) + tx = Transaction(sender=pk, receiver=receiver, amount=amount, nonce=nonce, fee=fee, data=payload, chain_id=chain.chain_id) tx.sign(sk) if mempool.add_transaction(tx): @@ -323,19 +399,14 @@ async def cli_loop(sk, pk, chain, mempool, network): # ── connect ── elif cmd == "connect": if len(parts) < 2: - print(" Usage: connect :") + print(" Usage: connect ") continue - try: - host, port_str = parts[1].rsplit(":", 1) - port = int(port_str) - except ValueError: - print(" Invalid format. Use host:port") - continue - success = await network.connect_to_peer(host, port) + maddr_str = parts[1] + success = await network.connect_to_peer(maddr_str) if success: - print(f" Connected to {host}:{port}") + print(f" Attempting to dial {maddr_str}...") else: - print(f" Failed to connect to {host}:{port}") + print(f" Failed to initiate connection to {maddr_str}") # ── address ── elif cmd == "address": @@ -389,23 +460,34 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data mempool = Mempool() network = P2PNetwork() - handler = make_network_handler(chain, mempool) + handler = make_network_handler(chain, mempool, network) network.register_handler(handler) + + rpc_server = JSONRPCServer(chain, mempool, network) - # When a new peer connects, send our state so they can sync + # When a new peer connects, send our hello so they can handshake async def on_peer_connected(writer): import json as _json sync_msg = _json.dumps({ - "type": "sync", - "data": {"accounts": chain.state.accounts} + "type": "hello", + "data": { + "chain_id": chain.chain_id, + "genesis_hash": chain.chain[0].hash, + "latest_block_index": chain.last_block.index, + "latest_block_hash": chain.last_block.hash + } }) + "\n" writer.write(sync_msg.encode()) await writer.drain() - logger.info("🔄 Sent state sync to new peer") + logger.info("🔄 Sent hello handshake to new peer") network.register_on_peer_connected(on_peer_connected) await network.start(port=port, host=host) + + # Start RPC server on a port correlated to the node port (e.g. 8545 if P2P is 9000) + rpc_port = 8545 + (port - 9000) + await rpc_server.start(host="127.0.0.1", port=rpc_port) # Fund this node's wallet so it can transact in the demo if fund > 0: @@ -414,11 +496,7 @@ async def on_peer_connected(writer): # Connect to a seed peer if requested if connect_to: - try: - host, peer_port = connect_to.rsplit(":", 1) - await network.connect_to_peer(host, int(peer_port)) - except ValueError: - logger.error("Invalid --connect format. Use host:port") + await network.connect_to_peer(connect_to) try: await cli_loop(sk, pk, chain, mempool, network) @@ -431,6 +509,8 @@ async def on_peer_connected(writer): logger.info("Chain saved to '%s'", datadir) except Exception as e: logger.error("Failed to save chain during shutdown: %s", e) + + await rpc_server.stop() await network.stop() @@ -438,7 +518,7 @@ def main(): parser = argparse.ArgumentParser(description="MiniChain Node — Testnet Demo") parser.add_argument("--host", type=str, default="127.0.0.1", help="Host/IP to bind the P2P server (default: 127.0.0.1)") parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)") - parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (host:port)") + parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (multiaddr)") parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)") parser.add_argument("--datadir", type=str, default=None, help="Directory to save/load blockchain state (enables persistence)") args = parser.parse_args() diff --git a/minichain/chain.py b/minichain/chain.py index b37dcad..1aa3917 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -34,6 +34,7 @@ class Blockchain: def __init__(self, genesis_path="genesis.json"): self.chain = [] self.state = State() + self.chain_id = "minichain-default" self._lock = threading.RLock() self._create_genesis_block(genesis_path) @@ -63,6 +64,9 @@ def _create_genesis_block(self, genesis_path): account = self.state.get_account(address) account['balance'] = balance + self.chain_id = config.get("chain_id", "minichain-default") + self.state.chain_id = self.chain_id + timestamp = config.get("timestamp") difficulty = config.get("difficulty") @@ -89,6 +93,9 @@ def _create_genesis_block(self, genesis_path): genesis_block.hash = computed_hash self.chain.append(genesis_block) + + # Snapshot the state exactly after genesis allocation for clean reorg rebuilds + self._genesis_state_snapshot = self.state.snapshot() @property def last_block(self): @@ -98,6 +105,16 @@ def last_block(self): with self._lock: # Acquire lock for thread-safe access return self.chain[-1] + def get_total_work(self, chain_list=None): + """ + Calculates the cumulative PoW of a chain. + Work is proportional to 2^difficulty. + """ + if chain_list is None: + with self._lock: + chain_list = self.chain + return sum(2 ** (block.difficulty or 1) for block in chain_list) + def add_block(self, block): """ Validates and adds a block to the chain if all transactions succeed. @@ -113,6 +130,7 @@ def add_block(self, block): # Validate transactions on a temporary state copy temp_state = self.state.copy() + temp_state.chain_id = self.chain_id receipts = [] for tx in block.transactions: @@ -147,3 +165,77 @@ def add_block(self, block): self.state = temp_state self.chain.append(block) return True + + def resolve_conflicts(self, new_chain_list) -> tuple[bool, list]: + """ + Evaluates a competing chain. If it has strictly greater cumulative work, + attempts a reorg. Rebuilds state from genesis to guarantee validity. + Returns: (success_bool, list_of_orphaned_transactions) + """ + if not new_chain_list: + return False, [] + + with self._lock: + current_work = self.get_total_work() + new_work = self.get_total_work(new_chain_list) + + if new_work <= current_work: + logger.debug("Incoming chain (work: %s) is not heavier than local chain (work: %s). Rejecting.", new_work, current_work) + return False, [] + + # 1. Verify genesis block matches + if new_chain_list[0].hash != self.chain[0].hash: + logger.warning("Reorg failed: Genesis hash mismatch.") + return False, [] + + logger.info("Incoming chain is heavier (%s > %s). Attempting reorg...", new_work, current_work) + + # 2. Snapshot current chain in case reorg fails validation + original_chain = list(self.chain) + + # 3. Rebuild state entirely from genesis using the new chain + temp_state = State() + temp_state.chain_id = self.chain_id + temp_state.restore(self._genesis_state_snapshot) + + # Verify and apply blocks 1 to N + for i in range(1, len(new_chain_list)): + prev_block = new_chain_list[i-1] + block = new_chain_list[i] + + try: + validate_block_link_and_hash(prev_block, block) + except ValueError as exc: + logger.warning("Reorg failed at block %s: %s", block.index, exc) + return False, [] + + receipts = [] + for tx in block.transactions: + receipt = temp_state.validate_and_apply(tx) + if receipt is None: + logger.warning("Reorg failed: Transaction validation failed in block %s", block.index) + return False, [] + receipts.append(receipt) + + total_fees = sum(getattr(r, 'gas_used', 0) for r in receipts) + if block.miner: + temp_state.credit_mining_reward(block.miner, reward=temp_state.DEFAULT_MINING_REWARD + total_fees) + + computed_receipt_root = calculate_receipt_root(receipts) + if block.receipt_root != computed_receipt_root: + logger.warning("Reorg failed: Invalid receipt root at block %s. Expected %s, got %s", block.index, computed_receipt_root, block.receipt_root) + return False, [] + + if block.state_root != temp_state.state_root(): + logger.warning("Reorg failed: Invalid state root at block %s", block.index) + return False, [] + + # 4. Success! Compute orphaned transactions. + old_txs = {tx.tx_id: tx for b in original_chain[1:] for tx in b.transactions} + new_tx_ids = {tx.tx_id for b in new_chain_list[1:] for tx in b.transactions} + orphans = [tx for tx_id, tx in old_txs.items() if tx_id not in new_tx_ids] + + self.chain = new_chain_list + self.state = temp_state + logger.info("Reorg successful! Switched to new chain tip: Block %s", self.last_block.index) + return True, orphans diff --git a/minichain/contract.py b/minichain/contract.py index 4ad8233..7aca420 100644 --- a/minichain/contract.py +++ b/minichain/contract.py @@ -38,13 +38,30 @@ def _safe_exec_worker(code, globals_dict, context_dict, result_queue, gas_limit) try: import resource # Limit CPU time (seconds) and memory (bytes) - example values - resource.setrlimit(resource.RLIMIT_CPU, (2, 2)) # Align with p.join timeout (2 seconds) + resource.setrlimit(resource.RLIMIT_CPU, (10, 10)) # Align with p.join timeout (10 seconds) resource.setrlimit(resource.RLIMIT_AS, (100 * 1024 * 1024, 100 * 1024 * 1024)) except ImportError: logger.warning("Resource module not available. Contract will run without OS-level resource limits.") except (OSError, ValueError) as e: logger.warning("Failed to set resource limits: %s", e) + transfers = [] + + def transfer_out(address, amount): + if not isinstance(amount, int) or amount <= 0: + raise ValueError("Invalid transfer amount") + if not isinstance(address, str): + raise ValueError("Invalid address type") + if not address or len(address) not in (40, 64): + raise ValueError("Invalid address format") + try: + int(address, 16) + except ValueError: + raise ValueError("Invalid address format") + transfers.append({"to": address, "amount": amount}) + + globals_dict["__builtins__"]["transfer_out"] = transfer_out + meter = GasMeter(gas_limit) sys.settrace(meter.trace_calls) @@ -54,7 +71,7 @@ def _safe_exec_worker(code, globals_dict, context_dict, result_queue, gas_limit) sys.settrace(None) gas_used = meter.initial_gas - meter.gas - result_queue.put({"status": "success", "storage": context_dict.get("storage"), "gas_used": gas_used}) + result_queue.put({"status": "success", "storage": context_dict.get("storage"), "transfers": transfers, "gas_used": gas_used}) except OutOfGasException as e: result_queue.put({"status": "error", "error": "Out of gas!", "gas_used": gas_limit}) except Exception as e: @@ -147,7 +164,7 @@ def execute(self, contract_address, sender_address, payload, amount, gas_limit): args=(code, globals_for_exec, context, queue, gas_limit) ) p.start() - p.join(timeout=2) # 2 second timeout + p.join(timeout=10) # 10 second timeout if p.is_alive(): p.kill() @@ -172,13 +189,7 @@ def execute(self, contract_address, sender_address, payload, amount, gas_limit): logger.error("Contract storage not JSON serializable") return {"success": False, "gas_used": result.get("gas_used", gas_limit), "error": "Storage not JSON serializable"} - # Commit updated storage only after successful execution - self.state.update_contract_storage( - contract_address, - result["storage"] - ) - - return {"success": True, "gas_used": result["gas_used"], "error": None} + return {"success": True, "gas_used": result["gas_used"], "transfers": result.get("transfers", []), "storage": result["storage"], "error": None} except Exception as e: logger.error("Contract Execution Failed", exc_info=True) diff --git a/minichain/mempool.py b/minichain/mempool.py index 6e3d8d9..a6e7630 100644 --- a/minichain/mempool.py +++ b/minichain/mempool.py @@ -5,8 +5,7 @@ class Mempool: def __init__(self, max_size=1000, transactions_per_block=100): - self._pool = {} - self._size = 0 + self._list = [] # Single sorted list self._lock = threading.Lock() self.max_size = max_size self.transactions_per_block = transactions_per_block @@ -17,64 +16,62 @@ def add_transaction(self, tx): return False with self._lock: - existing = self._pool.get(tx.sender, {}).get(tx.nonce) + existing_idx = None + i_min = 0 + i_max = len(self._list) + + for i, existing_tx in enumerate(self._list): + if existing_tx.sender == tx.sender: + if existing_tx.nonce == tx.nonce: + existing_idx = i + elif existing_tx.nonce < tx.nonce: + # Must insert AFTER the largest lower-nonce transaction + i_min = max(i_min, i + 1) + elif existing_tx.nonce > tx.nonce: + # Must insert BEFORE the smallest higher-nonce transaction + i_max = min(i_max, i) - if existing: - if existing.tx_id == tx.tx_id: + if existing_idx is not None: + existing_tx = self._list[existing_idx] + if existing_tx.tx_id == tx.tx_id: logger.warning("Mempool: Duplicate transaction rejected %s", tx.tx_id) return False - # Fix: Guard against older replacements (e.g. rejected block restore) - # Only allow overwrite if it's a genuinely newer replacement - if tx.timestamp <= existing.timestamp: + if tx.timestamp <= existing_tx.timestamp: logger.warning("Mempool: Ignoring older replacement %s", tx.tx_id) return False + self._list.pop(existing_idx) + if i_max > existing_idx: + i_max -= 1 + if i_min > existing_idx: + i_min -= 1 else: - if self._size >= self.max_size: + if len(self._list) >= self.max_size: logger.warning("Mempool: Full, rejecting transaction") return False - self._size += 1 - self._pool.setdefault(tx.sender, {})[tx.nonce] = tx - return True - - def get_transactions_for_block(self): - with self._lock: - snapshot = {s: list(pool.values()) for s, pool in self._pool.items()} - for txs in snapshot.values(): - txs.sort(key=lambda t: t.nonce) + i_min = min(i_min, i_max) - selected = [] - while len(selected) < self.transactions_per_block: - best_tx = None - best_sender = None + # Insert before the first tx in [i_min, i_max] that has a lower fee + insert_idx = i_max + for j in range(i_min, i_max): + if getattr(self._list[j], 'fee', 0) < getattr(tx, 'fee', 0): + insert_idx = j + break - for sender, txs in snapshot.items(): - if txs: - current_criteria = (-getattr(txs[0], 'fee', 0), txs[0].timestamp, sender, txs[0].nonce) - best_criteria = (-getattr(best_tx, 'fee', 0), best_tx.timestamp, best_sender, best_tx.nonce) if best_tx else None - if best_tx is None or current_criteria < best_criteria: - best_tx = txs[0] - best_sender = sender - - if not best_tx: - break - - selected.append(best_tx) - snapshot[best_sender].pop(0) + self._list.insert(insert_idx, tx) + return True - return selected + def get_transactions_for_block(self): + with self._lock: + # O(k) retrieval, where k = transactions_per_block! The list is strictly ordered upon insertion. + return list(self._list[:self.transactions_per_block]) def remove_transactions(self, transactions): with self._lock: - for tx in transactions: - pool = self._pool.get(tx.sender) - if pool and tx.nonce in pool: - del pool[tx.nonce] - self._size -= 1 - if not pool: - del self._pool[tx.sender] + keys_to_remove = {(tx.sender, tx.nonce) for tx in transactions} + self._list = [tx for tx in self._list if (tx.sender, tx.nonce) not in keys_to_remove] def __len__(self): with self._lock: - return self._size + return len(self._list) diff --git a/minichain/p2p.py b/minichain/p2p.py index 7bb1e34..28efe38 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -1,390 +1,208 @@ """ -Minimal TCP-based P2P network layer for MiniChain testnet demo. - -Each node runs an asyncio TCP server and can connect to peers. -Messages are newline-delimited JSON. +Libp2p-based P2P network layer for MiniChain. +Runs libp2p via trio in a background thread to stay compatible with asyncio. """ import asyncio import json import logging - +import threading +import trio +import queue + +from libp2p import new_host +TProtocol = str +from libp2p.peer.peerinfo import info_from_p2p_addr +from multiaddr import Multiaddr from .serialization import canonical_json_hash, canonical_json_dumps -from .validators import is_valid_receiver logger = logging.getLogger(__name__) -TOPIC = "minichain-global" -SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"} - +SUPPORTED_MESSAGE_TYPES = {"hello", "tx", "block", "chain_request", "chain_response"} +PROTOCOL_ID = TProtocol("/minichain/1.0.0") class P2PNetwork: - """ - Lightweight peer-to-peer networking using asyncio TCP streams. - - JSON wire format (one JSON object per line): - {"type": "sync" | "tx" | "block", "data": {...}} - """ + """Lightweight peer-to-peer networking using libp2p.""" def __init__(self, handler_callback=None): - self._handler_callback = None - if handler_callback is not None: - self.register_handler(handler_callback) - self._peers: list[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = [] - self._server: asyncio.Server | None = None - self._port: int = 0 - self._listen_tasks: list[asyncio.Task] = [] + self._handler_callback = handler_callback self._on_peer_connected = None self._seen_tx_ids = set() self._seen_block_hashes = set() + self._to_trio = queue.Queue() + self._to_asyncio = queue.Queue() + self._peer_count = 0 + self._peer_count_lock = threading.Lock() def register_handler(self, handler_callback): - if not callable(handler_callback): - raise ValueError("handler_callback must be callable") self._handler_callback = handler_callback def register_on_peer_connected(self, handler_callback): - if not callable(handler_callback): - raise ValueError("handler_callback must be callable") self._on_peer_connected = handler_callback - async def _notify_peer_connected(self, writer, error_message): - if self._on_peer_connected: - try: - await self._on_peer_connected(writer) - except Exception: - logger.exception(error_message) - async def start(self, port: int = 9000, host: str = "127.0.0.1"): - """Start listening for incoming peer connections on the given port.""" - self._port = port - self._server = await asyncio.start_server(self._handle_incoming, host, port) - logger.info("Network: Listening on %s:%d", host, port) + self.port = port + self.host_addr = host + self.loop = asyncio.get_running_loop() + + threading.Thread(target=trio.run, args=(self._trio_main,), daemon=True).start() + asyncio.create_task(self._asyncio_reader()) + logger.info(f"Network: Starting libp2p on port {port}") async def stop(self): - """Gracefully shut down the server and disconnect all peers.""" logger.info("Network: Shutting down") - for task in self._listen_tasks: - task.cancel() - if self._listen_tasks: - await asyncio.gather(*self._listen_tasks, return_exceptions=True) - self._listen_tasks.clear() - for _, writer in self._peers: - try: - writer.close() - await writer.wait_closed() - except Exception: - pass - self._peers.clear() - if self._server: - self._server.close() - await self._server.wait_closed() - self._server = None - - async def connect_to_peer(self, host: str, port: int) -> bool: - """Actively connect to another MiniChain node.""" - try: - reader, writer = await asyncio.open_connection(host, port) - self._peers.append((reader, writer)) - task = asyncio.create_task( - self._listen_to_peer(reader, writer, f"{host}:{port}") - ) - self._listen_tasks.append(task) - await self._notify_peer_connected(writer, "Network: Error during outbound peer sync") - logger.info("Network: Connected to peer %s:%d", host, port) - return True - except Exception as exc: - logger.error("Network: Failed to connect to %s:%d — %s", host, port, exc) - return False - - async def _handle_incoming( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - ): - """Accept an incoming peer connection.""" - peername = writer.get_extra_info("peername") - addr = f"{peername[0]}:{peername[1]}" if peername else "unknown" - logger.info("Network: Incoming peer connection from %s", addr) - self._peers.append((reader, writer)) - task = asyncio.create_task(self._listen_to_peer(reader, writer, addr)) - self._listen_tasks.append(task) - await self._notify_peer_connected(writer, "Network: Error during peer sync") - - def _validate_transaction_payload(self, payload): - if not isinstance(payload, dict): - return False - - required_fields = { - "sender": str, - "amount": int, - "fee": int, - "nonce": int, - "timestamp": int, - "signature": str, - } - optional_fields = { - "receiver": (str, type(None)), - "data": (str, type(None)), - } - allowed_fields = set(required_fields) | set(optional_fields) - - if set(payload) != allowed_fields: - return False - - for field, expected_type in required_fields.items(): - if not isinstance(payload.get(field), expected_type): - return False - - for field, expected_type in optional_fields.items(): - if not isinstance(payload.get(field), expected_type): - return False - - if payload["amount"] < 0: - return False - - receiver = payload.get("receiver") - if receiver is not None and not is_valid_receiver(receiver): - return False - - return True - - def _validate_sync_payload(self, payload): - if not isinstance(payload, dict) or set(payload) != {"accounts"}: - return False - - accounts = payload["accounts"] - if not isinstance(accounts, dict): - return False - - for address, account in accounts.items(): - if not isinstance(address, str) or not isinstance(account, dict): - return False - required = {"balance", "nonce", "code", "storage"} - if set(account) != required: - return False - if not isinstance(account["balance"], int): - return False - if not isinstance(account["nonce"], int): - return False - if not isinstance(account["code"], (str, type(None))): - return False - if not isinstance(account["storage"], dict): - return False + self._to_trio.put(("STOP", None)) + async def connect_to_peer(self, maddr_str: str) -> bool: + self._to_trio.put(("CONNECT", maddr_str)) return True - def _validate_block_payload(self, payload): - if not isinstance(payload, dict): - return False - - required_fields = { - "index": int, - "previous_hash": str, - "merkle_root": (str, type(None)), - "state_root": str, - "receipt_root": (str, type(None)), - "receipts": list, - "transactions": list, - "timestamp": int, - "difficulty": (int, type(None)), - "nonce": int, - "hash": str, - } - optional_fields = {"miner": (str, type(None))} - allowed_fields = set(required_fields) | set(optional_fields) - - if not set(required_fields).issubset(payload): - return False - - if not set(payload).issubset(allowed_fields): - return False - - for field, expected_type in required_fields.items(): - if not isinstance(payload.get(field), expected_type): - return False - - if "miner" in payload and not isinstance(payload["miner"], (str, type(None))): - return False - - for r_payload in payload.get("receipts", []): - if not isinstance(r_payload, dict): - return False - if "tx_hash" not in r_payload or not isinstance(r_payload["tx_hash"], str): - return False - if "status" not in r_payload or not isinstance(r_payload["status"], int): - return False - if "gas_used" in r_payload and not isinstance(r_payload["gas_used"], int): - return False - if "error_message" in r_payload and not isinstance(r_payload["error_message"], (str, type(None))): - return False - if "logs" in r_payload and not isinstance(r_payload["logs"], list): - return False - if "contract_address" in r_payload and not isinstance(r_payload["contract_address"], (str, type(None))): - return False - - return all( - self._validate_transaction_payload(tx_payload) - for tx_payload in payload["transactions"] - ) - - def _validate_message(self, message): - # FIX: Check if message is a dictionary first to prevent crashes - if not isinstance(message, dict): - logger.warning("Network: Received non-dict message") - return False - required_fields = {"type", "data"} - if not required_fields.issubset(set(message)): - return False - if not set(message).issubset(required_fields): - return False - - msg_type = message.get("type") - payload = message.get("data") - - if msg_type not in SUPPORTED_MESSAGE_TYPES: - return False - - validators = { - "sync": self._validate_sync_payload, - "tx": self._validate_transaction_payload, - "block": self._validate_block_payload, - } - return validators[msg_type](payload) - def _message_id(self, msg_type, payload): - if msg_type == "tx": - return canonical_json_hash(payload) - if msg_type == "block": - return payload["hash"] + if msg_type == "tx": return canonical_json_hash(payload) + if msg_type == "block": return payload["hash"] return None - def _mark_seen(self, msg_type, payload): - message_id = self._message_id(msg_type, payload) - if message_id is None: - return - if msg_type == "tx": - self._seen_tx_ids.add(message_id) - elif msg_type == "block": - self._seen_block_hashes.add(message_id) - def _is_duplicate(self, msg_type, payload): - message_id = self._message_id(msg_type, payload) - if message_id is None: - return False - if msg_type == "tx": - return message_id in self._seen_tx_ids - if msg_type == "block": - return message_id in self._seen_block_hashes - return False - - async def _listen_to_peer( - self, - reader: asyncio.StreamReader, - writer: asyncio.StreamWriter, - addr: str, - ): - """Read newline-delimited JSON messages from a peer.""" - try: - while True: - line = await reader.readline() - if not line: - break - try: - data = json.loads(line.decode().strip()) - except (json.JSONDecodeError, UnicodeDecodeError): - logger.warning("Network: Malformed message from %s", addr) - continue - if not self._validate_message(data): - logger.warning("Network: Invalid message schema from %s", addr) - continue + mid = self._message_id(msg_type, payload) + if not mid: return False + return mid in (self._seen_tx_ids if msg_type == "tx" else self._seen_block_hashes) - msg_type = data["type"] - payload = data["data"] - if self._is_duplicate(msg_type, payload): - logger.info("Network: Duplicate %s ignored from %s", msg_type, addr) - continue - self._mark_seen(msg_type, payload) - data["_peer_addr"] = addr - - if self._handler_callback: - try: - await self._handler_callback(data) - except Exception: - logger.exception( - "Network: Handler error for message from %s", addr - ) - except asyncio.CancelledError: - pass - except ConnectionResetError: - pass - finally: - logger.info("Network: Peer %s disconnected", addr) - try: - writer.close() - await writer.wait_closed() - except Exception: - pass - if (reader, writer) in self._peers: - self._peers.remove((reader, writer)) + def _mark_seen(self, msg_type, payload): + mid = self._message_id(msg_type, payload) + if mid: (self._seen_tx_ids if msg_type == "tx" else self._seen_block_hashes).add(mid) async def _broadcast_raw(self, payload: dict): - """Send a JSON message to every connected peer concurrently.""" - line = (canonical_json_dumps(payload) + "\n").encode() - peers_snapshot = list(self._peers) - - async def _send(reader, writer): - """Send to a single peer; return the pair on failure.""" - try: - writer.write(line) - await writer.drain() - return None - except Exception: - return (reader, writer) + self._to_trio.put(("BROADCAST", payload)) - results = await asyncio.gather( - *(_send(r, w) for r, w in peers_snapshot) - ) - - for pair in results: - if pair is None: - continue - _reader, writer = pair - try: - writer.close() - await writer.wait_closed() - except Exception: - pass - if pair in self._peers: - self._peers.remove(pair) + async def _unicast_raw(self, target_addr: str, payload: dict): + self._to_trio.put(("UNICAST", (target_addr, payload))) async def broadcast_transaction(self, tx): - sender = getattr(tx, "sender", "") - logger.info("Network: Broadcasting Tx from %s...", sender[:8]) - try: - payload = {"type": "tx", "data": tx.to_dict()} - except (TypeError, ValueError) as exc: - logger.error("Network: Failed to serialize tx: %s", exc) - return + payload = {"type": "tx", "data": tx.to_dict()} self._mark_seen("tx", payload["data"]) await self._broadcast_raw(payload) async def broadcast_block(self, block): - """Broadcast a block. Block must have miner populated.""" - logger.info("Network: Broadcasting Block #%d", block.index) + payload = {"type": "block", "data": block.to_dict()} + self._mark_seen("block", payload["data"]) + await self._broadcast_raw(payload) - # Enforce that the block is fully populated before it enters the network layer - if getattr(block, "miner", None) is None: - raise ValueError("block.miner must be populated before broadcasting") + async def broadcast_chain_request(self): + await self._broadcast_raw({"type": "chain_request", "data": {}}) - payload = { - "type": "block", - "data": block.to_dict() - } + async def send_chain_response(self, blocks_dicts, peer_stream=None): + await self._broadcast_raw({"type": "chain_response", "data": {"blocks": blocks_dicts}}) - self._mark_seen("block", payload["data"]) - await self._broadcast_raw(payload) + async def disconnect_peer(self, peer_addr): + self._to_trio.put(("DISCONNECT", peer_addr)) @property def peer_count(self) -> int: - return len(self._peers) + with self._peer_count_lock: + return self._peer_count + + async def _asyncio_reader(self): + while True: + try: msg = await self.loop.run_in_executor(None, self._to_asyncio.get) + except Exception: continue + + if msg[0] == "MSG": + data = msg[1] + msg_type, payload = data.get("type"), data.get("data") + if msg_type not in SUPPORTED_MESSAGE_TYPES or self._is_duplicate(msg_type, payload): continue + self._mark_seen(msg_type, payload) + if self._handler_callback: await self._handler_callback(data) + elif msg[0] == "PEER_CONNECTED": + class MockWriter: + def write(self, data): self.data = data + async def drain(self): pass + if self._on_peer_connected: + writer = MockWriter() + await self._on_peer_connected(writer) + if hasattr(writer, 'data'): + try: + req = json.loads(writer.data.decode().strip()) + await self._broadcast_raw(req) + except Exception: pass + + async def _trio_main(self): + host = new_host() + listen_addr = Multiaddr(f"/ip4/{self.host_addr}/tcp/{self.port}") + await host.get_network().listen(listen_addr) + print(f" Network Multiaddr: {listen_addr}/p2p/{host.get_id().to_string()}") + + streams = [] + + async def stream_handler(stream): + streams.append(stream) + with self._peer_count_lock: + self._peer_count += 1 + peer_id = stream.muxed_conn.peer_id + addr = f"peer:{peer_id}" + self._to_asyncio.put(("PEER_CONNECTED", None)) + try: + while True: + data = await stream.read(4096) + if not data: break + for line in data.split(b'\n'): + if not line: continue + try: + msg = json.loads(line.decode().strip()) + msg["_peer_addr"] = addr + self._to_asyncio.put(("MSG", msg)) + except Exception: pass + except Exception: pass + if stream in streams: + streams.remove(stream) + with self._peer_count_lock: + self._peer_count -= 1 + + host.set_stream_handler(PROTOCOL_ID, stream_handler) + + async def check_queue(): + while True: + try: + while not self._to_trio.empty(): + cmd, arg = self._to_trio.get_nowait() + if cmd == "STOP": return True + elif cmd == "CONNECT": + try: + maddr = Multiaddr(arg) + info = info_from_p2p_addr(maddr) + await host.connect(info) + stream = await host.new_stream(info.peer_id, [PROTOCOL_ID]) + host.get_network().nursery.start_soon(stream_handler, stream) + except Exception as e: + logger.error(f"Dial error: {e}") + elif cmd == "BROADCAST": + msg = (canonical_json_dumps(arg) + "\n").encode() + for s in list(streams): + try: await s.write(msg) + except Exception: pass + elif cmd == "UNICAST": + target_addr, payload = arg + msg = (canonical_json_dumps(payload) + "\n").encode() + for s in list(streams): + addr = f"peer:{s.muxed_conn.peer_id}" + if addr == target_addr: + try: await s.write(msg) + except Exception: pass + elif cmd == "DISCONNECT": + for s in list(streams): + addr = f"peer:{s.muxed_conn.peer_id}" + if addr == arg: + try: await s.reset() + except Exception: pass + if s in streams: + streams.remove(s) + with self._peer_count_lock: + self._peer_count -= 1 + except Exception: pass + await trio.sleep(0.1) + + async with trio.open_nursery() as nursery: + async def run_monitor(): + if await check_queue(): + await host.close() + nursery.cancel_scope.cancel() + nursery.start_soon(run_monitor) diff --git a/minichain/rpc.py b/minichain/rpc.py new file mode 100644 index 0000000..58e4733 --- /dev/null +++ b/minichain/rpc.py @@ -0,0 +1,93 @@ +import logging +import json +import asyncio +from aiohttp import web +from minichain.transaction import Transaction + +logger = logging.getLogger(__name__) + +class JSONRPCServer: + def __init__(self, chain, mempool, network): + self.chain = chain + self.mempool = mempool + self.network = network + self.app = web.Application() + self.app.add_routes([web.post('/', self.handle_rpc)]) + + async def start(self, host="127.0.0.1", port=8545): + self.runner = web.AppRunner(self.app) + await self.runner.setup() + self.site = web.TCPSite(self.runner, host, port) + await self.site.start() + logger.info("🚀 JSON-RPC Server running on http://%s:%d", host, port) + + async def stop(self): + if hasattr(self, 'site'): + await self.site.stop() + if hasattr(self, 'runner'): + await self.runner.cleanup() + + async def handle_rpc(self, request): + try: + req_data = await request.json() + except json.JSONDecodeError: + return web.json_response({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}, "id": None}) + + if isinstance(req_data, list): + responses = [] + for req in req_data: + responses.append(await self._process_single(req)) + return web.json_response(responses) + else: + response = await self._process_single(req_data) + return web.json_response(response) + + async def _process_single(self, req): + if not isinstance(req, dict) or "method" not in req or req.get("jsonrpc") != "2.0": + return {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}, "id": req.get("id") if isinstance(req, dict) else None} + + method = req["method"] + params = req.get("params", []) + req_id = req.get("id") + + try: + if method == "mc_blockNumber": + result = self.chain.last_block.index + elif method == "mc_getBlockByNumber": + if not params: + raise ValueError("Missing block number") + idx = params[0] + if idx == "latest": + block = self.chain.last_block + else: + idx = int(idx) + if idx < 0 or idx >= len(self.chain.chain): + block = None + else: + block = self.chain.chain[idx] + result = block.to_dict() if block else None + elif method == "mc_getBalance": + if not params: + raise ValueError("Missing address") + address = params[0] + account = self.chain.state.get_account(address) + result = account["balance"] if account else 0 + elif method == "mc_sendTransaction": + if not params: + raise ValueError("Missing transaction payload") + tx_data = params[0] + tx = Transaction.from_dict(tx_data) + if not tx.verify(): + raise ValueError("Invalid signature") + if self.mempool.add_transaction(tx): + asyncio.create_task(self.network.broadcast_transaction(tx)) + result = tx.tx_id + else: + raise ValueError("Transaction rejected by Mempool") + else: + return {"jsonrpc": "2.0", "error": {"code": -32601, "message": f"Method not found: {method}"}, "id": req_id} + + return {"jsonrpc": "2.0", "result": result, "id": req_id} + except Exception as e: + logger.error("RPC Error processing %s: %s", method, e) + return {"jsonrpc": "2.0", "error": {"code": -32000, "message": str(e)}, "id": req_id} diff --git a/minichain/state.py b/minichain/state.py index f817c6c..13c7c02 100644 --- a/minichain/state.py +++ b/minichain/state.py @@ -15,6 +15,7 @@ def __init__(self): # { address: {'balance': int, 'nonce': int, 'code': str|None, 'storage': dict} } self.accounts = {} self.contract_machine = ContractMachine(self) + self.chain_id = "minichain-default" def state_root(self) -> str: """ @@ -46,6 +47,10 @@ def verify_transaction_logic(self, tx): logger.error("Error: Invalid signature for tx from %s...", tx.sender[:8]) return False + if getattr(tx, "chain_id", None) != self.chain_id: + logger.error("Error: Invalid chain_id in tx from %s...", tx.sender[:8]) + return False + sender_acc = self.get_account(tx.sender) total_cost = tx.amount + getattr(tx, 'fee', 0) @@ -65,8 +70,21 @@ def copy(self): """ new_state = copy.deepcopy(self) new_state.contract_machine = ContractMachine(new_state) # Reinitialize contract_machine + new_state.chain_id = self.chain_id return new_state + def snapshot(self): + """ + Returns a deep copy of the current accounts dictionary for rollback safety. + """ + return copy.deepcopy(self.accounts) + + def restore(self, snapshot_data): + """ + Restores the state's accounts dictionary from a snapshot. + """ + self.accounts = copy.deepcopy(snapshot_data) + def validate_and_apply(self, tx): """ Validate and apply a transaction. @@ -145,6 +163,23 @@ def apply_transaction(self, tx): sender['balance'] += tx.amount # Refund amount return Receipt(tx.tx_id, status=0, error_message=result.get("error", "Execution failed"), gas_used=gas_used) + transfers = result.get("transfers", []) + total_transferred_out = sum(t["amount"] for t in transfers) + + if total_transferred_out > receiver['balance']: + # Rollback transfer if execution attempts to spend more than balance + receiver['balance'] -= tx.amount + sender['balance'] += tx.amount # Refund amount + return Receipt(tx.tx_id, status=0, error_message="Insufficient contract balance for transfers", gas_used=gas_used) + + # Execution & transfers valid: commit state changes atomically + self.update_contract_storage(tx.receiver, result["storage"]) + + receiver['balance'] -= total_transferred_out + for t in transfers: + target_acc = self.get_account(t["to"]) + target_acc['balance'] += t["amount"] + return Receipt(tx.tx_id, status=1, gas_used=gas_used) # LOGIC BRANCH 3: Regular Transfer diff --git a/minichain/transaction.py b/minichain/transaction.py index ca282ea..7754dcd 100644 --- a/minichain/transaction.py +++ b/minichain/transaction.py @@ -6,7 +6,7 @@ class Transaction: - _TX_FIELDS = frozenset({"sender", "receiver", "amount", "fee", "nonce", "data", "timestamp", "signature"}) + _TX_FIELDS = frozenset({"sender", "receiver", "amount", "fee", "nonce", "data", "timestamp", "chain_id", "signature"}) def __setattr__(self, name, value) -> None: if name in self._TX_FIELDS and getattr(self, "_sealed", False): @@ -23,13 +23,14 @@ def _normalize_ts(ts) -> int: # If it's already in milliseconds (>= 1e12), just ensure it's an integer return int(ts) - def __init__(self, sender, receiver, amount, nonce, fee=0, data=None, signature=None, timestamp=None): + def __init__(self, sender, receiver, amount, nonce, fee=0, data=None, chain_id="minichain-default", signature=None, timestamp=None): self.sender = sender self.receiver = receiver self.amount = amount self.fee = fee self.nonce = nonce self.data = data + self.chain_id = chain_id self.timestamp = self._normalize_ts(timestamp) if timestamp is not None else round(time.time() * 1000) self.signature = signature self._cached_tx_id = None @@ -37,19 +38,19 @@ def __init__(self, sender, receiver, amount, nonce, fee=0, data=None, signature= def to_dict(self): return {"sender": self.sender, "receiver": self.receiver, "amount": self.amount, "fee": self.fee, - "nonce": self.nonce, "data": self.data, "timestamp": self.timestamp, + "nonce": self.nonce, "data": self.data, "chain_id": self.chain_id, "timestamp": self.timestamp, "signature": self.signature} def to_signing_dict(self): return {"sender": self.sender, "receiver": self.receiver, "amount": self.amount, "fee": self.fee, - "nonce": self.nonce, "data": self.data, "timestamp": self.timestamp} + "nonce": self.nonce, "data": self.data, "chain_id": self.chain_id, "timestamp": self.timestamp} @classmethod def from_dict(cls, payload: dict): return cls(sender=payload["sender"], receiver=payload.get("receiver"), amount=payload["amount"], nonce=payload["nonce"], fee=payload["fee"], - data=payload.get("data"), signature=payload.get("signature"), - timestamp=payload.get("timestamp")) + data=payload.get("data"), chain_id=payload.get("chain_id", "minichain-default"), + signature=payload.get("signature"), timestamp=payload.get("timestamp")) @property def hash_payload(self): diff --git a/requirements.txt b/requirements.txt index 99cd065..98c335a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,5 @@ pynacl==1.6.2 trie>=3.1.0 +libp2p +aiohttp>=3.10.11 +multiaddr diff --git a/tests/test_contract_transfers.py b/tests/test_contract_transfers.py new file mode 100644 index 0000000..30dc1b4 --- /dev/null +++ b/tests/test_contract_transfers.py @@ -0,0 +1,110 @@ +import unittest +from nacl.signing import SigningKey +from nacl.encoding import HexEncoder +from minichain.state import State +from minichain.block import Transaction + +class TestContractTransfers(unittest.TestCase): + def setUp(self): + self.state = State() + self.sender_sk = SigningKey.generate() + self.sender_pk = self.sender_sk.verify_key.encode(encoder=HexEncoder).decode() + + self.target_pk = SigningKey.generate().verify_key.encode(encoder=HexEncoder).decode() + + # Credit sender with enough balance to deploy and call + self.state.credit_mining_reward(self.sender_pk, 10000) + + def _sign(self, tx): + tx.sign(self.sender_sk) + return tx + + def test_successful_transfer_out(self): + # 1. Deploy Contract + code = """ +target = msg['data']['target'] +transfer_out(target, 50) +transfer_out(target, 25) +""" + deploy_tx = self._sign(Transaction(self.sender_pk, None, amount=100, nonce=0, data=code, fee=1000)) + receipt = self.state.apply_transaction(deploy_tx) + self.assertEqual(receipt.status, 1) + contract_addr = receipt.contract_address + + # Sender sent 100 to contract, plus 1000 fee + self.assertEqual(self.state.get_account(contract_addr)['balance'], 100) + self.assertEqual(self.state.get_account(self.target_pk)['balance'], 0) + + # 2. Call Contract to transfer out 75 coins + call_tx = self._sign(Transaction(self.sender_pk, contract_addr, amount=0, nonce=1, data={"target": self.target_pk}, fee=1000)) + receipt2 = self.state.apply_transaction(call_tx) + + self.assertEqual(receipt2.status, 1) + + # Contract balance should be 100 - 75 = 25 + self.assertEqual(self.state.get_account(contract_addr)['balance'], 25) + + # Target should have 75 + self.assertEqual(self.state.get_account(self.target_pk)['balance'], 75) + + def test_failed_transfer_out_insufficient_balance(self): + # 1. Deploy Contract + code = """ +target = msg['data']['target'] +# Try to transfer 500, but contract only has 100 +transfer_out(target, 500) +storage['malicious_state'] = 'corrupted' +""" + deploy_tx = self._sign(Transaction(self.sender_pk, None, amount=100, nonce=0, data=code, fee=1000)) + receipt = self.state.apply_transaction(deploy_tx) + self.assertEqual(receipt.status, 1) + contract_addr = receipt.contract_address + + # 2. Call Contract + call_tx = self._sign(Transaction(self.sender_pk, contract_addr, amount=50, nonce=1, data={"target": self.target_pk}, fee=1000)) + receipt2 = self.state.apply_transaction(call_tx) + + # Should fail with status 0 + self.assertEqual(receipt2.status, 0) + self.assertEqual(receipt2.error_message, "Insufficient contract balance for transfers") + + # State should be completely rolled back (target balance 0, contract balance remains 100) + self.assertEqual(self.state.get_account(contract_addr)['balance'], 100) + self.assertEqual(self.state.get_account(self.target_pk)['balance'], 0) + + # Sender's balance should have decreased by only the fee amount (or gas_used if refunded) as the 50 amount was refunded + # Starting balance 10000, minus (100+1000) for deploy = 8900 + # Call tx net cost is receipt2.gas_used + self.assertEqual(self.state.get_account(self.sender_pk)['balance'], 8900 - receipt2.gas_used) + + # Storage should NOT be updated + self.assertEqual(self.state.get_account(contract_addr)['storage'], {}) + + def test_transfer_with_incoming_funds(self): + # 1. Deploy Contract (0 initial balance) + code = """ +target = msg['data']['target'] +# We use the incoming funds to instantly transfer out! +transfer_out(target, msg['value']) +""" + deploy_tx = self._sign(Transaction(self.sender_pk, None, amount=0, nonce=0, data=code, fee=1000)) + receipt = self.state.apply_transaction(deploy_tx) + self.assertEqual(receipt.status, 1) + contract_addr = receipt.contract_address + + self.assertEqual(self.state.get_account(contract_addr)['balance'], 0) + + # 2. Call Contract sending 50 coins + call_tx = self._sign(Transaction(self.sender_pk, contract_addr, amount=50, nonce=1, data={"target": self.target_pk}, fee=1000)) + receipt2 = self.state.apply_transaction(call_tx) + + self.assertEqual(receipt2.status, 1) + + # Contract balance should be 0 (received 50, sent 50) + self.assertEqual(self.state.get_account(contract_addr)['balance'], 0) + + # Target should have exactly 50 + self.assertEqual(self.state.get_account(self.target_pk)['balance'], 50) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_protocol_hardening.py b/tests/test_protocol_hardening.py index 538c8cc..dcb651f 100644 --- a/tests/test_protocol_hardening.py +++ b/tests/test_protocol_hardening.py @@ -95,17 +95,16 @@ def test_remove_transactions_by_sender_nonce_when_tx_id_differs(self): class TestP2PValidationAndDedup(unittest.IsolatedAsyncioTestCase): async def test_invalid_message_schema_is_rejected(self): - network = P2PNetwork() - - invalid_message = {"type": "tx", "data": {"sender": "abc"}} - self.assertFalse(network._validate_message(invalid_message)) + invalid_payload = {"sender": "abc"} + with self.assertRaises(Exception): + Transaction.from_dict(invalid_payload) async def test_block_schema_accepts_current_block_wire_format(self): sender_sk = SigningKey.generate() sender_pk = sender_sk.verify_key.encode(encoder=HexEncoder).decode() receiver_pk = SigningKey.generate().verify_key.encode(encoder=HexEncoder).decode() - tx = Transaction(sender_pk, receiver_pk, 1, 0, timestamp=123) + tx = Transaction(sender_pk, receiver_pk, 1, 0, timestamp=1600000000000) tx.sign(sender_sk) from minichain.receipt import Receipt @@ -116,7 +115,7 @@ async def test_block_schema_accepts_current_block_wire_format(self): index=1, previous_hash="0" * 64, transactions=[tx], - timestamp=456, + timestamp=1600000000000, difficulty=2, state_root="0"*64, receipts=[receipt], @@ -125,10 +124,8 @@ async def test_block_schema_accepts_current_block_wire_format(self): block.nonce = 9 block.hash = block.compute_hash() - network = P2PNetwork() - message = {"type": "block", "data": block.to_dict()} - - self.assertTrue(network._validate_message(message)) + parsed_block = Block.from_dict(block.to_dict()) + self.assertEqual(parsed_block.hash, block.hash) async def test_duplicate_tx_and_block_detection(self): network = P2PNetwork() diff --git a/tests/test_reorg.py b/tests/test_reorg.py new file mode 100644 index 0000000..4abb7bf --- /dev/null +++ b/tests/test_reorg.py @@ -0,0 +1,116 @@ +import pytest +import os +import json +import time + +from minichain.chain import Blockchain +from minichain.transaction import Transaction +from minichain.mempool import Mempool + +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from main import mine_and_process_block + +from nacl.signing import SigningKey +from nacl.encoding import HexEncoder + +@pytest.fixture +def genesis_file(tmp_path): + path = tmp_path / "genesis_reorg.json" + sk = SigningKey.generate() + pk = sk.verify_key.encode(encoder=HexEncoder).decode() + data = { + "timestamp": int(time.time()), + "difficulty": 0, + "alloc": { + pk: {"balance": 1000} + } + } + with open(path, "w") as f: + json.dump(data, f) + return str(path), sk, pk + +def test_resolve_conflicts_heavier_chain(genesis_file): + g_path, sk, pk = genesis_file + + node_a = Blockchain(genesis_path=g_path) + node_b = Blockchain(genesis_path=g_path) + + assert node_a.get_total_work() == node_b.get_total_work() + + pool_b = Mempool() + tx = Transaction(sender=pk, receiver="b"*64, amount=10, nonce=0, fee=1) + tx.sign(sk) + pool_b.add_transaction(tx) + + mined_b = mine_and_process_block(node_b, pool_b, pk) + assert mined_b is not None + assert node_b.get_total_work() > node_a.get_total_work() + + # Node A receives Node B's chain + success, orphans = node_a.resolve_conflicts(node_b.chain) + + assert success is True + assert node_a.last_block.hash == node_b.last_block.hash + assert node_a.state.accounts == node_b.state.accounts + assert len(orphans) == 0 + +def test_resolve_conflicts_reorg_with_orphans(genesis_file): + g_path, sk, pk = genesis_file + + node_a = Blockchain(genesis_path=g_path) + node_b = Blockchain(genesis_path=g_path) + + pool_a = Mempool() + pool_b = Mempool() + + # Node A mines tx1 (nonce 0) + tx1 = Transaction(sender=pk, receiver="a"*64, amount=10, nonce=0, fee=1) + tx1.sign(sk) + pool_a.add_transaction(tx1) + mine_and_process_block(node_a, pool_a, pk) + + # Node B mines tx2 (nonce 0, competing transaction) + tx2 = Transaction(sender=pk, receiver="b"*64, amount=20, nonce=0, fee=1) + tx2.sign(sk) + pool_b.add_transaction(tx2) + mine_and_process_block(node_b, pool_b, pk) + + # Node B mines tx3 (nonce 1) to become the heavier chain + tx3 = Transaction(sender=pk, receiver="c"*64, amount=30, nonce=1, fee=1) + tx3.sign(sk) + pool_b.add_transaction(tx3) + block_b2 = mine_and_process_block(node_b, pool_b, pk) + + assert node_b.get_total_work() > node_a.get_total_work() + + # Node A attempts reorg using B's heavier chain + success, orphans = node_a.resolve_conflicts(node_b.chain) + + assert success is True + assert node_a.last_block.hash == block_b2.hash + + # tx1 was in A's chain but NOT in B's chain. It should be orphaned. + assert len(orphans) == 1 + assert orphans[0].tx_id == tx1.tx_id + +def test_resolve_conflicts_rejects_lighter_chain(genesis_file): + g_path, sk, pk = genesis_file + + node_a = Blockchain(genesis_path=g_path) + node_b = Blockchain(genesis_path=g_path) + + pool_a = Mempool() + + # Node A mines a block + tx1 = Transaction(sender=pk, receiver="a"*64, amount=10, nonce=0, fee=1) + tx1.sign(sk) + pool_a.add_transaction(tx1) + mine_and_process_block(node_a, pool_a, pk) + + # Node B is empty. It tries to reorg Node A with its shorter chain. + success, orphans = node_a.resolve_conflicts(node_b.chain) + + assert success is False + assert len(orphans) == 0 + assert node_a.get_total_work() > node_b.get_total_work() diff --git a/tests/test_rpc.py b/tests/test_rpc.py new file mode 100644 index 0000000..3ab4ea6 --- /dev/null +++ b/tests/test_rpc.py @@ -0,0 +1,75 @@ +import pytest +import aiohttp +import asyncio +from minichain.chain import Blockchain +from minichain.mempool import Mempool +from minichain.p2p import P2PNetwork +from minichain.rpc import JSONRPCServer + +@pytest.fixture +def anyio_backend(): + return 'asyncio' + +@pytest.fixture +async def rpc_server(free_tcp_port): + chain = Blockchain() + mempool = Mempool() + network = P2PNetwork() + + server = JSONRPCServer(chain, mempool, network) + port = free_tcp_port + await server.start(host="127.0.0.1", port=port) + + yield server, port, chain, mempool + + await server.app.cleanup() + +@pytest.mark.anyio +async def test_rpc_blockNumber(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = {"jsonrpc": "2.0", "method": "mc_blockNumber", "id": 1} + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["result"] == 0 + assert data["id"] == 1 + +@pytest.mark.anyio +async def test_rpc_getBlockByNumber(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = {"jsonrpc": "2.0", "method": "mc_getBlockByNumber", "params": [0], "id": 2} + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["result"]["index"] == 0 + assert data["id"] == 2 + +@pytest.mark.anyio +async def test_rpc_invalid_request_format(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = 1 + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert "error" in data + assert data["error"]["code"] == -32600 + assert data["id"] is None + +@pytest.mark.anyio +async def test_rpc_invalid_method(rpc_server): + server, port, chain, mempool = rpc_server + + async with aiohttp.ClientSession() as session: + payload = {"jsonrpc": "2.0", "method": "mc_unknown", "id": 3} + async with session.post(f"http://127.0.0.1:{port}/", json=payload) as resp: + assert resp.status == 200 + data = await resp.json() + assert "error" in data + assert data["error"]["code"] == -32601 + assert data["id"] == 3 diff --git a/tests/test_transaction_signing.py b/tests/test_transaction_signing.py index ee3c845..56d14ee 100644 --- a/tests/test_transaction_signing.py +++ b/tests/test_transaction_signing.py @@ -141,9 +141,24 @@ def test_unsigned_transaction_fails_verification(alice, bob): # ------------------------------------------------------------------ -# 4. Replay protection +# 4. Replay protection and Cross-Chain protection # ------------------------------------------------------------------ +def test_wrong_chain_id_rejected(alice, bob, funded_state): + """A transaction with a chain_id differing from the state's chain_id must be rejected.""" + alice_sk, alice_pk = alice + _, bob_pk = bob + + tx = Transaction(alice_pk, bob_pk, 10, nonce=0, chain_id="wrong-chain") + tx.sign(alice_sk) + + assert not funded_state.apply_transaction(tx), "Transaction for a different chain_id must be rejected." + # Ensure the rejected transaction did not mutate the ledger + assert funded_state.get_account(alice_pk)["balance"] == 100, \ + "Alice's balance must remain unchanged after a cross-chain rejection." + assert funded_state.get_account(alice_pk)["nonce"] == 0, \ + "Alice's nonce must remain unchanged after a cross-chain rejection." + def test_replay_attack_same_nonce_rejected(alice, bob, funded_state): """Replaying the same transaction must be rejected the second time.""" alice_sk, alice_pk = alice