diff --git a/src/ltx_audio_vae.h b/src/ltx_audio_vae.h index d1d765d75..aa4f49ade 100644 --- a/src/ltx_audio_vae.h +++ b/src/ltx_audio_vae.h @@ -5,6 +5,7 @@ #include #include #include +#include #include "ggml_extend.hpp" @@ -200,6 +201,37 @@ namespace LTXV { return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml(tensor)); } + static inline bool is_power_of_two(size_t x) { + return x > 0 && (x & (x - 1)) == 0; + } + + static void fft(std::vector>& a) { + constexpr float ktau = 6.28318530717958647692f; + size_t n = a.size(); + GGML_ASSERT(is_power_of_two(n) && "Filter length must be a power of two for FFT-based spectrogram computation"); + for (int i = 1, j = 0; i < n; i++) { + int bit = n >> 1; + for (; j & bit; bit >>= 1) + j ^= bit; + j ^= bit; + if (i < j) + std::swap(a[i], a[j]); + } + for (int len = 2; len <= n; len <<= 1) { + float ang = -ktau / len; + std::complex wlen{std::cos(ang), std::sin(ang)}; + for (int i = 0; i < n; i += len) { + std::complex w{1, 0}; + for (int j = 0; j < len / 2; j++) { + std::complex u = a[i + j], v = a[i + j + len / 2] * w; + a[i + j] = u + v; + a[i + j + len / 2] = u - v; + w = w * wlen; + } + } + } + } + static sd::Tensor compute_log_mel_spectrogram(const sd::Tensor& waveform_in, const sd::Tensor& forward_basis, const sd::Tensor& mel_basis, @@ -208,19 +240,20 @@ namespace LTXV { GGML_ASSERT(forward_basis.dim() >= 3); GGML_ASSERT(mel_basis.dim() >= 2); - const int64_t time = waveform.shape()[0]; - const int64_t channels = waveform.shape()[1]; - const int64_t batch = waveform.shape()[2]; - const int64_t filter_len = forward_basis.shape()[0]; - const int64_t basis_freq2 = forward_basis.shape().back(); - const int64_t n_freqs = basis_freq2 / 2; - const int64_t n_mels = mel_basis.shape()[1]; + const int64_t time = waveform.shape()[0]; + const int64_t channels = waveform.shape()[1]; + const int64_t batch = waveform.shape()[2]; + const int64_t filter_len = forward_basis.shape()[0]; + const int64_t n_freqs = filter_len / 2 + 1; + const int64_t n_mels = mel_basis.shape()[1]; + const int64_t left_pad = std::max(0, filter_len - hop_length); const int64_t padded_time = time + left_pad; const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length; sd::Tensor log_mel({n_mels, frame_count, channels, batch}); std::vector padded(static_cast(padded_time), 0.0f); + std::vector> fft_buffer(static_cast(filter_len)); std::vector magnitude(static_cast(n_freqs), 0.0f); for (int64_t b = 0; b < batch; ++b) { @@ -232,23 +265,27 @@ namespace LTXV { for (int64_t frame = 0; frame < frame_count; ++frame) { const int64_t frame_offset = frame * hop_length; + + for (int64_t k = 0; k < filter_len; ++k) { + float gate_weight = forward_basis.index(k, 0, 0); + + float gated_sample = padded[static_cast(frame_offset + k)] * gate_weight; + + fft_buffer[static_cast(k)] = {gated_sample, 0.0f}; + } + + fft(fft_buffer); + for (int64_t f = 0; f < n_freqs; ++f) { - double real = 0.0; - double imag = 0.0; - for (int64_t k = 0; k < filter_len; ++k) { - const float sample = padded[static_cast(frame_offset + k)]; - real += static_cast(sample) * static_cast(forward_basis.index(k, 0, f)); - imag += static_cast(sample) * static_cast(forward_basis.index(k, 0, f + n_freqs)); - } - magnitude[static_cast(f)] = static_cast(std::sqrt(real * real + imag * imag)); + magnitude[static_cast(f)] = std::abs(fft_buffer[static_cast(f)]); } for (int64_t m = 0; m < n_mels; ++m) { - double mel_value = 0.0; + float mel_value = 0.0f; for (int64_t f = 0; f < n_freqs; ++f) { - mel_value += static_cast(mel_basis.index(f, m)) * static_cast(magnitude[static_cast(f)]); + mel_value += mel_basis.index(f, m) * magnitude[static_cast(f)]; } - log_mel.index(m, frame, c, b) = static_cast(std::log(std::max(mel_value, 1e-5))); + log_mel.index(m, frame, c, b) = std::log(std::max(mel_value, 1e-5f)); } } }