Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,15 @@ if(LAC_BUILD_TESTS)
target_compile_options(lac_cli_tests PRIVATE -UNDEBUG)
endif()
add_test(NAME lac_cli_tests COMMAND lac_cli_tests $<TARGET_FILE:lac_cli>)

add_executable(lac_rice_tests
tests/test_rice.cpp
)
target_link_libraries(lac_rice_tests PRIVATE lac)
if(MSVC)
target_compile_options(lac_rice_tests PRIVATE /UNDEBUG)
else()
target_compile_options(lac_rice_tests PRIVATE -UNDEBUG)
endif()
add_test(NAME lac_rice_tests COMMAND lac_rice_tests)
endif()
9 changes: 6 additions & 3 deletions src/codec/rice/rice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ int32_t Rice::unsigned_to_signed(uint32_t u) {
}

void Rice::encode(BitWriter& w, int32_t value, uint32_t k) {
uint32_t u = signed_to_unsigned(value);
const uint32_t u = signed_to_unsigned(value);

uint32_t q = u >> k;
uint32_t r = u & ((1u << k) - 1);
// k <= 31 by construction (adaptive k clamps to 31; static k is a u5 field),
// but guard the shifts regardless: `u >> k` and `1u << k` are undefined
// behaviour for k >= 32. Rice::decode already rejects k > 31.
const uint32_t q = (k >= 32u) ? 0u : (u >> k);
const uint32_t r = (k >= 32u) ? u : (u & ((uint32_t{1} << k) - 1u));

w.write_unary_ones(q);
w.write_bit(0);
Expand Down
14 changes: 10 additions & 4 deletions src/codec/simd/neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,16 @@ inline void ms_encode_scalar_impl(const int32_t* L, const int32_t* R,
size_t n) {
if (!L || !R || !M || !S || n == 0) return;
for (size_t i = 0; i < n; ++i) {
int32_t l = L[i];
int32_t r = R[i];
M[i] = (l + r) >> 1;
S[i] = l - r;
const int32_t l = L[i];
const int32_t r = R[i];
// Compute the sum/difference in uint32 to avoid signed-overflow UB.
// This reproduces the wrapping semantics of the NEON vaddq_s32 /
// vsubq_s32 path bit-for-bit, and is identical to `l + r` / `l - r`
// for the validated 16/24-bit sample domain (which never overflows).
const int32_t sum =
static_cast<int32_t>(static_cast<uint32_t>(l) + static_cast<uint32_t>(r));
M[i] = sum >> 1;
S[i] = static_cast<int32_t>(static_cast<uint32_t>(l) - static_cast<uint32_t>(r));
}
}

Expand Down
65 changes: 65 additions & 0 deletions tests/test_rice.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Direct Rice coder tests: round-trip, signed-mapping boundaries, and the
// k > 31 rejection contract. Built with asserts forced on (-UNDEBUG) and run
// under the ASan/UBSan CI job to catch shift/overflow UB.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

#include "codec/bitstream/bit_reader.hpp"
#include "codec/bitstream/bit_writer.hpp"
#include "codec/rice/rice.hpp"

namespace {

bool roundtrip(int32_t value, uint32_t k) {
BitWriter w;
Rice::encode(w, value, k);
w.flush_to_byte();
std::vector<uint8_t> buf = w.take_buffer();
BitReader r(buf.data(), buf.size());
int32_t out = 0;
if (!Rice::decode(r, k, out)) return false;
return out == value;
}

} // namespace

int main() {
// Moderate values round-trip for every valid k. The unsigned-mapped value
// stays small, so the unary quotient is bounded even at k = 0.
for (uint32_t k = 0; k <= 31; ++k) {
for (int32_t v = -64; v <= 64; ++v) {
assert(roundtrip(v, k));
}
}

// Full-range boundary values, restricted to high k so the quotient is tiny.
// INT32_MIN maps to the maximum unsigned value (0xFFFFFFFF), which is the
// hardest case for the k == 31 shift boundary.
const int32_t kMin = std::numeric_limits<int32_t>::min();
const int32_t kMax = std::numeric_limits<int32_t>::max();
for (uint32_t k = 24; k <= 31; ++k) {
assert(roundtrip(0, k));
assert(roundtrip(1, k));
assert(roundtrip(-1, k));
assert(roundtrip(kMax, k));
assert(roundtrip(kMin, k));
}
assert(roundtrip(kMin, 31));
assert(roundtrip(kMax, 31));

// Decode must reject k > 31 (shifts by >= 32 are undefined behaviour).
{
const std::vector<uint8_t> data(8, 0xFFu);
for (uint32_t k = 32; k <= 40; ++k) {
BitReader r(data.data(), data.size());
int32_t out = 0;
assert(!Rice::decode(r, k, out));
}
}

std::cout << "rice tests ok\n";
return 0;
}
Loading