diff --git a/src/codec.rs b/src/codec.rs index 11555db..9e13dd9 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -3,7 +3,7 @@ use crate::{ model::VectorCodec, wire::{ DEFAULT_MAX_DECODE_COUNT, Reader, check_byte_len, check_decode_count, check_element_bytes, - decode_zigzag, encode_varuint, encode_zigzag, extend_repeat, + decode_zigzag, encode_varuint, encode_zigzag, extend_repeat_with_budget, }, }; @@ -175,12 +175,18 @@ fn encode_u64_rle(values: &[u64], out: &mut Vec) { } fn decode_u64_rle(reader: &mut Reader<'_>) -> Result> { + let start = reader.position(); let runs_len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; - let mut out = Vec::new(); + let mut runs = Vec::with_capacity(runs_len); for _ in 0..runs_len { let value = reader.read_varuint()?; let count = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; - extend_repeat(&mut out, value, count)?; + runs.push((value, count)); + } + let budget_input = reader.position() - start; + let mut out = Vec::new(); + for (value, count) in runs { + extend_repeat_with_budget(&mut out, value, count, 8, Some(budget_input))?; } Ok(out) } @@ -461,12 +467,18 @@ fn encode_i64_rle(values: &[i64], out: &mut Vec) { } fn decode_i64_rle(reader: &mut Reader<'_>) -> Result> { + let start = reader.position(); let runs_len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; - let mut out = Vec::new(); + let mut runs = Vec::with_capacity(runs_len); for _ in 0..runs_len { let value = decode_zigzag(reader.read_varuint()?); let count = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; - extend_repeat(&mut out, value, count)?; + runs.push((value, count)); + } + let budget_input = reader.position() - start; + let mut out = Vec::new(); + for (value, count) in runs { + extend_repeat_with_budget(&mut out, value, count, 8, Some(budget_input))?; } Ok(out) } diff --git a/src/lib.rs b/src/lib.rs index 7770573..b1725a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ pub use error::{Result, TwilicError}; pub use model::{Message, Schema, Value}; pub use protocol::{SessionEncoder, TwilicCodec}; pub use session::{SessionOptions, UnknownReferencePolicy}; -pub use wire::DEFAULT_MAX_DECODE_COUNT; +pub use wire::{DEFAULT_MAX_DECODE_COUNT, DEFAULT_MAX_DECODE_OUTPUT_RATIO}; pub fn encode(value: &Value) -> Result> { v2::encode(value) @@ -90,6 +90,16 @@ mod tests { assert!(err.to_string().contains(DECODE_COUNT_LIMIT_MSG)); } + #[test] + fn wire_extend_repeat_rejects_output_ratio_bomb() { + use crate::wire::{DECODE_OUTPUT_RATIO_MSG, extend_repeat_with_budget}; + let mut out = Vec::new(); + let err = extend_repeat_with_budget(&mut out, 0u8, 100_000, 1, Some(8)) + .expect_err("expected output ratio error"); + assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG)); + assert!(out.is_empty()); + } + #[test] fn roundtrip_dynamic_value() { let value = Value::Map(vec![ diff --git a/src/protocol.rs b/src/protocol.rs index 184bdae..6b2e56a 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -13,8 +13,9 @@ use crate::{ DictionaryFallback, DictionaryProfile, SessionOptions, SessionState, UnknownReferencePolicy, }, wire::{ - DEFAULT_MAX_DECODE_COUNT, Reader, check_decode_count, decode_zigzag, encode_bitmap, - encode_bytes, encode_string, encode_varuint, encode_zigzag, extend_repeat, + DEFAULT_MAX_DECODE_COUNT, Reader, check_decode_count, check_decode_output_bytes, + decode_zigzag, encode_bitmap, encode_bytes, encode_string, encode_varuint, encode_zigzag, + extend_repeat_with_budget, max_decode_output_bytes, }, }; @@ -2922,11 +2923,16 @@ fn rle_encode_bytes(input: &[u8]) -> Vec { fn rle_decode_bytes(input: &[u8]) -> Result> { let mut reader = Reader::new(input); let run_count = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; - let mut out = Vec::new(); + let mut runs = Vec::with_capacity(run_count); for _ in 0..run_count { let len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; let byte = reader.read_u8()?; - extend_repeat(&mut out, byte, len)?; + runs.push((byte, len)); + } + let budget_input = reader.position(); + let mut out = Vec::new(); + for (byte, len) in runs { + extend_repeat_with_budget(&mut out, byte, len, 1, Some(budget_input))?; } if !reader.is_eof() { return Err(TwilicError::InvalidData( @@ -3065,6 +3071,7 @@ fn control_huffman_decode_bytes(input: &[u8]) -> Result> { } let total = freqs.iter().map(|f| *f as usize).sum::(); check_decode_count(total, DEFAULT_MAX_DECODE_COUNT)?; + check_decode_output_bytes(total, input.len())?; if total == 0 { return Ok(Vec::new()); } @@ -3072,13 +3079,13 @@ fn control_huffman_decode_bytes(input: &[u8]) -> Result> { .ok_or(TwilicError::InvalidData("control stream huffman tree"))?; if let HuffNode::Leaf(symbol) = nodes[root] { let mut out = Vec::new(); - extend_repeat(&mut out, symbol, total)?; + extend_repeat_with_budget(&mut out, symbol, total, 1, Some(input.len()))?; return Ok(out); } let remaining = input.len().saturating_sub(reader.position()); let bitstream = reader.read_exact(remaining)?; - let mut out = Vec::with_capacity(total); + let mut out = Vec::with_capacity(total.min(max_decode_output_bytes(input.len()))); let mut byte_idx = 0usize; let mut bit_idx = 0u8; for _ in 0..total { @@ -3219,6 +3226,7 @@ fn control_fse_frame_decode(input: &[u8]) -> Result> { } let table_size = 1u32 << table_log; let len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?; + check_decode_output_bytes(len, input.len())?; let used = reader.read_bounded_count(256)?; if used > 256 || used > table_size as usize { return Err(TwilicError::InvalidData("control stream fse used symbols")); @@ -3280,7 +3288,7 @@ fn control_fse_frame_decode(input: &[u8]) -> Result> { let mut renorm_idx = renorm.len(); let mask = table_size - 1; - let mut out = Vec::with_capacity(len); + let mut out = Vec::with_capacity(len.min(max_decode_output_bytes(input.len()))); for _ in 0..len { let slot = (state & mask) as usize; let symbol = *decode_table @@ -4141,4 +4149,20 @@ mod tests { assert_eq!(find_template_id(&templates, &probe), Some(2)); } + + #[test] + fn rle_decode_bytes_uses_consumed_run_budget_not_input_len() { + let mut encoded = Vec::new(); + encode_varuint(1, &mut encoded); + encode_varuint(5_000, &mut encoded); + encoded.push(0xAB); + encoded.push(0); + + let err = rle_decode_bytes(&encoded).expect_err("expected output ratio guard"); + assert!( + err.to_string() + .contains(crate::wire::DECODE_OUTPUT_RATIO_MSG), + "unexpected error: {err}" + ); + } } diff --git a/src/wire.rs b/src/wire.rs index 7d27d57..ec1315a 100644 --- a/src/wire.rs +++ b/src/wire.rs @@ -1,9 +1,11 @@ use crate::error::{Result, TwilicError}; pub const DEFAULT_MAX_DECODE_COUNT: usize = 1 << 20; +pub const DEFAULT_MAX_DECODE_OUTPUT_RATIO: usize = 1 << 10; pub const DECODE_COUNT_LIMIT_MSG: &str = "decode count limit exceeded"; pub const DECODE_LENGTH_OVERFLOW_MSG: &str = "decode length overflow"; +pub const DECODE_OUTPUT_RATIO_MSG: &str = "decode output ratio exceeded"; #[inline] pub fn check_decode_count(count: usize, max: usize) -> Result<()> { @@ -35,12 +37,43 @@ pub fn check_element_bytes( } } +#[inline] +pub fn max_decode_output_bytes(input_len: usize) -> usize { + input_len + .saturating_mul(DEFAULT_MAX_DECODE_OUTPUT_RATIO) + .min(DEFAULT_MAX_DECODE_COUNT) +} + +#[inline] +pub fn check_decode_output_bytes(output_bytes: usize, input_len: usize) -> Result<()> { + if output_bytes > max_decode_output_bytes(input_len) { + return Err(TwilicError::InvalidData(DECODE_OUTPUT_RATIO_MSG)); + } + Ok(()) +} + pub fn extend_repeat(out: &mut Vec, value: T, count: usize) -> Result<()> { + extend_repeat_with_budget(out, value, count, 1, None) +} + +pub fn extend_repeat_with_budget( + out: &mut Vec, + value: T, + count: usize, + element_bytes: usize, + input_len: Option, +) -> Result<()> { let new_len = out .len() .checked_add(count) .ok_or(TwilicError::InvalidData(DECODE_COUNT_LIMIT_MSG))?; check_decode_count(new_len, DEFAULT_MAX_DECODE_COUNT)?; + if let Some(input_len) = input_len { + let output_bytes = new_len + .checked_mul(element_bytes) + .ok_or(TwilicError::InvalidData(DECODE_OUTPUT_RATIO_MSG))?; + check_decode_output_bytes(output_bytes, input_len)?; + } out.extend(std::iter::repeat_n(value, count)); Ok(()) } @@ -116,6 +149,10 @@ impl<'a> Reader<'a> { self.offset } + pub fn input_len(&self) -> usize { + self.input.len() + } + pub fn remaining(&self) -> usize { self.input.len().saturating_sub(self.offset) } diff --git a/tests/codec_spec_vectors.rs b/tests/codec_spec_vectors.rs index 8cc45b8..7f761c6 100644 --- a/tests/codec_spec_vectors.rs +++ b/tests/codec_spec_vectors.rs @@ -7,9 +7,26 @@ use twilic_rust::{ encode_i64_vector, encode_u64_vector, }, model::VectorCodec, - wire::{Reader, encode_varuint}, + wire::{DECODE_OUTPUT_RATIO_MSG, Reader, encode_varuint}, }; +#[test] +fn vector_rle_rejects_decompression_bomb_despite_trailing_column_bytes() { + let mut rle = Vec::new(); + encode_varuint(1, &mut rle); + encode_varuint(0, &mut rle); + encode_varuint(100_000, &mut rle); + let rle_byte_len = rle.len(); + let mut bytes = rle; + bytes.extend(std::iter::repeat_n(0u8, 16 * 1024)); + + let mut reader = Reader::new(&bytes); + let err = + decode_u64_vector(&mut reader, VectorCodec::Rle).expect_err("expected output ratio error"); + assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG)); + assert_eq!(reader.position(), rle_byte_len); +} + #[test] fn simple8b_i64_roundtrip_small_values() { let values = vec![1, 2, 3, -1, 0, 4, -2, 6, 8, 10, -3, 5]; diff --git a/tests/control_stream_and_control_spec.rs b/tests/control_stream_and_control_spec.rs index fd49a68..2dc0883 100644 --- a/tests/control_stream_and_control_spec.rs +++ b/tests/control_stream_and_control_spec.rs @@ -3,9 +3,62 @@ use twilic as twilic_rust; use twilic_rust::{ TwilicCodec, TwilicError, model::{ControlMessage, ControlStreamCodec, KeyRef, Message, MessageKind, Value}, - wire::Reader, + wire::{DECODE_OUTPUT_RATIO_MSG, Reader, encode_varuint}, }; +fn encode_control_stream_wire(codec: ControlStreamCodec, encoded_payload: &[u8]) -> Vec { + let mut out = vec![MessageKind::ControlStream as u8, codec as u8]; + encode_varuint(encoded_payload.len() as u64, &mut out); + out.extend_from_slice(encoded_payload); + out +} + +#[test] +fn control_stream_rle_rejects_decompression_bomb() { + let mut rle = Vec::new(); + encode_varuint(1, &mut rle); + encode_varuint(100_000, &mut rle); + rle.push(0x00); + let bytes = encode_control_stream_wire(ControlStreamCodec::Rle, &rle); + let mut codec = TwilicCodec::default(); + let err = codec + .decode_message(&bytes) + .expect_err("expected rle output ratio error"); + assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG)); +} + +#[test] +fn control_stream_huffman_rejects_decompression_bomb() { + let mut huff = vec![1]; + encode_varuint(1, &mut huff); + huff.push(0x00); + encode_varuint(100_000, &mut huff); + let bytes = encode_control_stream_wire(ControlStreamCodec::Huffman, &huff); + let mut codec = TwilicCodec::default(); + let err = codec + .decode_message(&bytes) + .expect_err("expected huffman output ratio error"); + assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG)); +} + +#[test] +fn control_stream_fse_rejects_decompression_bomb() { + let mut frame = vec![1]; + encode_varuint(100_000, &mut frame); + encode_varuint(1, &mut frame); + frame.push(0x00); + encode_varuint(2, &mut frame); + encode_varuint(0, &mut frame); + let mut fse = vec![3]; + fse.extend_from_slice(&frame); + let bytes = encode_control_stream_wire(ControlStreamCodec::Fse, &fse); + let mut codec = TwilicCodec::default(); + let err = codec + .decode_message(&bytes) + .expect_err("expected fse output ratio error"); + assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG)); +} + #[test] fn control_stream_roundtrips_for_all_declared_codecs() { let mut codec = TwilicCodec::default();