Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions benchmarks/bench_nfft_direct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ static void nfft_adjoint_direct_3d(benchmark::State& state) {

// Register benchmarks for direct transforms
BENCH(nfft_forward_direct_1d, SUFFIX)
->Args({8, 25})
->Args({16, 50})
->Args({32, 100})
->Args({64, 200})
->Args({128, 400})
Expand All @@ -176,6 +178,8 @@ BENCH(nfft_forward_direct_1d, SUFFIX)
->Complexity();

BENCH(nfft_adjoint_direct_1d, SUFFIX)
->Args({8, 25})
->Args({16, 50})
->Args({32, 100})
->Args({64, 200})
->Args({128, 400})
Expand Down
86 changes: 78 additions & 8 deletions kernel/nfft/nfft.c
Original file line number Diff line number Diff line change
Expand Up @@ -142,25 +142,47 @@ static inline void sort(const X(plan) *ths)
* for k in I_N^d
* f_hat[k] = sum_{j=0}^{M_total-1} f[j] * exp(-2(pi) k x[j])
*/
/* Block size for the phase recurrence in the direct transforms */
#define NFFT_DIRECT_RECURRENCE_BLOCK 32

/* Accurate phase for exp(+-i 2pi k x): reduce k*x modulo 1 into ~[-1/2,1/2) so COS/SIN see a
* small argument, error does not grow with N. Requires FMA single-rounding semantics. */
static inline R X(reduced_omega)(const R k, const R x)
{
const R n = RINT(k * x); // Nearest integer to k * x.
return K2PI * FFMA(k, x, -n); // Calculate k * x - n witha single rounding, then multiply 2 * pi.
}

void X(trafo_direct)(const X(plan) *ths)
{
C *f_hat = (C*)ths->f_hat, *f = (C*)ths->f;

if (ths->d == 1)
{
/* specialize for univariate case, rationale: faster */
const INT B = NFFT_DIRECT_RECURRENCE_BLOCK;
INT j;
#ifdef _OPENMP
#pragma omp parallel for default(shared) private(j)
#endif
for (j = 0; j < ths->M_total; j++)
{
C v = K(0.0);
INT k_L;
for (k_L = 0; k_L < ths->N_total; k_L++)
const R x = ths->x[j];
const R dphi = K2PI * x; /* |dphi| <= pi: accurate without reduction */
const C dw = COS(dphi) - II * SIN(dphi); /* per-step phase factor exp(-i 2pi x) */
INT k_L = 0;
while (k_L < ths->N_total)
{
R omega = K2PI * ((R)(k_L - ths->N_total/2)) * ths->x[j];
v += f_hat[k_L] * (COS(omega) - II * SIN(omega));
/* Accurate seed exp(-i 2pi (k_L - N/2) x), then recur within the block. */
const R omega = X(reduced_omega)((R)(k_L - ths->N_total/2), x);
C w = COS(omega) - II * SIN(omega);
INT kend = k_L + B; if (kend > ths->N_total) kend = ths->N_total;
for (; k_L < kend; k_L++)
{
v += f_hat[k_L] * w;
w *= dw;
}
}

f[j] = v;
Expand Down Expand Up @@ -217,7 +239,45 @@ void X(adjoint_direct)(const X(plan) *ths)
if (ths->d == 1)
{
/* specialize for univariate case, rationale: faster */
const INT B = NFFT_DIRECT_RECURRENCE_BLOCK;
#ifdef _OPENMP
if (ths->N_total > B)
{
/* Give each thread a disjoint, contiguous range of frequencies [klo,khi) (so the
* f_hat[k] writes are race-free) and run the phase recurrence within it, re-seeded
* every B steps. */
#pragma omp parallel default(shared)
{
const int nt = omp_get_num_threads();
const int tid = omp_get_thread_num();
const INT klo = (INT)(((long long)ths->N_total * tid) / nt);
const INT khi = (INT)(((long long)ths->N_total * (tid + 1)) / nt);
INT j;
for (j = 0; j < ths->M_total; j++)
{
const R x = ths->x[j];
const R dphi = K2PI * x;
const C dw = COS(dphi) + II * SIN(dphi);
INT k_L = klo;
while (k_L < khi)
{
const R omega = X(reduced_omega)((R)(k_L - ths->N_total/2), x);
C w = COS(omega) + II * SIN(omega);
INT kend = k_L + B; if (kend > khi) kend = khi;
for (; k_L < kend; k_L++)
{
f_hat[k_L] += f[j] * w;
w *= dw;
}
}
}
}
}
else
{
/* N <= B: the recurrence spans at most one block per thread-range, so its per-block
* seed/setup costs more than it saves once threaded. Use the plain per-k
* parallelisation. At these tiny N the per-entry phase error is in check. */
INT k_L;
#pragma omp parallel for default(shared) private(k_L)
for (k_L = 0; k_L < ths->N_total; k_L++)
Expand All @@ -229,15 +289,25 @@ void X(adjoint_direct)(const X(plan) *ths)
f_hat[k_L] += f[j] * (COS(omega) + II * SIN(omega));
}
}
}
#else
INT j;
for (j = 0; j < ths->M_total; j++)
{
INT k_L;
for (k_L = 0; k_L < ths->N_total; k_L++)
const R x = ths->x[j];
const R dphi = K2PI * x;
const C dw = COS(dphi) + II * SIN(dphi);
INT k_L = 0;
while (k_L < ths->N_total)
{
R omega = K2PI * ((R)(k_L - ths->N_total / 2)) * ths->x[j];
f_hat[k_L] += f[j] * (COS(omega) + II * SIN(omega));
const R omega = X(reduced_omega)((R)(k_L - ths->N_total/2), x);
C w = COS(omega) + II * SIN(omega);
INT kend = k_L + B; if (kend > ths->N_total) kend = ths->N_total;
for (; k_L < kend; k_L++)
{
f_hat[k_L] += f[j] * w;
w *= dw;
}
}
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions tests/data/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ nfct_adjoint_2d_25_10_50.txt \
nfct_adjoint_2d_25_25_25.txt \
nfct_adjoint_2d_25_25_50.txt \
nfct_adjoint_3d_10_10_10_10.txt \
nfft_1d_1024_50.txt \
nfft_1d_10_1.txt \
nfft_1d_10_10.txt \
nfft_1d_10_20.txt \
Expand All @@ -88,6 +89,7 @@ nfft_1d_50_1.txt \
nfft_1d_50_10.txt \
nfft_1d_50_20.txt \
nfft_1d_50_50.txt \
nfft_1d_512_50.txt \
nfft_2d_10_10_20.txt \
nfft_2d_10_10_50.txt \
nfft_2d_10_20_20.txt \
Expand All @@ -97,6 +99,7 @@ nfft_2d_20_10_50.txt \
nfft_2d_20_20_20.txt \
nfft_2d_20_20_50.txt \
nfft_3d_10_10_10_10.txt \
nfft_adjoint_1d_1024_50.txt \
nfft_adjoint_1d_10_1.txt \
nfft_adjoint_1d_10_10.txt \
nfft_adjoint_1d_10_20.txt \
Expand All @@ -121,6 +124,7 @@ nfft_adjoint_1d_50_1.txt \
nfft_adjoint_1d_50_10.txt \
nfft_adjoint_1d_50_20.txt \
nfft_adjoint_1d_50_50.txt \
nfft_adjoint_1d_512_50.txt \
nfft_adjoint_2d_10_10_20.txt \
nfft_adjoint_2d_10_10_50.txt \
nfft_adjoint_2d_10_20_20.txt \
Expand Down
8 changes: 8 additions & 0 deletions tests/data/generated/nfft_testcases.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ static const testcase_delegate_file_t nfft_1d_50_1 = {setup_file, destroy_file,
static const testcase_delegate_file_t nfft_1d_50_10 = {setup_file, destroy_file, ABSPATH("data/nfft_1d_50_10.txt")};
static const testcase_delegate_file_t nfft_1d_50_20 = {setup_file, destroy_file, ABSPATH("data/nfft_1d_50_20.txt")};
static const testcase_delegate_file_t nfft_1d_50_50 = {setup_file, destroy_file, ABSPATH("data/nfft_1d_50_50.txt")};
static const testcase_delegate_file_t nfft_1d_512_50 = {setup_file, destroy_file, ABSPATH("data/nfft_1d_512_50.txt")};
static const testcase_delegate_file_t nfft_1d_1024_50 = {setup_file, destroy_file, ABSPATH("data/nfft_1d_1024_50.txt")};

static const testcase_delegate_file_t *testcases_1d_file[] =
{
Expand Down Expand Up @@ -55,6 +57,8 @@ static const testcase_delegate_file_t *testcases_1d_file[] =
&nfft_1d_50_10,
&nfft_1d_50_20,
&nfft_1d_50_50,
&nfft_1d_512_50,
&nfft_1d_1024_50,
};
static const testcase_delegate_t **testcases_1d_file_ = (const testcase_delegate_t**)testcases_1d_file;

Expand Down Expand Up @@ -112,6 +116,8 @@ static const testcase_delegate_file_t nfft_adjoint_1d_50_1 = {setup_file, destro
static const testcase_delegate_file_t nfft_adjoint_1d_50_10 = {setup_file, destroy_file, ABSPATH("data/nfft_adjoint_1d_50_10.txt")};
static const testcase_delegate_file_t nfft_adjoint_1d_50_20 = {setup_file, destroy_file, ABSPATH("data/nfft_adjoint_1d_50_20.txt")};
static const testcase_delegate_file_t nfft_adjoint_1d_50_50 = {setup_file, destroy_file, ABSPATH("data/nfft_adjoint_1d_50_50.txt")};
static const testcase_delegate_file_t nfft_adjoint_1d_512_50 = {setup_file, destroy_file, ABSPATH("data/nfft_adjoint_1d_512_50.txt")};
static const testcase_delegate_file_t nfft_adjoint_1d_1024_50 = {setup_file, destroy_file, ABSPATH("data/nfft_adjoint_1d_1024_50.txt")};

static const testcase_delegate_file_t *testcases_adjoint_1d_file[] =
{
Expand Down Expand Up @@ -139,6 +145,8 @@ static const testcase_delegate_file_t *testcases_adjoint_1d_file[] =
&nfft_adjoint_1d_50_10,
&nfft_adjoint_1d_50_20,
&nfft_adjoint_1d_50_50,
&nfft_adjoint_1d_512_50,
&nfft_adjoint_1d_1024_50,
};
static const testcase_delegate_t **testcases_adjoint_1d_file_ = (const testcase_delegate_t**)testcases_adjoint_1d_file;

Expand Down
Loading
Loading