From c07324ec0203b842fe81dd0c3bc43534f7a9401a Mon Sep 17 00:00:00 2001 From: dicethedev Date: Wed, 3 Jun 2026 07:12:36 +0100 Subject: [PATCH] fix req-resp multi-chunk response decoding --- crates/net/p2p/src/req_resp/encoding.rs | 187 ++++++++++++++++++++---- 1 file changed, 155 insertions(+), 32 deletions(-) diff --git a/crates/net/p2p/src/req_resp/encoding.rs b/crates/net/p2p/src/req_resp/encoding.rs index 3cb26b0a..02d9343a 100644 --- a/crates/net/p2p/src/req_resp/encoding.rs +++ b/crates/net/p2p/src/req_resp/encoding.rs @@ -11,15 +11,8 @@ pub const MAX_COMPRESSED_PAYLOAD_SIZE: usize = 32 + MAX_PAYLOAD_SIZE + MAX_PAYLO /// Decoded payload together with the size of its on-wire snappy-compressed /// bytes (excluding the varint length prefix). /// -/// `compressed_size` is accurate for single-chunk streams (Status request / -/// response, BlocksByRoot request). For multi-chunk streams (BlocksByRoot -/// response) the value is over-reported on the first chunk because -/// `decode_payload` slurps the whole stream via `read_to_end` before parsing -/// the first varint, so `compressed_size` measures everything left after that -/// varint rather than just this chunk's snappy frame. The metric is still -/// useful for single-chunk traffic and as an order-of-magnitude signal on -/// multi-chunk responses; precise per-chunk accounting would require refactoring -/// `decode_payload` to read one varint + one snappy frame at a time. +/// `compressed_size` excludes the varint length prefix and covers only the +/// snappy frame for this payload. pub struct DecodedPayload { pub uncompressed: Vec, pub compressed_size: usize, @@ -30,20 +23,7 @@ pub async fn decode_payload(io: &mut T) -> io::Result where T: AsyncRead + Unpin + Send, { - let mut buf = vec![]; - let read = io - .take(MAX_COMPRESSED_PAYLOAD_SIZE as u64) - .read_to_end(&mut buf) - .await?; - - if read < 2 { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "message too short", - )); - } - let (size, rest) = decode_varint(&buf)?; - let compressed_size = rest.len(); + let size = read_varint(io).await?; if size as usize > MAX_PAYLOAD_SIZE { return Err(io::Error::new( @@ -52,22 +32,112 @@ where )); } - let mut decoder = snap::read::FrameDecoder::new(rest); - let mut uncompressed = Vec::new(); - io::Read::read_to_end(&mut decoder, &mut uncompressed)?; - if uncompressed.len() != size as usize { - return Err(io::Error::new( - io::ErrorKind::InvalidData, - "uncompressed size does not match received size", - )); + if size == 0 { + return Ok(DecodedPayload { + uncompressed: Vec::new(), + compressed_size: 0, + }); } + let (uncompressed, compressed_size) = read_snappy_frame(io, size as usize).await?; + Ok(DecodedPayload { uncompressed, compressed_size, }) } +async fn read_varint(io: &mut T) -> io::Result +where + T: AsyncRead + Unpin + Send, +{ + let mut buf = [0_u8; 5]; + + for i in 0..5 { + io.read_exact(&mut buf[i..=i]).await?; + + if buf[i] & 0x80 == 0 { + return decode_varint(&buf[..=i]).map(|(size, _)| size); + } + } + + Err(io::Error::new( + io::ErrorKind::InvalidData, + "message size is bigger than 28 bits", + )) +} + +async fn read_snappy_frame( + io: &mut T, + expected_uncompressed_size: usize, +) -> io::Result<(Vec, usize)> +where + T: AsyncRead + Unpin + Send, +{ + let mut frame = Vec::new(); + let mut uncompressed_len = 0_usize; + + while uncompressed_len < expected_uncompressed_size { + let mut header = [0_u8; 4]; + io.read_exact(&mut header).await?; + let chunk_type = header[0]; + let chunk_len = u32::from_le_bytes([header[1], header[2], header[3], 0]) as usize; + + if frame.len() + header.len() + chunk_len > MAX_COMPRESSED_PAYLOAD_SIZE { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "compressed message size exceeds maximum allowed", + )); + } + + frame.extend_from_slice(&header); + + let data_start = frame.len(); + frame.resize(data_start + chunk_len, 0); + io.read_exact(&mut frame[data_start..]).await?; + let data = &frame[data_start..]; + + let chunk_uncompressed_len = match chunk_type { + 0x00 => { + let block = data.get(4..).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "compressed chunk too short") + })?; + snap::raw::decompress_len(block) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))? + } + 0x01 => chunk_len.checked_sub(4).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "uncompressed chunk too short") + })?, + 0x80..=0xff => 0, + _ => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "unsupported snappy chunk type", + )); + } + }; + + uncompressed_len = uncompressed_len + .checked_add(chunk_uncompressed_len) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "uncompressed size overflow") + })?; + + if uncompressed_len > expected_uncompressed_size { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "uncompressed size does not match received size", + )); + } + } + + let mut decoder = snap::read::FrameDecoder::new(frame.as_slice()); + let mut uncompressed = vec![0; expected_uncompressed_size]; + io::Read::read_exact(&mut decoder, &mut uncompressed)?; + + Ok((uncompressed, frame.len())) +} + /// Write a varint-prefixed, snappy-compressed SSZ payload. Returns the size /// of the snappy-compressed bytes (excluding the varint length prefix). pub async fn write_payload(io: &mut T, encoded: &[u8]) -> io::Result @@ -137,7 +207,8 @@ pub fn decode_varint(buf: &[u8]) -> io::Result<(u32, &[u8])> { #[cfg(test)] mod tests { - use super::decode_varint; + use super::{decode_payload, decode_varint, write_payload}; + use futures::io::Cursor; #[test] fn test_decode_varint() { @@ -149,4 +220,56 @@ mod tests { let expected: &[u8] = &[]; assert_eq!(rest, expected); } + + #[tokio::test] + async fn decode_payload_leaves_following_payload_in_stream() { + let first = b"first payload"; + let second = b"second payload"; + + let mut stream = Cursor::new(Vec::new()); + let first_compressed_size = write_payload(&mut stream, first).await.unwrap(); + let second_compressed_size = write_payload(&mut stream, second).await.unwrap(); + + let mut stream = Cursor::new(stream.into_inner()); + + let decoded = decode_payload(&mut stream).await.unwrap(); + assert_eq!(decoded.uncompressed, first); + assert_eq!(decoded.compressed_size, first_compressed_size); + + let decoded = decode_payload(&mut stream).await.unwrap(); + assert_eq!(decoded.uncompressed, second); + assert_eq!(decoded.compressed_size, second_compressed_size); + } + + #[tokio::test] + async fn decode_payload_reads_all_snappy_chunks_for_one_payload() { + let payload = vec![42; 128 * 1024]; + + let mut stream = Cursor::new(Vec::new()); + let compressed_size = write_payload(&mut stream, &payload).await.unwrap(); + + let mut stream = Cursor::new(stream.into_inner()); + let decoded = decode_payload(&mut stream).await.unwrap(); + + assert_eq!(decoded.uncompressed, payload); + assert_eq!(decoded.compressed_size, compressed_size); + } + + #[tokio::test] + async fn decode_payload_handles_empty_payload_before_following_payload() { + let second = b"after empty"; + + let mut stream = Cursor::new(Vec::new()); + let empty_compressed_size = write_payload(&mut stream, &[]).await.unwrap(); + write_payload(&mut stream, second).await.unwrap(); + + let mut stream = Cursor::new(stream.into_inner()); + + let decoded = decode_payload(&mut stream).await.unwrap(); + assert!(decoded.uncompressed.is_empty()); + assert_eq!(decoded.compressed_size, empty_compressed_size); + + let decoded = decode_payload(&mut stream).await.unwrap(); + assert_eq!(decoded.uncompressed, second); + } }