Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions src/ltx_audio_vae.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <numeric>
#include <string>
#include <vector>
#include <complex>

#include "ggml_extend.hpp"

Expand Down Expand Up @@ -200,6 +201,37 @@ namespace LTXV {
return squeeze_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(tensor));
}

static inline bool is_power_of_two(size_t x) {
return x > 0 && (x & (x - 1)) == 0;
}

static void fft(std::vector<std::complex<float>>& a) {
constexpr float ktau = 6.28318530717958647692f;
size_t n = a.size();
GGML_ASSERT(is_power_of_two(n) && "Filter length must be a power of two for FFT-based spectrogram computation");
for (int i = 1, j = 0; i < n; i++) {
int bit = n >> 1;
for (; j & bit; bit >>= 1)
j ^= bit;
j ^= bit;
if (i < j)
std::swap(a[i], a[j]);
}
for (int len = 2; len <= n; len <<= 1) {
float ang = -ktau / len;
std::complex<float> wlen{std::cos(ang), std::sin(ang)};
for (int i = 0; i < n; i += len) {
std::complex<float> w{1, 0};
for (int j = 0; j < len / 2; j++) {
std::complex<float> u = a[i + j], v = a[i + j + len / 2] * w;
a[i + j] = u + v;
a[i + j + len / 2] = u - v;
w = w * wlen;
}
}
}
}

static sd::Tensor<float> compute_log_mel_spectrogram(const sd::Tensor<float>& waveform_in,
const sd::Tensor<float>& forward_basis,
const sd::Tensor<float>& mel_basis,
Expand All @@ -208,19 +240,20 @@ namespace LTXV {
GGML_ASSERT(forward_basis.dim() >= 3);
GGML_ASSERT(mel_basis.dim() >= 2);

const int64_t time = waveform.shape()[0];
const int64_t channels = waveform.shape()[1];
const int64_t batch = waveform.shape()[2];
const int64_t filter_len = forward_basis.shape()[0];
const int64_t basis_freq2 = forward_basis.shape().back();
const int64_t n_freqs = basis_freq2 / 2;
const int64_t n_mels = mel_basis.shape()[1];
const int64_t time = waveform.shape()[0];
const int64_t channels = waveform.shape()[1];
const int64_t batch = waveform.shape()[2];
const int64_t filter_len = forward_basis.shape()[0];
const int64_t n_freqs = filter_len / 2 + 1;
const int64_t n_mels = mel_basis.shape()[1];

const int64_t left_pad = std::max<int64_t>(0, filter_len - hop_length);
const int64_t padded_time = time + left_pad;
const int64_t frame_count = padded_time < filter_len ? 0 : 1 + (padded_time - filter_len) / hop_length;

sd::Tensor<float> log_mel({n_mels, frame_count, channels, batch});
std::vector<float> padded(static_cast<size_t>(padded_time), 0.0f);
std::vector<std::complex<float>> fft_buffer(static_cast<size_t>(filter_len));
std::vector<float> magnitude(static_cast<size_t>(n_freqs), 0.0f);

for (int64_t b = 0; b < batch; ++b) {
Expand All @@ -232,23 +265,27 @@ namespace LTXV {

for (int64_t frame = 0; frame < frame_count; ++frame) {
const int64_t frame_offset = frame * hop_length;

for (int64_t k = 0; k < filter_len; ++k) {
float gate_weight = forward_basis.index(k, 0, 0);

float gated_sample = padded[static_cast<size_t>(frame_offset + k)] * gate_weight;

fft_buffer[static_cast<size_t>(k)] = {gated_sample, 0.0f};
}

fft(fft_buffer);

for (int64_t f = 0; f < n_freqs; ++f) {
double real = 0.0;
double imag = 0.0;
for (int64_t k = 0; k < filter_len; ++k) {
const float sample = padded[static_cast<size_t>(frame_offset + k)];
real += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f));
imag += static_cast<double>(sample) * static_cast<double>(forward_basis.index(k, 0, f + n_freqs));
}
magnitude[static_cast<size_t>(f)] = static_cast<float>(std::sqrt(real * real + imag * imag));
magnitude[static_cast<size_t>(f)] = std::abs(fft_buffer[static_cast<size_t>(f)]);
}

for (int64_t m = 0; m < n_mels; ++m) {
double mel_value = 0.0;
float mel_value = 0.0f;
for (int64_t f = 0; f < n_freqs; ++f) {
mel_value += static_cast<double>(mel_basis.index(f, m)) * static_cast<double>(magnitude[static_cast<size_t>(f)]);
mel_value += mel_basis.index(f, m) * magnitude[static_cast<size_t>(f)];
}
log_mel.index(m, frame, c, b) = static_cast<float>(std::log(std::max(mel_value, 1e-5)));
log_mel.index(m, frame, c, b) = std::log(std::max(mel_value, 1e-5f));
}
}
}
Expand Down
Loading