Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
model::VectorCodec,
wire::{
DEFAULT_MAX_DECODE_COUNT, Reader, check_byte_len, check_decode_count, check_element_bytes,
decode_zigzag, encode_varuint, encode_zigzag, extend_repeat,
decode_zigzag, encode_varuint, encode_zigzag, extend_repeat_with_budget,
},
};

Expand Down Expand Up @@ -175,12 +175,18 @@ fn encode_u64_rle(values: &[u64], out: &mut Vec<u8>) {
}

/// Decodes a run-length-encoded `u64` vector from `reader`.
///
/// Wire layout, as consumed here: a bounded run count, followed by that many
/// `(value: varuint, count: bounded count)` pairs.
///
/// Decoding happens in two passes so that the output-ratio budget is derived
/// from the bytes the RLE section itself consumed, not from whatever else
/// follows it in the reader:
///
/// 1. every run header is parsed without expanding anything;
/// 2. the runs are expanded, each expansion being checked against the budget.
///
/// # Errors
///
/// Propagates any error returned by the reader, and any count-limit or
/// output-ratio error returned by `extend_repeat_with_budget`.
fn decode_u64_rle(reader: &mut Reader<'_>) -> Result<Vec<u64>> {
    let start = reader.position();
    let runs_len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
    // `runs_len` comes straight from the input and may be as large as
    // DEFAULT_MAX_DECODE_COUNT, so do not reserve that many slots up front.
    // Capping by the bytes still unread keeps the reservation proportional to
    // the input; capacity is only a hint, so decoding is unaffected even if a
    // run could somehow be encoded in less than one byte.
    let mut runs = Vec::with_capacity(runs_len.min(reader.remaining()));
    for _ in 0..runs_len {
        let value = reader.read_varuint()?;
        let count = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
        runs.push((value, count));
    }
    // Bytes consumed by the run count plus all run headers.
    let budget_input = reader.position() - start;
    let mut out = Vec::new();
    for (value, count) in runs {
        // 8 is the size in bytes of one decoded `u64` element.
        // NOTE(review): `max_decode_output_bytes` clamps the byte budget to
        // DEFAULT_MAX_DECODE_COUNT, so budget-checked output here tops out at
        // DEFAULT_MAX_DECODE_COUNT / 8 elements — confirm that is intended.
        extend_repeat_with_budget(&mut out, value, count, 8, Some(budget_input))?;
    }
    Ok(out)
}
Expand Down Expand Up @@ -461,12 +467,18 @@ fn encode_i64_rle(values: &[i64], out: &mut Vec<u8>) {
}

/// Decodes a run-length-encoded `i64` vector from `reader`.
///
/// Wire layout, as consumed here: a bounded run count, followed by that many
/// `(value: zigzag varuint, count: bounded count)` pairs. Each value is passed
/// through `decode_zigzag` as it is read.
///
/// Decoding happens in two passes so that the output-ratio budget is derived
/// from the bytes the RLE section itself consumed, not from whatever else
/// follows it in the reader:
///
/// 1. every run header is parsed without expanding anything;
/// 2. the runs are expanded, each expansion being checked against the budget.
///
/// # Errors
///
/// Propagates any error returned by the reader, and any count-limit or
/// output-ratio error returned by `extend_repeat_with_budget`.
fn decode_i64_rle(reader: &mut Reader<'_>) -> Result<Vec<i64>> {
    let start = reader.position();
    let runs_len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
    // `runs_len` comes straight from the input and may be as large as
    // DEFAULT_MAX_DECODE_COUNT, so do not reserve that many slots up front.
    // Capping by the bytes still unread keeps the reservation proportional to
    // the input; capacity is only a hint, so decoding is unaffected even if a
    // run could somehow be encoded in less than one byte.
    let mut runs = Vec::with_capacity(runs_len.min(reader.remaining()));
    for _ in 0..runs_len {
        let value = decode_zigzag(reader.read_varuint()?);
        let count = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
        runs.push((value, count));
    }
    // Bytes consumed by the run count plus all run headers.
    let budget_input = reader.position() - start;
    let mut out = Vec::new();
    for (value, count) in runs {
        // 8 is the size in bytes of one decoded `i64` element.
        // NOTE(review): `max_decode_output_bytes` clamps the byte budget to
        // DEFAULT_MAX_DECODE_COUNT, so budget-checked output here tops out at
        // DEFAULT_MAX_DECODE_COUNT / 8 elements — confirm that is intended.
        extend_repeat_with_budget(&mut out, value, count, 8, Some(budget_input))?;
    }
    Ok(out)
}
Expand Down
12 changes: 11 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use error::{Result, TwilicError};
pub use model::{Message, Schema, Value};
pub use protocol::{SessionEncoder, TwilicCodec};
pub use session::{SessionOptions, UnknownReferencePolicy};
pub use wire::DEFAULT_MAX_DECODE_COUNT;
pub use wire::{DEFAULT_MAX_DECODE_COUNT, DEFAULT_MAX_DECODE_OUTPUT_RATIO};

