diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c25a67..8605e44 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,4 +80,15 @@ if(LAC_BUILD_TESTS) target_compile_options(lac_cli_tests PRIVATE -UNDEBUG) endif() add_test(NAME lac_cli_tests COMMAND lac_cli_tests $) + + add_executable(lac_rice_tests + tests/test_rice.cpp + ) + target_link_libraries(lac_rice_tests PRIVATE lac) + if(MSVC) + target_compile_options(lac_rice_tests PRIVATE /UNDEBUG) + else() + target_compile_options(lac_rice_tests PRIVATE -UNDEBUG) + endif() + add_test(NAME lac_rice_tests COMMAND lac_rice_tests) endif() diff --git a/src/codec/rice/rice.cpp b/src/codec/rice/rice.cpp index 08245b9..765013f 100644 --- a/src/codec/rice/rice.cpp +++ b/src/codec/rice/rice.cpp @@ -15,10 +15,13 @@ int32_t Rice::unsigned_to_signed(uint32_t u) { } void Rice::encode(BitWriter& w, int32_t value, uint32_t k) { - uint32_t u = signed_to_unsigned(value); + const uint32_t u = signed_to_unsigned(value); - uint32_t q = u >> k; - uint32_t r = u & ((1u << k) - 1); + // k <= 31 by construction (adaptive k clamps to 31; static k is a u5 field), + // but guard the shifts regardless: `u >> k` and `1u << k` are undefined + // behaviour for k >= 32. Rice::decode already rejects k > 31. + const uint32_t q = (k >= 32u) ? 0u : (u >> k); + const uint32_t r = (k >= 32u) ? u : (u & ((uint32_t{1} << k) - 1u)); w.write_unary_ones(q); w.write_bit(0); diff --git a/src/codec/simd/neon.cpp b/src/codec/simd/neon.cpp index bbe8b3a..f5e50ac 100644 --- a/src/codec/simd/neon.cpp +++ b/src/codec/simd/neon.cpp @@ -16,10 +16,16 @@ inline void ms_encode_scalar_impl(const int32_t* L, const int32_t* R, size_t n) { if (!L || !R || !M || !S || n == 0) return; for (size_t i = 0; i < n; ++i) { - int32_t l = L[i]; - int32_t r = R[i]; - M[i] = (l + r) >> 1; - S[i] = l - r; + const int32_t l = L[i]; + const int32_t r = R[i]; + // Compute the sum/difference in uint32 to avoid signed-overflow UB. + // This reproduces the wrapping semantics of the NEON vaddq_s32 / + // vsubq_s32 path bit-for-bit, and is identical to `l + r` / `l - r` + // for the validated 16/24-bit sample domain (which never overflows). + const int32_t sum = + static_cast(static_cast(l) + static_cast(r)); + M[i] = sum >> 1; + S[i] = static_cast(static_cast(l) - static_cast(r)); } } diff --git a/tests/test_rice.cpp b/tests/test_rice.cpp new file mode 100644 index 0000000..5ab2a72 --- /dev/null +++ b/tests/test_rice.cpp @@ -0,0 +1,65 @@ +// Direct Rice coder tests: round-trip, signed-mapping boundaries, and the +// k > 31 rejection contract. Built with asserts forced on (-UNDEBUG) and run +// under the ASan/UBSan CI job to catch shift/overflow UB. +#include +#include +#include +#include +#include + +#include "codec/bitstream/bit_reader.hpp" +#include "codec/bitstream/bit_writer.hpp" +#include "codec/rice/rice.hpp" + +namespace { + +bool roundtrip(int32_t value, uint32_t k) { + BitWriter w; + Rice::encode(w, value, k); + w.flush_to_byte(); + std::vector buf = w.take_buffer(); + BitReader r(buf.data(), buf.size()); + int32_t out = 0; + if (!Rice::decode(r, k, out)) return false; + return out == value; +} + +} // namespace + +int main() { + // Moderate values round-trip for every valid k. The unsigned-mapped value + // stays small, so the unary quotient is bounded even at k = 0. + for (uint32_t k = 0; k <= 31; ++k) { + for (int32_t v = -64; v <= 64; ++v) { + assert(roundtrip(v, k)); + } + } + + // Full-range boundary values, restricted to high k so the quotient is tiny. + // INT32_MIN maps to the maximum unsigned value (0xFFFFFFFF), which is the + // hardest case for the k == 31 shift boundary. + const int32_t kMin = std::numeric_limits::min(); + const int32_t kMax = std::numeric_limits::max(); + for (uint32_t k = 24; k <= 31; ++k) { + assert(roundtrip(0, k)); + assert(roundtrip(1, k)); + assert(roundtrip(-1, k)); + assert(roundtrip(kMax, k)); + assert(roundtrip(kMin, k)); + } + assert(roundtrip(kMin, 31)); + assert(roundtrip(kMax, 31)); + + // Decode must reject k > 31 (shifts by >= 32 are undefined behaviour). + { + const std::vector data(8, 0xFFu); + for (uint32_t k = 32; k <= 40; ++k) { + BitReader r(data.data(), data.size()); + int32_t out = 0; + assert(!Rice::decode(r, k, out)); + } + } + + std::cout << "rice tests ok\n"; + return 0; +}