"""Tests for message framing.""" import pytest import time from radicle_reticulum.messages import ( MessageType, MessageHeader, NodeAnnouncement, InventoryAnnouncement, RefAnnouncement, Ping, Pong, decode_message, HEADER_SIZE, ) class TestMessageHeader: """Test message header encoding/decoding.""" def test_encode_decode_roundtrip(self): """Test header encode/decode roundtrip.""" header = MessageHeader( msg_type=MessageType.NODE_ANNOUNCEMENT, timestamp=1234567890123, payload_length=42, ) encoded = header.encode() assert len(encoded) == HEADER_SIZE decoded = MessageHeader.decode(encoded) assert decoded.msg_type == header.msg_type assert decoded.timestamp == header.timestamp assert decoded.payload_length == header.payload_length def test_header_too_short(self): """Test that short data raises error.""" with pytest.raises(ValueError, match="Header too short"): MessageHeader.decode(b"\x00\x01") class TestNodeAnnouncement: """Test NodeAnnouncement message.""" def test_encode_decode_roundtrip(self): """Test encode/decode roundtrip.""" msg = NodeAnnouncement( node_id="did:key:z6MkhaXgBZDvotDkL5257faiztiGiC2QtKLGpbnnEGta2doK", features=0x0003, version=1, ) encoded = msg.encode() decoded = NodeAnnouncement.decode(encoded) assert decoded.node_id == msg.node_id assert decoded.features == msg.features assert decoded.version == msg.version def test_to_message_includes_header(self): """Test that to_message includes proper header.""" msg = NodeAnnouncement(node_id="did:key:z6Mk...") full_message = msg.to_message() header = MessageHeader.decode(full_message) assert header.msg_type == MessageType.NODE_ANNOUNCEMENT assert header.timestamp > 0 assert header.payload_length == len(msg.encode()) class TestInventoryAnnouncement: """Test InventoryAnnouncement message.""" def test_encode_decode_roundtrip(self): """Test encode/decode roundtrip.""" msg = InventoryAnnouncement( node_id="did:key:z6MkhaXgBZDvotDkL5257faiztiGiC2QtKLGpbnnEGta2doK", repositories=[ "rad:z3gqcJUoA1n9HaHKufZs5FCSGazv5", "rad:z4gqcJUoA1n9HaHKufZs5FCSGazv6", ], ) encoded = msg.encode() decoded = InventoryAnnouncement.decode(encoded) assert decoded.node_id == msg.node_id assert decoded.repositories == msg.repositories def test_empty_repositories(self): """Test with empty repository list.""" msg = InventoryAnnouncement( node_id="did:key:z6Mk...", repositories=[], ) encoded = msg.encode() decoded = InventoryAnnouncement.decode(encoded) assert decoded.repositories == [] class TestRefAnnouncement: """Test RefAnnouncement message.""" def test_encode_decode_roundtrip(self): """Test encode/decode roundtrip.""" msg = RefAnnouncement( repository_id="rad:z3gqcJUoA1n9HaHKufZs5FCSGazv5", ref_name="refs/heads/main", old_oid=b"\x00" * 20, new_oid=bytes.fromhex("abc123def456789012345678901234567890abcd"), signature=b"fake_signature_bytes", ) encoded = msg.encode() decoded = RefAnnouncement.decode(encoded) assert decoded.repository_id == msg.repository_id assert decoded.ref_name == msg.ref_name assert decoded.old_oid == msg.old_oid assert decoded.new_oid == msg.new_oid assert decoded.signature == msg.signature class TestPingPong: """Test Ping/Pong messages.""" def test_ping_encode_decode(self): """Test Ping encode/decode.""" ping = Ping() encoded = ping.encode() decoded = Ping.decode(encoded) assert decoded.nonce == ping.nonce def test_pong_echoes_nonce(self): """Test Pong echoes ping nonce.""" ping = Ping() pong = Pong(nonce=ping.nonce) assert pong.nonce == ping.nonce class TestDecodeMessage: """Test the decode_message function.""" def test_decode_node_announcement(self): """Test decoding a NodeAnnouncement message.""" msg = NodeAnnouncement(node_id="did:key:z6Mk...") full_message = msg.to_message() header, decoded = decode_message(full_message) assert header.msg_type == MessageType.NODE_ANNOUNCEMENT assert isinstance(decoded, NodeAnnouncement) assert decoded.node_id == msg.node_id def test_decode_inventory_announcement(self): """Test decoding an InventoryAnnouncement message.""" msg = InventoryAnnouncement( node_id="did:key:z6Mk...", repositories=["repo1", "repo2"], ) full_message = msg.to_message() header, decoded = decode_message(full_message) assert header.msg_type == MessageType.INVENTORY_ANNOUNCEMENT assert isinstance(decoded, InventoryAnnouncement) assert decoded.repositories == msg.repositories def test_decode_ping(self): """Test decoding a Ping message.""" ping = Ping() full_message = ping.to_message() header, decoded = decode_message(full_message) assert header.msg_type == MessageType.PING assert isinstance(decoded, Ping) def test_unknown_message_type_raises(self): """Test that unknown message types raise error.""" # Create a message with invalid type import struct bad_message = struct.pack("!BQH", 0xFF, 0, 0) with pytest.raises(ValueError, match="Unknown message type"): decode_message(bad_message)