pub fn encode(value: &Value) -> Result<Vec<u8>> {
v2::encode(value)
Expand Down Expand Up @@ -90,6 +90,16 @@ mod tests {
assert!(err.to_string().contains(DECODE_COUNT_LIMIT_MSG));
}

/// A 100_000-element expansion checked against an 8-byte input budget must be
/// refused, and the refusal must leave the output vector untouched.
#[test]
fn wire_extend_repeat_rejects_output_ratio_bomb() {
    use crate::wire::{DECODE_OUTPUT_RATIO_MSG, extend_repeat_with_budget};

    let mut sink: Vec<u8> = Vec::new();
    let outcome = extend_repeat_with_budget(&mut sink, 0u8, 100_000, 1, Some(8));
    let message = outcome
        .expect_err("expected output ratio error")
        .to_string();

    assert!(message.contains(DECODE_OUTPUT_RATIO_MSG));
    // The budget check runs before anything is appended.
    assert!(sink.is_empty());
}

#[test]
fn roundtrip_dynamic_value() {
let value = Value::Map(vec![
Expand Down
38 changes: 31 additions & 7 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ use crate::{
DictionaryFallback, DictionaryProfile, SessionOptions, SessionState, UnknownReferencePolicy,
},
wire::{
DEFAULT_MAX_DECODE_COUNT, Reader, check_decode_count, decode_zigzag, encode_bitmap,
encode_bytes, encode_string, encode_varuint, encode_zigzag, extend_repeat,
DEFAULT_MAX_DECODE_COUNT, Reader, check_decode_count, check_decode_output_bytes,
decode_zigzag, encode_bitmap, encode_bytes, encode_string, encode_varuint, encode_zigzag,
extend_repeat_with_budget, max_decode_output_bytes,
},
};

Expand Down Expand Up @@ -2922,11 +2923,16 @@ fn rle_encode_bytes(input: &[u8]) -> Vec<u8> {
fn rle_decode_bytes(input: &[u8]) -> Result<Vec<u8>> {
let mut reader = Reader::new(input);
let run_count = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
let mut out = Vec::new();
let mut runs = Vec::with_capacity(run_count);
for _ in 0..run_count {
let len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
let byte = reader.read_u8()?;
extend_repeat(&mut out, byte, len)?;
runs.push((byte, len));
}
let budget_input = reader.position();
let mut out = Vec::new();
for (byte, len) in runs {
extend_repeat_with_budget(&mut out, byte, len, 1, Some(budget_input))?;
Comment thread
cursor[bot] marked this conversation as resolved.
}
if !reader.is_eof() {
return Err(TwilicError::InvalidData(
Expand Down Expand Up @@ -3065,20 +3071,21 @@ fn control_huffman_decode_bytes(input: &[u8]) -> Result<Vec<u8>> {
}
let total = freqs.iter().map(|f| *f as usize).sum::<usize>();
check_decode_count(total, DEFAULT_MAX_DECODE_COUNT)?;
check_decode_output_bytes(total, input.len())?;
if total == 0 {
return Ok(Vec::new());
}
let (nodes, root) = build_huffman_tree(&freqs)
.ok_or(TwilicError::InvalidData("control stream huffman tree"))?;
if let HuffNode::Leaf(symbol) = nodes[root] {
let mut out = Vec::new();
extend_repeat(&mut out, symbol, total)?;
extend_repeat_with_budget(&mut out, symbol, total, 1, Some(input.len()))?;
return Ok(out);
}

let remaining = input.len().saturating_sub(reader.position());
let bitstream = reader.read_exact(remaining)?;
let mut out = Vec::with_capacity(total);
let mut out = Vec::with_capacity(total.min(max_decode_output_bytes(input.len())));
let mut byte_idx = 0usize;
let mut bit_idx = 0u8;
for _ in 0..total {
Expand Down Expand Up @@ -3219,6 +3226,7 @@ fn control_fse_frame_decode(input: &[u8]) -> Result<Vec<u8>> {
}
let table_size = 1u32 << table_log;
let len = reader.read_bounded_count(DEFAULT_MAX_DECODE_COUNT)?;
check_decode_output_bytes(len, input.len())?;
let used = reader.read_bounded_count(256)?;
if used > 256 || used > table_size as usize {
return Err(TwilicError::InvalidData("control stream fse used symbols"));
Expand Down Expand Up @@ -3280,7 +3288,7 @@ fn control_fse_frame_decode(input: &[u8]) -> Result<Vec<u8>> {
let mut renorm_idx = renorm.len();
let mask = table_size - 1;

let mut out = Vec::with_capacity(len);
let mut out = Vec::with_capacity(len.min(max_decode_output_bytes(input.len())));
for _ in 0..len {
let slot = (state & mask) as usize;
let symbol = *decode_table
Expand Down Expand Up @@ -4141,4 +4149,20 @@ mod tests {

assert_eq!(find_template_id(&templates, &probe), Some(2));
}

#[test]
fn rle_decode_bytes_uses_consumed_run_budget_not_input_len() {
let mut encoded = Vec::new();
encode_varuint(1, &mut encoded);
encode_varuint(5_000, &mut encoded);
encoded.push(0xAB);
encoded.push(0);

let err = rle_decode_bytes(&encoded).expect_err("expected output ratio guard");
assert!(
err.to_string()
.contains(crate::wire::DECODE_OUTPUT_RATIO_MSG),
"unexpected error: {err}"
);
}
}
37 changes: 37 additions & 0 deletions src/wire.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::error::{Result, TwilicError};

pub const DEFAULT_MAX_DECODE_COUNT: usize = 1 << 20;
pub const DEFAULT_MAX_DECODE_OUTPUT_RATIO: usize = 1 << 10;

pub const DECODE_COUNT_LIMIT_MSG: &str = "decode count limit exceeded";
pub const DECODE_LENGTH_OVERFLOW_MSG: &str = "decode length overflow";
pub const DECODE_OUTPUT_RATIO_MSG: &str = "decode output ratio exceeded";

#[inline]
pub fn check_decode_count(count: usize, max: usize) -> Result<()> {
Expand Down Expand Up @@ -35,12 +37,43 @@ pub fn check_element_bytes(
}
}

/// Upper bound, in bytes, on decoded output permitted for `input_len` bytes
/// of encoded input.
///
/// The bound is `input_len * DEFAULT_MAX_DECODE_OUTPUT_RATIO` (saturating on
/// overflow), clamped so that it never exceeds `DEFAULT_MAX_DECODE_COUNT`.
#[inline]
pub fn max_decode_output_bytes(input_len: usize) -> usize {
    let ratio_bound = input_len.saturating_mul(DEFAULT_MAX_DECODE_OUTPUT_RATIO);
    if ratio_bound > DEFAULT_MAX_DECODE_COUNT {
        DEFAULT_MAX_DECODE_COUNT
    } else {
        ratio_bound
    }
}

/// Verifies that `output_bytes` fits within the output budget for an input of
/// `input_len` bytes.
///
/// # Errors
///
/// Returns `TwilicError::InvalidData(DECODE_OUTPUT_RATIO_MSG)` when
/// `output_bytes` is greater than `max_decode_output_bytes(input_len)`.
#[inline]
pub fn check_decode_output_bytes(output_bytes: usize, input_len: usize) -> Result<()> {
    let allowed = max_decode_output_bytes(input_len);
    if output_bytes <= allowed {
        Ok(())
    } else {
        Err(TwilicError::InvalidData(DECODE_OUTPUT_RATIO_MSG))
    }
}

/// Appends `count` clones of `value` to `out` without applying an
/// output-ratio budget.
///
/// Delegates to `extend_repeat_with_budget` with no input length, so only the
/// checks that function performs unconditionally are applied.
///
/// # Errors
///
/// Returns whatever `extend_repeat_with_budget` returns.
pub fn extend_repeat<T: Clone>(out: &mut Vec<T>, value: T, count: usize) -> Result<()> {
    // With no input length supplied the byte budget is skipped, so the
    // element size passed here is never consulted.
    let no_budget: Option<usize> = None;
    extend_repeat_with_budget(out, value, count, 1, no_budget)
}

/// Appends `count` clones of `value` to `out`, refusing to grow past the
/// decode limits.
///
/// Two limits are checked before `out` is modified:
///
/// * the resulting length (`out.len() + count`) must pass
///   `check_decode_count` against `DEFAULT_MAX_DECODE_COUNT`;
/// * when `input_len` is `Some`, the resulting length multiplied by
///   `element_bytes` must pass `check_decode_output_bytes` for that input
///   length. When `input_len` is `None` this second check is skipped.
///
/// # Errors
///
/// * `TwilicError::InvalidData(DECODE_COUNT_LIMIT_MSG)` if the resulting
///   length overflows `usize`, or the error from `check_decode_count`.
/// * `TwilicError::InvalidData(DECODE_OUTPUT_RATIO_MSG)` if the byte size
///   overflows `usize`, or the error from `check_decode_output_bytes`.
pub fn extend_repeat_with_budget<T: Clone>(
    out: &mut Vec<T>,
    value: T,
    count: usize,
    element_bytes: usize,
    input_len: Option<usize>,
) -> Result<()> {
    let Some(grown_len) = out.len().checked_add(count) else {
        return Err(TwilicError::InvalidData(DECODE_COUNT_LIMIT_MSG));
    };
    check_decode_count(grown_len, DEFAULT_MAX_DECODE_COUNT)?;

    if let Some(budget_input) = input_len {
        match grown_len.checked_mul(element_bytes) {
            Some(grown_bytes) => check_decode_output_bytes(grown_bytes, budget_input)?,
            None => return Err(TwilicError::InvalidData(DECODE_OUTPUT_RATIO_MSG)),
        }
    }

    // Every check passed; only now is `out` touched.
    out.resize(grown_len, value);
    Ok(())
}
Expand Down Expand Up @@ -116,6 +149,10 @@ impl<'a> Reader<'a> {
self.offset
}

/// Returns the total length of the reader's underlying `input`.
///
/// The value does not depend on the current offset, so it is the same no
/// matter how much of the input has already been consumed.
pub fn input_len(&self) -> usize {
self.input.len()
}

/// Number of input bytes that lie at or beyond the current offset.
///
/// The subtraction saturates, so an offset past the end of the input yields
/// `0` instead of underflowing.
pub fn remaining(&self) -> usize {
    let total = self.input.len();
    total.saturating_sub(self.offset)
}
Expand Down
19 changes: 18 additions & 1 deletion tests/codec_spec_vectors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,26 @@ use twilic_rust::{
encode_i64_vector, encode_u64_vector,
},
model::VectorCodec,
wire::{Reader, encode_varuint},
wire::{DECODE_OUTPUT_RATIO_MSG, Reader, encode_varuint},
};

/// An RLE section that expands far beyond its own encoded size must be
/// rejected even when a large amount of unrelated data follows it in the
/// same buffer, and the reader must stop exactly at the end of the RLE
/// section.
#[test]
fn vector_rle_rejects_decompression_bomb_despite_trailing_column_bytes() {
    // Three varuints written back to back: 1, 0 and 100_000.
    // Presumably run count, run value and repeat count — confirm against the
    // u64 RLE decoder.
    let mut wire = Vec::new();
    for field in [1, 0, 100_000] {
        encode_varuint(field, &mut wire);
    }
    let rle_section_len = wire.len();

    // 16 KiB of zero bytes after the RLE section.
    wire.resize(rle_section_len + 16 * 1024, 0u8);

    let mut reader = Reader::new(&wire);
    let outcome = decode_u64_vector(&mut reader, VectorCodec::Rle);
    let message = outcome
        .expect_err("expected output ratio error")
        .to_string();

    assert!(message.contains(DECODE_OUTPUT_RATIO_MSG));
    assert_eq!(reader.position(), rle_section_len);
}

#[test]
fn simple8b_i64_roundtrip_small_values() {
let values = vec![1, 2, 3, -1, 0, 4, -2, 6, 8, 10, -3, 5];
Expand Down
55 changes: 54 additions & 1 deletion tests/control_stream_and_control_spec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,62 @@ use twilic as twilic_rust;
use twilic_rust::{
TwilicCodec, TwilicError,
model::{ControlMessage, ControlStreamCodec, KeyRef, Message, MessageKind, Value},
wire::Reader,
wire::{DECODE_OUTPUT_RATIO_MSG, Reader, encode_varuint},
};

/// Builds the wire bytes of a control-stream message around an
/// already-encoded payload.
///
/// Output layout: the `MessageKind::ControlStream` tag byte, the `codec` tag
/// byte, the payload length as a varuint, then the payload bytes unchanged.
fn encode_control_stream_wire(codec: ControlStreamCodec, encoded_payload: &[u8]) -> Vec<u8> {
    let kind_tag = MessageKind::ControlStream as u8;
    let codec_tag = codec as u8;
    let payload_len = encoded_payload.len() as u64;

    let mut frame = Vec::from([kind_tag, codec_tag]);
    encode_varuint(payload_len, &mut frame);
    frame.extend(encoded_payload.iter().copied());
    frame
}

/// A tiny RLE control-stream payload that asks for a 100_000-fold expansion
/// must be rejected with the output-ratio error.
#[test]
fn control_stream_rle_rejects_decompression_bomb() {
    // Payload bytes: varuint 1, varuint 100_000, then a single 0x00 byte.
    // Presumably run count, run length and run byte — confirm against the
    // control-stream RLE decoder.
    let mut payload = Vec::new();
    for field in [1, 100_000] {
        encode_varuint(field, &mut payload);
    }
    payload.push(0x00);

    let wire = encode_control_stream_wire(ControlStreamCodec::Rle, &payload);
    let mut decoder = TwilicCodec::default();
    let outcome = decoder.decode_message(&wire);
    let err = outcome.expect_err("expected rle output ratio error");
    let message = err.to_string();
    assert!(message.contains(DECODE_OUTPUT_RATIO_MSG));
}

// Builds a Huffman control-stream message from a handful of payload bytes
// that include the varuint 100_000, then asserts that `decode_message` fails
// with an error whose text contains DECODE_OUTPUT_RATIO_MSG.
//
// NOTE(review): the meaning of each payload field below is inferred, not
// visible in this file — confirm against the Huffman control-stream decoder.
#[test]
fn control_stream_huffman_rejects_decompression_bomb() {
// Leading byte 1, followed by varuint 1.
let mut huff = vec![1];
encode_varuint(1, &mut huff);
// Presumably one symbol (0x00) and its frequency (100_000), which would
// make the declared output 100_000 bytes — TODO confirm.
huff.push(0x00);
encode_varuint(100_000, &mut huff);
// Wrap the payload in the control-stream header and decode it.
let bytes = encode_control_stream_wire(ControlStreamCodec::Huffman, &huff);
let mut codec = TwilicCodec::default();
let err = codec
.decode_message(&bytes)
.expect_err("expected huffman output ratio error");
assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG));
}

// Builds an FSE control-stream message whose small frame includes the varuint
// 100_000, then asserts that `decode_message` fails with an error whose text
// contains DECODE_OUTPUT_RATIO_MSG.
//
// NOTE(review): the meaning of each frame field below is inferred, not
// visible in this file — confirm against the FSE control-stream decoder.
#[test]
fn control_stream_fse_rejects_decompression_bomb() {
// Presumably: table log (1), decoded length (100_000), used-symbol count (1),
// then one symbol byte and two further varuints — TODO confirm.
let mut frame = vec![1];
encode_varuint(100_000, &mut frame);
encode_varuint(1, &mut frame);
frame.push(0x00);
encode_varuint(2, &mut frame);
encode_varuint(0, &mut frame);
// The frame is prefixed with a single byte (3) before being wrapped;
// presumably a mode or variant tag — TODO confirm.
let mut fse = vec![3];
fse.extend_from_slice(&frame);
// Wrap the payload in the control-stream header and decode it.
let bytes = encode_control_stream_wire(ControlStreamCodec::Fse, &fse);
let mut codec = TwilicCodec::default();
let err = codec
.decode_message(&bytes)
.expect_err("expected fse output ratio error");
assert!(err.to_string().contains(DECODE_OUTPUT_RATIO_MSG));
}

#[test]
fn control_stream_roundtrips_for_all_declared_codecs() {
let mut codec = TwilicCodec::default();
Expand Down