Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 155 additions & 32 deletions crates/net/p2p/src/req_resp/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,8 @@ pub const MAX_COMPRESSED_PAYLOAD_SIZE: usize = 32 + MAX_PAYLOAD_SIZE + MAX_PAYLO
/// Decoded payload together with the size of its on-wire snappy-compressed
/// bytes (excluding the varint length prefix).
///
/// `compressed_size` is accurate for single-chunk streams (Status request /
/// response, BlocksByRoot request). For multi-chunk streams (BlocksByRoot
/// response) the value is over-reported on the first chunk because
/// `decode_payload` slurps the whole stream via `read_to_end` before parsing
/// the first varint, so `compressed_size` measures everything left after that
/// varint rather than just this chunk's snappy frame. The metric is still
/// useful for single-chunk traffic and as an order-of-magnitude signal on
/// multi-chunk responses; precise per-chunk accounting would require refactoring
/// `decode_payload` to read one varint + one snappy frame at a time.
/// `compressed_size` excludes the varint length prefix and covers only the
/// snappy frame for this payload.
pub struct DecodedPayload {
pub uncompressed: Vec<u8>,
pub compressed_size: usize,
Expand All @@ -30,20 +23,7 @@ pub async fn decode_payload<T>(io: &mut T) -> io::Result<DecodedPayload>
where
T: AsyncRead + Unpin + Send,
{
let mut buf = vec![];
let read = io
.take(MAX_COMPRESSED_PAYLOAD_SIZE as u64)
.read_to_end(&mut buf)
.await?;

if read < 2 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"message too short",
));
}
let (size, rest) = decode_varint(&buf)?;
let compressed_size = rest.len();
let size = read_varint(io).await?;

if size as usize > MAX_PAYLOAD_SIZE {
return Err(io::Error::new(
Expand All @@ -52,22 +32,112 @@ where
));
}

let mut decoder = snap::read::FrameDecoder::new(rest);
let mut uncompressed = Vec::new();
io::Read::read_to_end(&mut decoder, &mut uncompressed)?;
if uncompressed.len() != size as usize {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"uncompressed size does not match received size",
));
if size == 0 {
return Ok(DecodedPayload {
uncompressed: Vec::new(),
compressed_size: 0,
});
}

let (uncompressed, compressed_size) = read_snappy_frame(io, size as usize).await?;

Ok(DecodedPayload {
uncompressed,
compressed_size,
})
}

async fn read_varint<T>(io: &mut T) -> io::Result<u32>
where
T: AsyncRead + Unpin + Send,
{
let mut buf = [0_u8; 5];

for i in 0..5 {
io.read_exact(&mut buf[i..=i]).await?;

if buf[i] & 0x80 == 0 {
return decode_varint(&buf[..=i]).map(|(size, _)| size);
}
}

Err(io::Error::new(
io::ErrorKind::InvalidData,
"message size is bigger than 28 bits",
))
}

async fn read_snappy_frame<T>(
io: &mut T,
expected_uncompressed_size: usize,
) -> io::Result<(Vec<u8>, usize)>
where
T: AsyncRead + Unpin + Send,
{
let mut frame = Vec::new();
let mut uncompressed_len = 0_usize;

while uncompressed_len < expected_uncompressed_size {
let mut header = [0_u8; 4];
io.read_exact(&mut header).await?;
let chunk_type = header[0];
let chunk_len = u32::from_le_bytes([header[1], header[2], header[3], 0]) as usize;

if frame.len() + header.len() + chunk_len > MAX_COMPRESSED_PAYLOAD_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"compressed message size exceeds maximum allowed",
));
}

frame.extend_from_slice(&header);

let data_start = frame.len();
frame.resize(data_start + chunk_len, 0);
io.read_exact(&mut frame[data_start..]).await?;
let data = &frame[data_start..];

let chunk_uncompressed_len = match chunk_type {
0x00 => {
let block = data.get(4..).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "compressed chunk too short")
})?;
snap::raw::decompress_len(block)
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
}
0x01 => chunk_len.checked_sub(4).ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "uncompressed chunk too short")
})?,
0x80..=0xff => 0,
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"unsupported snappy chunk type",
));
}
};

uncompressed_len = uncompressed_len
.checked_add(chunk_uncompressed_len)
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidData, "uncompressed size overflow")
})?;

if uncompressed_len > expected_uncompressed_size {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"uncompressed size does not match received size",
));
}
}

let mut decoder = snap::read::FrameDecoder::new(frame.as_slice());
let mut uncompressed = vec![0; expected_uncompressed_size];
io::Read::read_exact(&mut decoder, &mut uncompressed)?;

Ok((uncompressed, frame.len()))
}

/// Write a varint-prefixed, snappy-compressed SSZ payload. Returns the size
/// of the snappy-compressed bytes (excluding the varint length prefix).
pub async fn write_payload<T>(io: &mut T, encoded: &[u8]) -> io::Result<usize>
Expand Down Expand Up @@ -137,7 +207,8 @@ pub fn decode_varint(buf: &[u8]) -> io::Result<(u32, &[u8])> {

#[cfg(test)]
mod tests {
use super::decode_varint;
use super::{decode_payload, decode_varint, write_payload};
use futures::io::Cursor;

#[test]
fn test_decode_varint() {
Expand All @@ -149,4 +220,56 @@ mod tests {
let expected: &[u8] = &[];
assert_eq!(rest, expected);
}

#[tokio::test]
async fn decode_payload_leaves_following_payload_in_stream() {
let first = b"first payload";
let second = b"second payload";

let mut stream = Cursor::new(Vec::new());
let first_compressed_size = write_payload(&mut stream, first).await.unwrap();
let second_compressed_size = write_payload(&mut stream, second).await.unwrap();

let mut stream = Cursor::new(stream.into_inner());

let decoded = decode_payload(&mut stream).await.unwrap();
assert_eq!(decoded.uncompressed, first);
assert_eq!(decoded.compressed_size, first_compressed_size);

let decoded = decode_payload(&mut stream).await.unwrap();
assert_eq!(decoded.uncompressed, second);
assert_eq!(decoded.compressed_size, second_compressed_size);
}

#[tokio::test]
async fn decode_payload_reads_all_snappy_chunks_for_one_payload() {
let payload = vec![42; 128 * 1024];

let mut stream = Cursor::new(Vec::new());
let compressed_size = write_payload(&mut stream, &payload).await.unwrap();

let mut stream = Cursor::new(stream.into_inner());
let decoded = decode_payload(&mut stream).await.unwrap();

assert_eq!(decoded.uncompressed, payload);
assert_eq!(decoded.compressed_size, compressed_size);
}

#[tokio::test]
async fn decode_payload_handles_empty_payload_before_following_payload() {
let second = b"after empty";

let mut stream = Cursor::new(Vec::new());
let empty_compressed_size = write_payload(&mut stream, &[]).await.unwrap();
write_payload(&mut stream, second).await.unwrap();

let mut stream = Cursor::new(stream.into_inner());

let decoded = decode_payload(&mut stream).await.unwrap();
assert!(decoded.uncompressed.is_empty());
assert_eq!(decoded.compressed_size, empty_compressed_size);

let decoded = decode_payload(&mut stream).await.unwrap();
assert_eq!(decoded.uncompressed, second);
}
}