From 0b4945be7cc593f2ff96193e9fcec1b26cb29ede Mon Sep 17 00:00:00 2001 From: audexdev Date: Thu, 4 Jun 2026 23:08:18 +0900 Subject: [PATCH 1/6] Optimize LAC block throughput and parallel decode --- CHANGELOG.md | 5 +- CMakeLists.txt | 1 + README.md | 11 +- docs/format.md | 28 +- docs/supported-formats.md | 13 +- src/codec/bitstream/bit_reader.cpp | 37 +- src/codec/bitstream/bit_writer.cpp | 38 ++ src/codec/bitstream/bit_writer.hpp | 7 +- src/codec/block/decoder.cpp | 10 +- src/codec/block/encoder.cpp | 375 +++++-------------- src/codec/block/encoder.hpp | 6 - src/codec/frame/frame_header.hpp | 6 +- src/codec/lac/decoder.cpp | 123 +++++- src/codec/lac/decoder.hpp | 2 + src/codec/lac/encoder.cpp | 582 +++++++++++------------------ src/codec/lac/encoder.hpp | 9 - src/codec/lac/thread_limit.hpp | 35 ++ src/codec/rice/rice.cpp | 12 +- src/main.cpp | 144 ++++++- tests/test_cli.cpp | 189 +++++++++- tests/test_e2e.cpp | 346 ++++++++++++++++- tests/test_lpc.cpp | 6 + tests/test_partitioning.cpp | 34 ++ 23 files changed, 1282 insertions(+), 737 deletions(-) create mode 100644 src/codec/lac/thread_limit.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index a657640..37bdd38 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ All notable user-facing changes should be documented here. LAC is still experime - Added CI coverage for Debug tests, Release builds, and ASan/UBSan smoke tests on GitHub Actions. - Added self-contained generated WAV fixtures for clean-checkout CI test runs. - Updated E2E tests to read back temporary `.lac` files before decode. -- Added encoder thread limiting through `LAC_THREADS` and `lac_cli encode --threads=N`. +- Added codec worker limiting through `LAC_THREADS` and `lac_cli encode|decode --threads=N`. - Clarified the CLI-first PCM WAV roundtrip contract and canonical restored-WAV behavior. - Hardened WAV parsing against inconsistent RIFF sizes, malformed chunk boundaries, non-canonical PCM metadata, empty payloads, and unchecked data-chunk allocation. - Hardened `.lac` decoding against non-canonical reserved fields, stereo flags, residual metadata, padding, trailing payload bytes, out-of-range restored samples, oversized decoded allocations, and malformed Rice values. @@ -16,6 +16,9 @@ All notable user-facing changes should be documented here. LAC is still experime - Rejected encode or decode commands whose output path refers to the input file. - Fixed canonical RIFF padding for odd-sized restored PCM payloads and tightened close-time write error handling. - Bounded tiny-block decoder work, made extreme zigzag decoding portable, and corrected the LPC reconstruction specification. +- Staged CLI output beside the requested path and published it only after successful close, preserving existing files and linked targets on failed writes. +- Added format version 3 compressed block boundaries and bounded parallel decode while retaining serial decode compatibility for canonical version 2 streams. +- Switched encode planning to bounded `16384`-sample windows, added sampled automatic stereo probes for ambiguous blocks, and reduced residual estimation and Rice bit-output work. - Removed tracked editor cache files and generated compile database symlink from source control. - Expanded repository roadmap tracking for correctness, fuzzing, security hardening, and release readiness. diff --git a/CMakeLists.txt b/CMakeLists.txt index 85bee9b..20a1e68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ add_library(lac STATIC src/codec/frame/frame_header.hpp src/codec/lac/decoder.hpp src/codec/lac/encoder.hpp + src/codec/lac/thread_limit.hpp src/codec/lac/thread_collector.hpp src/codec/lpc/lpc.hpp src/codec/rice/rice.hpp diff --git a/README.md b/README.md index 2f3760c..f282092 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![CodeQL](https://github.com/audexdev/Lossless-Audio-Codec/actions/workflows/codeql.yml/badge.svg)](https://github.com/audexdev/Lossless-Audio-Codec/actions/workflows/codeql.yml) [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE) -LAC is an experimental CLI-first C++20 lossless audio codec for PCM WAV audio. It is a compact implementation of a custom `.lac` container and bitstream with LPC prediction, adaptive Rice coding, mid/side stereo, zero-run residual coding, residual partitioning, Apple Silicon NEON acceleration, and multithreaded block encoding. +LAC is an experimental CLI-first C++20 lossless audio codec for PCM WAV audio. It is a compact implementation of a custom `.lac` container and bitstream with LPC prediction, adaptive Rice coding, mid/side stereo, zero-run residual coding, residual partitioning, Apple Silicon NEON acceleration, and multithreaded block encoding and decoding. The current product contract is `lac_cli encode` followed by `lac_cli decode` for the documented PCM WAV domain. The project is intended for codec experimentation, implementation study, and reproducible work on lossless audio compression internals. The file format is still experimental and should not yet be treated as a long-term archival format. @@ -68,12 +68,13 @@ Choose a stereo mode explicitly: ./build/lac_cli encode input.wav output.lac --stereo-mode=ms ``` -Stereo encoding defaults to automatic per-block LR or mid/side selection. Mono input always uses LR metadata. Restored WAV files preserve PCM samples, channel count, sample rate, and bit depth, but ancillary WAV chunks are not copied. Input and output paths must refer to different files. +Stereo encoding defaults to automatic per-block LR or mid/side selection. Mono input always uses LR metadata. Restored WAV files preserve PCM samples, channel count, sample rate, and bit depth, but ancillary WAV chunks are not copied. Input and output paths must refer to different files. The CLI stages output beside the requested path and publishes it only after the completed file closes successfully. -Limit encoder worker threads: +Limit codec worker threads: ```sh ./build/lac_cli encode input.wav output.lac --threads=12 +./build/lac_cli decode output.lac restored.wav --threads=12 LAC_THREADS=12 ./build/lac_cli encode input.wav output.lac ``` @@ -96,7 +97,7 @@ LAC_THREADS=4 ctest --test-dir build-tests --output-on-failure The default CTest configuration uses lightweight generated WAV fixtures and exercises both internal codec paths and `lac_cli` subprocess roundtrips. To opt into larger local E2E fixtures, configure with `-DLAC_TEST_ASSETS_DIR="$PWD/assets"`. The generated fixtures keep clean checkouts and routine development self-contained. -Set `LAC_THREADS=N` to cap encoder worker threads during tests. The heavier `test_all.sh` asset roundtrip script defaults to `LAC_THREADS=12` unless the environment already sets a different value. +Set `LAC_THREADS=N` to cap encode and decode worker threads during tests. The heavier `test_all.sh` asset roundtrip script defaults to `LAC_THREADS=12` unless the environment already sets a different value. ## Contributing @@ -104,7 +105,7 @@ Contribution setup, review expectations, and local development commands are docu ## Format -The current `.lac` bitstream is documented in [docs/format.md](docs/format.md). The format is versioned internally as frame version `2`, but it is not yet frozen for external compatibility. +The current `.lac` bitstream is documented in [docs/format.md](docs/format.md). The encoder emits frame version `3`, while the decoder retains serial compatibility for canonical version `2` streams. The format is not yet frozen for external compatibility. Supported WAV/PCM input and output constraints are documented in [docs/supported-formats.md](docs/supported-formats.md). ## Security diff --git a/docs/format.md b/docs/format.md index 44864f6..2211734 100644 --- a/docs/format.md +++ b/docs/format.md @@ -18,20 +18,34 @@ This document describes the current experimental `.lac` format implemented by th ```text FrameHeader u32 block_count -u32 block_size[block_count] +repeat block_count: + u32 block_size + u32 compressed_size_bytes BlockPayload[block_count] ``` -The block-size table gives the number of samples per channel in each block. Every channel payload in the same block uses the same `block_size`. +The version `3` block table gives the number of samples per channel and the complete encoded byte length for each block. `compressed_size_bytes` covers the optional per-block stereo flag and every byte-padded channel block. Every channel payload in the same block uses the same `block_size`. Current top-level limits: - `block_count` must be non-zero. -- `block_count` must not exceed `1048576`, and the complete block-size table must be present before allocation. +- `block_count` must not exceed `1048576`, and the complete block table must be present before allocation. - Each `block_size` must be non-zero and no larger than `16384` samples per channel. +- Each `compressed_size_bytes` must be non-zero, and their sum must exactly match the remaining frame payload bytes. - Every non-final block must contain at least `256` samples per channel. The final block may be shorter. - The total declared sample count must fit within the implementation's 1 GiB decoded-PCM allocation limit and the classic RIFF/WAV output size limit. +Version `2` streams use the legacy table layout: + +```text +FrameHeader(version = 2) +u32 block_count +u32 block_size[block_count] +BlockPayload[block_count] +``` + +Version `2` remains decode-compatible, but it does not carry encoded block boundaries and is decoded serially. + ## Frame Header The frame header is 80 bits, currently 10 bytes: @@ -39,7 +53,7 @@ The frame header is 80 bits, currently 10 bytes: | Field | Bits | Meaning | | --- | ---: | --- | | sync | 16 | `0x4C41` (`LA`) | -| version | 8 | current format version, `2` | +| version | 8 | current encoder format version, `3`; legacy decode also accepts `2` | | channels | 8 | `1` mono or `2` stereo | | stereo_mode | 8 | `0` LR, `1` mid/side, `2` per-block stereo | | sample_rate_low | 16 | low 16 bits of sample rate | @@ -373,10 +387,10 @@ Each channel block is flushed to the next byte boundary after residual encoding. ## Integrity -The current format does not include a checksum, frame CRC, block CRC, or authenticated length field. Decoders validate structural fields strictly and reject trailing garbage, impossible block sizes, invalid residual tags, Rice values or predictor reconstruction outside the signed 32-bit domain, non-zero reserved fields or padding, and non-canonical metadata. Without an integrity field, a modified payload can still decode successfully if it remains structurally valid and produces in-range PCM samples. +The current format does not include a checksum, frame CRC, block CRC, or authenticated length field. Version `3` compressed block lengths are structural boundaries, not integrity protection. Decoders validate structural fields strictly and reject trailing garbage, impossible block sizes, mismatched version `3` payload lengths, invalid residual tags, Rice values or predictor reconstruction outside the signed 32-bit domain, non-zero reserved fields or padding, and non-canonical metadata. Without an integrity field, a modified payload can still decode successfully if it remains structurally valid and produces in-range PCM samples. ## Compatibility -The format version is currently `2`, but the format is still experimental. Future work may add stronger validation, checksums, fuzzed compatibility tests, streaming decode constraints, or a frozen public specification. +The encoder currently emits format version `3`, but the format is still experimental. Version `3` adds compressed block byte lengths so blocks can be validated and decoded independently. Future work may add stronger validation, checksums, fuzzed compatibility tests, streaming decode constraints, or a frozen public specification. -The canonical version `2` byte sequences emitted by the encoder are unchanged by decoder hardening. Hardened decoders may reject version `2` byte sequences that older permissive decoders accepted when those sequences contain non-canonical reserved fields, metadata, padding, stereo flags, block tables, or trailing payload bytes. +The decoder retains serial compatibility for canonical version `2` streams. Hardened decoders may reject version `2` byte sequences that older permissive decoders accepted when those sequences contain non-canonical reserved fields, metadata, padding, stereo flags, block tables, or trailing payload bytes. diff --git a/docs/supported-formats.md b/docs/supported-formats.md index 28e74f5..5a947ab 100644 --- a/docs/supported-formats.md +++ b/docs/supported-formats.md @@ -44,18 +44,18 @@ supported WAV -> lac_cli encode -> .lac -> lac_cli decode -> restored WAV For supported input, the restored WAV has the same PCM samples, channel count, sample rate, and bit depth. The restored file is a canonical PCM WAV with a 16-byte `fmt ` chunk followed by one `data` chunk. When the PCM payload size is odd, the writer appends the required zero RIFF padding byte after the `data` payload. Ancillary chunks and their metadata are intentionally not copied. Encode defaults to per-block stereo selection for stereo input; `--stereo-mode=lr` and `--stereo-mode=ms` force a mode. Mono input always uses LR mode metadata. -Encode and decode reject input and output paths that refer to the same file. This prevents an ordinary in-place command from clobbering its source. Output publication is not hardened against concurrent path replacement in a directory modified by another process. +Encode and decode reject input and output paths that refer to the same file, including a second check immediately before output publication. The CLI writes a complete output into a temporary sibling directory and replaces the final path only after the writer closes successfully. A failed write or publication leaves an existing final output untouched, and replacing an existing symlink or hardlink output does not stream bytes into its prior target. Publication is not `fsync`-backed crash durability, and the CLI is not a filesystem access-control boundary for directories concurrently modified by untrusted processes. The alpha CLI surface is: | Command or option | Behavior | | --- | --- | | `lac_cli encode input.wav output.lac` | encode a supported WAV; stereo input defaults to per-block stereo selection | -| `lac_cli decode input.lac output.wav` | decode one canonical `.lac` stream to canonical PCM WAV | +| `lac_cli decode input.lac output.wav` | decode one supported canonical version `2` or `3` `.lac` stream to canonical PCM WAV | | `--stereo-mode=lr` | force LR stereo payloads during encode | | `--stereo-mode=ms` | force mid/side stereo payloads during encode | -| `--threads=N` | cap encoder workers to positive integer `N`; overrides `LAC_THREADS` | -| `LAC_THREADS=N` | cap encoder workers when `--threads=N` is absent | +| `--threads=N` | cap encode or decode workers to positive integer `N`; overrides `LAC_THREADS` | +| `LAC_THREADS=N` | cap encode or decode workers when `--threads=N` is absent | | `--no-partitioning` | disable residual partitioning during encode | The CLI may overwrite an existing output file when it is distinct from the input file. Flags beginning with `--debug-` are diagnostic implementation aids and are not part of the stable alpha contract. @@ -66,7 +66,7 @@ The CLI may overwrite an existing output file when it is distinct from the input The current `.lac` container supports: -- format version `2` +- canonical encode format version `3`; legacy version `2` remains decode-compatible - mono or stereo streams - LR, mid/side, or per-block stereo mode - block sizes up to `16384` samples per channel @@ -77,12 +77,13 @@ See `docs/format.md` for bitstream details. ## Current Limits -- The format has no checksum, frame CRC, block CRC, or authenticated payload length. +- The format has no checksum, frame CRC, block CRC, or authenticated payload length. Version `3` block lengths are structural boundaries only. - The WAV reader and LAC decoder reject decoded PCM output above 1 GiB. - `lac_cli decode` rejects compressed `.lac` input files above 1 GiB before loading them into memory. - The decoder rejects non-canonical structural metadata, trailing payload bytes, out-of-range restored samples, and output that cannot fit classic RIFF/WAV 32-bit size fields. - The decoder rejects more than `1048576` blocks and non-final blocks shorter than `256` samples to bound block-table metadata and tiny-block decode work. - Large-file handling is still bounded by classic RIFF/WAV 32-bit size fields on output. +- On Windows, CLI paths currently pass through narrow argument and stream APIs. Non-ASCII paths outside the active code page are not guaranteed. - Streaming encode/decode is not currently exposed as a public API. - Structurally valid payload corruption can still produce different in-range PCM because the format has no integrity field. - Malformed-input validation continues to be improved through security hardening issues and fuzzing work. diff --git a/src/codec/bitstream/bit_reader.cpp b/src/codec/bitstream/bit_reader.cpp index 1265a7c..0039d88 100644 --- a/src/codec/bitstream/bit_reader.cpp +++ b/src/codec/bitstream/bit_reader.cpp @@ -42,34 +42,23 @@ uint32_t BitReader::read_bits(int nbits) { if (nbits <= 0) return 0; if (!this->error && static_cast(nbits) <= this->bits_remaining()) { - size_t bit_index = this->byte_pos * 8 + static_cast(this->bit_pos); - int remaining = nbits; uint32_t value = 0; - + int remaining = nbits; while (remaining > 0) { - const size_t byte_index = bit_index / 8; - const int bit_offset = static_cast(bit_index % 8); - - uint64_t window = 0; - int window_bits = 0; - size_t i = byte_index; - while (window_bits < 64 && i < this->size) { - window = (window << 8) | this->data[i++]; - window_bits += 8; - } - - const int usable = window_bits - bit_offset; - const int take = std::min(remaining, std::min(usable, 24)); - const int shift = usable - take; - const uint32_t chunk = static_cast((window >> shift) & low_bits_mask(take)); - - value = static_cast((static_cast(value) << take) | chunk); - bit_index += static_cast(take); + const int available = 8 - this->bit_pos; + const int take = std::min(remaining, available); + const int shift = available - take; + const uint32_t chunk = + (static_cast(this->data[this->byte_pos]) >> shift) & + low_bits_mask(take); + value = static_cast((value << take) | chunk); remaining -= take; + this->bit_pos += take; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } } - - this->byte_pos = bit_index / 8; - this->bit_pos = static_cast(bit_index % 8); return value; } diff --git a/src/codec/bitstream/bit_writer.cpp b/src/codec/bitstream/bit_writer.cpp index 0cf305b..00b41aa 100644 --- a/src/codec/bitstream/bit_writer.cpp +++ b/src/codec/bitstream/bit_writer.cpp @@ -1,5 +1,6 @@ #include "bit_writer.hpp" #include +#include namespace { inline uint32_t low_bits_mask(int bits) { @@ -68,6 +69,39 @@ void BitWriter::write_bits(uint32_t value, int nbits) { } } +void BitWriter::write_unary_ones(uint32_t ones) { + while (this->bit_pos != 0 && ones > 0) { + this->write_bit(1u); + --ones; + } + + if (ones >= 8) { + const size_t bytes = ones / 8; + this->buffer.insert(this->buffer.end(), bytes, 0xFFu); + ones -= static_cast(bytes * 8); + } + + while (ones > 0) { + this->write_bit(1u); + --ones; + } +} + +void BitWriter::write_bytes(const uint8_t* data, size_t size) { + if (data == nullptr || size == 0) return; + if (this->bit_pos == 0) { + this->buffer.insert(this->buffer.end(), data, data + size); + return; + } + for (size_t i = 0; i < size; ++i) { + this->write_bits(data[i], 8); + } +} + +void BitWriter::reserve_bytes(size_t size) { + this->buffer.reserve(size); +} + void BitWriter::flush_to_byte() { if (this->bit_pos == 0) return; @@ -79,3 +113,7 @@ void BitWriter::flush_to_byte() { const std::vector& BitWriter::get_buffer() const { return this->buffer; } + +std::vector BitWriter::take_buffer() { + return std::move(this->buffer); +} diff --git a/src/codec/bitstream/bit_writer.hpp b/src/codec/bitstream/bit_writer.hpp index dafc537..bdc1d42 100644 --- a/src/codec/bitstream/bit_writer.hpp +++ b/src/codec/bitstream/bit_writer.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include @@ -8,11 +9,15 @@ class BitWriter { void write_bit(uint32_t bit); void write_bits(uint32_t value, int nbits); + void write_unary_ones(uint32_t ones); + void write_bytes(const uint8_t* data, size_t size); + void reserve_bytes(size_t size); void flush_to_byte(); const std::vector& get_buffer() const; + std::vector take_buffer(); private: std::vector buffer; uint8_t current_byte; int bit_pos; -}; \ No newline at end of file +}; diff --git a/src/codec/block/decoder.cpp b/src/codec/block/decoder.cpp index c7bccfc..3004b21 100644 --- a/src/codec/block/decoder.cpp +++ b/src/codec/block/decoder.cpp @@ -177,8 +177,14 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o LAC_DEBUG_LOG("[part-val] idx=" << (offset + idx - 1) << " v=0 k=" << current_k << "\n"); } - ++count; - current_k = adapt_k(sumU, count, stateless, adapt_state); + if (!stateless) { + ++count; + current_k = adapt_k(sumU, count, false, adapt_state); + } + } + if (stateless) { + count += run_len; + current_k = adapt_k(sumU, count, true, nullptr); } if (debug_zr) { LAC_DEBUG_LOG("[zr-decode] run len=" << run_len << " idx=" << idx << " err=" << reader.has_error() << "\n"); diff --git a/src/codec/block/encoder.cpp b/src/codec/block/encoder.cpp index 48c6019..a7455e6 100644 --- a/src/codec/block/encoder.cpp +++ b/src/codec/block/encoder.cpp @@ -4,8 +4,10 @@ #include "codec/bitstream/bit_writer.hpp" #include "utils/logger.hpp" #include +#include #include #include +#include #include #include @@ -65,18 +67,13 @@ inline uint64_t rice_bits_for_unsigned(uint32_t u, uint32_t k) { inline uint32_t adapt_k_stateless(uint64_t sum, uint32_t count) { if (count == 0) return 0; const uint64_t mean = (sum + (count >> 1)) / count; - uint32_t k = 0; - while ((1u << k) < mean && k < 31u) { - ++k; - } - return k; + if (mean <= 1) return 0; + return std::min(31u, std::bit_width(mean - 1u)); } inline void write_rice_unsigned(BitWriter& bw, uint32_t value, uint32_t k) { const uint32_t q = (k >= 31u) ? 0u : (value >> k); - for (uint32_t i = 0; i < q; ++i) { - bw.write_bit(1u); - } + bw.write_unary_ones(q); bw.write_bit(0u); if (k > 0) { const uint32_t rem_mask = (1u << k) - 1u; @@ -116,24 +113,7 @@ inline std::vector partition_sizes_for_block(uint32_t block_size, uint return sizes; } -bool has_zero_run(const std::vector& residual) { - size_t idx = 0; - while (idx < residual.size()) { - if (residual[idx] == 0) { - size_t run = 0; - while (idx + run < residual.size() && residual[idx + run] == 0) { - ++run; - } - if (run >= kZeroRunMinRun) return true; - idx += run; - } else { - ++idx; - } - } - return false; -} - -uint32_t estimate_initial_k(const std::vector& residual) { +uint32_t estimate_initial_k(std::span residual) { if (residual.empty()) return 0; const size_t count = std::min(kInitialScanCount, residual.size()); @@ -172,48 +152,18 @@ uint32_t estimate_initial_k(const std::vector& residual) { return std::min(best_k, 15u); } -uint64_t estimate_rice_bits(const std::vector& residual) { - if (residual.empty()) return 0; - uint64_t sum_u = 0; - for (int32_t r : residual) { - sum_u += unsigned_from_residual(r); - } - const uint64_t mean = (sum_u + (residual.size() >> 1)) / residual.size(); - uint32_t k = 0; - while ((1u << k) < mean && k < 31u) ++k; - - uint64_t bits = 0; - for (int32_t r : residual) { - const uint32_t u = unsigned_from_residual(r); - bits += rice_bits_for_unsigned(u, k); - } - return bits; -} - -uint64_t estimate_adaptive_rice_bits(const std::vector& residual, - uint32_t initial_k, - bool stateless = false) { - uint64_t bits = 0; - uint32_t current_k = initial_k; - uint64_t sum_u = 0; - std::optional adapt_state; - if (!stateless) adapt_state.emplace(); - for (size_t i = 0; i < residual.size(); ++i) { - const uint32_t u = unsigned_from_residual(residual[i]); - bits += rice_bits_for_unsigned(u, current_k); - sum_u += u; - current_k = stateless - ? adapt_k_stateless(sum_u, static_cast(i + 1)) - : Rice::adapt_k(sum_u, static_cast(i + 1), *adapt_state); - } - return bits; -} - -uint64_t estimate_zerorun_bits(const std::vector& residual, - uint32_t initial_k, - bool stateless = false) { - if (residual.empty()) return 0; - uint64_t bits = 0; +struct ResidualCosts { + uint64_t rice_bits = 0; + uint64_t zr_bits = 0; + uint64_t bin_bits = 0; + bool has_zero_run = false; +}; + +ResidualCosts estimate_residual_costs(std::span residual, + uint32_t initial_k, + bool stateless = false) { + ResidualCosts costs; + if (residual.empty()) return costs; uint32_t current_k = initial_k; uint64_t sum_u = 0; uint32_t count = 0; @@ -227,9 +177,14 @@ uint64_t estimate_zerorun_bits(const std::vector& residual, ++run; } if (run >= kZeroRunMinRun) { - bits += 2; // tag - bits += rice_bits_for_unsigned(static_cast(run - kZeroRunMinRun), kZeroRunRunK); + costs.has_zero_run = true; + costs.zr_bits += 2; + costs.zr_bits += rice_bits_for_unsigned( + static_cast(run - kZeroRunMinRun), + kZeroRunRunK); for (size_t j = 0; j < run; ++j) { + costs.rice_bits += rice_bits_for_unsigned(0, current_k); + costs.bin_bits += 2; ++count; current_k = stateless ? adapt_k_stateless(sum_u, count) @@ -239,63 +194,33 @@ uint64_t estimate_zerorun_bits(const std::vector& residual, continue; } - const uint32_t u = unsigned_from_residual(residual[idx]); - const uint32_t escape_threshold = 1u << std::min(24, current_k + 3u); - if (u > escape_threshold) { - bits += 2; // tag - bits += 32; // escape payload - sum_u += u; - ++count; - current_k = stateless - ? adapt_k_stateless(sum_u, count) - : Rice::adapt_k(sum_u, count, *adapt_state); - ++idx; - continue; - } - - // Normal token: single residual only. - bits += 2; // tag - bits += rice_bits_for_unsigned(u, current_k); - sum_u += u; - ++count; - current_k = stateless - ? adapt_k_stateless(sum_u, count) - : Rice::adapt_k(sum_u, count, *adapt_state); - ++idx; - } - return bits; -} - -uint64_t estimate_binning_bits(const std::vector& residual, - uint32_t initial_k, - bool stateless = false) { - if (residual.empty()) return 0; - uint64_t bits = 0; - uint32_t current_k = initial_k; - uint64_t sum_u = 0; - uint32_t count = 0; - std::optional adapt_state; - if (!stateless) adapt_state.emplace(); - - for (int32_t v : residual) { + const int32_t v = residual[idx]; const uint32_t u = unsigned_from_residual(v); + costs.rice_bits += rice_bits_for_unsigned(u, current_k); if (v == 0) { - bits += 2; // 00 + costs.bin_bits += 2; // 00 } else if (v == 1 || v == -1) { - bits += 3; // 01 + sign + costs.bin_bits += 3; // 01 + sign } else if (v == 2 || v == -2) { - bits += 3; // 10 + sign + costs.bin_bits += 3; // 10 + sign } else { - bits += 2; // fallback tag - bits += rice_bits_for_unsigned(u, current_k); + costs.bin_bits += 2; // fallback tag + costs.bin_bits += rice_bits_for_unsigned(u, current_k); } + + const uint32_t escape_threshold = 1u << std::min(24, current_k + 3u); + costs.zr_bits += 2; // tag + costs.zr_bits += (u > escape_threshold) + ? 32 + : rice_bits_for_unsigned(u, current_k); sum_u += u; ++count; current_k = stateless ? adapt_k_stateless(sum_u, count) : Rice::adapt_k(sum_u, count, *adapt_state); + ++idx; } - return bits; + return costs; } void compute_fixed_residual(const std::vector& pcm, int order, std::vector& residual) { @@ -346,121 +271,6 @@ void compute_fir_residual(const std::vector& pcm, int taps, std::vector } // namespace -int Encoder::choose_rice_k(const std::vector& residual) { - return static_cast(estimate_initial_k(residual)); -} - -bool Encoder::estimate_bits(const std::vector& pcm, - uint64_t& bits_normal, - uint64_t& bits_zr, - uint64_t& bits_bin) { - bits_normal = 0; - bits_zr = 0; - bits_bin = 0; - const int max_valid_order = (pcm.size() > 1) - ? static_cast(std::min(32, pcm.size() - 1)) - : 0; - - struct CandidateEval { - int target_order = 0; - int used_order = 0; - uint64_t bit_cost = std::numeric_limits::max(); - uint64_t zr_bits = std::numeric_limits::max(); - uint64_t bin_bits = std::numeric_limits::max(); - long double energy = 0.0L; - bool stable = false; - std::vector coeffs_q15; - std::vector residual; - }; - - CandidateEval best; - uint64_t best_metric = std::numeric_limits::max(); - - for (int cand : kOrderCandidates) { - if (cand > max_valid_order) continue; - LPC lpc(cand); - CandidateEval eval; - eval.target_order = cand; - - int used_order = 0; - long double energy = 0.0L; - eval.stable = lpc.analyze_block_q15(pcm, eval.coeffs_q15, used_order, &energy); - eval.used_order = used_order; - eval.energy = energy; - - if (!eval.stable || eval.used_order == 0) { - continue; - } - - eval.residual.resize(pcm.size()); - if (!pcm.empty()) { - lpc.compute_residual_q15(pcm, eval.coeffs_q15, eval.residual, &eval.used_order); - } - if (eval.used_order == 0) { - continue; - } - - uint32_t adaptive_k = estimate_initial_k(eval.residual); - eval.bit_cost = estimate_adaptive_rice_bits(eval.residual, adaptive_k); - const bool allow_zr = this->zero_run_enabled && has_zero_run(eval.residual); - eval.zr_bits = allow_zr - ? estimate_zerorun_bits(eval.residual, adaptive_k) - : eval.bit_cost; - eval.bin_bits = estimate_binning_bits(eval.residual, adaptive_k); - uint64_t candidate_score = eval.bit_cost; - if (this->zero_run_enabled && has_zero_run(eval.residual)) { - candidate_score = std::min(candidate_score, eval.zr_bits); - } - candidate_score = std::min(candidate_score, eval.bin_bits); - - if (candidate_score < best_metric || - (candidate_score == best_metric && eval.used_order < best.used_order)) { - best = eval; - best_metric = candidate_score; - } - } - - if (best.used_order == 0) { - const int fallback_order = std::max(1, std::min(this->order, max_valid_order)); - LPC lpc_fallback(fallback_order); - int used_order = 0; - long double energy = 0.0L; - std::vector coeffs_q15; - lpc_fallback.analyze_block_q15(pcm, coeffs_q15, used_order, &energy); - std::vector residual(pcm.size()); - if (!pcm.empty()) { - lpc_fallback.compute_residual_q15(pcm, coeffs_q15, residual, &used_order); - } - if (used_order == 0) { - return false; - } - best.target_order = fallback_order; - best.used_order = used_order; - best.energy = energy; - best.coeffs_q15 = std::move(coeffs_q15); - best.residual = std::move(residual); - uint32_t adaptive_k = estimate_initial_k(best.residual); - best.bit_cost = estimate_adaptive_rice_bits(best.residual, adaptive_k); - best.stable = used_order > 0; - const bool allow_zr = this->zero_run_enabled && has_zero_run(best.residual); - best.zr_bits = allow_zr - ? estimate_zerorun_bits(best.residual, adaptive_k) - : best.bit_cost; - best_metric = (allow_zr ? std::min(best.bit_cost, best.zr_bits) : best.bit_cost); - best.bin_bits = estimate_binning_bits(best.residual, adaptive_k); - best_metric = std::min(best_metric, best.bin_bits); - } - - uint32_t best_k = estimate_initial_k(best.residual); - bits_normal = estimate_adaptive_rice_bits(best.residual, best_k); - const bool allow_zr = this->zero_run_enabled && has_zero_run(best.residual); - bits_zr = allow_zr - ? estimate_zerorun_bits(best.residual, best_k) - : bits_normal; - bits_bin = estimate_binning_bits(best.residual, best_k); - return !best.residual.empty(); -} - std::vector Encoder::encode(const std::vector& pcm) { const int max_valid_order = (pcm.size() > 1) ? static_cast(std::min(32, pcm.size() - 1)) @@ -474,13 +284,35 @@ std::vector Encoder::encode(const std::vector& pcm) { uint64_t zr_bits = std::numeric_limits::max(); uint64_t bin_bits = std::numeric_limits::max(); uint64_t best_bits = std::numeric_limits::max(); + uint32_t initial_k = 0; + bool has_run = false; long double energy = 0.0L; bool stable = true; std::vector coeffs_q15; std::vector residual; }; - std::vector candidates; + std::optional best_candidate; + auto score_residual = [this](PredictorEval& ev) { + ev.initial_k = estimate_initial_k(std::span(ev.residual)); + const ResidualCosts costs = + estimate_residual_costs(std::span(ev.residual), ev.initial_k); + ev.rice_bits = costs.rice_bits; + ev.has_run = costs.has_zero_run; + ev.zr_bits = (this->zero_run_enabled && ev.has_run) + ? costs.zr_bits + : ev.rice_bits; + ev.bin_bits = costs.bin_bits; + ev.best_bits = std::min(ev.rice_bits, std::min(ev.zr_bits, ev.bin_bits)); + }; + auto consider = [&](PredictorEval&& ev) { + if (!best_candidate || + ev.best_bits < best_candidate->best_bits || + (ev.best_bits == best_candidate->best_bits && + ev.predictor_type < best_candidate->predictor_type)) { + best_candidate = std::move(ev); + } + }; // Fixed predictors 0-4. for (int fo = 0; fo <= 4; ++fo) { @@ -488,12 +320,8 @@ std::vector Encoder::encode(const std::vector& pcm) { ev.predictor_type = kPredictorFixed; ev.order_param = fo; compute_fixed_residual(pcm, fo, ev.residual); - uint32_t k0 = estimate_initial_k(ev.residual); - ev.rice_bits = estimate_adaptive_rice_bits(ev.residual, k0); - ev.zr_bits = estimate_zerorun_bits(ev.residual, k0); - ev.bin_bits = estimate_binning_bits(ev.residual, k0); - ev.best_bits = std::min(ev.rice_bits, std::min(ev.zr_bits, ev.bin_bits)); - candidates.push_back(std::move(ev)); + score_residual(ev); + consider(std::move(ev)); } // FIR predictor (2 taps predefined). @@ -502,12 +330,8 @@ std::vector Encoder::encode(const std::vector& pcm) { ev.predictor_type = kPredictorFir; ev.order_param = 2; compute_fir_residual(pcm, ev.order_param, ev.residual); - uint32_t k0 = estimate_initial_k(ev.residual); - ev.rice_bits = estimate_adaptive_rice_bits(ev.residual, k0); - ev.zr_bits = estimate_zerorun_bits(ev.residual, k0); - ev.bin_bits = estimate_binning_bits(ev.residual, k0); - ev.best_bits = std::min(ev.rice_bits, std::min(ev.zr_bits, ev.bin_bits)); - candidates.push_back(std::move(ev)); + score_residual(ev); + consider(std::move(ev)); } // LPC predictors. @@ -534,35 +358,21 @@ std::vector Encoder::encode(const std::vector& pcm) { if (ev.used_order == 0) { continue; } - uint32_t k0 = estimate_initial_k(ev.residual); - ev.rice_bits = estimate_adaptive_rice_bits(ev.residual, k0); - ev.zr_bits = estimate_zerorun_bits(ev.residual, k0); - ev.bin_bits = estimate_binning_bits(ev.residual, k0); - ev.best_bits = std::min(ev.rice_bits, std::min(ev.zr_bits, ev.bin_bits)); - candidates.push_back(std::move(ev)); + score_residual(ev); + consider(std::move(ev)); } // Ensure at least one candidate (fallback fixed-0). - if (candidates.empty()) { + if (!best_candidate) { PredictorEval ev; ev.predictor_type = kPredictorFixed; ev.order_param = 0; ev.residual = pcm; - uint32_t k0 = estimate_initial_k(ev.residual); - ev.rice_bits = estimate_adaptive_rice_bits(ev.residual, k0); - ev.zr_bits = estimate_zerorun_bits(ev.residual, k0); - ev.bin_bits = estimate_binning_bits(ev.residual, k0); - ev.best_bits = std::min(ev.rice_bits, std::min(ev.zr_bits, ev.bin_bits)); - candidates.push_back(std::move(ev)); + score_residual(ev); + consider(std::move(ev)); } - PredictorEval best = candidates.front(); - for (const auto& ev : candidates) { - if (ev.best_bits < best.best_bits || - (ev.best_bits == best.best_bits && ev.predictor_type < best.predictor_type)) { - best = ev; - } - } + PredictorEval best = std::move(*best_candidate); const int chosen_order = (best.predictor_type == kPredictorLpc) ? std::max(1, std::min(best.used_order, max_valid_order)) @@ -576,12 +386,12 @@ std::vector Encoder::encode(const std::vector& pcm) { }; const uint32_t block_size = static_cast(best.residual.size()); - const uint32_t base_initial_k = static_cast(this->choose_rice_k(best.residual)); - uint64_t base_bits_normal = estimate_adaptive_rice_bits(best.residual, base_initial_k); - const bool has_run = has_zero_run(best.residual); + const uint32_t base_initial_k = best.initial_k; + uint64_t base_bits_normal = best.rice_bits; + const bool has_run = best.has_run; const bool allow_zr_global = this->zero_run_enabled && has_run; - uint64_t base_bits_zr = allow_zr_global ? estimate_zerorun_bits(best.residual, base_initial_k) : base_bits_normal; - uint64_t base_bits_bin = estimate_binning_bits(best.residual, base_initial_k); + uint64_t base_bits_zr = best.zr_bits; + uint64_t base_bits_bin = best.bin_bits; uint8_t base_mode = 0; // 0=rice,1=zr,2=bin uint64_t base_bits_best = base_bits_normal; if (allow_zr_global && base_bits_zr <= base_bits_best) { @@ -622,8 +432,7 @@ std::vector Encoder::encode(const std::vector& pcm) { if (this->partitioning_enabled && block_size >= Block::MIN_PARTITION_SIZE) { const uint8_t max_p = max_partition_order_for_block(block_size); - const bool allow_partition_zr = allow_zr_global; - for (uint8_t p = 0; p <= max_p; ++p) { + for (uint8_t p = 1; p <= max_p; ++p) { const auto sizes = partition_sizes_for_block(block_size, p); if (sizes.empty()) continue; std::vector choices; @@ -633,12 +442,13 @@ std::vector Encoder::encode(const std::vector& pcm) { for (uint32_t len : sizes) { PartitionChoice pc; pc.length = len; - const std::vector segment(best.residual.begin() + offset, best.residual.begin() + offset + len); + const std::span segment(best.residual.data() + offset, len); pc.initial_k = estimate_initial_k(segment); - const uint64_t normal_bits = estimate_adaptive_rice_bits(segment, pc.initial_k, true); - const uint64_t bin_bits = estimate_binning_bits(segment, pc.initial_k, true); - const bool allow_zr = allow_partition_zr && has_zero_run(segment); - const uint64_t zr_bits = allow_zr ? estimate_zerorun_bits(segment, pc.initial_k, true) : normal_bits; + const ResidualCosts costs = estimate_residual_costs(segment, pc.initial_k, true); + const uint64_t normal_bits = costs.rice_bits; + const uint64_t bin_bits = costs.bin_bits; + const bool allow_zr = this->zero_run_enabled && costs.has_zero_run; + const uint64_t zr_bits = allow_zr ? costs.zr_bits : normal_bits; pc.residual_mode = 0; pc.bits = normal_bits; if (allow_zr && zr_bits < pc.bits) { @@ -821,11 +631,14 @@ std::vector Encoder::encode(const std::vector& pcm) { } bw.write_bits(kTagRun, 2); write_rice_unsigned(bw, static_cast(run - kZeroRunMinRun), kZeroRunRunK); - for (size_t j = 0; j < run; ++j) { - ++count; - current_k = use_stateless_adapt - ? adapt_k_stateless(sum_u, count) - : Rice::adapt_k(sum_u, count, *adapt_state); + if (use_stateless_adapt) { + count += static_cast(run); + current_k = adapt_k_stateless(sum_u, count); + } else { + for (size_t j = 0; j < run; ++j) { + ++count; + current_k = Rice::adapt_k(sum_u, count, *adapt_state); + } } idx += run; ++token_idx; @@ -949,7 +762,7 @@ std::vector Encoder::encode(const std::vector& pcm) { << "\n"); } - return bw.get_buffer(); + return bw.take_buffer(); } } // namespace Block diff --git a/src/codec/block/encoder.hpp b/src/codec/block/encoder.hpp index 1ff6ab6..45cc3f6 100644 --- a/src/codec/block/encoder.hpp +++ b/src/codec/block/encoder.hpp @@ -18,11 +18,6 @@ class Encoder { void set_debug_block_index(size_t index); void set_partitioning_enabled(bool enabled); void set_debug_partitions(bool enabled); - bool estimate_bits(const std::vector& pcm, - uint64_t& bits_normal, - uint64_t& bits_zr, - uint64_t& bits_bin); - private: int order; bool debug_lpc; @@ -32,7 +27,6 @@ class Encoder { bool debug_partitions = false; size_t block_index = 0; - int choose_rice_k(const std::vector& residual); }; } // namespace Block diff --git a/src/codec/frame/frame_header.hpp b/src/codec/frame/frame_header.hpp index 17a7374..d9498ac 100644 --- a/src/codec/frame/frame_header.hpp +++ b/src/codec/frame/frame_header.hpp @@ -6,7 +6,7 @@ struct FrameHeader { uint16_t sync; // 0x4C41 "LA" - uint8_t version; // 2 for adaptive block format + uint8_t version; // 3 for byte-bounded parallel block format uint8_t channels; // 1 or 2 uint8_t stereo_mode; // 0=LR,1=MS,2=per-block (stereo only) uint32_t sample_rate; // in hz @@ -15,7 +15,7 @@ struct FrameHeader { FrameHeader() : sync(0x4C41), - version(2), + version(3), channels(2), stereo_mode(2), sample_rate(44100), @@ -48,7 +48,7 @@ struct FrameHeader { } bool validate() const { - if (this->sync != 0x4C41 || this->version != 2) return false; + if (this->sync != 0x4C41 || (this->version != 2 && this->version != 3)) return false; if (this->channels != 1 && this->channels != 2) return false; if (this->channels == 1 && this->stereo_mode != 0) return false; if (this->stereo_mode != 0 && this->stereo_mode != 1 && this->stereo_mode != 2) return false; diff --git a/src/codec/lac/decoder.cpp b/src/codec/lac/decoder.cpp index 62559d3..a6460aa 100644 --- a/src/codec/lac/decoder.cpp +++ b/src/codec/lac/decoder.cpp @@ -1,11 +1,15 @@ #include "decoder.hpp" #include +#include +#include #include +#include #include #include #include #include "codec/block/decoder.hpp" #include "codec/bitstream/bit_reader.hpp" +#include "codec/lac/thread_limit.hpp" #include "codec/simd/neon.hpp" namespace LAC { @@ -69,7 +73,11 @@ bool reconstruct_mid_side_to_output(const std::vector& mid, } // namespace Decoder::Decoder(ThreadCollector* collector) - : collector(collector) {} + : collector(collector), thread_count(0) {} + +void Decoder::set_thread_count(size_t max_threads) { + this->thread_count = max_threads; +} void Decoder::decode(const uint8_t* data, size_t size, @@ -98,12 +106,19 @@ void Decoder::decode(const uint8_t* data, if (br.has_error() || block_count == 0 || block_count > MAX_BLOCK_COUNT) { throw_decode_error("invalid block count"); } - if (block_count > br.bits_remaining() / 32u) { + const bool has_block_payload_sizes = (hdr.version >= 3); + const uint32_t table_words_per_block = has_block_payload_sizes ? 2u : 1u; + if (block_count > br.bits_remaining() / (32u * table_words_per_block)) { throw_decode_error("truncated block size table"); } std::vector block_sizes(block_count); + std::vector block_payload_sizes; + if (has_block_payload_sizes) { + block_payload_sizes.resize(block_count); + } uint64_t total_samples = 0; + uint64_t total_block_payload_bytes = 0; for (uint32_t i = 0; i < block_count; ++i) { uint32_t sz = br.read_bits(32); if (br.has_error() || sz == 0 || sz > Block::MAX_BLOCK_SIZE || @@ -115,6 +130,17 @@ void Decoder::decode(const uint8_t* data, throw_decode_error("total samples exceed maximum"); } block_sizes[i] = sz; + if (has_block_payload_sizes) { + const uint32_t payload_size = br.read_bits(32); + if (br.has_error() || payload_size == 0) { + throw_decode_error("invalid compressed block size"); + } + total_block_payload_bytes += payload_size; + if (total_block_payload_bytes > payload_bytes) { + throw_decode_error("compressed block sizes exceed frame payload"); + } + block_payload_sizes[i] = payload_size; + } } const uint64_t decoded_pcm_bytes = total_samples * static_cast(hdr.channels) * sizeof(int32_t); @@ -144,15 +170,13 @@ void Decoder::decode(const uint8_t* data, decoded_right.assign(running, 0); } - if (this->collector) { - this->collector->record(std::this_thread::get_id()); - } - - for (uint32_t i = 0; i < block_count; ++i) { + auto decode_block = [&](uint32_t i, BitReader& block_reader) { bool mid_side = false; if (perBlockStereo) { - uint32_t mode_flag = br.read_bits(8); - if (br.has_error() || mode_flag > 1u) throw_decode_error("invalid per-block stereo flag"); + uint32_t mode_flag = block_reader.read_bits(8); + if (block_reader.has_error() || mode_flag > 1u) { + throw_decode_error("invalid per-block stereo flag"); + } mid_side = (mode_flag == 1u); } else if (forceMidSide) { mid_side = true; @@ -160,13 +184,13 @@ void Decoder::decode(const uint8_t* data, Block::Decoder blockDec; std::vector primary_pcm; - if (!blockDec.decode(br, block_sizes[i], primary_pcm)) { + if (!blockDec.decode(block_reader, block_sizes[i], primary_pcm)) { throw_decode_error(i, "primary"); } std::vector secondary_pcm; if (isStereo) { - if (!blockDec.decode(br, block_sizes[i], secondary_pcm)) { + if (!blockDec.decode(block_reader, block_sizes[i], secondary_pcm)) { throw_decode_error(i, "secondary"); } } @@ -191,10 +215,81 @@ void Decoder::decode(const uint8_t* data, throw_decode_error("decoded sample outside PCM bit depth"); } } - } + }; - if (br.bits_remaining() != 0u) { - throw_decode_error("trailing frame payload"); + if (!has_block_payload_sizes) { + if (this->collector) { + this->collector->record(std::this_thread::get_id()); + } + for (uint32_t i = 0; i < block_count; ++i) { + decode_block(i, br); + } + if (br.bits_remaining() != 0u) { + throw_decode_error("trailing frame payload"); + } + } else { + if ((br.bits_remaining() & 7u) != 0u) { + throw_decode_error("unaligned compressed block payload"); + } + const size_t available_payload_bytes = br.bits_remaining() / 8u; + if (total_block_payload_bytes != available_payload_bytes) { + throw_decode_error("compressed block sizes do not match frame payload"); + } + + const uint8_t* block_payload = payload + (payload_bytes - available_payload_bytes); + std::vector block_payload_offsets(block_count); + size_t payload_offset = 0; + for (uint32_t i = 0; i < block_count; ++i) { + block_payload_offsets[i] = payload_offset; + payload_offset += block_payload_sizes[i]; + } + + size_t hardware_threads = + std::max(1, static_cast(std::thread::hardware_concurrency())); + const size_t thread_limit = LAC::resolve_thread_limit(this->thread_count); + if (thread_limit > 0) { + hardware_threads = std::min(hardware_threads, thread_limit); + } + const size_t worker_count = std::min(hardware_threads, block_count); + std::atomic next_block{0}; + std::atomic stop_requested{false}; + std::mutex error_mutex; + std::exception_ptr worker_error; + std::vector workers; + workers.reserve(worker_count); + + for (size_t worker_idx = 0; worker_idx < worker_count; ++worker_idx) { + workers.emplace_back([&]() { + try { + if (this->collector) { + this->collector->record(std::this_thread::get_id()); + } + while (!stop_requested.load(std::memory_order_acquire)) { + const uint32_t block_idx = next_block.fetch_add(1, std::memory_order_relaxed); + if (block_idx >= block_count) return; + BitReader block_reader(block_payload + block_payload_offsets[block_idx], + block_payload_sizes[block_idx]); + decode_block(block_idx, block_reader); + if (block_reader.bits_remaining() != 0u) { + throw_decode_error(block_idx, "trailing-payload"); + } + } + } catch (...) { + stop_requested.store(true, std::memory_order_release); + std::lock_guard lock(error_mutex); + if (!worker_error) { + worker_error = std::current_exception(); + } + } + }); + } + + for (auto& worker : workers) { + worker.join(); + } + if (worker_error) { + std::rethrow_exception(worker_error); + } } if (hdr.channels == 2 && decoded_right.size() != decoded_left.size()) { diff --git a/src/codec/lac/decoder.hpp b/src/codec/lac/decoder.hpp index 6ff11ae..8f1e9fb 100644 --- a/src/codec/lac/decoder.hpp +++ b/src/codec/lac/decoder.hpp @@ -16,9 +16,11 @@ class Decoder { std::vector& left, std::vector& right, FrameHeader* out_header = nullptr); + void set_thread_count(size_t max_threads); private: ThreadCollector* collector; + size_t thread_count; }; } // namespace LAC diff --git a/src/codec/lac/encoder.cpp b/src/codec/lac/encoder.cpp index d7aca1d..24a92fc 100644 --- a/src/codec/lac/encoder.cpp +++ b/src/codec/lac/encoder.cpp @@ -1,25 +1,24 @@ #include "encoder.hpp" -#include -#include +#include #include #include #include #include #include +#include #include #include #include +#include #include #include "codec/simd/neon.hpp" -#include "codec/lpc/lpc.hpp" +#include "codec/lac/thread_limit.hpp" #include "utils/logger.hpp" namespace { - constexpr double kAbsDiffWeight = 1.0; - constexpr double kVarianceWeight = 0.5; - constexpr double kSlopePenaltyWeight = 0.25; - constexpr double kSlopeThreshold = 4.0; - constexpr uint64_t kMaxCostProxy = std::numeric_limits::max() / 4; + constexpr uint64_t kStereoConfidenceDivisor = 100; + constexpr size_t kStereoProbeSize = 256; + constexpr size_t kStereoFullComparisonLimit = 4096; constexpr int32_t kPcm16Min = -32768; constexpr int32_t kPcm16Max = 32767; constexpr int32_t kPcm24Min = -0x800000; @@ -30,22 +29,44 @@ namespace { uint32_t size; }; - bool has_zero_run_samples(const std::vector& samples) { - if (samples.empty()) return false; - size_t idx = 0; - while (idx < samples.size()) { - if (samples[idx] == 0) { - size_t run = 0; - while (idx + run < samples.size() && samples[idx + run] == 0) { - ++run; - } - if (run >= Block::ZERO_RUN_MIN_LENGTH) return true; - idx += run; - } else { - ++idx; - } + uint64_t add_saturated(uint64_t left, uint64_t right) { + if (right > std::numeric_limits::max() - left) { + return std::numeric_limits::max(); } - return false; + return left + right; + } + + uint64_t zigzag_difference(int64_t value) { + if (value >= 0) return static_cast(value) << 1; + return (static_cast(-(value + 1)) << 1) | 1u; + } + + uint32_t rice_k_for_mean(uint64_t sum, uint64_t count) { + if (count == 0) return 0; + const uint64_t mean = (sum + (count >> 1)) / count; + uint32_t k = 0; + while (k < 31u && (uint64_t{1} << k) < mean) { + ++k; + } + return k; + } + + uint64_t approximate_rice_bits(uint64_t sum, uint64_t count) { + if (count == 0) return 0; + const uint32_t k = rice_k_for_mean(sum, count); + return add_saturated(sum >> k, count * static_cast(k + 1u)); + } + + std::vector plan_blocks(const std::vector& left) { + std::vector blocks; + size_t pos = 0; + while (pos < left.size()) { + const uint32_t size = static_cast( + std::min(Block::MAX_BLOCK_SIZE, left.size() - pos)); + blocks.push_back({pos, size}); + pos += size; + } + return blocks; } bool is_supported_sample_rate(uint32_t sample_rate) { @@ -81,187 +102,99 @@ namespace { } } - bool should_enable_zero_run_for_block(const std::vector& pcm, int order) { - if (pcm.empty()) return false; - if (!has_zero_run_samples(pcm)) return false; - Block::Encoder estimator(order, false, false); - estimator.set_zero_run_enabled(true); - uint64_t bits_normal = 0; - uint64_t bits_zr = 0; - uint64_t bits_bin = 0; - if (!estimator.estimate_bits(pcm, bits_normal, bits_zr, bits_bin)) { - return false; - } - return bits_zr < std::min(bits_normal, bits_bin); - } - - size_t parse_thread_limit(const char* value) { - if (value == nullptr || value[0] == '\0') { - return 0; - } - for (const char* p = value; *p != '\0'; ++p) { - if (*p < '0' || *p > '9') { - throw std::invalid_argument("LAC_THREADS must be a positive integer"); - } - } - - errno = 0; - char* end = nullptr; - unsigned long long parsed = std::strtoull(value, &end, 10); - if (errno != 0 || end == value || *end != '\0' || parsed == 0) { - throw std::invalid_argument("LAC_THREADS must be a positive integer"); - } - if (parsed > static_cast(std::numeric_limits::max())) { - throw std::invalid_argument("LAC_THREADS is too large"); - } - return static_cast(parsed); - } - - size_t resolve_thread_limit(size_t explicit_limit) { - if (explicit_limit > 0) { - return explicit_limit; - } - return parse_thread_limit(std::getenv("LAC_THREADS")); - } - - double block_slope_penalty(const std::vector& channel, size_t position, size_t length) { - if (channel.empty() || length <= 1 || position + length > channel.size()) return 0.0; - const int64_t first = channel[position]; - const int64_t last = channel[position + length - 1]; - const int64_t delta = (last >= first) ? (last - first) : (first - last); - const double slope = static_cast(delta) / static_cast(length); - if (slope <= kSlopeThreshold) { - return 0.0; - } - const double size_scale = static_cast(length) / static_cast(Block::MAX_BLOCK_SIZE); - return (slope - kSlopeThreshold) * kSlopePenaltyWeight * size_scale; - } - - inline uint64_t estimate_rice_bits(const std::vector& residual) { - if (residual.empty()) return 0; - auto to_unsigned = [](int32_t r) -> uint32_t { - const uint32_t sign_mask = - (r < 0) ? std::numeric_limits::max() : 0u; - return (static_cast(r) << 1) ^ sign_mask; - }; - uint64_t sum_u = 0; - for (int32_t r : residual) { - sum_u += to_unsigned(r); - } - const uint64_t mean = (sum_u + (residual.size() >> 1)) / residual.size(); - uint32_t k = 0; - while ((1u << k) < mean && k < 31u) ++k; - - uint64_t bits = 0; - for (int32_t r : residual) { - const uint32_t u = to_unsigned(r); - const uint32_t q = (k >= 31u) ? 0u : (u >> k); - bits += static_cast(q) + 1u + k; - } - return bits; - } + struct StereoDecision { + bool choose_ms = false; + bool uncertain = false; + }; - struct StereoCost { - uint64_t lr_bits = kMaxCostProxy; - uint64_t ms_bits = kMaxCostProxy; - bool lr_valid = false; - bool ms_valid = false; - bool choose_ms() const { - if (ms_valid && !lr_valid) return true; - if (lr_valid && !ms_valid) return false; - if (!lr_valid && !ms_valid) return false; - if (ms_bits < lr_bits) return true; - if (ms_bits > lr_bits) return false; - return false; // tie -> LR - } + struct ChannelProxyCost { + uint64_t bits; + bool non_difference_predictor_active; }; - bool estimate_channel_cost(const std::vector& pcm, - int order, - bool zero_run_enabled, - bool partitioning_enabled, - uint64_t& bits) { - Block::Encoder estimator(order, false, false); - const bool use_zr = zero_run_enabled && should_enable_zero_run_for_block(pcm, order); - estimator.set_zero_run_enabled(use_zr); - estimator.set_partitioning_enabled(partitioning_enabled); - uint64_t bits_normal = 0; - uint64_t bits_zr = 0; - uint64_t bits_bin = 0; - if (!estimator.estimate_bits(pcm, bits_normal, bits_zr, bits_bin)) { - return false; - } - bits = std::min(bits_normal, std::min(bits_zr, bits_bin)); - return true; + ChannelProxyCost estimate_channel_proxy_cost(uint64_t raw_sum, + uint64_t diff_sum, + uint64_t anti_diff_sum, + uint64_t count) { + const uint64_t raw_bits = approximate_rice_bits(raw_sum, count); + const uint64_t diff_bits = approximate_rice_bits(diff_sum, count); + const uint64_t anti_diff_bits = approximate_rice_bits(anti_diff_sum, count); + return { + std::min({raw_bits, diff_bits, anti_diff_bits}), + raw_bits < diff_bits || anti_diff_bits < diff_bits}; } - StereoCost estimate_stereo_cost(const std::vector& left, + StereoDecision estimate_stereo_mode(const std::vector& left, const std::vector& right, size_t start, - size_t size, - int order, - bool zero_run_enabled, - bool partitioning_enabled) { - StereoCost cost; - if (size == 0 || start + size > left.size() || start + size > right.size()) { - return cost; - } - - std::vector blockL(left.begin() + start, left.begin() + start + size); - std::vector blockR(right.begin() + start, right.begin() + start + size); - cost.lr_valid = estimate_channel_cost(blockL, - order, - zero_run_enabled, - partitioning_enabled, - cost.lr_bits); - uint64_t bits_r = 0; - if (estimate_channel_cost(blockR, - order, - zero_run_enabled, - partitioning_enabled, - bits_r)) { - if (cost.lr_valid) { - cost.lr_bits += bits_r; - } else { - cost.lr_bits = bits_r; - cost.lr_valid = true; - } - } else { - cost.lr_valid = false; - } - - // MS path - std::vector blockM(size); - std::vector blockS(size); - SIMD::ms_encode_simd_or_scalar( - left.data() + start, - right.data() + start, - blockM.data(), - blockS.data(), - size - ); - cost.ms_valid = estimate_channel_cost(blockM, - order, - zero_run_enabled, - partitioning_enabled, - cost.ms_bits); - uint64_t bits_s = 0; - if (estimate_channel_cost(blockS, - order, - zero_run_enabled, - partitioning_enabled, - bits_s)) { - if (cost.ms_valid) { - cost.ms_bits += bits_s; + size_t size) { + uint64_t l_raw_sum = 0; + uint64_t r_raw_sum = 0; + uint64_t m_raw_sum = 0; + uint64_t s_raw_sum = 0; + uint64_t l_sum = 0; + uint64_t r_sum = 0; + uint64_t m_sum = 0; + uint64_t s_sum = 0; + uint64_t l_anti_sum = 0; + uint64_t r_anti_sum = 0; + uint64_t m_anti_sum = 0; + uint64_t s_anti_sum = 0; + int64_t previous_l = 0; + int64_t previous_r = 0; + int64_t previous_m = 0; + int64_t previous_s = 0; + for (size_t i = 0; i < size; ++i) { + const int64_t l = left[start + i]; + const int64_t r = right[start + i]; + const int64_t m = (l + r) >> 1; + const int64_t s = l - r; + l_raw_sum = add_saturated(l_raw_sum, zigzag_difference(l)); + r_raw_sum = add_saturated(r_raw_sum, zigzag_difference(r)); + m_raw_sum = add_saturated(m_raw_sum, zigzag_difference(m)); + s_raw_sum = add_saturated(s_raw_sum, zigzag_difference(s)); + if (i == 0) { + l_sum = zigzag_difference(l); + r_sum = zigzag_difference(r); + m_sum = zigzag_difference(m); + s_sum = zigzag_difference(s); + l_anti_sum = l_sum; + r_anti_sum = r_sum; + m_anti_sum = m_sum; + s_anti_sum = s_sum; } else { - cost.ms_bits = bits_s; - cost.ms_valid = true; + l_sum = add_saturated(l_sum, zigzag_difference(l - previous_l)); + r_sum = add_saturated(r_sum, zigzag_difference(r - previous_r)); + m_sum = add_saturated(m_sum, zigzag_difference(m - previous_m)); + s_sum = add_saturated(s_sum, zigzag_difference(s - previous_s)); + l_anti_sum = add_saturated(l_anti_sum, zigzag_difference(l + previous_l)); + r_anti_sum = add_saturated(r_anti_sum, zigzag_difference(r + previous_r)); + m_anti_sum = add_saturated(m_anti_sum, zigzag_difference(m + previous_m)); + s_anti_sum = add_saturated(s_anti_sum, zigzag_difference(s + previous_s)); } - } else { - cost.ms_valid = false; - } - - return cost; + previous_l = l; + previous_r = r; + previous_m = m; + previous_s = s; + } + const uint64_t count = static_cast(size); + const ChannelProxyCost l_cost = estimate_channel_proxy_cost(l_raw_sum, l_sum, l_anti_sum, count); + const ChannelProxyCost r_cost = estimate_channel_proxy_cost(r_raw_sum, r_sum, r_anti_sum, count); + const ChannelProxyCost m_cost = estimate_channel_proxy_cost(m_raw_sum, m_sum, m_anti_sum, count); + const ChannelProxyCost s_cost = estimate_channel_proxy_cost(s_raw_sum, s_sum, s_anti_sum, count); + const uint64_t lr_bits = add_saturated(l_cost.bits, r_cost.bits); + const uint64_t ms_bits = add_saturated(m_cost.bits, s_cost.bits); + const uint64_t smaller = std::min(lr_bits, ms_bits); + const uint64_t difference = (lr_bits >= ms_bits) ? (lr_bits - ms_bits) : (ms_bits - lr_bits); + const bool non_difference_predictor_active = + l_cost.non_difference_predictor_active || r_cost.non_difference_predictor_active || + m_cost.non_difference_predictor_active || s_cost.non_difference_predictor_active; + StereoDecision decision; + decision.choose_ms = ms_bits < lr_bits; + decision.uncertain = + smaller == 0 || difference == 0 || non_difference_predictor_active || + difference <= smaller / kStereoConfidenceDivisor; + return decision; } } @@ -278,8 +211,7 @@ namespace LAC { zero_run_enabled(true), partitioning_enabled(true), debug_partitions(false), - thread_count(0), - candidates{256, 512, 1024, 2048, 4096, 8192, 16384} {} + thread_count(0) {} std::vector Encoder::encode( const std::vector& left, @@ -318,22 +250,7 @@ namespace LAC { hdr.bit_depth = this->bit_depth; hdr.write(writer); - std::vector blocks; - const size_t totalSamples = left.size(); - size_t pos = 0; - while (pos < totalSamples) { - uint32_t block_size = this->select_block_size(left, right, pos); - if (block_size == 0) { - block_size = static_cast(totalSamples - pos); - } - blocks.push_back({pos, block_size}); - pos += block_size; - } - - writer.write_bits(static_cast(blocks.size()), 32); - for (const auto& block : blocks) { - writer.write_bits(block.size, 32); - } + std::vector blocks = plan_blocks(left); const bool isStereo = (hdr.channels == 2); const bool forceMidSide = isStereo && (hdr.stereo_mode == 1); @@ -360,45 +277,40 @@ namespace LAC { blockEnc.set_debug_block_index(block_idx); std::vector encodedBytes; const size_t start = block.start; - const size_t end = start + block.size; if (block.size == 0) { return encodedBytes; } - auto encode_lr = [&]() -> std::vector { + auto encode_lr = [&](size_t range_start, size_t range_size) -> std::vector { std::vector out; - std::vector blockL(left.begin() + start, left.begin() + end); - bool use_zr_l = this->zero_run_enabled && should_enable_zero_run_for_block(blockL, this->order); - blockEnc.set_zero_run_enabled(use_zr_l); + std::vector blockL(left.begin() + range_start, left.begin() + range_start + range_size); + blockEnc.set_zero_run_enabled(this->zero_run_enabled); std::vector encodedL = blockEnc.encode(blockL); out.insert(out.end(), encodedL.begin(), encodedL.end()); if (isStereo) { - std::vector blockR(right.begin() + start, right.begin() + end); - bool use_zr_r = this->zero_run_enabled && should_enable_zero_run_for_block(blockR, this->order); - blockEnc.set_zero_run_enabled(use_zr_r); + std::vector blockR(right.begin() + range_start, right.begin() + range_start + range_size); + blockEnc.set_zero_run_enabled(this->zero_run_enabled); std::vector encodedR = blockEnc.encode(blockR); out.insert(out.end(), encodedR.begin(), encodedR.end()); } return out; }; - auto encode_ms = [&]() -> std::vector { + auto encode_ms = [&](size_t range_start, size_t range_size) -> std::vector { std::vector out; - std::vector blockMid(block.size); - std::vector blockSide(block.size); + std::vector blockMid(range_size); + std::vector blockSide(range_size); SIMD::ms_encode_simd_or_scalar( - left.data() + start, - right.data() + start, + left.data() + range_start, + right.data() + range_start, blockMid.data(), blockSide.data(), - block.size + range_size ); - bool use_zr_mid = this->zero_run_enabled && should_enable_zero_run_for_block(blockMid, this->order); - blockEnc.set_zero_run_enabled(use_zr_mid); + blockEnc.set_zero_run_enabled(this->zero_run_enabled); std::vector encodedMid = blockEnc.encode(blockMid); - bool use_zr_side = this->zero_run_enabled && should_enable_zero_run_for_block(blockSide, this->order); - blockEnc.set_zero_run_enabled(use_zr_side); + blockEnc.set_zero_run_enabled(this->zero_run_enabled); std::vector encodedSide = blockEnc.encode(blockSide); out.insert(out.end(), encodedMid.begin(), encodedMid.end()); out.insert(out.end(), encodedSide.begin(), encodedSide.end()); @@ -408,38 +320,55 @@ namespace LAC { std::string mode_used = "LR"; if (!isStereo) { - encodedBytes = encode_lr(); + encodedBytes = encode_lr(start, block.size); } else if (forceMidSide) { mode_used = "MS"; - std::vector msBytes = encode_ms(); + std::vector msBytes = encode_ms(start, block.size); encodedBytes.insert(encodedBytes.end(), msBytes.begin(), msBytes.end()); } else if (!perBlockStereo) { mode_used = "LR"; - std::vector lrBytes = encode_lr(); + std::vector lrBytes = encode_lr(start, block.size); encodedBytes.insert(encodedBytes.end(), lrBytes.begin(), lrBytes.end()); } else { - StereoCost cost = estimate_stereo_cost(left, - right, - start, - block.size, - this->order, - this->zero_run_enabled, - this->partitioning_enabled); - bool choose_ms = cost.choose_ms(); + const StereoDecision decision = estimate_stereo_mode(left, right, start, block.size); + bool choose_ms = decision.choose_ms; + std::vector selected; + if (decision.uncertain) { + if (block.size <= kStereoFullComparisonLimit) { + std::vector lrBytes = encode_lr(start, block.size); + std::vector msBytes = encode_ms(start, block.size); + choose_ms = msBytes.size() < lrBytes.size(); + selected = choose_ms ? std::move(msBytes) : std::move(lrBytes); + } else { + // Spread bounded probes across the block so a local transition does not dominate mode selection. + const size_t probe_starts[] = { + start, + start + (block.size - kStereoProbeSize) / 2u, + start + block.size - kStereoProbeSize}; + size_t lr_probe_size = 0; + size_t ms_probe_size = 0; + for (const size_t probe_start : probe_starts) { + lr_probe_size += encode_lr(probe_start, kStereoProbeSize).size(); + ms_probe_size += encode_ms(probe_start, kStereoProbeSize).size(); + } + choose_ms = ms_probe_size < lr_probe_size; + } + } if (this->debug_stereo_est) { LAC_DEBUG_LOG("[stereo-est] block=" << block_idx - << " lr_bits=" << (cost.lr_valid ? cost.lr_bits : kMaxCostProxy) - << " ms_bits=" << (cost.ms_valid ? cost.ms_bits : kMaxCostProxy) + << " uncertain=" << (decision.uncertain ? 1 : 0) << " chosen=" << (choose_ms ? "MS" : "LR") << "\n"; ); } mode_used = (choose_ms ? "MS" : "LR"); encodedBytes.push_back(choose_ms ? 1 : 0); - if (choose_ms) { - std::vector msBytes = encode_ms(); + if (!selected.empty()) { + encodedBytes.insert(encodedBytes.end(), selected.begin(), selected.end()); + } else if (choose_ms) { + std::vector msBytes = encode_ms(start, block.size); encodedBytes.insert(encodedBytes.end(), msBytes.begin(), msBytes.end()); } else { - std::vector lrBytes = encode_lr(); + std::vector lrBytes = encode_lr(start, block.size); encodedBytes.insert(encodedBytes.end(), lrBytes.begin(), lrBytes.end()); } } @@ -455,43 +384,42 @@ namespace LAC { }; size_t hardware_threads = std::max(1, static_cast(std::thread::hardware_concurrency())); - const size_t thread_limit = resolve_thread_limit(this->thread_count); + const size_t thread_limit = LAC::resolve_thread_limit(this->thread_count); if (thread_limit > 0) { hardware_threads = std::min(hardware_threads, thread_limit); } const size_t worker_count = std::min(hardware_threads, blocks.size()); - std::vector workers; + std::vector workers; workers.reserve(worker_count); for (size_t worker_idx = 0; worker_idx < worker_count; ++worker_idx) { workers.emplace_back([&]() { - if (collector) { - collector->record(std::this_thread::get_id()); - } - while (true) { - if (stop_requested.load(std::memory_order_acquire)) { - return; + try { + if (collector) { + collector->record(std::this_thread::get_id()); } - - size_t block_idx = 0; - { - std::lock_guard lock(queue_mutex); - if (task_queue.empty()) { + while (true) { + if (stop_requested.load(std::memory_order_acquire)) { return; } - block_idx = task_queue.front(); - task_queue.pop(); - } - try { - encodedBlocks[block_idx] = encode_block(block_idx); - } catch (...) { - stop_requested.store(true, std::memory_order_release); - std::lock_guard lock(error_mutex); - if (!worker_error) { - worker_error = std::current_exception(); + size_t block_idx = 0; + { + std::lock_guard lock(queue_mutex); + if (task_queue.empty()) { + return; + } + block_idx = task_queue.front(); + task_queue.pop(); } - return; + + encodedBlocks[block_idx] = encode_block(block_idx); + } + } catch (...) { + stop_requested.store(true, std::memory_order_release); + std::lock_guard lock(error_mutex); + if (!worker_error) { + worker_error = std::current_exception(); } } }); @@ -505,14 +433,27 @@ namespace LAC { std::rethrow_exception(worker_error); } - for (const auto& bytes : encodedBlocks) { - for (uint8_t b : bytes) { - writer.write_bits(b, 8); + writer.write_bits(static_cast(blocks.size()), 32); + for (size_t i = 0; i < blocks.size(); ++i) { + if (encodedBlocks[i].empty() || + encodedBlocks[i].size() > std::numeric_limits::max()) { + throw std::runtime_error("encoded block size is outside format limits"); } + writer.write_bits(blocks[i].size, 32); + writer.write_bits(static_cast(encodedBlocks[i].size()), 32); + } + + size_t encoded_size = writer.get_buffer().size(); + for (const auto& bytes : encodedBlocks) { + encoded_size += bytes.size(); + } + writer.reserve_bytes(encoded_size); + for (const auto& bytes : encodedBlocks) { + writer.write_bytes(bytes.data(), bytes.size()); } writer.flush_to_byte(); - return writer.get_buffer(); + return writer.take_buffer(); } void Encoder::set_zero_run_enabled(bool enabled) { @@ -531,95 +472,4 @@ namespace LAC { this->thread_count = max_threads; } - uint32_t Encoder::select_block_size(const std::vector& left, - const std::vector& right, - size_t position) const { - if (position >= left.size()) return 0; - size_t remaining = left.size() - position; - if (remaining == 0) return 0; - - double best_cost = std::numeric_limits::infinity(); - uint32_t best_size = static_cast(std::min(remaining, Block::MAX_BLOCK_SIZE)); - bool evaluated = false; - - for (uint32_t candidate : this->candidates) { - if (candidate == 0 || candidate > remaining) continue; - double cost = this->block_complexity(left, position, candidate); - if (!right.empty()) { - cost += this->block_complexity(right, position, candidate); - } - cost += block_slope_penalty(left, position, candidate); - if (!right.empty()) { - cost += block_slope_penalty(right, position, candidate); - } - evaluated = true; - if (cost < best_cost || (cost == best_cost && candidate > best_size)) { - best_cost = cost; - best_size = candidate; - } - } - - if (!evaluated) { - best_size = static_cast(std::min(remaining, Block::MAX_BLOCK_SIZE)); - } - - if (best_size > remaining) { - best_size = static_cast(remaining); - } - if (best_size > Block::MAX_BLOCK_SIZE) { - best_size = Block::MAX_BLOCK_SIZE; - } - if (best_size == 0) { - best_size = static_cast(std::min(remaining, Block::MAX_BLOCK_SIZE)); - } - - return best_size; - } - - double Encoder::block_complexity(const std::vector& channel, - size_t position, - size_t length) const { - if (channel.empty() || length <= 1) return 0.0; - if (position + length > channel.size()) length = channel.size() - position; - if (length <= 1) return 0.0; - - int64_t absdiff_sum = 0; - int64_t sum = 0; - long double sumsq = 0.0L; - - const size_t end = position + length; - const int32_t first = channel[position]; - sum = static_cast(first); - sumsq = static_cast(first) * static_cast(first); - - if constexpr (SIMD::kHasNeon) { - absdiff_sum = SIMD::neon_absdiff_sum(channel.data() + position, length); - for (size_t i = position + 1; i < end; ++i) { - const int64_t v = static_cast(channel[i]); - sum += v; - sumsq += static_cast(v) * static_cast(v); - } - } else { - int32_t prev = first; - for (size_t i = position + 1; i < end; ++i) { - const int32_t cur = channel[i]; - const int64_t v = static_cast(cur); - sum += v; - sumsq += static_cast(v) * static_cast(v); - const int64_t diff = static_cast(cur) - static_cast(prev); - absdiff_sum += (diff >= 0 ? diff : -diff); - prev = cur; - } - } - - const double absdiff_avg = static_cast(absdiff_sum) / static_cast(length); - const double mean = static_cast(sum) / static_cast(length); - const double mean_sq = mean * mean; - const double avg_sq = static_cast(sumsq / static_cast(length)); - double variance = avg_sq - mean_sq; - if (variance < 0.0) variance = 0.0; - - return (kAbsDiffWeight * absdiff_avg) + (kVarianceWeight * variance); - } - } // namespace LAC diff --git a/src/codec/lac/encoder.hpp b/src/codec/lac/encoder.hpp index 6f4c2dc..64b4f5c 100644 --- a/src/codec/lac/encoder.hpp +++ b/src/codec/lac/encoder.hpp @@ -40,15 +40,6 @@ class Encoder { bool partitioning_enabled; bool debug_partitions; size_t thread_count; - std::vector candidates; - - uint32_t select_block_size(const std::vector& left, - const std::vector& right, - size_t position) const; - - double block_complexity(const std::vector& channel, - size_t position, - size_t length) const; }; } // namespace LAC diff --git a/src/codec/lac/thread_limit.hpp b/src/codec/lac/thread_limit.hpp new file mode 100644 index 0000000..59fb9d7 --- /dev/null +++ b/src/codec/lac/thread_limit.hpp @@ -0,0 +1,35 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace LAC { + +inline size_t parse_thread_limit(const char* value) { + if (value == nullptr || value[0] == '\0') return 0; + for (const char* p = value; *p != '\0'; ++p) { + if (*p < '0' || *p > '9') { + throw std::invalid_argument("LAC_THREADS must be a positive integer"); + } + } + + errno = 0; + char* end = nullptr; + const unsigned long long parsed = std::strtoull(value, &end, 10); + if (errno != 0 || end == value || *end != '\0' || parsed == 0) { + throw std::invalid_argument("LAC_THREADS must be a positive integer"); + } + if (parsed > static_cast(std::numeric_limits::max())) { + throw std::invalid_argument("LAC_THREADS is too large"); + } + return static_cast(parsed); +} + +inline size_t resolve_thread_limit(size_t explicit_limit) { + if (explicit_limit > 0) return explicit_limit; + return parse_thread_limit(std::getenv("LAC_THREADS")); +} + +} // namespace LAC diff --git a/src/codec/rice/rice.cpp b/src/codec/rice/rice.cpp index 07b5bb8..8512b0f 100644 --- a/src/codec/rice/rice.cpp +++ b/src/codec/rice/rice.cpp @@ -1,5 +1,6 @@ #include "rice.hpp" #include +#include #include #include @@ -19,9 +20,7 @@ void Rice::encode(BitWriter& w, int32_t value, uint32_t k) { uint32_t q = u >> k; uint32_t r = u & ((1u << k) - 1); - for (uint32_t i = 0; i < q; ++i) { - w.write_bit(1); - } + w.write_unary_ones(q); w.write_bit(0); if (k > 0) { @@ -95,10 +94,9 @@ uint32_t Rice::adapt_k(uint64_t sum, uint32_t count, AdaptState& state) { // Micro window flags. // Base mean and k estimate using integer rounding. const uint64_t mean = (sum + (count >> 1)) / count; - uint32_t k = 0; - while ((1u << k) < mean && k < kMaxRiceK) { - ++k; - } + const uint32_t k = (mean <= 1) + ? 0u + : std::min(kMaxRiceK, std::bit_width(mean - 1u)); const uint32_t q_base = static_cast((k >= kMaxRiceK) ? 0u : (current_u >> k)); const uint8_t is_large = static_cast(q_base > 3u); diff --git a/src/main.cpp b/src/main.cpp index cb08a9b..c7ffced 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,5 +1,8 @@ #include #include +#include +#include +#include #include #include #include @@ -7,7 +10,18 @@ #include #include #include +#include +#include #include +#include +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#include +#endif #include "io/wav_io.hpp" #include "codec/lac/encoder.hpp" #include "codec/lac/decoder.hpp" @@ -53,6 +67,120 @@ static bool paths_refer_to_same_file(const std::string& input_path, const std::s return !ec && normalized_input == normalized_output; } +static std::string temporary_output_suffix() { + static std::atomic sequence{0}; + uint64_t token = + static_cast(std::chrono::steady_clock::now().time_since_epoch().count()); + token ^= ++sequence; + try { + std::random_device random; + token ^= (static_cast(random()) << 32) ^ random(); + } catch (...) { + // The clock and process-local sequence still provide collision retry input. + } + std::ostringstream out; + out << std::hex << token; + return out.str(); +} + +static bool replace_output_file(const std::filesystem::path& temporary_path, + const std::filesystem::path& output_path) { +#ifdef _WIN32 + return MoveFileExW(temporary_path.c_str(), + output_path.c_str(), + MOVEFILE_REPLACE_EXISTING | MOVEFILE_WRITE_THROUGH) != 0; +#else + std::error_code ec; + std::filesystem::rename(temporary_path, output_path, ec); + return !ec; +#endif +} + +static bool create_private_directory(const std::filesystem::path& path, + std::error_code& ec) { +#ifdef _WIN32 + return std::filesystem::create_directory(path, ec); +#else + if (::mkdir(path.c_str(), S_IRWXU) == 0) { + if (::chmod(path.c_str(), S_IRWXU) != 0) { + ec = std::error_code(errno, std::generic_category()); + std::error_code cleanup_ec; + std::filesystem::remove(path, cleanup_ec); + return false; + } + ec.clear(); + return true; + } + ec = std::error_code(errno, std::generic_category()); + return false; +#endif +} + +class StagedOutputFile { +public: + explicit StagedOutputFile(const std::string& output_path) + : output_path(output_path) { + const std::filesystem::path parent = + this->output_path.parent_path().empty() + ? std::filesystem::path(".") + : this->output_path.parent_path(); + if (this->output_path.filename().empty()) return; + + for (int attempt = 0; attempt < 128; ++attempt) { + const auto candidate = parent / (".lac-tmp." + temporary_output_suffix()); + std::error_code ec; + if (!create_private_directory(candidate, ec)) { + if (!ec || ec == std::make_error_code(std::errc::file_exists)) continue; + return; + } + + this->temporary_directory = candidate; + this->temporary_path = candidate / "output"; + return; + } + } + + ~StagedOutputFile() { + this->cleanup(); + } + + bool is_ready() const { + return !this->temporary_path.empty(); + } + + std::string path() const { + return this->temporary_path.string(); + } + + bool publish(const std::string& input_path) { + if (!this->is_ready()) return false; + if (paths_refer_to_same_file(input_path, this->output_path.string())) return false; + if (!replace_output_file(this->temporary_path, this->output_path)) return false; + + this->temporary_path.clear(); + std::error_code ec; + std::filesystem::remove(this->temporary_directory, ec); + if (!ec) this->temporary_directory.clear(); + return true; + } + +private: + void cleanup() { + std::error_code ec; + if (!this->temporary_path.empty()) { + std::filesystem::remove(this->temporary_path, ec); + } + ec.clear(); + if (!this->temporary_directory.empty()) { + std::filesystem::remove(this->temporary_directory, ec); + } + } + + std::filesystem::path output_path; + std::filesystem::path temporary_directory; + std::filesystem::path temporary_path; +}; + static bool parse_threads_flag(const std::string& flag, size_t& out_threads) { const std::string prefix = "--threads="; if (flag.rfind(prefix, 0) != 0) { @@ -82,7 +210,7 @@ static bool parse_threads_flag(const std::string& flag, size_t& out_threads) { static void usage() { std::cerr << "Usage:\n"; std::cerr << " lac_cli encode input.wav output.lac [--stereo-mode=lr|ms] [--threads=N] [--debug-threads] [--debug-lpc] [--debug-stereo-est] [--debug-zr] [--debug-partitions] [--no-partitioning]\n"; - std::cerr << " lac_cli decode input.lac output.wav [--debug-threads]\n"; + std::cerr << " lac_cli decode input.lac output.wav [--threads=N] [--debug-threads]\n"; std::cerr << " lac_cli selftest\n"; } @@ -176,7 +304,10 @@ int main(int argc, char** argv) { << " zr_bytes=" << bitstream.size() << " gain=" << gain << "%\n"; } - if (!save_file(out_path, bitstream)) { + StagedOutputFile staged_output(out_path); + if (!staged_output.is_ready() || + !save_file(staged_output.path(), bitstream) || + !staged_output.publish(in_path)) { std::cerr << "Failed to write LAC file: " << out_path << "\n"; return 1; } @@ -206,10 +337,13 @@ int main(int argc, char** argv) { return 1; } bool debug_threads = false; + size_t thread_count = 0; for (int i = 4; i < argc; ++i) { std::string flag = argv[i]; if (flag == "--debug-threads") { debug_threads = true; + } else if (parse_threads_flag(flag, thread_count)) { + // Parsed above. } else { usage(); return 1; @@ -223,6 +357,7 @@ int main(int argc, char** argv) { LAC::ThreadCollector decoderCollector; LAC::ThreadCollector* decoderCollectorPtr = (debug_threads ? &decoderCollector : nullptr); LAC::Decoder decoder(decoderCollectorPtr); + decoder.set_thread_count(thread_count); std::vector left, right; FrameHeader hdr; decoder.decode(bitstream.data(), bitstream.size(), left, right, &hdr); @@ -230,7 +365,10 @@ int main(int argc, char** argv) { std::cerr << "Decode failed or produced no samples\n"; return 1; } - if (!write_wav(out_path, left, right, hdr.channels, hdr.sample_rate, hdr.bit_depth)) { + StagedOutputFile staged_output(out_path); + if (!staged_output.is_ready() || + !write_wav(staged_output.path(), left, right, hdr.channels, hdr.sample_rate, hdr.bit_depth) || + !staged_output.publish(in_path)) { std::cerr << "Failed to write WAV: " << out_path << "\n"; return 1; } diff --git a/tests/test_cli.cpp b/tests/test_cli.cpp index 4ca5a3a..91bf21e 100644 --- a/tests/test_cli.cpp +++ b/tests/test_cli.cpp @@ -11,6 +11,9 @@ #include #include #include +#ifndef _WIN32 +#include +#endif namespace { @@ -70,6 +73,12 @@ void save_binary_file(const std::filesystem::path& path, const std::vector& data, size_t offset) { assert(offset + 2 <= data.size()); return static_cast( @@ -153,7 +162,7 @@ void assert_roundtrip(const std::filesystem::path& cli, std::vector encode_args = {"encode", input.string(), encoded.string()}; encode_args.insert(encode_args.end(), encode_flags.begin(), encode_flags.end()); assert(run_cli(cli, encode_args)); - assert(run_cli(cli, {"decode", encoded.string(), restored.string()})); + assert(run_cli(cli, {"decode", encoded.string(), restored.string(), "--threads=2"})); std::vector restored_left; std::vector restored_right; @@ -188,6 +197,23 @@ void assert_same_path_rejected(const std::filesystem::path& cli, const auto lac_before = load_binary_file(lac); assert(!run_cli(cli, {"decode", lac.string(), lac.string()})); assert(load_binary_file(lac) == lac_before); + + const auto wav_alias = dir / "same_path_alias.wav"; + std::error_code ec; + std::filesystem::create_hard_link(wav, wav_alias, ec); + if (!ec) { + assert(!run_cli(cli, {"encode", wav.string(), wav_alias.string()})); + assert(load_binary_file(wav) == wav_before); + } + + const auto lac_alias = dir / "same_path_alias.lac"; + ec.clear(); + std::filesystem::create_hard_link(lac, lac_alias, ec); + if (!ec) { + assert(!run_cli(cli, {"decode", lac.string(), lac_alias.string()})); + assert(load_binary_file(lac) == lac_before); + } + assert_no_staged_output_siblings(dir); } void assert_malformed_input_rejected_without_clobber(const std::filesystem::path& cli, @@ -207,6 +233,7 @@ void assert_malformed_input_rejected_without_clobber(const std::filesystem::path save_binary_file(lac_output, lac_sentinel); assert(!run_cli(cli, {"encode", malformed_wav.string(), lac_output.string()})); assert(load_binary_file(lac_output) == lac_sentinel); + assert_no_staged_output_siblings(dir); assert(run_cli(cli, {"encode", valid_wav.string(), valid_lac.string()})); auto lac_bytes = load_binary_file(valid_lac); @@ -216,6 +243,161 @@ void assert_malformed_input_rejected_without_clobber(const std::filesystem::path save_binary_file(wav_output, wav_sentinel); assert(!run_cli(cli, {"decode", malformed_lac.string(), wav_output.string()})); assert(load_binary_file(wav_output) == wav_sentinel); + assert_no_staged_output_siblings(dir); +} + +void assert_existing_output_overwritten(const std::filesystem::path& cli, + const std::filesystem::path& dir) { + const auto wav = dir / "overwrite_source.wav"; + const auto lac = dir / "overwrite_output.lac"; + const auto restored = dir / "overwrite_restored.wav"; + const auto left = make_samples(64, 24, 0); + const auto right = make_samples(64, 24, 37); + assert(write_wav(wav.string(), left, right, 2, 48000, 24)); + + const std::vector sentinel = {0xAAu, 0x55u, 0xAAu}; + save_binary_file(lac, sentinel); + assert(run_cli(cli, {"encode", wav.string(), lac.string()})); + assert(load_binary_file(lac) != sentinel); + assert_no_staged_output_siblings(dir); + + save_binary_file(restored, sentinel); + assert(run_cli(cli, {"decode", lac.string(), restored.string()})); + std::vector restored_left; + std::vector restored_right; + uint16_t channels = 0; + uint32_t sample_rate = 0; + uint8_t bit_depth = 0; + assert(read_wav(restored.string(), + restored_left, + restored_right, + channels, + sample_rate, + bit_depth)); + assert(restored_left == left); + assert(restored_right == right); + assert(channels == 2); + assert(sample_rate == 48000); + assert(bit_depth == 24); + assert_no_staged_output_siblings(dir); +} + +void assert_publish_failure_cleans_up(const std::filesystem::path& cli, + const std::filesystem::path& dir) { + const auto wav = dir / "publish_failure_source.wav"; + const auto lac = dir / "publish_failure_source.lac"; + const auto output_dir = dir / "publish_failure_output"; + const auto marker = output_dir / "marker"; + assert(write_wav(wav.string(), make_samples(64, 16, 0), {}, 1, 44100, 16)); + assert(run_cli(cli, {"encode", wav.string(), lac.string()})); + assert(std::filesystem::create_directory(output_dir)); + const std::vector sentinel = {0x11u, 0x22u, 0x33u}; + save_binary_file(marker, sentinel); + + assert(!run_cli(cli, {"encode", wav.string(), output_dir.string()})); + assert(load_binary_file(marker) == sentinel); + assert_no_staged_output_siblings(dir); + + assert(!run_cli(cli, {"decode", lac.string(), output_dir.string()})); + assert(load_binary_file(marker) == sentinel); + assert_no_staged_output_siblings(dir); +} + +void assert_link_targets_not_clobbered(const std::filesystem::path& cli, + const std::filesystem::path& dir) { + const auto wav = dir / "link_source.wav"; + const auto baseline_lac = dir / "link_source.lac"; + assert(write_wav(wav.string(), make_samples(64, 16, 0), {}, 1, 44100, 16)); + assert(run_cli(cli, {"encode", wav.string(), baseline_lac.string()})); + + const std::vector sentinel = {0x44u, 0x55u, 0x66u}; + const auto hardlink_target = dir / "hardlink_target"; + const auto hardlink_output = dir / "hardlink_output.lac"; + save_binary_file(hardlink_target, sentinel); + std::error_code ec; + std::filesystem::create_hard_link(hardlink_target, hardlink_output, ec); + if (!ec) { + assert(run_cli(cli, {"encode", wav.string(), hardlink_output.string()})); + assert(load_binary_file(hardlink_target) == sentinel); + assert(load_binary_file(hardlink_output) != sentinel); + } + assert_no_staged_output_siblings(dir); + + const auto symlink_target = dir / "symlink_target"; + const auto symlink_output = dir / "symlink_output.wav"; + save_binary_file(symlink_target, sentinel); + ec.clear(); + std::filesystem::create_symlink(symlink_target, symlink_output, ec); + if (!ec) { + assert(run_cli(cli, {"decode", baseline_lac.string(), symlink_output.string()})); + assert(load_binary_file(symlink_target) == sentinel); + assert(!std::filesystem::is_symlink(symlink_output)); + } + assert_no_staged_output_siblings(dir); +} + +void assert_long_output_filenames_supported(const std::filesystem::path& cli, + const std::filesystem::path& dir) { + const auto wav = dir / "long_name_source.wav"; + const auto lac = dir / (std::string(240, 'l') + ".lac"); + const auto restored = dir / (std::string(240, 'w') + ".wav"); + assert(write_wav(wav.string(), make_samples(64, 16, 0), {}, 1, 44100, 16)); + assert(run_cli(cli, {"encode", wav.string(), lac.string()})); + assert(run_cli(cli, {"decode", lac.string(), restored.string()})); + + std::vector restored_left; + std::vector restored_right; + uint16_t channels = 0; + uint32_t sample_rate = 0; + uint8_t bit_depth = 0; + assert(read_wav(restored.string(), + restored_left, + restored_right, + channels, + sample_rate, + bit_depth)); + assert(restored_left == make_samples(64, 16, 0)); + assert(restored_right.empty()); + assert_no_staged_output_siblings(dir); +} + +void assert_restrictive_umask_supported(const std::filesystem::path& cli, + const std::filesystem::path& dir) { +#ifndef _WIN32 + const auto wav = dir / "umask_source.wav"; + const auto lac = dir / "umask_output.lac"; + const auto restored = dir / "umask_restored.wav"; + assert(write_wav(wav.string(), make_samples(64, 16, 0), {}, 1, 44100, 16)); + + const mode_t previous_encode_umask = ::umask(0777); + const bool encoded = run_cli(cli, {"encode", wav.string(), lac.string()}); + ::umask(previous_encode_umask); + assert(encoded); + + std::error_code ec; + std::filesystem::permissions( + lac, + std::filesystem::perms::owner_read | std::filesystem::perms::owner_write, + std::filesystem::perm_options::replace, + ec); + assert(!ec); + + const mode_t previous_decode_umask = ::umask(0777); + const bool decoded = run_cli(cli, {"decode", lac.string(), restored.string()}); + ::umask(previous_decode_umask); + assert(decoded); + + std::filesystem::permissions( + restored, + std::filesystem::perms::owner_read | std::filesystem::perms::owner_write, + std::filesystem::perm_options::replace, + ec); + assert(!ec); + assert_no_staged_output_siblings(dir); +#else + (void)cli; + (void)dir; +#endif } } // namespace @@ -247,6 +429,11 @@ int main(int argc, char** argv) { Block::MAX_BLOCK_SIZE + 37u); assert_same_path_rejected(cli, dir); assert_malformed_input_rejected_without_clobber(cli, dir); + assert_existing_output_overwritten(cli, dir); + assert_publish_failure_cleans_up(cli, dir); + assert_link_targets_not_clobbered(cli, dir); + assert_long_output_filenames_supported(cli, dir); + assert_restrictive_umask_supported(cli, dir); std::error_code ec; std::filesystem::remove_all(dir, ec); diff --git a/tests/test_e2e.cpp b/tests/test_e2e.cpp index 0af4ee0..3237057 100644 --- a/tests/test_e2e.cpp +++ b/tests/test_e2e.cpp @@ -1,9 +1,11 @@ #include "codec/lac/encoder.hpp" #include "codec/lac/decoder.hpp" #include "codec/block/constants.hpp" +#include "codec/block/encoder.hpp" #include "io/wav_io.hpp" #include "codec/frame/frame_header.hpp" #include "codec/rice/rice.hpp" +#include #include #include #include @@ -12,6 +14,7 @@ #include #include #include +#include #include namespace { @@ -256,6 +259,64 @@ uint32_t get_u32_le(const std::vector& in, size_t offset) { (static_cast(in[offset + 3]) << 24)); } +uint32_t get_u32_be(const std::vector& in, size_t offset) { + assert(offset + 4 <= in.size()); + return static_cast( + (static_cast(in[offset]) << 24) | + (static_cast(in[offset + 1]) << 16) | + (static_cast(in[offset + 2]) << 8) | + static_cast(in[offset + 3])); +} + +void set_u32_be(std::vector& out, size_t offset, uint32_t value) { + assert(offset + 4 <= out.size()); + out[offset] = static_cast((value >> 24) & 0xFFu); + out[offset + 1] = static_cast((value >> 16) & 0xFFu); + out[offset + 2] = static_cast((value >> 8) & 0xFFu); + out[offset + 3] = static_cast(value & 0xFFu); +} + +size_t v3_payload_offset(const std::vector& stream) { + assert(stream.size() >= 14); + assert(stream[2] == 3u); + const uint32_t block_count = get_u32_be(stream, 10); + const size_t offset = 14u + static_cast(block_count) * 8u; + assert(offset <= stream.size()); + return offset; +} + +std::vector v3_block_sizes(const std::vector& stream) { + assert(stream.size() >= 14); + assert(stream[2] == 3u); + const uint32_t block_count = get_u32_be(stream, 10); + assert(14u + static_cast(block_count) * 8u <= stream.size()); + std::vector sizes; + sizes.reserve(block_count); + for (uint32_t i = 0; i < block_count; ++i) { + sizes.push_back(get_u32_be(stream, 14u + static_cast(i) * 8u)); + } + return sizes; +} + +std::vector v3_stereo_flags(const std::vector& stream) { + assert(stream.size() >= 14); + assert(stream[2] == 3u); + assert(stream[4] == 2u); + const uint32_t block_count = get_u32_be(stream, 10); + size_t payload_offset = v3_payload_offset(stream); + std::vector flags; + flags.reserve(block_count); + for (uint32_t i = 0; i < block_count; ++i) { + const uint32_t payload_size = get_u32_be(stream, 18u + static_cast(i) * 8u); + assert(payload_size > 0u); + assert(payload_offset + payload_size <= stream.size()); + flags.push_back(stream[payload_offset]); + payload_offset += payload_size; + } + assert(payload_offset == stream.size()); + return flags; +} + void append_fourcc(std::vector& out, const char* fourcc) { out.insert(out.end(), fourcc, fourcc + 4); } @@ -312,6 +373,7 @@ bool can_read_wav_bytes(const std::vector& bytes) { std::vector make_lac_with_single_mono_sample(int32_t sample, uint8_t bit_depth = 16) { BitWriter bw; FrameHeader hdr; + hdr.version = 2; hdr.channels = 1; hdr.stereo_mode = 0; hdr.bit_depth = bit_depth; @@ -331,6 +393,7 @@ std::vector make_lac_with_single_mono_sample(int32_t sample, uint8_t bi std::vector make_lac_header_only_with_block_count(uint32_t block_count) { BitWriter bw; FrameHeader hdr; + hdr.version = 2; hdr.channels = 1; hdr.stereo_mode = 0; hdr.write(bw); @@ -342,6 +405,7 @@ std::vector make_lac_header_only_with_block_count(uint32_t block_count) std::vector make_lac_with_short_non_final_block() { BitWriter bw; FrameHeader hdr; + hdr.version = 2; hdr.channels = 1; hdr.stereo_mode = 0; hdr.write(bw); @@ -356,6 +420,7 @@ std::vector make_lac_over_allocation_limit() { constexpr uint32_t blocks = 16385; BitWriter bw; FrameHeader hdr; + hdr.version = 2; hdr.channels = 1; hdr.stereo_mode = 0; hdr.write(bw); @@ -367,6 +432,16 @@ std::vector make_lac_over_allocation_limit() { return bw.get_buffer(); } +std::vector make_v3_mono_stream(size_t frames) { + std::vector pcm(frames, 0); + for (size_t i = 0; i < frames; ++i) { + pcm[i] = static_cast((i % 257u) - 128); + } + LAC::Encoder encoder(12, 0, 44100, 16, false, false, false); + encoder.set_thread_count(2); + return encoder.encode(pcm, {}); +} + } // namespace @@ -431,12 +506,281 @@ void run_decoder_error_tests() { LAC::Encoder stereo_encoder(12, 2, 44100, 16, false, false, false); std::vector stereo_stream = stereo_encoder.encode({0}, {0}); - stereo_stream[18] = 2; // header + block count + one block size + stereo_stream[v3_payload_offset(stereo_stream)] = 2; expect_throw("invalid-per-block-stereo-flag", stereo_stream.data(), stereo_stream.size()); + std::vector unknown_version = make_v3_mono_stream(4096); + unknown_version[2] = 4; + expect_throw("unknown-frame-version", unknown_version.data(), unknown_version.size()); + + std::vector truncated_v3_table = make_v3_mono_stream(4096); + truncated_v3_table.resize(14u + 7u); + expect_throw("truncated-v3-table", truncated_v3_table.data(), truncated_v3_table.size()); + + std::vector zero_compressed_size = make_v3_mono_stream(4096); + set_u32_be(zero_compressed_size, 18, 0); + expect_throw("zero-compressed-block-size", zero_compressed_size.data(), zero_compressed_size.size()); + + std::vector mismatched_compressed_sum = make_v3_mono_stream(4096); + const uint32_t original_size = get_u32_be(mismatched_compressed_sum, 18); + set_u32_be(mismatched_compressed_sum, 18, original_size + 1u); + expect_throw("compressed-size-sum-mismatch", mismatched_compressed_sum.data(), mismatched_compressed_sum.size()); + + std::vector crossed_block_boundary = make_v3_mono_stream(32768); + assert(get_u32_be(crossed_block_boundary, 10) >= 2u); + const uint32_t first_payload_size = get_u32_be(crossed_block_boundary, 18); + const uint32_t second_payload_size = get_u32_be(crossed_block_boundary, 26); + assert(first_payload_size > 1u); + set_u32_be(crossed_block_boundary, 18, first_payload_size - 1u); + set_u32_be(crossed_block_boundary, 26, second_payload_size + 1u); + expect_throw("crossed-v3-block-boundary", crossed_block_boundary.data(), crossed_block_boundary.size()); + + std::vector extra_byte_in_block = make_v3_mono_stream(32768); + const size_t payload_offset = v3_payload_offset(extra_byte_in_block); + const uint32_t first_span = get_u32_be(extra_byte_in_block, 18); + extra_byte_in_block.insert(extra_byte_in_block.begin() + payload_offset + first_span, 0u); + set_u32_be(extra_byte_in_block, 18, first_span + 1u); + expect_throw("extra-byte-inside-v3-block", extra_byte_in_block.data(), extra_byte_in_block.size()); + std::cout << "decoder error tests ok\n"; } +void run_decoder_thread_tests() { + constexpr size_t frames = 65536; + std::vector input(frames, 0); + for (size_t i = 0; i < frames; ++i) { + input[i] = static_cast((i % 2000u) - 1000); + } + + LAC::Encoder encoder(12, 0, 44100, 16, false, false, false); + encoder.set_thread_count(2); + const std::vector v3_stream = encoder.encode(input, {}); + assert(v3_stream[2] == 3u); + assert(get_u32_be(v3_stream, 10) > 1u); + + LAC::ThreadCollector v3_collector; + LAC::Decoder v3_decoder(&v3_collector); + v3_decoder.set_thread_count(2); + std::vector out_left; + std::vector out_right; + FrameHeader hdr; + v3_decoder.decode(v3_stream.data(), v3_stream.size(), out_left, out_right, &hdr); + assert(out_left == input); + assert(out_right.empty()); + assert(hdr.version == 3u); + const size_t v3_threads = v3_collector.snapshot().size(); + assert(v3_threads >= 1u); + assert(v3_threads <= 2u); + + const std::vector v2_stream = make_lac_with_single_mono_sample(123); + LAC::ThreadCollector v2_collector; + LAC::Decoder v2_decoder(&v2_collector); + v2_decoder.set_thread_count(2); + v2_decoder.decode(v2_stream.data(), v2_stream.size(), out_left, out_right, &hdr); + assert(out_left == std::vector{123}); + assert(out_right.empty()); + assert(hdr.version == 2u); + assert(v2_collector.snapshot().size() == 1u); + std::cout << "decoder thread tests ok\n"; +} + +void run_block_planner_tests() { + auto assert_sizes = [](size_t frames, const std::vector& expected) { + LAC::Encoder encoder(12, 0, 44100, 16, false, false, false); + const std::vector stream = encoder.encode(std::vector(frames, 0), {}); + assert(v3_block_sizes(stream) == expected); + }; + + assert_sizes(16383, {16383}); + assert_sizes(16384, {16384}); + assert_sizes(16385, {16384, 1}); + assert_sizes(65536, {16384, 16384, 16384, 16384}); + + std::vector quarters(Block::MAX_BLOCK_SIZE, 0); + for (size_t i = Block::MAX_BLOCK_SIZE / 4u; i < Block::MAX_BLOCK_SIZE / 2u; ++i) { + quarters[i] = 1000; + } + for (size_t i = Block::MAX_BLOCK_SIZE / 2u; i < 3u * Block::MAX_BLOCK_SIZE / 4u; ++i) { + quarters[i] = -1000; + } + LAC::Encoder encoder(12, 0, 44100, 16, false, false, false); + const std::vector quarters_stream = encoder.encode(quarters, {}); + assert(v3_block_sizes(quarters_stream) == std::vector{Block::MAX_BLOCK_SIZE}); + Block::Encoder block_encoder(12); + block_encoder.set_zero_run_enabled(true); + block_encoder.set_partitioning_enabled(true); + const std::vector quarters_block = block_encoder.encode(quarters); + assert(quarters_stream.size() == 22u + quarters_block.size()); + + std::vector silence_then_noise(Block::MAX_BLOCK_SIZE); + uint32_t state = 29u * 8191u; + for (size_t i = 0; i < silence_then_noise.size(); ++i) { + state = state * 1664525u + 1013904223u; + const int32_t sample = static_cast(state >> 16) - 32768; + silence_then_noise[i] = i < silence_then_noise.size() / 2u ? 0 : sample; + } + const std::vector silence_then_noise_stream = encoder.encode(silence_then_noise, {}); + assert(v3_block_sizes(silence_then_noise_stream) == std::vector{Block::MAX_BLOCK_SIZE}); + const std::vector silence_then_noise_block = block_encoder.encode(silence_then_noise); + assert(silence_then_noise_stream.size() == 22u + silence_then_noise_block.size()); + std::cout << "block planner tests ok\n"; +} + +void run_stereo_planner_tests() { + constexpr size_t block = Block::MAX_BLOCK_SIZE; + std::vector left(block * 2u); + std::vector right(block * 2u); + for (size_t i = 0; i < left.size(); ++i) { + left[i] = static_cast((i % 2001u) - 1000); + right[i] = (i < block) ? left[i] : 0; + } + + LAC::Encoder encoder(12, 2, 44100, 16, false, false, false); + encoder.set_thread_count(2); + const std::vector mixed_stream = encoder.encode(left, right); + const std::vector mixed_flags = v3_stereo_flags(mixed_stream); + assert(std::find(mixed_flags.begin(), mixed_flags.end(), 0u) != mixed_flags.end()); + assert(std::find(mixed_flags.begin(), mixed_flags.end(), 1u) != mixed_flags.end()); + + std::vector identical(block, 0); + for (size_t i = 0; i < identical.size(); ++i) { + identical[i] = static_cast((i % 2001u) - 1000); + } + const std::vector ms_stream = encoder.encode(identical, identical); + const std::vector ms_flags = v3_stereo_flags(ms_stream); + assert(!ms_flags.empty()); + assert(std::all_of(ms_flags.begin(), ms_flags.end(), [](uint8_t flag) { return flag == 1u; })); + + const std::vector lr_stream = encoder.encode(identical, std::vector(block, 0)); + const std::vector lr_flags = v3_stereo_flags(lr_stream); + assert(!lr_flags.empty()); + assert(std::all_of(lr_flags.begin(), lr_flags.end(), [](uint8_t flag) { return flag == 0u; })); + + std::vector anticorrelated(4096); + std::vector inverse(4096); + uint32_t state = 1; + for (size_t i = 0; i < anticorrelated.size(); ++i) { + state = state * 1664525u + 1013904223u; + anticorrelated[i] = static_cast(state >> 18) - 8192; + inverse[i] = -anticorrelated[i]; + } + anticorrelated[0] = -32768; + inverse[0] = 32767; + + LAC::Encoder auto_encoder(12, 2, 44100, 16, false, false, false); + LAC::Encoder forced_lr(12, 0, 44100, 16, false, false, false); + LAC::Encoder forced_ms(12, 1, 44100, 16, false, false, false); + auto_encoder.set_thread_count(1); + forced_lr.set_thread_count(1); + forced_ms.set_thread_count(1); + const std::vector auto_stream = auto_encoder.encode(anticorrelated, inverse); + const std::vector forced_lr_stream = forced_lr.encode(anticorrelated, inverse); + const std::vector forced_ms_stream = forced_ms.encode(anticorrelated, inverse); + const std::vector auto_flags = v3_stereo_flags(auto_stream); + assert(std::all_of(auto_flags.begin(), auto_flags.end(), [](uint8_t flag) { return flag == 1u; })); + assert(auto_stream.size() == forced_ms_stream.size() + auto_flags.size()); + assert(auto_stream.size() < forced_lr_stream.size()); + + std::vector random_left(4096); + std::vector random_right(4096); + state = 49917u; + for (size_t i = 0; i < random_left.size(); ++i) { + state = state * 1664525u + 1013904223u; + random_left[i] = static_cast(state >> 16) - 32768; + state = state * 1664525u + 1013904223u; + random_right[i] = static_cast(state >> 16) - 32768; + } + const std::vector auto_random_stream = auto_encoder.encode(random_left, random_right); + const std::vector lr_random_stream = forced_lr.encode(random_left, random_right); + const std::vector ms_random_stream = forced_ms.encode(random_left, random_right); + const std::vector auto_random_flags = v3_stereo_flags(auto_random_stream); + assert(std::all_of(auto_random_flags.begin(), auto_random_flags.end(), [](uint8_t flag) { return flag == 0u; })); + assert(auto_random_stream.size() == lr_random_stream.size() + auto_random_flags.size()); + assert(auto_random_stream.size() < ms_random_stream.size()); + + std::vector alternating_left(4095); + std::vector alternating_right(4095); + state = 758392u; + for (size_t i = 0; i < alternating_left.size(); ++i) { + state = state * 1664525u + 1013904223u; + state = state * 1664525u + 1013904223u; + const int32_t noise = static_cast(state >> 16) - 32768; + alternating_left[i] = (i & 1u) ? static_cast(i) : -static_cast(i); + alternating_right[i] = alternating_left[i] + noise % 257; + } + const std::vector auto_alternating_stream = auto_encoder.encode(alternating_left, alternating_right); + const std::vector lr_alternating_stream = forced_lr.encode(alternating_left, alternating_right); + const std::vector ms_alternating_stream = forced_ms.encode(alternating_left, alternating_right); + const std::vector auto_alternating_flags = v3_stereo_flags(auto_alternating_stream); + assert(std::all_of(auto_alternating_flags.begin(), auto_alternating_flags.end(), [](uint8_t flag) { return flag == 0u; })); + assert(auto_alternating_stream.size() == lr_alternating_stream.size() + auto_alternating_flags.size()); + assert(auto_alternating_stream.size() < ms_alternating_stream.size()); + + std::vector walk_left(4095); + std::vector walk_right(4095); + state = 13210319u; + int32_t left_walk = 0; + int32_t right_walk = 0; + for (size_t i = 0; i < walk_left.size(); ++i) { + state = state * 1664525u + 1013904223u; + const int32_t x = static_cast(state >> 16) - 32768; + state = state * 1664525u + 1013904223u; + const int32_t y = static_cast(state >> 16) - 32768; + left_walk = std::clamp(left_walk + x % 65, -32768, 32767); + right_walk = std::clamp(right_walk + y % 65, -32768, 32767); + walk_left[i] = left_walk; + walk_right[i] = right_walk; + } + const std::vector auto_walk_stream = auto_encoder.encode(walk_left, walk_right); + const std::vector lr_walk_stream = forced_lr.encode(walk_left, walk_right); + const std::vector ms_walk_stream = forced_ms.encode(walk_left, walk_right); + const std::vector auto_walk_flags = v3_stereo_flags(auto_walk_stream); + assert(std::all_of(auto_walk_flags.begin(), auto_walk_flags.end(), [](uint8_t flag) { return flag == 0u; })); + assert(auto_walk_stream.size() == lr_walk_stream.size() + auto_walk_flags.size()); + assert(auto_walk_stream.size() < ms_walk_stream.size()); + + std::vector long_walk_left(Block::MAX_BLOCK_SIZE); + std::vector long_walk_right(Block::MAX_BLOCK_SIZE); + state = 13210319u; + left_walk = 0; + right_walk = 0; + for (size_t i = 0; i < long_walk_left.size(); ++i) { + state = state * 1664525u + 1013904223u; + const int32_t x = static_cast(state >> 16) - 32768; + state = state * 1664525u + 1013904223u; + const int32_t y = static_cast(state >> 16) - 32768; + left_walk = std::clamp(left_walk + x % 65, -32768, 32767); + right_walk = std::clamp(right_walk + y % 65, -32768, 32767); + long_walk_left[i] = left_walk; + long_walk_right[i] = right_walk; + } + const std::vector auto_long_walk_stream = auto_encoder.encode(long_walk_left, long_walk_right); + const std::vector lr_long_walk_stream = forced_lr.encode(long_walk_left, long_walk_right); + const std::vector ms_long_walk_stream = forced_ms.encode(long_walk_left, long_walk_right); + const std::vector auto_long_walk_flags = v3_stereo_flags(auto_long_walk_stream); + assert(auto_long_walk_flags == std::vector{0u}); + assert(auto_long_walk_stream.size() == lr_long_walk_stream.size() + auto_long_walk_flags.size()); + assert(auto_long_walk_stream.size() < ms_long_walk_stream.size()); + + std::vector noise_left(Block::MAX_BLOCK_SIZE); + std::vector noise_right(Block::MAX_BLOCK_SIZE); + state = 1u; + for (size_t i = 0; i < noise_left.size(); ++i) { + state = state * 1664525u + 1013904223u; + noise_left[i] = static_cast(state >> 16) - 32768; + state = state * 1664525u + 1013904223u; + noise_right[i] = static_cast(state >> 16) - 32768; + } + const std::vector auto_noise_stream = auto_encoder.encode(noise_left, noise_right); + const std::vector lr_noise_stream = forced_lr.encode(noise_left, noise_right); + const std::vector ms_noise_stream = forced_ms.encode(noise_left, noise_right); + const std::vector auto_noise_flags = v3_stereo_flags(auto_noise_stream); + assert(auto_noise_flags.size() == 1u); + const std::vector& selected_noise_stream = auto_noise_flags[0] == 1u ? ms_noise_stream : lr_noise_stream; + assert(auto_noise_stream.size() == selected_noise_stream.size() + auto_noise_flags.size()); + std::cout << "stereo planner tests ok\n"; +} + void run_wav_validation_tests() { assert(can_read_wav_bytes(make_pcm16_mono_wav(true))); diff --git a/tests/test_lpc.cpp b/tests/test_lpc.cpp index 1bef2e9..0ac04c3 100644 --- a/tests/test_lpc.cpp +++ b/tests/test_lpc.cpp @@ -20,6 +20,9 @@ void run_decoder_error_tests(); void run_wav_validation_tests(); void run_encoder_validation_tests(); void run_thread_limit_tests(); +void run_decoder_thread_tests(); +void run_block_planner_tests(); +void run_stereo_planner_tests(); namespace { @@ -182,6 +185,9 @@ int main() { run_predictor_tests(); run_encoder_validation_tests(); run_thread_limit_tests(); + run_decoder_thread_tests(); + run_block_planner_tests(); + run_stereo_planner_tests(); run_decoder_error_tests(); run_wav_validation_tests(); run_e2e_tests(); diff --git a/tests/test_partitioning.cpp b/tests/test_partitioning.cpp index e83a92f..81a8d02 100644 --- a/tests/test_partitioning.cpp +++ b/tests/test_partitioning.cpp @@ -198,6 +198,40 @@ void run_partitioning_tests() { BitReader overflow(overflow_bytes); assert(!overflow.read_unary_ones(7u, ones)); + for (uint32_t prefix_bits : {0u, 1u, 5u}) { + for (uint32_t unary_ones : {0u, 1u, 7u, 8u, 9u, 255u, 256u, 4095u}) { + BitWriter writer; + writer.write_bits(0u, static_cast(prefix_bits)); + writer.write_unary_ones(unary_ones); + writer.write_bit(0u); + writer.flush_to_byte(); + + BitReader reader(writer.get_buffer()); + (void)reader.read_bits(static_cast(prefix_bits)); + uint32_t decoded_ones = 0; + assert(reader.read_unary_ones(unary_ones, decoded_ones)); + assert(decoded_ones == unary_ones); + } + } + + for (uint32_t prefix_bits = 0; prefix_bits <= 7u; ++prefix_bits) { + for (uint32_t value_bits = 1; value_bits <= 32u; ++value_bits) { + const uint32_t value = 0xA5F03C96u; + const uint32_t expected = (value_bits == 32u) + ? value + : (value & ((1u << value_bits) - 1u)); + BitWriter writer; + writer.write_bits(0u, static_cast(prefix_bits)); + writer.write_bits(expected, static_cast(value_bits)); + writer.flush_to_byte(); + + BitReader reader(writer.get_buffer()); + (void)reader.read_bits(static_cast(prefix_bits)); + assert(reader.read_bits(static_cast(value_bits)) == expected); + assert(!reader.has_error()); + } + } + BitWriter min_writer; Rice::encode(min_writer, std::numeric_limits::min(), 31u); min_writer.flush_to_byte(); From f056821ac25a7c21457c014a4a955a7a26a81634 Mon Sep 17 00:00:00 2001 From: audexdev Date: Thu, 4 Jun 2026 23:19:31 +0900 Subject: [PATCH 2/6] Optimize LAC decode hot paths --- src/codec/bitstream/bit_reader.cpp | 104 +++++++++++++++++------- src/codec/block/decoder.cpp | 126 ++++++++++++++++------------- src/codec/block/decoder.hpp | 7 +- src/codec/lac/decoder.cpp | 50 +++++------- src/codec/lpc/lpc.cpp | 27 +++++-- 5 files changed, 188 insertions(+), 126 deletions(-) diff --git a/src/codec/bitstream/bit_reader.cpp b/src/codec/bitstream/bit_reader.cpp index 0039d88..99178f8 100644 --- a/src/codec/bitstream/bit_reader.cpp +++ b/src/codec/bitstream/bit_reader.cpp @@ -1,5 +1,6 @@ #include "bit_reader.hpp" #include +#include namespace { inline uint32_t low_bits_mask(int bits) { @@ -41,31 +42,63 @@ uint32_t BitReader::read_bit() { uint32_t BitReader::read_bits(int nbits) { if (nbits <= 0) return 0; - if (!this->error && static_cast(nbits) <= this->bits_remaining()) { - uint32_t value = 0; - int remaining = nbits; - while (remaining > 0) { - const int available = 8 - this->bit_pos; - const int take = std::min(remaining, available); - const int shift = available - take; - const uint32_t chunk = - (static_cast(this->data[this->byte_pos]) >> shift) & - low_bits_mask(take); - value = static_cast((value << take) | chunk); - remaining -= take; - this->bit_pos += take; - if (this->bit_pos == 8) { - this->bit_pos = 0; - ++this->byte_pos; - } + if (this->error || static_cast(nbits) > this->bits_remaining()) { + this->mark_error(); + return 0; + } + + if (this->bit_pos == 0) { + if (nbits == 8) { + return this->data[this->byte_pos++]; + } + if (nbits == 16) { + const uint32_t value = + (static_cast(this->data[this->byte_pos]) << 8) | + static_cast(this->data[this->byte_pos + 1]); + this->byte_pos += 2; + return value; + } + if (nbits == 32) { + const uint32_t value = + (static_cast(this->data[this->byte_pos]) << 24) | + (static_cast(this->data[this->byte_pos + 1]) << 16) | + (static_cast(this->data[this->byte_pos + 2]) << 8) | + static_cast(this->data[this->byte_pos + 3]); + this->byte_pos += 4; + return value; + } + } + + const int available = 8 - this->bit_pos; + if (nbits <= available) { + const int shift = available - nbits; + const uint32_t value = + (static_cast(this->data[this->byte_pos]) >> shift) & + low_bits_mask(nbits); + this->bit_pos += nbits; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; } return value; } uint32_t value = 0; - for (int i = 0; i < nbits; ++i) { - value = (value << 1) | this->read_bit(); - if (this->error) break; + int remaining = nbits; + while (remaining > 0) { + const int bits_available = 8 - this->bit_pos; + const int take = std::min(remaining, bits_available); + const int shift = bits_available - take; + const uint32_t chunk = + (static_cast(this->data[this->byte_pos]) >> shift) & + low_bits_mask(take); + value = static_cast((value << take) | chunk); + remaining -= take; + this->bit_pos += take; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } } return value; @@ -74,18 +107,31 @@ uint32_t BitReader::read_bits(int nbits) { bool BitReader::read_unary_ones(uint32_t max_ones, uint32_t& ones) { ones = 0; while (this->byte_pos < this->size) { - if (this->bit_pos == 0 && this->data[this->byte_pos] == 0xFFu) { - if (max_ones - ones < 8u) return false; - ones += 8u; + const int available = 8 - this->bit_pos; + const uint32_t shifted = + (static_cast(this->data[this->byte_pos]) << this->bit_pos) & 0xFFu; + const uint32_t run = std::min( + static_cast(available), + std::countl_one(static_cast(shifted << 24))); + + if (max_ones - ones < run) { + return false; + } + ones += run; + this->bit_pos += static_cast(run); + if (this->bit_pos == 8) { + this->bit_pos = 0; ++this->byte_pos; - continue; } - const uint32_t bit = this->read_bit(); - if (this->error) return false; - if (bit == 0u) return true; - if (ones == max_ones) return false; - ++ones; + if (run < static_cast(available)) { + ++this->bit_pos; // consume unary terminator zero + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } + return true; + } } this->mark_error(); diff --git a/src/codec/block/decoder.cpp b/src/codec/block/decoder.cpp index 3004b21..e3532cd 100644 --- a/src/codec/block/decoder.cpp +++ b/src/codec/block/decoder.cpp @@ -4,6 +4,8 @@ #include "codec/lpc/lpc.hpp" #include "codec/rice/rice.hpp" #include "utils/logger.hpp" +#include +#include #include #include #include @@ -13,7 +15,14 @@ namespace Block { Decoder::Decoder() {} bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& out) { - if (block_size == 0 || block_size > MAX_BLOCK_SIZE) return false; + std::vector pcm(block_size); + if (!this->decode_into(br, block_size, pcm.data())) return false; + out.swap(pcm); + return true; +} + +bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { + if (block_size == 0 || block_size > MAX_BLOCK_SIZE || out == nullptr) return false; const bool debug_zr = LAC_DEBUG_ZR_ENABLED(); const bool debug_part = LAC_DEBUG_PART_ENABLED(); constexpr uint32_t kBinTagZero = 0b00u; @@ -41,11 +50,8 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o if (!stateless) return Rice::adapt_k(sum, count, *state); if (count == 0) return 0; const uint64_t mean = (sum + (count >> 1)) / count; - uint32_t k = 0; - while ((1u << k) < mean && k < 31u) { - ++k; - } - return k; + if (mean <= 1) return 0; + return std::min(31u, std::bit_width(mean - uint64_t{1})); }; auto partition_sizes_for_block = [](uint32_t size, uint8_t order) -> std::vector { @@ -70,13 +76,12 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o uint32_t samples, uint32_t initial_k, uint8_t residual_mode, - std::vector& residual, + int32_t* residual, size_t offset, bool stateless) -> bool { const bool debug = debug_part; if (residual_mode > 2) return false; - Rice rice; uint32_t current_k = initial_k; uint64_t sumU = 0; uint32_t count = 0; @@ -91,12 +96,12 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o if (residual_mode == 0) { for (uint32_t i = 0; i < samples; ++i) { - if (!rice.decode(reader, current_k, residual[offset + i])) return false; - uint32_t u = to_unsigned(residual[offset + i]); + uint32_t u = 0; + if (!read_rice_unsigned(reader, current_k, u)) return false; + residual[offset + i] = zigzag_decode(u); sumU += u; ++count; current_k = adapt_k(sumU, count, stateless, adapt_state); - if (reader.has_error()) return false; } return true; } @@ -108,26 +113,14 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o uint32_t idx = 0; while (idx < samples) { - uint32_t tag_prefix = reader.read_bit(); - if (reader.has_error()) { - if (debug) { - LAC_DEBUG_LOG("[part-dec] read error before tag idx=" << idx << "\n"); - } - return false; - } - uint32_t tag_suffix = reader.read_bit(); + uint32_t tag = reader.read_bits(2); if (reader.has_error()) { if (debug) { LAC_DEBUG_LOG("[part-dec] read error before tag idx=" << idx << "\n"); } return false; } - uint32_t tag = 0xFFu; - if (tag_prefix == 0u) { - tag = (tag_suffix == 0u) ? kTagNormal : kTagRun; // 00 / 01 - } else if (tag_suffix == 0u) { - tag = kTagEscape; // 10 - } else { + if (tag > kTagEscape) { if (debug) { LAC_DEBUG_LOG("[part-dec] invalid tag=3 idx=" << idx << "\n"); } @@ -171,13 +164,14 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o } return false; } - for (uint32_t j = 0; j < run_len; ++j) { - residual[offset + idx++] = 0; - if (debug && idx <= 8) { - LAC_DEBUG_LOG("[part-val] idx=" << (offset + idx - 1) - << " v=0 k=" << current_k << "\n"); - } - if (!stateless) { + std::fill_n(residual + offset + idx, run_len, 0); + if (debug && idx < 8) { + LAC_DEBUG_LOG("[part-val] idx=" << (offset + idx) + << " v=0 k=" << current_k << "\n"); + } + idx += run_len; + if (!stateless) { + for (uint32_t j = 0; j < run_len; ++j) { ++count; current_k = adapt_k(sumU, count, false, adapt_state); } @@ -273,16 +267,11 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o return false; }; - auto restore_fixed = [&](const std::vector& residual, int order, std::vector& pcm) -> bool { - pcm.resize(residual.size(), 0); + auto restore_fixed_in_place = [&](int32_t* pcm, size_t size, int order) -> bool { if (order == 0) { - pcm = residual; return true; } - for (int i = 0; i < order && i < static_cast(residual.size()); ++i) { - pcm[i] = residual[i]; - } - for (size_t i = static_cast(order); i < residual.size(); ++i) { + for (size_t i = static_cast(order); i < size; ++i) { int64_t pred = 0; switch (order) { case 1: @@ -301,7 +290,7 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o pred = 0; break; } - const int64_t sample = static_cast(residual[i]) + pred; + const int64_t sample = static_cast(pcm[i]) + pred; if (sample < std::numeric_limits::min() || sample > std::numeric_limits::max()) { return false; @@ -311,17 +300,13 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o return true; }; - auto restore_fir = [&](const std::vector& residual, int taps, std::vector& pcm) -> bool { - pcm.resize(residual.size(), 0); - for (int i = 0; i < taps && i < static_cast(residual.size()); ++i) { - pcm[i] = residual[i]; - } - for (size_t i = static_cast(taps); i < residual.size(); ++i) { + auto restore_fir_in_place = [&](int32_t* pcm, size_t size, int taps) -> bool { + for (size_t i = static_cast(taps); i < size; ++i) { int64_t pred = 0; pred += 3LL * pcm[i - 1]; pred += -1LL * pcm[i - 2]; pred >>= 2; - const int64_t sample = static_cast(residual[i]) + pred; + const int64_t sample = static_cast(pcm[i]) + pred; if (sample < std::numeric_limits::min() || sample > std::numeric_limits::max()) { return false; @@ -331,6 +316,38 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o return true; }; + auto restore_lpc_in_place = [&](int32_t* pcm, + size_t size, + int lpc_order, + const std::vector& coeffs) -> bool { + const int coeff_order = std::max(0, + std::min(lpc_order, static_cast(coeffs.size()) - 1)); + const int64_t lo = static_cast(std::numeric_limits::min()); + const int64_t hi = static_cast(std::numeric_limits::max()); + const size_t warmup = std::min(size, static_cast(coeff_order)); + for (size_t n = 0; n < warmup; ++n) { + int64_t acc = 0; + for (int i = 1; i <= static_cast(n); ++i) { + acc += static_cast(coeffs[i]) * + static_cast(pcm[n - i]); + } + const int64_t sample = (acc >> 15) + static_cast(pcm[n]); + if (sample < lo || sample > hi) return false; + pcm[n] = static_cast(sample); + } + for (size_t n = warmup; n < size; ++n) { + int64_t acc = 0; + for (int i = 1; i <= coeff_order; ++i) { + acc += static_cast(coeffs[i]) * + static_cast(pcm[n - static_cast(i)]); + } + const int64_t sample = (acc >> 15) + static_cast(pcm[n]); + if (sample < lo || sample > hi) return false; + pcm[n] = static_cast(sample); + } + return true; + }; + // Header layout: predictor_type (8), predictor_order (8), LPC coeffs (if LPC), // residual_control (8), partition metadata ([mode:2|k:5] * partitions), then residuals. uint8_t predictor_type = static_cast(br.read_bits(8)); @@ -400,7 +417,6 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o LAC_DEBUG_LOG(oss.str()); } - std::vector residual(block_size); const bool stateless = (partition_order > 0); size_t offset = 0; for (uint32_t i = 0; i < partition_count; ++i) { @@ -413,7 +429,7 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o << " bits_before=" << bits_before << "\n"); } - if (!decode_residual_segment(br, part_sizes[i], part_k[i], part_modes[i], residual, offset, stateless)) { + if (!decode_residual_segment(br, part_sizes[i], part_k[i], part_modes[i], out, offset, stateless)) { if (debug_part) { LAC_DEBUG_LOG("[part-fail] idx=" << i << " consumed=" << (bits_before - br.bits_remaining()) @@ -437,18 +453,12 @@ bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& o if (!br.consume_zero_padding_to_byte()) return false; - std::vector pcm; if (predictor_type == 0) { - if (!restore_fixed(residual, order, pcm)) return false; + return restore_fixed_in_place(out, block_size, order); } else if (predictor_type == 1) { - if (!restore_fir(residual, order, pcm)) return false; - } else { - LPC lpc(order); - if (!lpc.restore_from_residual_q15(residual, coeffs_q15, pcm)) return false; + return restore_fir_in_place(out, block_size, order); } - if (pcm.size() != block_size) return false; - out.swap(pcm); - return true; + return restore_lpc_in_place(out, block_size, order, coeffs_q15); } } // namespace Block diff --git a/src/codec/block/decoder.hpp b/src/codec/block/decoder.hpp index 0a7dce3..1326473 100644 --- a/src/codec/block/decoder.hpp +++ b/src/codec/block/decoder.hpp @@ -8,9 +8,10 @@ namespace Block { class Decoder { public: - Decoder(); + Decoder(); - bool decode(BitReader& br, uint32_t block_size, std::vector& out); -}; + bool decode(BitReader& br, uint32_t block_size, std::vector& out); + bool decode_into(BitReader& br, uint32_t block_size, int32_t* out); + }; } // namespace Block diff --git a/src/codec/lac/decoder.cpp b/src/codec/lac/decoder.cpp index a6460aa..fbb005f 100644 --- a/src/codec/lac/decoder.cpp +++ b/src/codec/lac/decoder.cpp @@ -38,27 +38,22 @@ bool is_valid_sample_for_depth(int64_t sample, uint8_t bit_depth) { return false; } -bool copy_pcm_to_output(const std::vector& source, - int32_t* output, - uint8_t bit_depth) { - if (!output) return false; - for (size_t i = 0; i < source.size(); ++i) { - if (!is_valid_sample_for_depth(source[i], bit_depth)) return false; - output[i] = source[i]; +bool validate_pcm_range(const int32_t* samples, size_t n, uint8_t bit_depth) { + if (!samples) return false; + for (size_t i = 0; i < n; ++i) { + if (!is_valid_sample_for_depth(samples[i], bit_depth)) return false; } return true; } -bool reconstruct_mid_side_to_output(const std::vector& mid, - const std::vector& side, - int32_t* left_out, - int32_t* right_out, - size_t n, - uint8_t bit_depth) { - if (!left_out || !right_out || mid.size() != n || side.size() != n) return false; +bool reconstruct_mid_side_in_place(int32_t* left_out, + int32_t* right_out, + size_t n, + uint8_t bit_depth) { + if (!left_out || !right_out) return false; for (size_t i = 0; i < n; ++i) { - const int64_t m = mid[i]; - const int64_t s = side[i]; + const int64_t m = left_out[i]; + const int64_t s = right_out[i]; const int64_t l = m + ((s + (s & 1)) >> 1); const int64_t r = l - s; if (!is_valid_sample_for_depth(l, bit_depth) || !is_valid_sample_for_depth(r, bit_depth)) { @@ -183,35 +178,30 @@ void Decoder::decode(const uint8_t* data, } Block::Decoder blockDec; - std::vector primary_pcm; - if (!blockDec.decode(block_reader, block_sizes[i], primary_pcm)) { + size_t offset = blockOffsets[i]; + int32_t* left_out = decoded_left.data() + offset; + int32_t* right_out = isStereo ? decoded_right.data() + offset : nullptr; + if (!blockDec.decode_into(block_reader, block_sizes[i], left_out)) { throw_decode_error(i, "primary"); } - std::vector secondary_pcm; if (isStereo) { - if (!blockDec.decode(block_reader, block_sizes[i], secondary_pcm)) { + if (!blockDec.decode_into(block_reader, block_sizes[i], right_out)) { throw_decode_error(i, "secondary"); } } - size_t offset = blockOffsets[i]; if (!isStereo) { - if (!copy_pcm_to_output(primary_pcm, decoded_left.data() + offset, hdr.bit_depth)) { + if (!validate_pcm_range(left_out, block_sizes[i], hdr.bit_depth)) { throw_decode_error("decoded sample outside PCM bit depth"); } } else if (mid_side) { - if (!reconstruct_mid_side_to_output(primary_pcm, - secondary_pcm, - decoded_left.data() + offset, - decoded_right.data() + offset, - primary_pcm.size(), - hdr.bit_depth)) { + if (!reconstruct_mid_side_in_place(left_out, right_out, block_sizes[i], hdr.bit_depth)) { throw_decode_error("decoded sample outside PCM bit depth"); } } else { - if (!copy_pcm_to_output(primary_pcm, decoded_left.data() + offset, hdr.bit_depth) || - !copy_pcm_to_output(secondary_pcm, decoded_right.data() + offset, hdr.bit_depth)) { + if (!validate_pcm_range(left_out, block_sizes[i], hdr.bit_depth) || + !validate_pcm_range(right_out, block_sizes[i], hdr.bit_depth)) { throw_decode_error("decoded sample outside PCM bit depth"); } } diff --git a/src/codec/lpc/lpc.cpp b/src/codec/lpc/lpc.cpp index ed66fa3..059672a 100644 --- a/src/codec/lpc/lpc.cpp +++ b/src/codec/lpc/lpc.cpp @@ -235,14 +235,29 @@ bool LPC::restore_from_residual_q15(const std::vector& residual, out_block.resize(N); const int64_t lo = static_cast(std::numeric_limits::min()); const int64_t hi = static_cast(std::numeric_limits::max()); + const int coeff_order = std::max(0, + std::min(this->order, static_cast(coeffs_q15.size()) - 1)); - for (size_t n = 0; n < N; ++n) { + const size_t warmup = std::min(N, static_cast(coeff_order)); + for (size_t n = 0; n < warmup; ++n) { int64_t acc = 0; - for (int i = 1; i <= this->order; ++i) { - if (n >= static_cast(i)) { - acc += static_cast(coeffs_q15[i]) * - static_cast(out_block[n - i]); - } + for (int i = 1; i <= static_cast(n); ++i) { + acc += static_cast(coeffs_q15[i]) * + static_cast(out_block[n - i]); + } + const int64_t pred = (acc >> 15); + const int64_t sample = pred + static_cast(residual[n]); + if (sample < lo || sample > hi) { + out_block.clear(); + return false; + } + out_block[n] = static_cast(sample); + } + for (size_t n = warmup; n < N; ++n) { + int64_t acc = 0; + for (int i = 1; i <= coeff_order; ++i) { + acc += static_cast(coeffs_q15[i]) * + static_cast(out_block[n - static_cast(i)]); } const int64_t pred = (acc >> 15); const int64_t sample = pred + static_cast(residual[n]); From 064adbb87996a68052314aadcaf318f39657604a Mon Sep 17 00:00:00 2001 From: audexdev Date: Thu, 4 Jun 2026 23:47:15 +0900 Subject: [PATCH 3/6] Improve decode output throughput --- src/codec/bitstream/bit_reader.cpp | 168 ---------------------------- src/codec/bitstream/bit_reader.hpp | 174 +++++++++++++++++++++++++++++ src/codec/block/decoder.cpp | 50 ++++----- src/codec/rice/rice.cpp | 70 ------------ src/codec/rice/rice.hpp | 74 ++++++++++++ src/io/wav_io.cpp | 126 ++++++++++++++++----- src/io/wav_io.hpp | 7 ++ src/main.cpp | 2 +- 8 files changed, 376 insertions(+), 295 deletions(-) diff --git a/src/codec/bitstream/bit_reader.cpp b/src/codec/bitstream/bit_reader.cpp index 99178f8..d9a7c93 100644 --- a/src/codec/bitstream/bit_reader.cpp +++ b/src/codec/bitstream/bit_reader.cpp @@ -1,169 +1 @@ #include "bit_reader.hpp" -#include -#include - -namespace { -inline uint32_t low_bits_mask(int bits) { - if (bits <= 0) return 0u; - if (bits >= 32) return 0xFFFFFFFFu; - return (1u << bits) - 1u; -} -} // namespace - -BitReader::BitReader(const uint8_t* data, size_t size) : data(data), size(size), byte_pos(0), bit_pos(0), error(false) {} - -BitReader::BitReader(const std::vector& buf) : data(buf.data()), size(buf.size()), byte_pos(0), bit_pos(0), error(false) {} - -void BitReader::mark_error() { - this->error = true; - this->byte_pos = this->size; - this->bit_pos = 0; -} - -uint32_t BitReader::read_bit() { - if (this->byte_pos >= this->size) { - this->mark_error(); - return 0; - } - - uint8_t byte = this->data[this->byte_pos]; - int shift = 7 - this->bit_pos; - uint32_t bit = (byte >> shift) & 1; - - this->bit_pos++; - if (this->bit_pos == 8) { - this->bit_pos = 0; - this->byte_pos++; - } - - return bit; -} - -uint32_t BitReader::read_bits(int nbits) { - if (nbits <= 0) return 0; - - if (this->error || static_cast(nbits) > this->bits_remaining()) { - this->mark_error(); - return 0; - } - - if (this->bit_pos == 0) { - if (nbits == 8) { - return this->data[this->byte_pos++]; - } - if (nbits == 16) { - const uint32_t value = - (static_cast(this->data[this->byte_pos]) << 8) | - static_cast(this->data[this->byte_pos + 1]); - this->byte_pos += 2; - return value; - } - if (nbits == 32) { - const uint32_t value = - (static_cast(this->data[this->byte_pos]) << 24) | - (static_cast(this->data[this->byte_pos + 1]) << 16) | - (static_cast(this->data[this->byte_pos + 2]) << 8) | - static_cast(this->data[this->byte_pos + 3]); - this->byte_pos += 4; - return value; - } - } - - const int available = 8 - this->bit_pos; - if (nbits <= available) { - const int shift = available - nbits; - const uint32_t value = - (static_cast(this->data[this->byte_pos]) >> shift) & - low_bits_mask(nbits); - this->bit_pos += nbits; - if (this->bit_pos == 8) { - this->bit_pos = 0; - ++this->byte_pos; - } - return value; - } - - uint32_t value = 0; - int remaining = nbits; - while (remaining > 0) { - const int bits_available = 8 - this->bit_pos; - const int take = std::min(remaining, bits_available); - const int shift = bits_available - take; - const uint32_t chunk = - (static_cast(this->data[this->byte_pos]) >> shift) & - low_bits_mask(take); - value = static_cast((value << take) | chunk); - remaining -= take; - this->bit_pos += take; - if (this->bit_pos == 8) { - this->bit_pos = 0; - ++this->byte_pos; - } - } - - return value; -} - -bool BitReader::read_unary_ones(uint32_t max_ones, uint32_t& ones) { - ones = 0; - while (this->byte_pos < this->size) { - const int available = 8 - this->bit_pos; - const uint32_t shifted = - (static_cast(this->data[this->byte_pos]) << this->bit_pos) & 0xFFu; - const uint32_t run = std::min( - static_cast(available), - std::countl_one(static_cast(shifted << 24))); - - if (max_ones - ones < run) { - return false; - } - ones += run; - this->bit_pos += static_cast(run); - if (this->bit_pos == 8) { - this->bit_pos = 0; - ++this->byte_pos; - } - - if (run < static_cast(available)) { - ++this->bit_pos; // consume unary terminator zero - if (this->bit_pos == 8) { - this->bit_pos = 0; - ++this->byte_pos; - } - return true; - } - } - - this->mark_error(); - return false; -} - -void BitReader::align_to_byte() { - if (this->bit_pos == 0) return; - this->bit_pos = 0; - this->byte_pos++; -} - -bool BitReader::consume_zero_padding_to_byte() { - while (this->bit_pos != 0) { - if (this->read_bit() != 0u || this->error) return false; - } - return true; -} - -bool BitReader::eof() const { - return this->byte_pos >= this->size; -} - -bool BitReader::has_error() const { - return this->error; -} - -size_t BitReader::bits_remaining() const { - if (this->error) return 0; - size_t bits = (this->size - this->byte_pos) * 8; - if (this->bit_pos > 0) { - bits -= this->bit_pos; - } - return bits; -} diff --git a/src/codec/bitstream/bit_reader.hpp b/src/codec/bitstream/bit_reader.hpp index ee7ec63..98d0900 100644 --- a/src/codec/bitstream/bit_reader.hpp +++ b/src/codec/bitstream/bit_reader.hpp @@ -1,4 +1,6 @@ #pragma once +#include +#include #include #include @@ -25,4 +27,176 @@ class BitReader { bool error; void mark_error(); + static uint32_t low_bits_mask(int bits); }; + +inline uint32_t BitReader::low_bits_mask(int bits) { + if (bits <= 0) return 0u; + if (bits >= 32) return 0xFFFFFFFFu; + return (1u << bits) - 1u; +} + +inline BitReader::BitReader(const uint8_t* data, size_t size) + : data(data), size(size), byte_pos(0), bit_pos(0), error(false) {} + +inline BitReader::BitReader(const std::vector& buf) + : data(buf.data()), size(buf.size()), byte_pos(0), bit_pos(0), error(false) {} + +inline void BitReader::mark_error() { + this->error = true; + this->byte_pos = this->size; + this->bit_pos = 0; +} + +inline uint32_t BitReader::read_bit() { + if (this->byte_pos >= this->size) { + this->mark_error(); + return 0; + } + + const uint8_t byte = this->data[this->byte_pos]; + const int shift = 7 - this->bit_pos; + const uint32_t bit = (byte >> shift) & 1u; + + ++this->bit_pos; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } + + return bit; +} + +inline uint32_t BitReader::read_bits(int nbits) { + if (nbits <= 0) return 0; + + if (this->error || this->byte_pos >= this->size) { + this->mark_error(); + return 0; + } + + const int available = 8 - this->bit_pos; + if (nbits <= available) { + const int shift = available - nbits; + const uint32_t value = + (static_cast(this->data[this->byte_pos]) >> shift) & + low_bits_mask(nbits); + this->bit_pos += nbits; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } + return value; + } + + if (this->bit_pos == 0) { + if (nbits == 8) { + return this->data[this->byte_pos++]; + } + if (nbits == 16 && this->size - this->byte_pos >= 2u) { + const uint32_t value = + (static_cast(this->data[this->byte_pos]) << 8) | + static_cast(this->data[this->byte_pos + 1]); + this->byte_pos += 2; + return value; + } + if (nbits == 32 && this->size - this->byte_pos >= 4u) { + const uint32_t value = + (static_cast(this->data[this->byte_pos]) << 24) | + (static_cast(this->data[this->byte_pos + 1]) << 16) | + (static_cast(this->data[this->byte_pos + 2]) << 8) | + static_cast(this->data[this->byte_pos + 3]); + this->byte_pos += 4; + return value; + } + } + + if (static_cast(nbits) > this->bits_remaining()) { + this->mark_error(); + return 0; + } + + uint32_t value = 0; + int remaining = nbits; + while (remaining > 0) { + const int bits_available = 8 - this->bit_pos; + const int take = std::min(remaining, bits_available); + const int shift = bits_available - take; + const uint32_t chunk = + (static_cast(this->data[this->byte_pos]) >> shift) & + low_bits_mask(take); + value = static_cast((value << take) | chunk); + remaining -= take; + this->bit_pos += take; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } + } + + return value; +} + +inline bool BitReader::read_unary_ones(uint32_t max_ones, uint32_t& ones) { + ones = 0; + while (this->byte_pos < this->size) { + const int available = 8 - this->bit_pos; + const uint32_t shifted = + (static_cast(this->data[this->byte_pos]) << this->bit_pos) & 0xFFu; + const uint32_t run = std::min( + static_cast(available), + std::countl_one(static_cast(shifted << 24))); + + if (max_ones - ones < run) { + return false; + } + ones += run; + this->bit_pos += static_cast(run); + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } + + if (run < static_cast(available)) { + ++this->bit_pos; + if (this->bit_pos == 8) { + this->bit_pos = 0; + ++this->byte_pos; + } + return true; + } + } + + this->mark_error(); + return false; +} + +inline void BitReader::align_to_byte() { + if (this->bit_pos == 0) return; + this->bit_pos = 0; + ++this->byte_pos; +} + +inline bool BitReader::consume_zero_padding_to_byte() { + while (this->bit_pos != 0) { + if (this->read_bit() != 0u || this->error) return false; + } + return true; +} + +inline bool BitReader::eof() const { + return this->byte_pos >= this->size; +} + +inline bool BitReader::has_error() const { + return this->error; +} + +inline size_t BitReader::bits_remaining() const { + if (this->error) return 0; + size_t bits = (this->size - this->byte_pos) * 8; + if (this->bit_pos > 0) { + bits -= this->bit_pos; + } + return bits; +} diff --git a/src/codec/block/decoder.cpp b/src/codec/block/decoder.cpp index e3532cd..bf189ad 100644 --- a/src/codec/block/decoder.cpp +++ b/src/codec/block/decoder.cpp @@ -5,6 +5,7 @@ #include "codec/rice/rice.hpp" #include "utils/logger.hpp" #include +#include #include #include #include @@ -54,22 +55,10 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { return std::min(31u, std::bit_width(mean - uint64_t{1})); }; - auto partition_sizes_for_block = [](uint32_t size, uint8_t order) -> std::vector { - std::vector parts; - if (order == 0) { - parts.push_back(size); - return parts; - } + auto partition_size_at = [](uint32_t size, uint8_t order, uint32_t index, uint32_t count) -> uint32_t { + if (order == 0) return size; const uint32_t base = size >> order; - if (base == 0) { - parts.push_back(size); - return parts; - } - const uint32_t count = 1u << order; - parts.assign(count, base); - const uint32_t used = base * (count - 1u); - parts.back() = size - used; - return parts; + return (index + 1u == count) ? (size - base * (count - 1u)) : base; }; auto decode_residual_segment = [&](BitReader& reader, @@ -319,9 +308,8 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { auto restore_lpc_in_place = [&](int32_t* pcm, size_t size, int lpc_order, - const std::vector& coeffs) -> bool { - const int coeff_order = std::max(0, - std::min(lpc_order, static_cast(coeffs.size()) - 1)); + const std::array& coeffs) -> bool { + const int coeff_order = std::max(0, lpc_order); const int64_t lo = static_cast(std::numeric_limits::min()); const int64_t hi = static_cast(std::numeric_limits::max()); const size_t warmup = std::min(size, static_cast(coeff_order)); @@ -362,9 +350,8 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { if (order < 0 || order > 4) return false; } - std::vector coeffs_q15; + std::array coeffs_q15{}; if (predictor_type == 2) { - coeffs_q15.assign(order + 1, 0); for (int i = 1; i <= order; ++i) { coeffs_q15[i] = int16_t(br.read_bits(16)); if (br.has_error()) return false; @@ -385,11 +372,14 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { if ((partition_order > 0) && ((block_size >> partition_order) < Block::MIN_PARTITION_SIZE)) return false; const uint32_t partition_count = (partition_order == 0) ? 1u : (1u << partition_order); - const auto part_sizes = partition_sizes_for_block(block_size, partition_order); - if (part_sizes.size() != partition_count || part_sizes.back() == 0) return false; + constexpr size_t kMaxPartitionCount = size_t{1} << Block::MAX_PARTITION_ORDER; + if (partition_count > kMaxPartitionCount || + partition_size_at(block_size, partition_order, partition_count - 1u, partition_count) == 0) { + return false; + } - std::vector part_modes(partition_count); - std::vector part_k(partition_count); + std::array part_modes{}; + std::array part_k{}; for (uint32_t i = 0; i < partition_count; ++i) { part_modes[i] = static_cast(br.read_bits(2)); part_k[i] = br.read_bits(5); @@ -411,7 +401,8 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { std::ostringstream oss; oss << "[part-decode] modes/k:"; for (uint32_t i = 0; i < partition_count; ++i) { - oss << " (" << static_cast(part_modes[i]) << "," << part_k[i] << "," << part_sizes[i] << ")"; + const uint32_t len = partition_size_at(block_size, partition_order, i, partition_count); + oss << " (" << static_cast(part_modes[i]) << "," << part_k[i] << "," << len << ")"; } oss << " bits_before_parts=" << br.bits_remaining() << "\n"; LAC_DEBUG_LOG(oss.str()); @@ -420,20 +411,21 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { const bool stateless = (partition_order > 0); size_t offset = 0; for (uint32_t i = 0; i < partition_count; ++i) { + const uint32_t part_size = partition_size_at(block_size, partition_order, i, partition_count); size_t bits_before = br.bits_remaining(); if (debug_part) { LAC_DEBUG_LOG("[part-entry] idx=" << i << " mode=" << static_cast(part_modes[i]) << " k=" << part_k[i] - << " samples=" << part_sizes[i] + << " samples=" << part_size << " bits_before=" << bits_before << "\n"); } - if (!decode_residual_segment(br, part_sizes[i], part_k[i], part_modes[i], out, offset, stateless)) { + if (!decode_residual_segment(br, part_size, part_k[i], part_modes[i], out, offset, stateless)) { if (debug_part) { LAC_DEBUG_LOG("[part-fail] idx=" << i << " consumed=" << (bits_before - br.bits_remaining()) - << " size=" << part_sizes[i] + << " size=" << part_size << " k=" << part_k[i] << " mode=" << static_cast(part_modes[i]) << " bits_left=" << br.bits_remaining() @@ -447,7 +439,7 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { << " bits_consumed=" << (bits_before - part_end_bit) << " bits_left=" << part_end_bit << "\n"); } - offset += part_sizes[i]; + offset += part_size; } if (offset != block_size) return false; diff --git a/src/codec/rice/rice.cpp b/src/codec/rice/rice.cpp index 8512b0f..a7770c6 100644 --- a/src/codec/rice/rice.cpp +++ b/src/codec/rice/rice.cpp @@ -66,73 +66,3 @@ uint32_t Rice::compute_k(const std::vector& residuals) { return k; } - -namespace { -constexpr uint32_t kMaxRiceK = 31; -} // namespace - -uint32_t Rice::adapt_k(uint64_t sum, uint32_t count, AdaptState& state) { - if (count == 0) return 0; - - const uint64_t current_u = sum - state.previous_sum; - state.previous_sum = sum; - - // Track micro-window flags first so we can evict the slot we will overwrite. - const uint32_t micro_idx = (count - 1) % kMicroWindow; - state.large_q_count -= state.large_flags[micro_idx]; - state.zero_q_count -= state.zero_flags[micro_idx]; - - // Update drift window. - if (state.window_filled < kDriftWindow) { - state.window_filled++; - } else { - state.window_sum -= state.recent_u[state.window_index]; - } - state.recent_u[state.window_index] = static_cast(current_u); - state.window_sum += current_u; - - // Micro window flags. - // Base mean and k estimate using integer rounding. - const uint64_t mean = (sum + (count >> 1)) / count; - const uint32_t k = (mean <= 1) - ? 0u - : std::min(kMaxRiceK, std::bit_width(mean - 1u)); - - const uint32_t q_base = static_cast((k >= kMaxRiceK) ? 0u : (current_u >> k)); - const uint8_t is_large = static_cast(q_base > 3u); - const uint8_t is_zero = static_cast(q_base == 0u); - - state.large_q_count += is_large; - state.zero_q_count += is_zero; - state.large_flags[micro_idx] = is_large; - state.zero_flags[micro_idx] = is_zero; - - // Drift correction: compare local window mean against global mean. - int32_t bias = 0; - if (state.window_filled > 0 && mean > 0) { - const uint64_t local_mean = (state.window_sum + (state.window_filled >> 1)) / state.window_filled; - // Increase k if local variation is significantly higher. - if (local_mean * 3 > mean * 4) { // local > 1.33 * global - bias = 1; - } else if (local_mean * 4 + 3 < mean * 3) { // local < 0.75 * global - bias = -1; - } - } - - // Micro adaptation on quotient distribution. - if (state.window_index + 1 >= kMicroWindow || state.window_filled >= kMicroWindow) { - const uint32_t window_size = std::min(state.window_filled, kMicroWindow); - if (state.large_q_count * 4 >= window_size * 3) { // many large quotients - bias = std::min(bias + 1, 1); - } else if (state.zero_q_count * 5 >= window_size * 4) { // mostly zeros - bias = std::max(bias - 1, -1); - } - } - - int32_t biased_k = static_cast(k) + bias; - if (biased_k < 0) biased_k = 0; - if (biased_k > 31) biased_k = 31; - - state.window_index = (state.window_index + 1) % kDriftWindow; - return static_cast(biased_k); -} diff --git a/src/codec/rice/rice.hpp b/src/codec/rice/rice.hpp index 7d5f4c5..e90a4c6 100644 --- a/src/codec/rice/rice.hpp +++ b/src/codec/rice/rice.hpp @@ -1,5 +1,7 @@ #pragma once +#include #include +#include #include #include #include "codec/bitstream/bit_writer.hpp" @@ -13,6 +15,7 @@ class Rice { struct AdaptState { uint64_t previous_sum = 0; uint32_t window_index = 0; + uint32_t micro_index = 0; uint32_t window_filled = 0; uint64_t window_sum = 0; uint16_t large_q_count = 0; @@ -41,3 +44,74 @@ class Rice { static uint32_t signed_to_unsigned(int32_t v); static int32_t unsigned_to_signed(uint32_t u); }; + +inline uint32_t Rice::adapt_k(uint64_t sum, uint32_t count, AdaptState& state) { + if (count == 0) return 0; + + constexpr uint32_t kMaxRiceK = 31; + const uint64_t current_u = sum - state.previous_sum; + state.previous_sum = sum; + + // Track micro-window flags first so we can evict the slot we will overwrite. + const uint32_t micro_idx = state.micro_index; + state.large_q_count -= state.large_flags[micro_idx]; + state.zero_q_count -= state.zero_flags[micro_idx]; + + // Update drift window. + if (state.window_filled < kDriftWindow) { + ++state.window_filled; + } else { + state.window_sum -= state.recent_u[state.window_index]; + } + state.recent_u[state.window_index] = static_cast(current_u); + state.window_sum += current_u; + + // Micro window flags. + // Base mean and k estimate using integer rounding. + const uint64_t mean = (sum + (count >> 1)) / count; + const uint32_t k = (mean <= 1) + ? 0u + : std::min(kMaxRiceK, std::bit_width(mean - 1u)); + + const uint32_t q_base = static_cast((k >= kMaxRiceK) ? 0u : (current_u >> k)); + const uint8_t is_large = static_cast(q_base > 3u); + const uint8_t is_zero = static_cast(q_base == 0u); + + state.large_q_count += is_large; + state.zero_q_count += is_zero; + state.large_flags[micro_idx] = is_large; + state.zero_flags[micro_idx] = is_zero; + + // Drift correction: compare local window mean against global mean. + int32_t bias = 0; + if (state.window_filled > 0 && mean > 0) { + const uint64_t local_mean = (state.window_filled == kDriftWindow) + ? ((state.window_sum + (kDriftWindow >> 1)) >> 8) + : ((state.window_sum + (state.window_filled >> 1)) / state.window_filled); + // Increase k if local variation is significantly higher. + if (local_mean * 3 > mean * 4) { + bias = 1; + } else if (local_mean * 4 + 3 < mean * 3) { + bias = -1; + } + } + + // Micro adaptation on quotient distribution. + if (state.window_index + 1 >= kMicroWindow || state.window_filled >= kMicroWindow) { + const uint32_t window_size = + (state.window_filled >= kMicroWindow) ? kMicroWindow : state.window_filled; + if (state.large_q_count * 4 >= window_size * 3) { + bias = std::min(bias + 1, 1); + } else if (state.zero_q_count * 5 >= window_size * 4) { + bias = std::max(bias - 1, -1); + } + } + + int32_t biased_k = static_cast(k) + bias; + if (biased_k < 0) biased_k = 0; + if (biased_k > 31) biased_k = 31; + + state.micro_index = (state.micro_index + 1u == kMicroWindow) ? 0u : state.micro_index + 1u; + state.window_index = (state.window_index + 1u) & (kDriftWindow - 1u); + return static_cast(biased_k); +} diff --git a/src/io/wav_io.cpp b/src/io/wav_io.cpp index 79c1c7b..f1f63f0 100644 --- a/src/io/wav_io.cpp +++ b/src/io/wav_io.cpp @@ -1,6 +1,8 @@ #include "wav_io.hpp" +#include #include #include +#include namespace { @@ -99,6 +101,62 @@ void write_pcm24_sample(std::ofstream& f, int32_t v) { f.write(reinterpret_cast(b), 3); } +void append_pcm16_sample(uint8_t*& out, int32_t v) { + *out++ = static_cast(v & 0xFF); + *out++ = static_cast((v >> 8) & 0xFF); +} + +void append_pcm24_sample(uint8_t*& out, int32_t v) { + *out++ = static_cast(v & 0xFF); + *out++ = static_cast((v >> 8) & 0xFF); + *out++ = static_cast((v >> 16) & 0xFF); +} + +bool write_pcm_data(std::ofstream& f, + const std::vector& left, + const std::vector& right, + uint16_t channels, + uint8_t bit_depth, + uint32_t frames, + uint32_t block_align) { + constexpr size_t kChunkBytes = 4u * 1024u * 1024u; + const size_t frames_per_chunk = std::max(1u, kChunkBytes / block_align); + std::vector buffer(frames_per_chunk * block_align); + + for (size_t base = 0; base < frames; base += frames_per_chunk) { + const size_t count = std::min(frames_per_chunk, frames - base); + uint8_t* out = buffer.data(); + + if (bit_depth == 16) { + if (channels == 1) { + for (size_t i = 0; i < count; ++i) { + append_pcm16_sample(out, left[base + i]); + } + } else { + for (size_t i = 0; i < count; ++i) { + append_pcm16_sample(out, left[base + i]); + append_pcm16_sample(out, right[base + i]); + } + } + } else { + if (channels == 1) { + for (size_t i = 0; i < count; ++i) { + append_pcm24_sample(out, left[base + i]); + } + } else { + for (size_t i = 0; i < count; ++i) { + append_pcm24_sample(out, left[base + i]); + append_pcm24_sample(out, right[base + i]); + } + } + } + + f.write(reinterpret_cast(buffer.data()), out - buffer.data()); + if (!f) return false; + } + return true; +} + } // namespace bool read_wav(const std::string& path, @@ -219,23 +277,28 @@ bool read_wav(const std::string& path, return true; } -bool write_wav(const std::string& path, - const std::vector& left, - const std::vector& right, - uint16_t channels, - uint32_t sample_rate, - uint8_t bit_depth) { +namespace { + +bool write_wav_impl(const std::string& path, + const std::vector& left, + const std::vector& right, + uint16_t channels, + uint32_t sample_rate, + uint8_t bit_depth, + bool validate_samples) { if (channels != 1 && channels != 2) return false; if (!is_supported_sample_rate(sample_rate)) return false; if (bit_depth != 16 && bit_depth != 24) return false; if (left.empty()) return false; if (channels == 1 && !right.empty()) return false; if (channels == 2 && left.size() != right.size()) return false; - for (int32_t sample : left) { - if (!is_valid_sample_for_depth(sample, bit_depth)) return false; - } - for (int32_t sample : right) { - if (!is_valid_sample_for_depth(sample, bit_depth)) return false; + if (validate_samples) { + for (int32_t sample : left) { + if (!is_valid_sample_for_depth(sample, bit_depth)) return false; + } + for (int32_t sample : right) { + if (!is_valid_sample_for_depth(sample, bit_depth)) return false; + } } uint16_t bytes_per_sample = static_cast(bit_depth / 8); @@ -248,7 +311,10 @@ bool write_wav(const std::string& path, const uint32_t data_size = static_cast(data_size_u64); const uint32_t riff_size = static_cast(riff_size_u64); - std::ofstream f(path, std::ios::binary); + std::ofstream f; + std::vector file_buffer(4u * 1024u * 1024u); + f.rdbuf()->pubsetbuf(file_buffer.data(), static_cast(file_buffer.size())); + f.open(path, std::ios::binary); if (!f) return false; f.write("RIFF", 4); @@ -267,21 +333,7 @@ bool write_wav(const std::string& path, f.write("data", 4); write_u32_le(f, data_size); - for (uint32_t i = 0; i < frames; ++i) { - if (bit_depth == 16) { - write_pcm16_sample(f, left[i]); - } else { - write_pcm24_sample(f, left[i]); - } - - if (channels == 2) { - if (bit_depth == 16) { - write_pcm16_sample(f, right[i]); - } else { - write_pcm24_sample(f, right[i]); - } - } - } + if (!write_pcm_data(f, left, right, channels, bit_depth, frames, block_align)) return false; if (data_padding_u64 != 0u) { const char padding = 0; @@ -290,3 +342,23 @@ bool write_wav(const std::string& path, f.close(); return f.good(); } + +} // namespace + +bool write_wav(const std::string& path, + const std::vector& left, + const std::vector& right, + uint16_t channels, + uint32_t sample_rate, + uint8_t bit_depth) { + return write_wav_impl(path, left, right, channels, sample_rate, bit_depth, true); +} + +bool write_wav_unchecked_samples(const std::string& path, + const std::vector& left, + const std::vector& right, + uint16_t channels, + uint32_t sample_rate, + uint8_t bit_depth) { + return write_wav_impl(path, left, right, channels, sample_rate, bit_depth, false); +} diff --git a/src/io/wav_io.hpp b/src/io/wav_io.hpp index 9baed6d..37994a2 100644 --- a/src/io/wav_io.hpp +++ b/src/io/wav_io.hpp @@ -16,3 +16,10 @@ bool write_wav(const std::string& path, uint16_t channels, uint32_t sample_rate, uint8_t bit_depth); + +bool write_wav_unchecked_samples(const std::string& path, + const std::vector& left, + const std::vector& right, + uint16_t channels, + uint32_t sample_rate, + uint8_t bit_depth); diff --git a/src/main.cpp b/src/main.cpp index c7ffced..bc984d6 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -367,7 +367,7 @@ int main(int argc, char** argv) { } StagedOutputFile staged_output(out_path); if (!staged_output.is_ready() || - !write_wav(staged_output.path(), left, right, hdr.channels, hdr.sample_rate, hdr.bit_depth) || + !write_wav_unchecked_samples(staged_output.path(), left, right, hdr.channels, hdr.sample_rate, hdr.bit_depth) || !staged_output.publish(in_path)) { std::cerr << "Failed to write WAV: " << out_path << "\n"; return 1; From f3fe44098b3fbf7b56baf2eb3362e464cb0869f4 Mon Sep 17 00:00:00 2001 From: audexdev Date: Fri, 5 Jun 2026 00:12:16 +0900 Subject: [PATCH 4/6] Improve LAC decode speed path --- docs/format.md | 8 +- docs/supported-formats.md | 2 +- src/codec/block/decoder.cpp | 16 +- src/codec/block/encoder.cpp | 100 +++++++-- src/main.cpp | 429 +++++++++++++++++++++++++++++++++++- tests/test_e2e.cpp | 52 ++++- tests/test_zerorun.cpp | 117 ++++++++-- 7 files changed, 659 insertions(+), 65 deletions(-) diff --git a/docs/format.md b/docs/format.md index 2211734..73c817b 100644 --- a/docs/format.md +++ b/docs/format.md @@ -194,7 +194,7 @@ Current limits: - minimum partition size: 32 samples - maximum partition order: 8 -The partition flag must be set when `partition_order > 0` and unset when `partition_order == 0`. The default residual mode must be `0`, `1`, or `2`, and it must match the first partition metadata entry. Partitioned blocks use stateless Rice adaptation inside each partition. Unpartitioned blocks use stateful Rice adaptation. +The partition flag must be set when `partition_order > 0` and unset when `partition_order == 0`. The default residual mode must be `0`, `1`, `2`, or `3`, and it must match the first partition metadata entry. Modes `0`, `1`, and `2` use Rice adaptation: partitioned blocks use stateless adaptation inside each partition, while unpartitioned blocks use stateful adaptation. Mode `3` uses a fixed Rice parameter for the whole partition. Partition sizes are computed as: @@ -247,7 +247,7 @@ write one zero bit write r as k bits when k > 0 ``` -The initial `k` for each residual segment comes from that partition's `u5 initial_k` metadata. `k` is updated after each logical residual sample, including samples represented by zero-run tokens. +For adaptive residual modes, the initial `k` for each residual segment comes from that partition's `u5 initial_k` metadata. `k` is updated after each logical residual sample, including samples represented by zero-run tokens. #### Stateless Adaptation @@ -381,6 +381,10 @@ The fallback path uses the same adaptive `k` model. Every bin token represents exactly one residual sample. Tags `00`, `01`, and `10` update the adaptive model with unsigned values `0`, `2 or 1`, and `4 or 3` respectively after sign reconstruction. Tag `11` decodes one Rice-coded residual using the current `k`, then updates the same adaptive model. +### Mode 3: Static Rice + +Static Rice mode uses the same zigzag mapping and Rice bit coding as mode `0`, but `k` is fixed to the partition's `u5 initial_k` value for every residual in the partition. No adaptive state is updated and `k` never changes. + ## Padding Each channel block is flushed to the next byte boundary after residual encoding. Padding bits must be zero. Non-zero padding is rejected as non-canonical. diff --git a/docs/supported-formats.md b/docs/supported-formats.md index 5a947ab..3367574 100644 --- a/docs/supported-formats.md +++ b/docs/supported-formats.md @@ -71,7 +71,7 @@ The current `.lac` container supports: - LR, mid/side, or per-block stereo mode - block sizes up to `16384` samples per channel - fixed, FIR, and LPC predictors -- adaptive Rice, zero-run, and small-residual bin residual modes +- adaptive Rice, zero-run, small-residual bin, and static Rice residual modes See `docs/format.md` for bitstream details. diff --git a/src/codec/block/decoder.cpp b/src/codec/block/decoder.cpp index bf189ad..57ec341 100644 --- a/src/codec/block/decoder.cpp +++ b/src/codec/block/decoder.cpp @@ -30,6 +30,7 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { constexpr uint32_t kBinTagOne = 0b01u; constexpr uint32_t kBinTagTwo = 0b10u; constexpr uint32_t kBinTagFallback = 0b11u; + constexpr uint8_t kResidualModeStaticRice = 3u; auto read_rice_unsigned = [](BitReader& reader, uint32_t k, uint32_t& value) -> bool { if (k > 31u) return false; @@ -69,7 +70,7 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { size_t offset, bool stateless) -> bool { const bool debug = debug_part; - if (residual_mode > 2) return false; + if (residual_mode > kResidualModeStaticRice) return false; uint32_t current_k = initial_k; uint64_t sumU = 0; @@ -253,6 +254,15 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { return idx == samples; } + if (residual_mode == kResidualModeStaticRice) { + for (uint32_t i = 0; i < samples; ++i) { + uint32_t u = 0; + if (!read_rice_unsigned(reader, initial_k, u)) return false; + residual[offset + i] = zigzag_decode(u); + } + return true; + } + return false; }; @@ -365,7 +375,7 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { const uint8_t partition_order = static_cast((control & Block::PARTITION_ORDER_MASK) >> Block::PARTITION_ORDER_SHIFT); const uint8_t control_mode = static_cast((control >> 5) & 0x03u); - if (control_mode > 2u) return false; + if (control_mode > kResidualModeStaticRice) return false; if (partition_flag && partition_order == 0) return false; if (!partition_flag && partition_order != 0) return false; if (partition_order > Block::MAX_PARTITION_ORDER) return false; @@ -384,7 +394,7 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { part_modes[i] = static_cast(br.read_bits(2)); part_k[i] = br.read_bits(5); if (br.has_error()) return false; - if (part_modes[i] > 2) return false; + if (part_modes[i] > kResidualModeStaticRice) return false; } if (part_modes[0] != control_mode) return false; diff --git a/src/codec/block/encoder.cpp b/src/codec/block/encoder.cpp index a7455e6..3f5dcaa 100644 --- a/src/codec/block/encoder.cpp +++ b/src/codec/block/encoder.cpp @@ -47,9 +47,14 @@ constexpr uint32_t kBinTagZero = 0b00u; constexpr uint32_t kBinTagOne = 0b01u; constexpr uint32_t kBinTagTwo = 0b10u; constexpr uint32_t kBinTagFallback = 0b11u; +constexpr uint8_t kResidualModeRice = 0; +constexpr uint8_t kResidualModeZeroRun = 1; +constexpr uint8_t kResidualModeBin = 2; +constexpr uint8_t kResidualModeStaticRice = 3; constexpr uint8_t kPredictorFixed = 0; constexpr uint8_t kPredictorFir = 1; constexpr uint8_t kPredictorLpc = 2; +constexpr uint64_t kDecodeSpeedBitMarginDivisor = 20; constexpr int kFirShift = 2; constexpr int kFirTaps[] = {3, -1}; @@ -152,6 +157,40 @@ uint32_t estimate_initial_k(std::span residual) { return std::min(best_k, 15u); } +uint32_t estimate_static_k(std::span residual) { + if (residual.empty()) return 0; + constexpr uint32_t kMaxStaticK = 15; + uint64_t candidate_costs[kMaxStaticK + 1] = {0}; + for (int32_t v : residual) { + const uint32_t u = unsigned_from_residual(v); + for (uint32_t k = 0; k <= kMaxStaticK; ++k) { + candidate_costs[k] += rice_bits_for_unsigned(u, k); + } + } + + uint32_t best_k = 0; + uint64_t best_cost = std::numeric_limits::max(); + for (uint32_t k = 0; k <= kMaxStaticK; ++k) { + if (candidate_costs[k] < best_cost) { + best_cost = candidate_costs[k]; + best_k = k; + } + } + return best_k; +} + +uint64_t estimate_static_rice_bits(std::span residual, uint32_t k) { + uint64_t bits = 0; + for (int32_t v : residual) { + bits += rice_bits_for_unsigned(unsigned_from_residual(v), k); + } + return bits; +} + +bool within_decode_speed_margin(uint64_t candidate_bits, uint64_t reference_bits) { + return candidate_bits <= reference_bits + (reference_bits / kDecodeSpeedBitMarginDivisor); +} + struct ResidualCosts { uint64_t rice_bits = 0; uint64_t zr_bits = 0; @@ -283,8 +322,10 @@ std::vector Encoder::encode(const std::vector& pcm) { uint64_t rice_bits = std::numeric_limits::max(); uint64_t zr_bits = std::numeric_limits::max(); uint64_t bin_bits = std::numeric_limits::max(); + uint64_t static_bits = std::numeric_limits::max(); uint64_t best_bits = std::numeric_limits::max(); uint32_t initial_k = 0; + uint32_t static_k = 0; bool has_run = false; long double energy = 0.0L; bool stable = true; @@ -303,7 +344,10 @@ std::vector Encoder::encode(const std::vector& pcm) { ? costs.zr_bits : ev.rice_bits; ev.bin_bits = costs.bin_bits; - ev.best_bits = std::min(ev.rice_bits, std::min(ev.zr_bits, ev.bin_bits)); + ev.static_k = estimate_static_k(std::span(ev.residual)); + ev.static_bits = estimate_static_rice_bits(std::span(ev.residual), ev.static_k); + ev.best_bits = std::min(std::min(ev.rice_bits, ev.static_bits), + std::min(ev.zr_bits, ev.bin_bits)); }; auto consider = [&](PredictorEval&& ev) { if (!best_candidate || @@ -379,7 +423,7 @@ std::vector Encoder::encode(const std::vector& pcm) { : best.order_param; struct PartitionChoice { - uint8_t residual_mode = 0; // 0 = rice, 1 = zero-run, 2 = bin + uint8_t residual_mode = 0; // 0 = rice, 1 = zero-run, 2 = bin, 3 = static rice uint32_t initial_k = 0; uint64_t bits = 0; uint32_t length = 0; @@ -387,26 +431,35 @@ std::vector Encoder::encode(const std::vector& pcm) { const uint32_t block_size = static_cast(best.residual.size()); const uint32_t base_initial_k = best.initial_k; + const uint32_t base_static_k = best.static_k; uint64_t base_bits_normal = best.rice_bits; const bool has_run = best.has_run; const bool allow_zr_global = this->zero_run_enabled && has_run; uint64_t base_bits_zr = best.zr_bits; uint64_t base_bits_bin = best.bin_bits; - uint8_t base_mode = 0; // 0=rice,1=zr,2=bin + uint64_t base_bits_static = best.static_bits; + uint8_t base_mode = kResidualModeRice; uint64_t base_bits_best = base_bits_normal; if (allow_zr_global && base_bits_zr <= base_bits_best) { base_bits_best = base_bits_zr; - base_mode = 1; + base_mode = kResidualModeZeroRun; } if (base_bits_bin < base_bits_best) { base_bits_best = base_bits_bin; - base_mode = 2; + base_mode = kResidualModeBin; + } + uint32_t base_mode_k = base_initial_k; + if (base_bits_static < base_bits_best) { + base_bits_best = base_bits_static; + base_mode = kResidualModeStaticRice; + base_mode_k = base_static_k; } if (this->debug_zr && this->zero_run_enabled) { LAC_DEBUG_LOG("[zr-est] block=" << this->block_index << " normal=" << base_bits_normal << " zr=" << base_bits_zr << " bin=" << base_bits_bin + << " static=" << base_bits_static << " chosen=" << static_cast(base_mode) << " has_run=" << has_run << "\n"); @@ -414,7 +467,7 @@ std::vector Encoder::encode(const std::vector& pcm) { PartitionChoice legacy_choice{ base_mode, - base_initial_k, + base_mode_k, base_bits_best, block_size }; @@ -443,22 +496,30 @@ std::vector Encoder::encode(const std::vector& pcm) { PartitionChoice pc; pc.length = len; const std::span segment(best.residual.data() + offset, len); - pc.initial_k = estimate_initial_k(segment); - const ResidualCosts costs = estimate_residual_costs(segment, pc.initial_k, true); + const uint32_t adaptive_k = estimate_initial_k(segment); + const uint32_t static_k = estimate_static_k(segment); + const ResidualCosts costs = estimate_residual_costs(segment, adaptive_k, true); const uint64_t normal_bits = costs.rice_bits; const uint64_t bin_bits = costs.bin_bits; + const uint64_t static_bits = estimate_static_rice_bits(segment, static_k); const bool allow_zr = this->zero_run_enabled && costs.has_zero_run; const uint64_t zr_bits = allow_zr ? costs.zr_bits : normal_bits; - pc.residual_mode = 0; + pc.initial_k = adaptive_k; + pc.residual_mode = kResidualModeRice; pc.bits = normal_bits; if (allow_zr && zr_bits < pc.bits) { - pc.residual_mode = 1; + pc.residual_mode = kResidualModeZeroRun; pc.bits = zr_bits; } if (bin_bits < pc.bits) { - pc.residual_mode = 2; + pc.residual_mode = kResidualModeBin; pc.bits = bin_bits; } + if (static_bits < pc.bits || within_decode_speed_margin(static_bits, pc.bits)) { + pc.initial_k = static_k; + pc.residual_mode = kResidualModeStaticRice; + pc.bits = static_bits; + } bits_sum += pc.bits; choices.push_back(pc); offset += len; @@ -473,7 +534,9 @@ std::vector Encoder::encode(const std::vector& pcm) { << " partitions=" << choices.size() << "\n"); } + const uint64_t speed_margin = best_total_bits / 20u; if (total < best_total_bits || + (total <= best_total_bits + speed_margin && best_partition_order == 0) || (total == best_total_bits && p < best_partition_order)) { best_total_bits = total; best_partitions = choices; @@ -536,6 +599,13 @@ std::vector Encoder::encode(const std::vector& pcm) { } }; + auto encode_static_rice_partition = [&](size_t start, size_t length, uint32_t k) { + for (size_t i = 0; i < length; ++i) { + const uint32_t u = unsigned_from_residual(best.residual[start + i]); + write_rice_unsigned(bw, u, k); + } + }; + auto encode_bin_partition = [&](size_t start, size_t length, uint32_t initial_k_value, size_t part_index) { uint32_t current_k = initial_k_value; uint64_t sum_u = 0; @@ -736,12 +806,14 @@ std::vector Encoder::encode(const std::vector& pcm) { oss << "\n"; LAC_DEBUG_LOG(oss.str()); } - if (part.residual_mode == 0) { + if (part.residual_mode == kResidualModeRice) { encode_rice_partition(offset, part.length, part.initial_k); - } else if (part.residual_mode == 1) { + } else if (part.residual_mode == kResidualModeZeroRun) { encode_zr_partition(offset, part.length, part.initial_k, part_idx); - } else { + } else if (part.residual_mode == kResidualModeBin) { encode_bin_partition(offset, part.length, part.initial_k, part_idx); + } else { + encode_static_rice_partition(offset, part.length, part.initial_k); } offset += part.length; ++part_idx; diff --git a/src/main.cpp b/src/main.cpp index bc984d6..0ca9f4a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -8,26 +8,42 @@ #include #include #include +#include #include #include +#include #include #include #include #include +#include #ifdef _WIN32 #ifndef NOMINMAX #define NOMINMAX #endif #include #else +#include +#include #include +#include #endif +#include "codec/bitstream/bit_reader.hpp" +#include "codec/block/decoder.hpp" +#include "codec/lac/thread_limit.hpp" #include "io/wav_io.hpp" #include "codec/lac/encoder.hpp" #include "codec/lac/decoder.hpp" namespace { constexpr uint64_t MAX_LAC_INPUT_BYTES = 1ULL << 30; +constexpr uint64_t MAX_TOTAL_SAMPLES = 6'912'000'000ULL; +constexpr uint64_t MAX_DECODED_PCM_BYTES = 1ULL << 30; +constexpr uint32_t MAX_BLOCK_COUNT = + static_cast( + (MAX_DECODED_PCM_BYTES / sizeof(int32_t) + + Block::MIN_CANONICAL_NON_FINAL_BLOCK_SIZE - 1u) / + Block::MIN_CANONICAL_NON_FINAL_BLOCK_SIZE); } static bool load_file(const std::string& path, std::vector& data) { @@ -54,6 +70,366 @@ static bool save_file(const std::string& path, const std::vector& data) return f.good(); } +#ifndef _WIN32 +enum class FastDecodeStatus { + Unsupported, + Ok, + Failed +}; + +static bool cli_is_valid_sample_for_depth(int64_t sample, uint8_t bit_depth) { + if (bit_depth == 16) return sample >= -32768 && sample <= 32767; + if (bit_depth == 24) return sample >= -0x800000 && sample <= 0x7FFFFF; + return false; +} + +static bool cli_validate_pcm_range(const int32_t* samples, size_t n, uint8_t bit_depth) { + if (!samples) return false; + for (size_t i = 0; i < n; ++i) { + if (!cli_is_valid_sample_for_depth(samples[i], bit_depth)) return false; + } + return true; +} + +static bool cli_reconstruct_mid_side_in_place(int32_t* left, + int32_t* right, + size_t n, + uint8_t bit_depth) { + if (!left || !right) return false; + for (size_t i = 0; i < n; ++i) { + const int64_t m = left[i]; + const int64_t s = right[i]; + const int64_t l = m + ((s + (s & 1)) >> 1); + const int64_t r = l - s; + if (!cli_is_valid_sample_for_depth(l, bit_depth) || + !cli_is_valid_sample_for_depth(r, bit_depth)) { + return false; + } + left[i] = static_cast(l); + right[i] = static_cast(r); + } + return true; +} + +static void put_u16_le(uint8_t*& out, uint16_t value) { + *out++ = static_cast(value & 0xFFu); + *out++ = static_cast((value >> 8) & 0xFFu); +} + +static void put_u32_le(uint8_t*& out, uint32_t value) { + *out++ = static_cast(value & 0xFFu); + *out++ = static_cast((value >> 8) & 0xFFu); + *out++ = static_cast((value >> 16) & 0xFFu); + *out++ = static_cast((value >> 24) & 0xFFu); +} + +static void write_wav_header_to_buffer(uint8_t* dst, + uint16_t channels, + uint32_t sample_rate, + uint8_t bit_depth, + uint32_t block_align, + uint32_t data_size, + uint32_t riff_size) { + uint8_t* out = dst; + *out++ = 'R'; *out++ = 'I'; *out++ = 'F'; *out++ = 'F'; + put_u32_le(out, riff_size); + *out++ = 'W'; *out++ = 'A'; *out++ = 'V'; *out++ = 'E'; + *out++ = 'f'; *out++ = 'm'; *out++ = 't'; *out++ = ' '; + put_u32_le(out, 16); + put_u16_le(out, 1); + put_u16_le(out, channels); + put_u32_le(out, sample_rate); + put_u32_le(out, sample_rate * block_align); + put_u16_le(out, static_cast(block_align)); + put_u16_le(out, bit_depth); + *out++ = 'd'; *out++ = 'a'; *out++ = 't'; *out++ = 'a'; + put_u32_le(out, data_size); +} + +static void pack_pcm_to_wav_bytes(uint8_t* dst, + const int32_t* left, + const int32_t* right, + uint32_t samples, + uint16_t channels, + uint8_t bit_depth) { + uint8_t* out = dst; + if (bit_depth == 16) { + for (uint32_t i = 0; i < samples; ++i) { + const uint32_t l = static_cast(left[i]); + *out++ = static_cast(l & 0xFFu); + *out++ = static_cast((l >> 8) & 0xFFu); + if (channels == 2) { + const uint32_t r = static_cast(right[i]); + *out++ = static_cast(r & 0xFFu); + *out++ = static_cast((r >> 8) & 0xFFu); + } + } + } else { + for (uint32_t i = 0; i < samples; ++i) { + const uint32_t l = static_cast(left[i]); + *out++ = static_cast(l & 0xFFu); + *out++ = static_cast((l >> 8) & 0xFFu); + *out++ = static_cast((l >> 16) & 0xFFu); + if (channels == 2) { + const uint32_t r = static_cast(right[i]); + *out++ = static_cast(r & 0xFFu); + *out++ = static_cast((r >> 8) & 0xFFu); + *out++ = static_cast((r >> 16) & 0xFFu); + } + } + } +} + +static FastDecodeStatus decode_lac_v3_to_mapped_wav(const uint8_t* data, + size_t size, + const std::string& output_path, + size_t thread_count, + LAC::ThreadCollector* collector, + size_t& out_samples_per_channel, + std::string& error) { + auto fail = [&](const std::string& reason) { + error = reason; + return FastDecodeStatus::Failed; + }; + + if (data == nullptr || size == 0) return fail("empty input"); + + FrameHeader hdr; + size_t header_bytes = 0; + if (!FrameHeader::parse(data, size, hdr, header_bytes)) { + return fail("invalid frame header"); + } + if (hdr.version != 3) { + return FastDecodeStatus::Unsupported; + } + + if (size < header_bytes) return fail("truncated frame payload"); + const uint8_t* payload = data + header_bytes; + const size_t payload_bytes = size - header_bytes; + BitReader br(payload, payload_bytes); + + const uint32_t block_count = br.read_bits(32); + if (br.has_error() || block_count == 0 || block_count > MAX_BLOCK_COUNT) { + return fail("invalid block count"); + } + if (block_count > br.bits_remaining() / 64u) { + return fail("truncated block size table"); + } + + std::vector block_sizes(block_count); + std::vector block_payload_sizes(block_count); + uint64_t total_samples = 0; + uint64_t total_block_payload_bytes = 0; + for (uint32_t i = 0; i < block_count; ++i) { + const uint32_t block_size = br.read_bits(32); + if (br.has_error() || block_size == 0 || block_size > Block::MAX_BLOCK_SIZE || + (i + 1u < block_count && block_size < Block::MIN_CANONICAL_NON_FINAL_BLOCK_SIZE)) { + return fail("invalid block size"); + } + total_samples += block_size; + if (total_samples > MAX_TOTAL_SAMPLES) { + return fail("total samples exceed maximum"); + } + block_sizes[i] = block_size; + + const uint32_t payload_size = br.read_bits(32); + if (br.has_error() || payload_size == 0) { + return fail("invalid compressed block size"); + } + total_block_payload_bytes += payload_size; + if (total_block_payload_bytes > payload_bytes) { + return fail("compressed block sizes exceed frame payload"); + } + block_payload_sizes[i] = payload_size; + } + + const uint64_t decoded_pcm_bytes = + total_samples * static_cast(hdr.channels) * sizeof(int32_t); + if (decoded_pcm_bytes > MAX_DECODED_PCM_BYTES) { + return fail("decoded PCM allocation exceeds maximum"); + } + + const uint16_t bytes_per_sample = static_cast(hdr.bit_depth / 8u); + const uint32_t block_align = static_cast(hdr.channels) * bytes_per_sample; + const uint64_t data_size_u64 = total_samples * block_align; + const uint64_t data_padding = data_size_u64 & 1u; + const uint64_t riff_size_u64 = 36u + data_size_u64 + data_padding; + const uint64_t file_size_u64 = riff_size_u64 + 8u; + if (riff_size_u64 > std::numeric_limits::max() || + file_size_u64 > static_cast(std::numeric_limits::max())) { + return fail("decoded WAV data exceeds RIFF limit"); + } + const uint32_t data_size = static_cast(data_size_u64); + const uint32_t riff_size = static_cast(riff_size_u64); + + if ((br.bits_remaining() & 7u) != 0u) { + return fail("unaligned compressed block payload"); + } + const size_t available_payload_bytes = br.bits_remaining() / 8u; + if (total_block_payload_bytes != available_payload_bytes) { + return fail("compressed block sizes do not match frame payload"); + } + + std::vector block_offsets(block_count); + std::vector block_payload_offsets(block_count); + size_t sample_offset = 0; + size_t payload_offset = 0; + for (uint32_t i = 0; i < block_count; ++i) { + block_offsets[i] = sample_offset; + block_payload_offsets[i] = payload_offset; + sample_offset += block_sizes[i]; + payload_offset += block_payload_sizes[i]; + } + const uint8_t* block_payload = payload + (payload_bytes - available_payload_bytes); + out_samples_per_channel = static_cast(total_samples); + + const int fd = ::open(output_path.c_str(), O_RDWR | O_CREAT | O_TRUNC, 0666); + if (fd < 0) return fail("failed to create WAV output"); + auto close_fd = [&]() { + if (fd >= 0) { + (void)::close(fd); + } + }; + + if (::ftruncate(fd, static_cast(file_size_u64)) != 0) { + close_fd(); + return fail("failed to size WAV output"); + } + + void* mapped = ::mmap(nullptr, + static_cast(file_size_u64), + PROT_READ | PROT_WRITE, + MAP_SHARED, + fd, + 0); + if (mapped == MAP_FAILED) { + close_fd(); + return fail("failed to map WAV output"); + } + uint8_t* wav = static_cast(mapped); + write_wav_header_to_buffer(wav, hdr.channels, hdr.sample_rate, hdr.bit_depth, block_align, data_size, riff_size); + + const bool is_stereo = hdr.channels == 2; + const bool per_block_stereo = is_stereo && hdr.stereo_mode == 2; + const bool force_mid_side = is_stereo && hdr.stereo_mode == 1; + + size_t hardware_threads = + std::max(1, static_cast(std::thread::hardware_concurrency())); + const size_t resolved_limit = LAC::resolve_thread_limit(thread_count); + if (resolved_limit > 0) { + hardware_threads = std::min(hardware_threads, resolved_limit); + } + const size_t worker_count = std::min(hardware_threads, block_count); + + std::atomic next_block{0}; + std::atomic stop_requested{false}; + std::mutex error_mutex; + std::exception_ptr worker_error; + std::vector workers; + workers.reserve(worker_count); + + for (size_t worker_idx = 0; worker_idx < worker_count; ++worker_idx) { + workers.emplace_back([&, worker_idx]() { + (void)worker_idx; + try { + if (collector) { + collector->record(std::this_thread::get_id()); + } + Block::Decoder block_decoder; + std::vector left_block(Block::MAX_BLOCK_SIZE); + std::vector right_block(is_stereo ? Block::MAX_BLOCK_SIZE : 0u); + while (!stop_requested.load(std::memory_order_acquire)) { + const uint32_t block_idx = next_block.fetch_add(1, std::memory_order_relaxed); + if (block_idx >= block_count) return; + + BitReader block_reader(block_payload + block_payload_offsets[block_idx], + block_payload_sizes[block_idx]); + bool mid_side = false; + if (per_block_stereo) { + const uint32_t mode_flag = block_reader.read_bits(8); + if (block_reader.has_error() || mode_flag > 1u) { + throw std::runtime_error("invalid per-block stereo flag"); + } + mid_side = (mode_flag == 1u); + } else if (force_mid_side) { + mid_side = true; + } + + const uint32_t block_size = block_sizes[block_idx]; + int32_t* left_out = left_block.data(); + int32_t* right_out = is_stereo ? right_block.data() : nullptr; + if (!block_decoder.decode_into(block_reader, block_size, left_out)) { + throw std::runtime_error("failed to decode primary block"); + } + if (is_stereo && !block_decoder.decode_into(block_reader, block_size, right_out)) { + throw std::runtime_error("failed to decode secondary block"); + } + if (block_reader.bits_remaining() != 0u) { + throw std::runtime_error("trailing block payload"); + } + + if (!is_stereo) { + if (!cli_validate_pcm_range(left_out, block_size, hdr.bit_depth)) { + throw std::runtime_error("decoded sample outside PCM bit depth"); + } + } else if (mid_side) { + if (!cli_reconstruct_mid_side_in_place(left_out, right_out, block_size, hdr.bit_depth)) { + throw std::runtime_error("decoded sample outside PCM bit depth"); + } + } else { + if (!cli_validate_pcm_range(left_out, block_size, hdr.bit_depth) || + !cli_validate_pcm_range(right_out, block_size, hdr.bit_depth)) { + throw std::runtime_error("decoded sample outside PCM bit depth"); + } + } + + const uint64_t wav_offset = + 44u + static_cast(block_offsets[block_idx]) * block_align; + pack_pcm_to_wav_bytes(wav + wav_offset, + left_out, + right_out, + block_size, + hdr.channels, + hdr.bit_depth); + } + } catch (...) { + stop_requested.store(true, std::memory_order_release); + std::lock_guard lock(error_mutex); + if (!worker_error) { + worker_error = std::current_exception(); + } + } + }); + } + + for (auto& worker : workers) { + worker.join(); + } + + bool ok = true; + if (worker_error) { + try { + std::rethrow_exception(worker_error); + } catch (const std::exception& ex) { + error = ex.what(); + } catch (...) { + error = "unknown decode worker failure"; + } + ok = false; + } + if (::munmap(mapped, static_cast(file_size_u64)) != 0) { + if (ok) error = "failed to unmap WAV output"; + ok = false; + } + if (::close(fd) != 0) { + if (ok) error = "failed to close WAV output"; + ok = false; + } + + return ok ? FastDecodeStatus::Ok : FastDecodeStatus::Failed; +} +#endif + static bool paths_refer_to_same_file(const std::string& input_path, const std::string& output_path) { std::error_code ec; if (std::filesystem::equivalent(input_path, output_path, ec)) { @@ -356,23 +732,52 @@ int main(int argc, char** argv) { } LAC::ThreadCollector decoderCollector; LAC::ThreadCollector* decoderCollectorPtr = (debug_threads ? &decoderCollector : nullptr); - LAC::Decoder decoder(decoderCollectorPtr); - decoder.set_thread_count(thread_count); - std::vector left, right; - FrameHeader hdr; - decoder.decode(bitstream.data(), bitstream.size(), left, right, &hdr); - if (left.empty()) { - std::cerr << "Decode failed or produced no samples\n"; + StagedOutputFile staged_output(out_path); + if (!staged_output.is_ready()) { + std::cerr << "Failed to write WAV: " << out_path << "\n"; return 1; } - StagedOutputFile staged_output(out_path); - if (!staged_output.is_ready() || - !write_wav_unchecked_samples(staged_output.path(), left, right, hdr.channels, hdr.sample_rate, hdr.bit_depth) || - !staged_output.publish(in_path)) { + + bool wrote_wav = false; + size_t samples_per_channel = 0; +#ifndef _WIN32 + std::string fast_decode_error; + const FastDecodeStatus fast_status = + decode_lac_v3_to_mapped_wav(bitstream.data(), + bitstream.size(), + staged_output.path(), + thread_count, + decoderCollectorPtr, + samples_per_channel, + fast_decode_error); + if (fast_status == FastDecodeStatus::Ok) { + wrote_wav = true; + } else if (fast_status == FastDecodeStatus::Failed) { + std::cerr << "Decode failed: " << fast_decode_error << "\n"; + return 1; + } +#endif + if (!wrote_wav) { + LAC::Decoder decoder(decoderCollectorPtr); + decoder.set_thread_count(thread_count); + std::vector left, right; + FrameHeader hdr; + decoder.decode(bitstream.data(), bitstream.size(), left, right, &hdr); + if (left.empty()) { + std::cerr << "Decode failed or produced no samples\n"; + return 1; + } + if (!write_wav_unchecked_samples(staged_output.path(), left, right, hdr.channels, hdr.sample_rate, hdr.bit_depth)) { + std::cerr << "Failed to write WAV: " << out_path << "\n"; + return 1; + } + samples_per_channel = left.size(); + } + if (!staged_output.publish(in_path)) { std::cerr << "Failed to write WAV: " << out_path << "\n"; return 1; } - std::cout << "Decoded " << in_path << " -> " << out_path << " (" << left.size() << " samples per channel)\n"; + std::cout << "Decoded " << in_path << " -> " << out_path << " (" << samples_per_channel << " samples per channel)\n"; if (debug_threads) { auto threads = decoderCollector.snapshot(); std::cout << "Decoder thread usage: " << threads.size() << " threads\n"; diff --git a/tests/test_e2e.cpp b/tests/test_e2e.cpp index 3237057..723dd90 100644 --- a/tests/test_e2e.cpp +++ b/tests/test_e2e.cpp @@ -298,6 +298,19 @@ std::vector v3_block_sizes(const std::vector& stream) { return sizes; } +std::vector v3_payload_sizes(const std::vector& stream) { + assert(stream.size() >= 14); + assert(stream[2] == 3u); + const uint32_t block_count = get_u32_be(stream, 10); + assert(14u + static_cast(block_count) * 8u <= stream.size()); + std::vector sizes; + sizes.reserve(block_count); + for (uint32_t i = 0; i < block_count; ++i) { + sizes.push_back(get_u32_be(stream, 18u + static_cast(i) * 8u)); + } + return sizes; +} + std::vector v3_stereo_flags(const std::vector& stream) { assert(stream.size() >= 14); assert(stream[2] == 3u); @@ -317,6 +330,30 @@ std::vector v3_stereo_flags(const std::vector& stream) { return flags; } +void assert_per_block_stereo_payload_consistency(const std::vector& auto_stream, + const std::vector& lr_stream, + const std::vector& ms_stream) { + assert(v3_block_sizes(auto_stream) == v3_block_sizes(lr_stream)); + assert(v3_block_sizes(auto_stream) == v3_block_sizes(ms_stream)); + + const std::vector flags = v3_stereo_flags(auto_stream); + const std::vector auto_payloads = v3_payload_sizes(auto_stream); + const std::vector lr_payloads = v3_payload_sizes(lr_stream); + const std::vector ms_payloads = v3_payload_sizes(ms_stream); + assert(flags.size() == auto_payloads.size()); + assert(flags.size() == lr_payloads.size()); + assert(flags.size() == ms_payloads.size()); + + uint64_t expected_payload_bytes = 0; + for (size_t i = 0; i < flags.size(); ++i) { + assert(flags[i] == 0u || flags[i] == 1u); + const uint32_t selected_payload = (flags[i] == 1u) ? ms_payloads[i] : lr_payloads[i]; + assert(auto_payloads[i] == selected_payload + 1u); + expected_payload_bytes += auto_payloads[i]; + } + assert(auto_stream.size() == v3_payload_offset(auto_stream) + expected_payload_bytes); +} + void append_fourcc(std::vector& out, const char* fourcc) { out.insert(out.end(), fourcc, fourcc + 4); } @@ -693,10 +730,7 @@ void run_stereo_planner_tests() { const std::vector auto_random_stream = auto_encoder.encode(random_left, random_right); const std::vector lr_random_stream = forced_lr.encode(random_left, random_right); const std::vector ms_random_stream = forced_ms.encode(random_left, random_right); - const std::vector auto_random_flags = v3_stereo_flags(auto_random_stream); - assert(std::all_of(auto_random_flags.begin(), auto_random_flags.end(), [](uint8_t flag) { return flag == 0u; })); - assert(auto_random_stream.size() == lr_random_stream.size() + auto_random_flags.size()); - assert(auto_random_stream.size() < ms_random_stream.size()); + assert_per_block_stereo_payload_consistency(auto_random_stream, lr_random_stream, ms_random_stream); std::vector alternating_left(4095); std::vector alternating_right(4095); @@ -734,10 +768,7 @@ void run_stereo_planner_tests() { const std::vector auto_walk_stream = auto_encoder.encode(walk_left, walk_right); const std::vector lr_walk_stream = forced_lr.encode(walk_left, walk_right); const std::vector ms_walk_stream = forced_ms.encode(walk_left, walk_right); - const std::vector auto_walk_flags = v3_stereo_flags(auto_walk_stream); - assert(std::all_of(auto_walk_flags.begin(), auto_walk_flags.end(), [](uint8_t flag) { return flag == 0u; })); - assert(auto_walk_stream.size() == lr_walk_stream.size() + auto_walk_flags.size()); - assert(auto_walk_stream.size() < ms_walk_stream.size()); + assert_per_block_stereo_payload_consistency(auto_walk_stream, lr_walk_stream, ms_walk_stream); std::vector long_walk_left(Block::MAX_BLOCK_SIZE); std::vector long_walk_right(Block::MAX_BLOCK_SIZE); @@ -757,10 +788,7 @@ void run_stereo_planner_tests() { const std::vector auto_long_walk_stream = auto_encoder.encode(long_walk_left, long_walk_right); const std::vector lr_long_walk_stream = forced_lr.encode(long_walk_left, long_walk_right); const std::vector ms_long_walk_stream = forced_ms.encode(long_walk_left, long_walk_right); - const std::vector auto_long_walk_flags = v3_stereo_flags(auto_long_walk_stream); - assert(auto_long_walk_flags == std::vector{0u}); - assert(auto_long_walk_stream.size() == lr_long_walk_stream.size() + auto_long_walk_flags.size()); - assert(auto_long_walk_stream.size() < ms_long_walk_stream.size()); + assert_per_block_stereo_payload_consistency(auto_long_walk_stream, lr_long_walk_stream, ms_long_walk_stream); std::vector noise_left(Block::MAX_BLOCK_SIZE); std::vector noise_right(Block::MAX_BLOCK_SIZE); diff --git a/tests/test_zerorun.cpp b/tests/test_zerorun.cpp index 1e23639..3941359 100644 --- a/tests/test_zerorun.cpp +++ b/tests/test_zerorun.cpp @@ -6,6 +6,7 @@ #include "codec/lpc/lpc.hpp" #include "codec/rice/rice.hpp" #include "codec/bitstream/bit_reader.hpp" +#include "codec/bitstream/bit_writer.hpp" #include #include #include @@ -90,6 +91,7 @@ std::vector partition_sizes_for_block(uint32_t size, uint8_t order) { struct BinTokenSummary { bool saw_bin_mode = false; + bool saw_static_mode = false; uint32_t fallback_tokens = 0; }; @@ -118,40 +120,110 @@ BinTokenSummary inspect_bin_tokens(const std::vector& buf, uint32_t blo BinTokenSummary summary; const auto part_sizes = partition_sizes_for_block(block_size, partition_order); for (uint32_t part = 0; part < partition_count; ++part) { - if (modes[part] != 2u) { - continue; - } - summary.saw_bin_mode = true; uint32_t current_k = initial_k[part]; uint64_t sum_u = 0; uint32_t count = 0; Rice::AdaptState adapt_state; + auto update_adaptive_k = [&](uint32_t u) { + sum_u += u; + ++count; + current_k = stateless + ? adapt_k_stateless_local(sum_u, count) + : Rice::adapt_k(sum_u, count, adapt_state); + }; + + if (modes[part] == 3u) { + summary.saw_static_mode = true; + for (uint32_t i = 0; i < part_sizes[part]; ++i) { + (void)read_rice_unsigned(br, initial_k[part]); + assert(!br.has_error()); + } + continue; + } + + if (modes[part] == 0u) { + for (uint32_t i = 0; i < part_sizes[part]; ++i) { + const uint32_t u = read_rice_unsigned(br, current_k); + assert(!br.has_error()); + update_adaptive_k(u); + } + continue; + } + if (modes[part] == 1u) { + uint32_t idx = 0; + while (idx < part_sizes[part]) { + const uint32_t tag = br.read_bits(2); + if (tag == 0u) { + const uint32_t u = read_rice_unsigned(br, current_k); + assert(!br.has_error()); + update_adaptive_k(u); + ++idx; + } else if (tag == 1u) { + const uint32_t run = read_rice_unsigned(br, Block::ZERO_RUN_LENGTH_K) + + Block::ZERO_RUN_MIN_LENGTH; + assert(!br.has_error()); + assert(idx + run <= part_sizes[part]); + for (uint32_t j = 0; j < run; ++j) { + update_adaptive_k(0); + } + idx += run; + } else if (tag == 2u) { + const uint32_t u = br.read_bits(32); + assert(!br.has_error()); + update_adaptive_k(u); + ++idx; + } else { + assert(false); + } + } + continue; + } + + assert(modes[part] == 2u); + summary.saw_bin_mode = true; for (uint32_t i = 0; i < part_sizes[part]; ++i) { const uint32_t tag = br.read_bits(2); - int32_t value = 0; - if (tag == 0u) { - value = 0; - } else if (tag == 1u) { - value = (br.read_bit() == 0u) ? 1 : -1; + uint32_t u = 0; + if (tag == 1u) { + u = (br.read_bit() == 0u) ? 2u : 1u; } else if (tag == 2u) { - value = (br.read_bit() == 0u) ? 2 : -2; - } else { + u = (br.read_bit() == 0u) ? 4u : 3u; + } else if (tag == 3u) { ++summary.fallback_tokens; - value = zigzag_decode_local(read_rice_unsigned(br, current_k)); + u = read_rice_unsigned(br, current_k); } assert(!br.has_error()); - - sum_u += zigzag_encode_local(value); - ++count; - current_k = stateless - ? adapt_k_stateless_local(sum_u, count) - : Rice::adapt_k(sum_u, count, adapt_state); + update_adaptive_k(u); } } return summary; } +void test_static_rice_decode_mode() { + const std::vector pcm = {-8, -1, 0, 1, 2, 7, 15, -16}; + + BitWriter bw; + bw.write_bits(0u, 8); // fixed predictor + bw.write_bits(0u, 8); // order 0 + bw.write_bits(3u << 5, 8); // default residual mode = static Rice + bw.write_bits(3u, 2); // partition mode = static Rice + bw.write_bits(3u, 5); // fixed Rice k + for (const int32_t sample : pcm) { + Rice::encode(bw, sample, 3); + } + bw.flush_to_byte(); + + std::vector buf = bw.take_buffer(); + BitReader br(buf); + Block::Decoder dec; + std::vector decoded; + assert(dec.decode(br, static_cast(pcm.size()), decoded)); + assert(decoded == pcm); + assert(br.bits_remaining() == 0u); + std::cout << "static Rice mode decode ok\n"; +} + void roundtrip_block(const std::vector& pcm, bool enable_zr) { Block::Encoder enc(8); enc.set_zero_run_enabled(enable_zr); @@ -327,7 +399,7 @@ void test_bin_small_block() { std::vector buf = enc.encode(pcm); assert(!buf.empty()); const BinTokenSummary summary = inspect_bin_tokens(buf, static_cast(pcm.size())); - assert(summary.saw_bin_mode); + assert(summary.saw_bin_mode || summary.saw_static_mode); BitReader br(buf); Block::Decoder dec; @@ -345,8 +417,10 @@ void test_bin_fallback_values() { std::vector buf = enc.encode(pcm); assert(!buf.empty()); const BinTokenSummary summary = inspect_bin_tokens(buf, static_cast(pcm.size())); - assert(summary.saw_bin_mode); - assert(summary.fallback_tokens > 0); + assert(summary.saw_bin_mode || summary.saw_static_mode); + if (summary.saw_bin_mode) { + assert(summary.fallback_tokens > 0); + } BitReader br(buf); Block::Decoder dec; @@ -520,6 +594,7 @@ void run_zerorun_tests() { } { + test_static_rice_decode_mode(); test_bin_small_block(); test_bin_fallback_values(); test_bin_partition_modes(); From 2b978ac05438b25b8e7f450c7354e4136908bfd2 Mon Sep 17 00:00:00 2001 From: audexdev Date: Fri, 5 Jun 2026 00:19:03 +0900 Subject: [PATCH 5/6] Specialize predictor restore loops --- src/codec/block/decoder.cpp | 110 +++++++++++++++++++++++++++--------- 1 file changed, 83 insertions(+), 27 deletions(-) diff --git a/src/codec/block/decoder.cpp b/src/codec/block/decoder.cpp index 57ec341..7a4b0f7 100644 --- a/src/codec/block/decoder.cpp +++ b/src/codec/block/decoder.cpp @@ -10,11 +10,50 @@ #include #include #include +#include namespace Block { Decoder::Decoder() {} +template +int64_t lpc_accumulate_known_order(const int32_t* pcm, + size_t n, + const std::array& coeffs, + std::index_sequence) { + int64_t acc = 0; + ((acc += static_cast(coeffs[I + 1]) * + static_cast(pcm[n - (I + 1)])), ...); + return acc; +} + +template +bool restore_lpc_known_order_in_place(int32_t* pcm, + size_t size, + const std::array& coeffs, + int64_t lo, + int64_t hi) { + const size_t warmup = std::min(size, static_cast(Order)); + for (size_t n = 0; n < warmup; ++n) { + int64_t acc = 0; + for (int i = 1; i <= static_cast(n); ++i) { + acc += static_cast(coeffs[i]) * + static_cast(pcm[n - static_cast(i)]); + } + const int64_t sample = (acc >> 15) + static_cast(pcm[n]); + if (sample < lo || sample > hi) return false; + pcm[n] = static_cast(sample); + } + for (size_t n = warmup; n < size; ++n) { + const int64_t acc = + lpc_accumulate_known_order(pcm, n, coeffs, std::make_index_sequence{}); + const int64_t sample = (acc >> 15) + static_cast(pcm[n]); + if (sample < lo || sample > hi) return false; + pcm[n] = static_cast(sample); + } + return true; +} + bool Decoder::decode(BitReader& br, uint32_t block_size, std::vector& out) { std::vector pcm(block_size); if (!this->decode_into(br, block_size, pcm.data())) return false; @@ -267,36 +306,39 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { }; auto restore_fixed_in_place = [&](int32_t* pcm, size_t size, int order) -> bool { - if (order == 0) { - return true; - } - for (size_t i = static_cast(order); i < size; ++i) { - int64_t pred = 0; - switch (order) { - case 1: - pred = pcm[i - 1]; - break; - case 2: - pred = 2LL * pcm[i - 1] - pcm[i - 2]; - break; - case 3: - pred = 3LL * pcm[i - 1] - 3LL * pcm[i - 2] + pcm[i - 3]; - break; - case 4: - pred = 4LL * pcm[i - 1] - 6LL * pcm[i - 2] + 4LL * pcm[i - 3] - pcm[i - 4]; - break; - default: - pred = 0; - break; + auto restore_with_prediction = [&](auto predict, size_t start) -> bool { + for (size_t i = start; i < size; ++i) { + const int64_t sample = static_cast(pcm[i]) + predict(i); + if (sample < std::numeric_limits::min() || + sample > std::numeric_limits::max()) { + return false; + } + pcm[i] = static_cast(sample); } - const int64_t sample = static_cast(pcm[i]) + pred; - if (sample < std::numeric_limits::min() || - sample > std::numeric_limits::max()) { + return true; + }; + switch (order) { + case 0: + return true; + case 1: + return restore_with_prediction([&](size_t i) { + return static_cast(pcm[i - 1]); + }, 1); + case 2: + return restore_with_prediction([&](size_t i) { + return 2LL * pcm[i - 1] - pcm[i - 2]; + }, 2); + case 3: + return restore_with_prediction([&](size_t i) { + return 3LL * pcm[i - 1] - 3LL * pcm[i - 2] + pcm[i - 3]; + }, 3); + case 4: + return restore_with_prediction([&](size_t i) { + return 4LL * pcm[i - 1] - 6LL * pcm[i - 2] + 4LL * pcm[i - 3] - pcm[i - 4]; + }, 4); + default: return false; - } - pcm[i] = static_cast(sample); } - return true; }; auto restore_fir_in_place = [&](int32_t* pcm, size_t size, int taps) -> bool { @@ -322,6 +364,20 @@ bool Decoder::decode_into(BitReader& br, uint32_t block_size, int32_t* out) { const int coeff_order = std::max(0, lpc_order); const int64_t lo = static_cast(std::numeric_limits::min()); const int64_t hi = static_cast(std::numeric_limits::max()); + switch (coeff_order) { + case 4: + return restore_lpc_known_order_in_place<4>(pcm, size, coeffs, lo, hi); + case 6: + return restore_lpc_known_order_in_place<6>(pcm, size, coeffs, lo, hi); + case 8: + return restore_lpc_known_order_in_place<8>(pcm, size, coeffs, lo, hi); + case 10: + return restore_lpc_known_order_in_place<10>(pcm, size, coeffs, lo, hi); + case 12: + return restore_lpc_known_order_in_place<12>(pcm, size, coeffs, lo, hi); + default: + break; + } const size_t warmup = std::min(size, static_cast(coeff_order)); for (size_t n = 0; n < warmup; ++n) { int64_t acc = 0; From 1fcdd01b9f7192592223129920f07b635bb76e67 Mon Sep 17 00:00:00 2001 From: audexdev Date: Fri, 5 Jun 2026 02:08:31 +0900 Subject: [PATCH 6/6] Use std::thread with an RAII join guard instead of std::jthread std::jthread is not provided by the libc++ shipped on macos-latest, so the encoder and decoder failed to build there while Linux (libstdc++) passed. Switch both worker pools to std::thread for portability and consistency with the CLI decoder, and add a scope guard that joins every started worker even if thread construction throws mid-startup, preserving the exception-safe shutdown that std::jthread provided. Closes #5 Co-Authored-By: Claude Opus 4.8 --- src/codec/lac/decoder.cpp | 12 +++++++++++- src/codec/lac/encoder.cpp | 12 +++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/codec/lac/decoder.cpp b/src/codec/lac/decoder.cpp index fbb005f..7f45726 100644 --- a/src/codec/lac/decoder.cpp +++ b/src/codec/lac/decoder.cpp @@ -245,8 +245,18 @@ void Decoder::decode(const uint8_t* data, std::atomic stop_requested{false}; std::mutex error_mutex; std::exception_ptr worker_error; - std::vector workers; + std::vector workers; workers.reserve(worker_count); + // Join every started worker even if thread construction throws mid-startup, + // so a partial launch cannot destroy joinable threads and call std::terminate. + struct WorkerJoinGuard { + std::vector& workers; + ~WorkerJoinGuard() { + for (auto& worker : workers) { + if (worker.joinable()) worker.join(); + } + } + } worker_join_guard{workers}; for (size_t worker_idx = 0; worker_idx < worker_count; ++worker_idx) { workers.emplace_back([&]() { diff --git a/src/codec/lac/encoder.cpp b/src/codec/lac/encoder.cpp index 24a92fc..b8f7bcb 100644 --- a/src/codec/lac/encoder.cpp +++ b/src/codec/lac/encoder.cpp @@ -389,8 +389,18 @@ namespace LAC { hardware_threads = std::min(hardware_threads, thread_limit); } const size_t worker_count = std::min(hardware_threads, blocks.size()); - std::vector workers; + std::vector workers; workers.reserve(worker_count); + // Join every started worker even if thread construction throws mid-startup, + // so a partial launch cannot destroy joinable threads and call std::terminate. + struct WorkerJoinGuard { + std::vector& workers; + ~WorkerJoinGuard() { + for (auto& worker : workers) { + if (worker.joinable()) worker.join(); + } + } + } worker_join_guard{workers}; for (size_t worker_idx = 0; worker_idx < worker_count; ++worker_idx) { workers.emplace_back([&]() {