From 9faf26536eeb040088d97e1707c9df734a5a9d47 Mon Sep 17 00:00:00 2001
From: Zhihao Deng
Date: Mon, 11 May 2026 19:36:48 -0400
Subject: [PATCH 1/7] arena: add allocator + plan helper + tests

---
 src/TiledArray/tensor/arena.h | 159 ++++++++++++++++++++++++++++++++++
 tests/CMakeLists.txt          |   1 +
 tests/arena.cpp               | 131 ++++++++++++++++++++++++++++
 3 files changed, 291 insertions(+)
 create mode 100644 src/TiledArray/tensor/arena.h
 create mode 100644 tests/arena.cpp

diff --git a/src/TiledArray/tensor/arena.h b/src/TiledArray/tensor/arena.h
new file mode 100644
index 0000000000..ede7590b38
--- /dev/null
+++ b/src/TiledArray/tensor/arena.h
@@ -0,0 +1,159 @@
+/// Arena implementation
+#ifndef TILEDARRAY_TENSOR_ARENA_H__INCLUDED
+#define TILEDARRAY_TENSOR_ARENA_H__INCLUDED
+
+#include "TiledArray/config.h"
+#include "TiledArray/error.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <memory_resource>
+#include <new>
+#include <vector>
+
+namespace TiledArray {
+namespace detail {
+
+/// Kill switch: when true, hooks fall back to the legacy heap path.
+inline bool& arena_disabled() {
+  static bool flag = false;
+  return flag;
+}
+
+/// One-shot bump allocator; slab is co-owned via aliasing shared_ptrs.
+class Arena {
+ public:
+  explicit Arena(std::pmr::memory_resource* mr =
+                     std::pmr::new_delete_resource()) noexcept
+      : resource_(mr) {
+    TA_ASSERT(resource_ != nullptr);
+  }
+
+  Arena(const Arena&) = delete;
+  Arena& operator=(const Arena&) = delete;
+  Arena(Arena&&) noexcept = default;
+  Arena& operator=(Arena&&) noexcept = default;
+  ~Arena() = default;
+
+  /// Allocate the slab once; zero_init clears it for accumulation kernels.
+  void reserve(std::size_t bytes, bool zero_init = false) {
+    TA_ASSERT(capacity_ == 0);
+    TA_ASSERT(bytes > 0);
+    void* raw = resource_->allocate(bytes, alignof(std::max_align_t));
+    auto* mr = resource_;
+    auto deleter = [mr, bytes](std::byte* p) noexcept {
+      mr->deallocate(p, bytes, alignof(std::max_align_t));
+    };
+    slab_ = std::shared_ptr<std::byte[]>(static_cast<std::byte*>(raw),
+                                         std::move(deleter));
+    capacity_ = bytes;
+    cursor_ = 0;
+    if (zero_init) std::memset(slab_.get(), 0, bytes);
+  }
+
+  /// Aliasing view at a caller-aligned offset.
+  template <typename T>
+  std::shared_ptr<T[]> slice(std::size_t offset,
+                             std::size_t /*n_elem*/) const {
+    TA_ASSERT(slab_);
+    TA_ASSERT(offset % alignof(T) == 0);
+    TA_ASSERT(offset <= capacity_);
+    auto* p = reinterpret_cast<T*>(slab_.get() + offset);
+    return std::shared_ptr<T[]>(slab_, p);
+  }
+
+  /// Bump-allocate n elements of T; result is T-aligned.
+  template <typename T>
+  std::shared_ptr<T[]> claim(std::size_t n) {
+    TA_ASSERT(slab_);
+    auto base = reinterpret_cast<std::uintptr_t>(slab_.get() + cursor_);
+    auto aligned = (base + alignof(T) - 1) & ~(alignof(T) - 1);
+    std::size_t pad = static_cast<std::size_t>(aligned - base);
+    std::size_t consumed = pad + n * sizeof(T);
+    TA_ASSERT(cursor_ + consumed <= capacity_);
+    cursor_ += consumed;
+    return std::shared_ptr<T[]>(slab_, reinterpret_cast<T*>(aligned));
+  }
+
+  std::size_t capacity() const noexcept { return capacity_; }
+  std::size_t cursor() const noexcept { return cursor_; }
+  std::size_t remaining() const noexcept { return capacity_ - cursor_; }
+  bool empty() const noexcept { return cursor_ == 0; }
+  std::pmr::memory_resource* resource() const noexcept { return resource_; }
+
+ private:
+  std::pmr::memory_resource* resource_;
+  std::shared_ptr<std::byte[]> slab_;
+  std::size_t capacity_ = 0;
+  std::size_t cursor_ = 0;
+};
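+
+// Illustrative usage (a sketch, not exercised verbatim by this patch; all
+// names below are defined above):
+//
+//   Arena a;                              // backed by new_delete_resource
+//   a.reserve(1024, /*zero_init=*/true);  // one slab for the whole tile
+//   auto xs = a.claim<double>(16);        // bump-allocated, double-aligned
+//   auto ys = a.slice<double>(256, 8);    // aliasing view at byte offset 256
+//   // xs and ys co-own the slab: destroying `a` does not free it.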
+
+/// Per-cell offsets and total slab size produced by plan().
+struct ArenaPlan {
+  std::vector<std::size_t> offsets;
+  std::size_t total_bytes = 0;
+};
+
+/// Stride alignment used by production callers; 128 B is at least one cache
+/// line on every supported target.
+inline constexpr std::size_t kArenaCachelineAlign = 128;
+
+/// Round bytes up to a power-of-two alignment.
+inline std::size_t arena_align_up(std::size_t bytes,
+                                  std::size_t alignment) noexcept {
+  return (bytes + alignment - 1) & ~(alignment - 1);
+}
+
+/// Pre-walk cells once to compute offsets and total bytes.
+template <typename ShapeFn>
+ArenaPlan plan(std::size_t N_cells,
+               ShapeFn&& shape_fn,
+               std::size_t element_size,
+               std::size_t alignment) {
+  ArenaPlan out;
+  out.offsets.resize(N_cells);
+  std::size_t total = 0;
+  for (std::size_t ord = 0; ord < N_cells; ++ord) {
+    out.offsets[ord] = total;
+    auto&& r = shape_fn(ord);
+    std::size_t bytes = r.volume() * element_size;
+    total += arena_align_up(bytes, alignment);
+  }
+  out.total_bytes = total;
+  return out;
+}
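+
+// Worked example (illustrative numbers only): cells of {3, 5, 2} doubles with
+// alignment 16 get padded strides arena_align_up(24, 16) = 32,
+// arena_align_up(40, 16) = 48, and arena_align_up(16, 16) = 16, so plan()
+// yields offsets {0, 32, 80} and total_bytes = 96.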
+
+/// PMR adapter over an Arena; deallocation is a no-op (slab-owned lifetime).
+class ArenaResource final : public std::pmr::memory_resource {
+ public:
+  explicit ArenaResource(Arena* arena) noexcept : arena_(arena) {
+    TA_ASSERT(arena != nullptr);
+  }
+
+  Arena* arena() const noexcept { return arena_; }
+
+ protected:
+  void* do_allocate(std::size_t bytes, std::size_t alignment) override {
+    auto h = arena_->claim<std::byte>(arena_align_up(bytes, alignment));
+    return h.get();
+  }
+
+  void do_deallocate(void* /*p*/, std::size_t /*bytes*/,
+                     std::size_t /*alignment*/) override {}
+
+  bool do_is_equal(
+      const std::pmr::memory_resource& other) const noexcept override {
+    return this == &other;
+  }
+
+ private:
+  Arena* arena_;
+};
+
+}  // namespace detail
+}  // namespace TiledArray
+
+#endif  // TILEDARRAY_TENSOR_ARENA_H__INCLUDED
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index a30770fb18..bbd3b16247 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -102,6 +102,7 @@ set(ta_test_src_files ta_test.cpp
     linalg.cpp
     cp.cpp
     btas.cpp
+    arena.cpp
     )
 
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
diff --git a/tests/arena.cpp b/tests/arena.cpp
new file mode 100644
index 0000000000..46273e8645
--- /dev/null
+++ b/tests/arena.cpp
@@ -0,0 +1,131 @@
+#include "TiledArray/tensor/arena.h"
+
+#include "tiledarray.h"
+#include "unit_test_config.h"
+
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <vector>
+
+using TiledArray::detail::Arena;
+using TiledArray::detail::ArenaPlan;
+using TiledArray::detail::ArenaResource;
+using TiledArray::detail::plan;
+
+namespace {
+// Minimal Range-like shim for plan() tests: supports only volume().
+struct FakeRange {
+  std::size_t v;
+  std::size_t volume() const noexcept { return v; }
+};
+}  // namespace
+
+BOOST_AUTO_TEST_SUITE(arena_suite, TA_UT_LABEL_SERIAL)
+
+BOOST_AUTO_TEST_CASE(default_arena_is_empty) {
+  Arena a;
+  BOOST_CHECK_EQUAL(a.capacity(), 0u);
+  BOOST_CHECK_EQUAL(a.cursor(), 0u);
+  BOOST_CHECK(a.empty());
+  BOOST_CHECK(a.resource() != nullptr);
+}
+
+BOOST_AUTO_TEST_CASE(reserve_initializes_capacity) {
+  Arena a;
+  a.reserve(1024);
+  BOOST_CHECK_EQUAL(a.capacity(), 1024u);
+  BOOST_CHECK_EQUAL(a.cursor(), 0u);
+  BOOST_CHECK_EQUAL(a.remaining(), 1024u);
+}
+
+BOOST_AUTO_TEST_CASE(reserve_zero_init_clears_slab) {
+  Arena a;
+  a.reserve(64, /*zero_init=*/true);
+  auto h = a.slice<std::uint8_t>(0, 64);
+  for (std::size_t i = 0; i < 64; ++i) BOOST_CHECK_EQUAL(h[i], 0u);
+}
+
+BOOST_AUTO_TEST_CASE(slice_random_access_and_aliasing) {
+  Arena a;
+  a.reserve(1024);
+  std::shared_ptr<double[]> p1 = a.slice<double>(0, 4);
+  std::shared_ptr<double[]> p2 = a.slice<double>(64, 4);
+  for (int i = 0; i < 4; ++i) p1[i] = double(i);
+  for (int i = 0; i < 4; ++i) p2[i] = double(10 + i);
+  for (int i = 0; i < 4; ++i) BOOST_CHECK_EQUAL(p1[i], double(i));
+  for (int i = 0; i < 4; ++i) BOOST_CHECK_EQUAL(p2[i], double(10 + i));
+  BOOST_CHECK(static_cast<const void*>(&p2[0]) >=
+              static_cast<const void*>(&p1[4]));
+}
+
+BOOST_AUTO_TEST_CASE(claim_advances_cursor_and_aligns) {
+  Arena a;
+  a.reserve(1024);
+  std::shared_ptr<double[]> h = a.claim<double>(10);
+  BOOST_REQUIRE(h.get() != nullptr);
+  BOOST_CHECK_EQUAL(
+      reinterpret_cast<std::uintptr_t>(h.get()) % alignof(double), 0u);
+  BOOST_CHECK(a.cursor() >= 10u * sizeof(double));
+}
+
+BOOST_AUTO_TEST_CASE(slab_survives_arena_destruction) {
+  std::shared_ptr<double[]> survivor;
+  {
+    Arena tmp;
+    tmp.reserve(256);
+    survivor = tmp.claim<double>(10);
+    for (int i = 0; i < 10; ++i) survivor[i] = -i;
+  }
+  for (int i = 0; i < 10; ++i) BOOST_CHECK_EQUAL(survivor[i], -i);
+}
+
+BOOST_AUTO_TEST_CASE(plan_uniform_cells) {
+  ArenaPlan p = plan(
+      /*N_cells=*/6,
+      /*shape_fn=*/[](std::size_t /*ord*/) { return FakeRange{10}; },
+      /*element_size=*/sizeof(double),
+      /*alignment=*/alignof(double));
+  BOOST_CHECK_EQUAL(p.total_bytes, 6u * 10u * sizeof(double));
+  BOOST_CHECK_EQUAL(p.offsets.size(), 6u);
+  BOOST_CHECK_EQUAL(p.offsets[0], 0u);
+  BOOST_CHECK_EQUAL(p.offsets[5], 5u * 10u * sizeof(double));
+}
+
+BOOST_AUTO_TEST_CASE(plan_variable_cells_match_pivot_doc_example) {
+  ArenaPlan p = plan(
+      /*N_cells=*/12,
+      /*shape_fn=*/[](std::size_t /*ord*/) { return FakeRange{20}; },
+      /*element_size=*/sizeof(double),
+      /*alignment=*/alignof(double));
+  BOOST_CHECK_EQUAL(p.total_bytes, 12u * 20u * sizeof(double));
+  BOOST_CHECK_EQUAL(p.offsets[1], 20u * sizeof(double));
+}
+
+BOOST_AUTO_TEST_CASE(plan_then_construct_then_read) {
+  const std::size_t N = 4;
+  std::vector<std::size_t> volumes = {3, 5, 2, 7};
+  auto shape_fn = [&volumes](std::size_t ord) {
+    return FakeRange{volumes[ord]};
+  };
+  ArenaPlan p = plan(N, shape_fn, sizeof(double), alignof(double));
+  Arena a;
+  a.reserve(p.total_bytes);
+  std::vector<std::shared_ptr<double[]>> handles(N);
+  for (std::size_t ord = 0; ord < N; ++ord) {
+    handles[ord] = a.slice<double>(p.offsets[ord], volumes[ord]);
+    for (std::size_t i = 0; i < volumes[ord]; ++i)
+      handles[ord][i] = double(100 * ord + i);
+  }
+  for (std::size_t ord = 0; ord < N; ++ord)
+    for (std::size_t i = 0; i < volumes[ord]; ++i)
+      BOOST_CHECK_EQUAL(handles[ord][i], double(100 * ord + i));
+}
+
+BOOST_AUTO_TEST_CASE(arena_resource_is_identity_equal) {
+  Arena a;
+  a.reserve(64);
+  ArenaResource r1(&a);
+  ArenaResource r2(&a);
+  BOOST_CHECK(r1.is_equal(r1));
+  BOOST_CHECK(!r1.is_equal(r2));
+}
+
+BOOST_AUTO_TEST_SUITE_END()
From 5756befbb916fbc44501c14165bf9b0963056e64 Mon Sep 17 00:00:00 2001
From: Zhihao Deng
Date: Mon, 11 May 2026 19:45:36 -0400
Subject: [PATCH 2/7] arena_kernels: add ToT kernels + tests

---
 src/TiledArray/tensor/arena_kernels.h | 183 ++++++++++++++++++++++++++
 tests/CMakeLists.txt                  |   1 +
 tests/arena_kernels.cpp               | 109 +++++++++++++++
 3 files changed, 293 insertions(+)
 create mode 100644 src/TiledArray/tensor/arena_kernels.h
 create mode 100644 tests/arena_kernels.cpp

diff --git a/src/TiledArray/tensor/arena_kernels.h b/src/TiledArray/tensor/arena_kernels.h
new file mode 100644
index 0000000000..7a86615ac3
--- /dev/null
+++ b/src/TiledArray/tensor/arena_kernels.h
@@ -0,0 +1,183 @@
+/// Arena kernels for ToT trivial ops and contraction result initialization.
+
+#ifndef TILEDARRAY_TENSOR_ARENA_KERNELS_H__INCLUDED
+#define TILEDARRAY_TENSOR_ARENA_KERNELS_H__INCLUDED
+
+#include "TiledArray/config.h"
+#include "TiledArray/error.h"
+#include "TiledArray/tensor/arena.h"
+
+#include <cstddef>
+#include <memory>
+#include <new>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace TiledArray {
+namespace detail {
+
+namespace {
+
+/// Build outer storage whose deleter owns arena and alias keep-alive state.
+template <typename OuterTensor, typename KeepAlive>
+std::shared_ptr<typename OuterTensor::value_type[]>
+make_outer_data(std::size_t n_cells, std::shared_ptr<Arena> arena_handle,
+                KeepAlive keep_alive) {
+  using inner_t = typename OuterTensor::value_type;
+  std::allocator<inner_t> allocator;
+  inner_t* raw = allocator.allocate(n_cells);
+  auto deleter = [allocator = std::move(allocator),
+                  arena_handle = std::move(arena_handle),
+                  keep_alive = std::move(keep_alive),
+                  n_cells](inner_t* p) mutable {
+    for (std::size_t i = 0; i < n_cells; ++i) (p + i)->~inner_t();
+    allocator.deallocate(p, n_cells);
+    (void)arena_handle;
+    (void)keep_alive;
+  };
+  return std::shared_ptr<inner_t[]>(raw, std::move(deleter));
+}
+
+}  // namespace
+
+/// Apply a unary fill op while preserving each source inner shape.
+template <typename OuterTensor, typename SrcOuterTensor, typename FillOp>
+OuterTensor arena_trivial_unary(const SrcOuterTensor& src, FillOp&& fill_op) {
+  using inner_t = typename OuterTensor::value_type;
+  using elem_t = typename inner_t::value_type;
+  const std::size_t N_cells = src.range().volume() * src.nbatch();
+  auto shape_fn = [&src](std::size_t ord) -> decltype(auto) {
+    return src.data()[ord].range();
+  };
+  ArenaPlan p = plan(N_cells, shape_fn, sizeof(elem_t), alignof(elem_t));
+  auto arena = std::make_shared<Arena>();
+  if (p.total_bytes > 0) arena->reserve(p.total_bytes, false);
+  auto data =
+      make_outer_data<OuterTensor>(N_cells, arena, std::shared_ptr<void>{});
+  OuterTensor result(src.range(), src.nbatch(), std::move(data));
+  for (std::size_t ord = 0; ord < N_cells; ++ord) {
+    const auto& r = src.data()[ord].range();
+    const std::size_t n = r.volume();
+    if (n == 0) {
+      new (result.data() + ord) inner_t(r);
+      continue;
+    }
+    auto elem_data = arena->slice<elem_t>(p.offsets[ord], n);
+    new (result.data() + ord) inner_t(r, std::move(elem_data));
+    fill_op(result.data()[ord].data(), src.data()[ord].data(), n);
+  }
+  return result;
+}
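+
+// Example fill op (a sketch; `negate`, `src_tot`, and `OuterTensorT` are
+// hypothetical names). The kernel hands the op raw element arrays, so any
+// elementwise transform with this dst/src/count shape fits:
+//
+//   auto negate = [](double* dst, const double* src, std::size_t n) {
+//     for (std::size_t i = 0; i < n; ++i) dst[i] = -src[i];
+//   };
+//   auto out = arena_trivial_unary<OuterTensorT>(src_tot, negate);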
+
+/// Apply a binary fill op using the left operand's inner shapes.
+template <typename OuterTensor, typename LeftTensor, typename RightTensor,
+          typename FillOp>
+OuterTensor arena_trivial_binary(const LeftTensor& left,
+                                 const RightTensor& right, FillOp&& fill_op) {
+  using inner_t = typename OuterTensor::value_type;
+  using elem_t = typename inner_t::value_type;
+  TA_ASSERT(left.range().volume() == right.range().volume());
+  TA_ASSERT(left.nbatch() == right.nbatch());
+  const std::size_t N_cells = left.range().volume() * left.nbatch();
+  auto shape_fn = [&left](std::size_t ord) -> decltype(auto) {
+    return left.data()[ord].range();
+  };
+  ArenaPlan p = plan(N_cells, shape_fn, sizeof(elem_t), alignof(elem_t));
+  auto arena = std::make_shared<Arena>();
+  if (p.total_bytes > 0) arena->reserve(p.total_bytes, false);
+  auto data =
+      make_outer_data<OuterTensor>(N_cells, arena, std::shared_ptr<void>{});
+  OuterTensor result(left.range(), left.nbatch(), std::move(data));
+  for (std::size_t ord = 0; ord < N_cells; ++ord) {
+    const auto& r = left.data()[ord].range();
+    const std::size_t n = r.volume();
+    TA_ASSERT(n == right.data()[ord].range().volume());
+    if (n == 0) {
+      new (result.data() + ord) inner_t(r);
+      continue;
+    }
+    auto elem_data = arena->slice<elem_t>(p.offsets[ord], n);
+    new (result.data() + ord) inner_t(r, std::move(elem_data));
+    fill_op(result.data()[ord].data(), left.data()[ord].data(),
+            right.data()[ord].data(), n);
+  }
+  return result;
+}
+
+/// Shallow-permute outer cells while preserving inner data aliases.
+template <typename OuterTensor, typename SrcOuterTensor, typename Perm>
+OuterTensor arena_permute_shallow(const SrcOuterTensor& src,
+                                  const Perm& perm) {
+  using inner_t = typename OuterTensor::value_type;
+  TA_ASSERT(perm);
+  TA_ASSERT(perm.size() == src.range().rank());
+  auto perm_range = perm * src.range();
+  const std::size_t N_cells = src.range().volume();
+  const std::size_t total_cells = N_cells * src.nbatch();
+  const auto src_data_ref = src.data_shared();
+  auto data =
+      make_outer_data<OuterTensor>(total_cells,
+                                   std::make_shared<Arena>(),
+                                   std::move(src_data_ref));
+  OuterTensor result(perm_range, src.nbatch(), std::move(data));
+  for (std::size_t s = 0; s < N_cells; ++s) {
+    auto src_idx = src.range().idx(s);
+    auto tgt_ord = perm_range.ordinal(perm * src_idx);
+    for (std::size_t b = 0; b < src.nbatch(); ++b) {
+      const std::size_t s_off = b * N_cells + s;
+      const std::size_t t_off = b * N_cells + tgt_ord;
+      const inner_t& src_inner = src.data()[s_off];
+      auto src_inner_data = const_cast<inner_t&>(src_inner).data_shared();
+      new (result.data() + t_off) inner_t(src_inner.range(),
+                                          src_inner.nbatch(),
+                                          std::move(src_inner_data));
+    }
+  }
+  return result;
+}
+
+/// Allocate a slab-backed outer tile using caller-provided inner shapes.
+template <typename OuterTensor, typename Range, typename ShapeFn>
+OuterTensor arena_outer_init(const Range& outer_range, std::size_t batch_sz,
+                             ShapeFn&& shape_fn, bool zero_init = true) {
+  using inner_t = typename OuterTensor::value_type;
+  using elem_t = typename inner_t::value_type;
+  using inner_range_t =
+      std::decay_t<decltype(shape_fn(std::declval<std::size_t>()))>;
+  const std::size_t N_cells = outer_range.volume() * batch_sz;
+  std::vector<inner_range_t> ranges;
+  ranges.reserve(N_cells);
+  std::vector<std::size_t> offsets(N_cells);
+  std::size_t total_bytes = 0;
+  for (std::size_t ord = 0; ord < N_cells; ++ord) {
+    offsets[ord] = total_bytes;
+    ranges.emplace_back(shape_fn(ord));
+    const std::size_t bytes = ranges.back().volume() * sizeof(elem_t);
+    total_bytes += arena_align_up(bytes, alignof(elem_t));
+  }
+  auto arena = std::make_shared<Arena>();
+  // Arena::reserve requires a non-empty slab.
+  if (total_bytes > 0) {
+    arena->reserve(total_bytes, zero_init);
+  }
+  auto data =
+      make_outer_data<OuterTensor>(N_cells, arena, std::shared_ptr<void>{});
+  OuterTensor result(outer_range, batch_sz, std::move(data));
+  for (std::size_t ord = 0; ord < N_cells; ++ord) {
+    auto& r = ranges[ord];
+    const std::size_t n = r.volume();
+    if (n == 0) {
+      // Rank-0 empties must preserve Tensor's null-data/no-range invariant.
+      if (!r) {
+        new (result.data() + ord) inner_t();
+      } else {
+        new (result.data() + ord) inner_t(r);
+      }
+    } else {
+      auto elem_data = arena->slice<elem_t>(offsets[ord], n);
+      new (result.data() + ord) inner_t(r, std::move(elem_data));
+    }
+  }
+  return result;
+}
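+
+// Illustrative call (a sketch; `inner_ranges`, `outer_range`, and
+// `OuterTensorT` are hypothetical). shape_fn is invoked once per cell ordinal
+// and may return ranges of varying volume:
+//
+//   auto shape_fn = [&](std::size_t ord) { return inner_ranges[ord]; };
+//   auto tile = arena_outer_init<OuterTensorT>(outer_range, /*batch_sz=*/1,
+//                                              shape_fn, /*zero_init=*/true);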
+
+}  // namespace detail
+}  // namespace TiledArray
+
+#endif  // TILEDARRAY_TENSOR_ARENA_KERNELS_H__INCLUDED
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index bbd3b16247..511f387426 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -102,6 +102,7 @@ set(ta_test_src_files ta_test.cpp
    cp.cpp
    btas.cpp
    arena.cpp
+    arena_kernels.cpp
    )
 
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
diff --git a/tests/arena_kernels.cpp b/tests/arena_kernels.cpp
new file mode 100644
index 0000000000..d587a64409
--- /dev/null
+++ b/tests/arena_kernels.cpp
@@ -0,0 +1,109 @@
+/// Unit tests for arena-backed ToT kernels.
+
+#include "TiledArray/tensor/arena_kernels.h"
+
+#include "TiledArray/tensor.h"
+#include "TiledArray/tensor/arena.h"
+#include "tiledarray.h"
+#include "unit_test_config.h"
+
+#include <cstddef>
+#include <utility>
+
+namespace TA = TiledArray;
+using inner_t = TA::Tensor<double>;
+using outer_t = TA::Tensor<inner_t>;
+
+namespace {
+
+outer_t make_tot(std::size_t N_outer, std::size_t n_inner, double base = 1.0) {
+  outer_t outer(TA::Range{static_cast<long>(N_outer)}, 1);
+  for (std::size_t ord = 0; ord < N_outer; ++ord) {
+    inner_t inner(TA::Range{static_cast<long>(n_inner)});
+    for (std::size_t i = 0; i < n_inner; ++i)
+      inner.at_ordinal(i) = base + ord * 100.0 + i;
+    *(outer.data() + ord) = std::move(inner);
+  }
+  return outer;
+}
+
+bool tot_equal(const outer_t& a, const outer_t& b) {
+  if (a.range().volume() != b.range().volume()) return false;
+  for (std::size_t ord = 0; ord < a.range().volume(); ++ord) {
+    const inner_t& ai = *(a.data() + ord);
+    const inner_t& bi = *(b.data() + ord);
+    if (ai.range().volume() != bi.range().volume()) return false;
+    for (std::size_t i = 0; i < ai.range().volume(); ++i)
+      if (ai.at_ordinal(i) != bi.at_ordinal(i)) return false;
+  }
+  return true;
+}
+
+}  // namespace
+
+BOOST_AUTO_TEST_SUITE(arena_kernels_suite, TA_UT_LABEL_SERIAL)
+
+BOOST_AUTO_TEST_CASE(trivial_unary_clone_matches_heap_baseline) {
+  outer_t src = make_tot(4, 5, 1.0);
+  auto fill = [](double* dst, const double* src, std::size_t n) {
+    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
+  };
+  outer_t arena_result =
+      TA::detail::arena_trivial_unary<outer_t>(src, fill);
+  BOOST_CHECK(tot_equal(arena_result, src));
+}
+
+BOOST_AUTO_TEST_CASE(trivial_unary_scale_matches_heap_baseline) {
+  outer_t src = make_tot(4, 5, 1.0);
+  const double factor = 2.5;
+  auto fill = [factor](double* dst, const double* src, std::size_t n) {
+    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i] * factor;
+  };
+  outer_t arena_result =
+      TA::detail::arena_trivial_unary<outer_t>(src, fill);
+  outer_t baseline(src.range(), 1);
+  for (std::size_t ord = 0; ord < src.range().volume(); ++ord) {
+    inner_t inner((src.data() + ord)->range());
+    for (std::size_t i = 0; i < inner.range().volume(); ++i)
+      inner.at_ordinal(i) = (src.data() + ord)->at_ordinal(i) * factor;
+    *(baseline.data() + ord) = std::move(inner);
+  }
+  BOOST_CHECK(tot_equal(arena_result, baseline));
+}
+
+BOOST_AUTO_TEST_CASE(trivial_binary_add_matches_heap_baseline) {
+  outer_t L = make_tot(4, 5, 1.0);
+  outer_t R = make_tot(4, 5, 0.5);
+  auto fill = [](double* dst, const double* l, const double* r,
+                 std::size_t n) {
+    for (std::size_t i = 0; i < n; ++i) dst[i] = l[i] + r[i];
+  };
+  outer_t arena_result =
+      TA::detail::arena_trivial_binary<outer_t>(L, R, fill);
+  outer_t baseline(L.range(), 1);
+  for (std::size_t ord = 0; ord < L.range().volume(); ++ord) {
+    inner_t inner((L.data() + ord)->range());
+    for (std::size_t i = 0; i < inner.range().volume(); ++i)
+      inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) +
+                            (R.data() + ord)->at_ordinal(i);
+    *(baseline.data() + ord) = std::move(inner);
+  }
+  BOOST_CHECK(tot_equal(arena_result, baseline));
+}
+
+BOOST_AUTO_TEST_CASE(arena_outlives_kernel_call) {
+  // The result data deleter co-owns the arena.
+  outer_t arena_result;
+  {
+    outer_t src = make_tot(3, 4, 7.0);
+    auto fill = [](double* dst, const double* src, std::size_t n) {
+      for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
+    };
+    arena_result = TA::detail::arena_trivial_unary<outer_t>(src, fill);
+  }
+  for (std::size_t ord = 0; ord < arena_result.range().volume(); ++ord)
+    for (std::size_t i = 0;
+         i < (arena_result.data() + ord)->range().volume(); ++i)
+      BOOST_CHECK_EQUAL((arena_result.data() + ord)->at_ordinal(i),
+                        7.0 + ord * 100.0 + i);
+}
+
+BOOST_AUTO_TEST_SUITE_END()
From 5596879ceaa4222fb535c71c31e0449217ec3a05 Mon Sep 17 00:00:00 2001
From: Zhihao Deng
Date: Mon, 11 May 2026 20:01:25 -0400
Subject: [PATCH 3/7] arena_einsum: regime-A (outer-Hadamard) plans + dispatch
 + tests

---
 src/TiledArray/tensor/arena_einsum.h  | 621 ++++++++++++++++++++++++++
 src/TiledArray/tensor/arena_kernels.h |   8 +-
 tests/CMakeLists.txt                  |   1 +
 tests/arena_einsum_unit_suite.cpp     | 253 +++++++++++
 4 files changed, 881 insertions(+), 2 deletions(-)
 create mode 100644 src/TiledArray/tensor/arena_einsum.h
 create mode 100644 tests/arena_einsum_unit_suite.cpp

diff --git a/src/TiledArray/tensor/arena_einsum.h b/src/TiledArray/tensor/arena_einsum.h
new file mode 100644
index 0000000000..c4b3ce409c
--- /dev/null
+++ b/src/TiledArray/tensor/arena_einsum.h
@@ -0,0 +1,621 @@
+/// Arena-aware ToT einsum: plans, fused kernels, and dispatch.
+
+#ifndef TILEDARRAY_TENSOR_ARENA_EINSUM_H__INCLUDED
+#define TILEDARRAY_TENSOR_ARENA_EINSUM_H__INCLUDED
+
+#include "TiledArray/error.h"
+#include "TiledArray/math/gemm_helper.h"
+#include "TiledArray/permutation.h"
+#include "TiledArray/tensor/arena.h"
+#include "TiledArray/tensor/arena_kernels.h"
+#include "TiledArray/tensor/kernels.h"
+#include "TiledArray/tensor/type_traits.h"
+
+#include <optional>
+#include <type_traits>
+#include <utility>
+#include <variant>
+
+#if defined(_MSC_VER) && _MSC_VER < 1937  // VS 2022 < 17.7
+#  define TA_NO_UNIQUE_ADDRESS [[msvc::no_unique_address]]
+#else
+#  define TA_NO_UNIQUE_ADDRESS [[no_unique_address]]
+#endif
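+
+// With a std::monostate plan (non-ToT engines) the attribute above lets the
+// plan member occupy zero bytes, keeping the arena hook free for plain-tensor
+// contractions (see the sizeof-invariant test added in a later patch).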
+
+namespace TiledArray::detail {
+
+/// Specifies how an inner-cell range is derived from operand inner cells.
+enum class ArenaInnerShapeKind {
+  left_range,        // Hadamard inner; Scale tot_x_t
+  right_range,       // Scale t_x_tot
+  gemm_result_range  // inner Contraction (uses inner_gh)
+};
+
+/// Inner-shape derivation plan: kind + (optional) inner GemmHelper.
+struct ArenaInnerShapePlan {
+  ArenaInnerShapeKind kind;
+  std::optional<math::GemmHelper> inner_gh;  // only for gemm_result_range
+
+  /// Derives one result inner range from operand inner cells.
+  template <typename ResultInnerRange, typename LInner, typename RInner>
+  ResultInnerRange make(const LInner& l, const RInner& r) const {
+    switch (kind) {
+      case ArenaInnerShapeKind::left_range:
+        return l.range();
+      case ArenaInnerShapeKind::right_range:
+        return r.range();
+      case ArenaInnerShapeKind::gemm_result_range:
+        TA_ASSERT(inner_gh.has_value());
+        return inner_gh->template make_result_range<ResultInnerRange>(
+            l.range(), r.range());
+    }
+    TA_ASSERT(false);
+    return ResultInnerRange{};
+  }
+};
+
+/// Derives result ranges and constructs non-empty inner cells in one arena slab.
+template <typename Result, typename Left, typename Right>
+class ContractionArenaPlan {
+ public:
+  /// Stores the inner shape plan used to construct result cells.
+  explicit ContractionArenaPlan(ArenaInnerShapePlan p)
+      : inner_plan_(std::move(p)) {}
+
+  /// Constructs a result tile whose non-empty inner cells alias arena storage.
+  Result reserve_and_construct(const Left& left, const Right& right,
+                               const math::GemmHelper& outer_gh) const;
+
+ private:
+  ArenaInnerShapePlan inner_plan_{};
+};
+
+/// True when the result is a tensor-of-tensor with TA tensor inner cells.
+template <typename Result, typename Left, typename Right>
+inline constexpr bool is_contraction_arena_tot_v =
+    is_tensor_of_tensor_v<Result> &&
+    is_ta_tensor_v<typename Result::value_type>;
+
+/// Stores an arena plan for ToT results and std::monostate otherwise.
+template <typename Result, typename Left, typename Right>
+using arena_plan_storage_t = std::conditional_t<
+    is_contraction_arena_tot_v<Result, Left, Right>,
+    std::optional<ContractionArenaPlan<Result, Left, Right>>,
+    std::monostate>;
+
+/// Builds a contraction arena plan when the result and inner permutation allow it.
+template <typename Result, typename Left, typename Right>
+auto make_contraction_arena_plan(
+    ArenaInnerShapeKind inner_kind,
+    std::optional<math::GemmHelper> inner_gh,
+    const Permutation& inner_perm)
+    -> std::optional<ContractionArenaPlan<Result, Left, Right>> {
+  if (arena_disabled()) return std::nullopt;
+  if constexpr (!is_contraction_arena_tot_v<Result, Left, Right>) {
+    return std::nullopt;
+  } else {
+    if (bool(inner_perm) && !inner_perm.is_identity()) return std::nullopt;
+    if (inner_kind != ArenaInnerShapeKind::gemm_result_range) inner_gh.reset();
+    else if (!inner_gh.has_value()) return std::nullopt;
+    return std::optional<ContractionArenaPlan<Result, Left, Right>>(
+        std::in_place,
+        ArenaInnerShapePlan{inner_kind, std::move(inner_gh)});
+  }
+}
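+
+// Gating sketch (an assumption for illustration; ToT = Tensor<Tensor<double>>,
+// and `inner_gh`/`inner_perm` are hypothetical locals):
+//
+//   auto plan = make_contraction_arena_plan<ToT, ToT, ToT>(
+//       ArenaInnerShapeKind::gemm_result_range,
+//       std::make_optional(inner_gh), inner_perm);
+//   // nullopt when: the kill switch is set, the result is not a ToT, the
+//   // inner permutation is non-trivial, or inner_gh is missing.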
+
+/// Reserves arena storage and constructs the result tensor-of-tensor tile.
+template <typename Result, typename Left, typename Right>
+Result ContractionArenaPlan<Result, Left, Right>::reserve_and_construct(
+    const Left& left, const Right& right,
+    const math::GemmHelper& outer_gh) const {
+  using inner_t = typename Result::value_type;
+  using inner_range_t = typename inner_t::range_type;
+  using integer = math::blas::integer;
+
+  auto outer_range =
+      outer_gh.template make_result_range<typename Result::range_type>(
+          left.range(), right.range());
+
+  integer M, N, K;
+  outer_gh.compute_matrix_sizes(M, N, K, left.range(), right.range());
+  const integer lda =
+      (outer_gh.left_op() == math::blas::NoTranspose) ? K : M;
+  const integer ldb =
+      (outer_gh.right_op() == math::blas::NoTranspose) ? N : K;
+  TA_ASSERT(left.nbatch() == right.nbatch());
+  const std::size_t batch_sz = static_cast<std::size_t>(left.nbatch());
+  const std::size_t mn =
+      static_cast<std::size_t>(M) * static_cast<std::size_t>(N);
+
+  auto range_for = [&](std::size_t ord) -> inner_range_t {
+    if (mn == 0) return inner_range_t{};
+    const integer b = static_cast<integer>(ord / mn);
+    const integer rem = static_cast<integer>(ord % mn);
+    const integer m = rem / N;
+    const integer n = rem % N;
+
+    if (inner_plan_.kind == ArenaInnerShapeKind::left_range) {
+      if constexpr (is_tensor_of_tensor_v<Left>) {
+        const auto* lbase = left.batch_data(static_cast<std::size_t>(b));
+        for (integer k = 0; k != K; ++k) {
+          const auto aoff =
+              (outer_gh.left_op() == math::blas::NoTranspose)
+                  ? m * lda + k : k * lda + m;
+          const auto& lc = *(lbase + aoff);
+          if (!lc.empty()) return lc.range();
+        }
+      }
+      return inner_range_t{};
+    }
+    if (inner_plan_.kind == ArenaInnerShapeKind::right_range) {
+      if constexpr (is_tensor_of_tensor_v<Right>) {
+        const auto* rbase = right.batch_data(static_cast<std::size_t>(b));
+        for (integer k = 0; k != K; ++k) {
+          const auto boff =
+              (outer_gh.right_op() == math::blas::NoTranspose)
+                  ? k * ldb + n : n * ldb + k;
+          const auto& rc = *(rbase + boff);
+          if (!rc.empty()) return rc.range();
+        }
+      }
+      return inner_range_t{};
+    }
+    // gemm_result_range needs both operands to be ToT.
+    if constexpr (is_tensor_of_tensor_v<Left> &&
+                  is_tensor_of_tensor_v<Right>) {
+      const auto* lbase = left.batch_data(static_cast<std::size_t>(b));
+      const auto* rbase = right.batch_data(static_cast<std::size_t>(b));
+      for (integer k = 0; k != K; ++k) {
+        const auto aoff =
+            (outer_gh.left_op() == math::blas::NoTranspose)
+                ? m * lda + k : k * lda + m;
+        const auto boff =
+            (outer_gh.right_op() == math::blas::NoTranspose)
+                ? k * ldb + n : n * ldb + k;
+        const auto& lc = *(lbase + aoff);
+        const auto& rc = *(rbase + boff);
+        if (lc.empty() || rc.empty()) continue;
+        return inner_plan_.template make<inner_range_t>(lc, rc);
+      }
+    }
+    return inner_range_t{};
+  };
+
+  return detail::arena_outer_init<Result>(
+      outer_range, batch_sz, range_for, kArenaCachelineAlign,
+      /*zero_init=*/true);
+}
+
+/// Accumulates a contraction into an already-allocated result cell.
+template <typename Result, typename Left, typename Right, typename Scalar>
+void fused_contraction_inplace(Result& result, const Left& left,
+                               const Right& right, Scalar alpha,
+                               const math::GemmHelper& gh) {
+  if (left.empty() || right.empty()) return;
+  TA_ASSERT(!result.empty());
+  result.gemm(left, right, alpha, gh);
+}
+
+/// Accumulates an elementwise product into an already-allocated result cell.
+template <typename Result, typename Left, typename Right>
+void fused_hadamard_inplace(Result& result, const Left& left,
+                            const Right& right) {
+  if (left.empty() || right.empty()) return;
+  TA_ASSERT(!result.empty());
+  inplace_tensor_op(
+      [](typename Result::value_type& MADNESS_RESTRICT r,
+         const typename Left::value_type& MADNESS_RESTRICT l,
+         const typename Right::value_type& MADNESS_RESTRICT rr) {
+        r += l * rr;
+      },
+      result, left, right);
+}
+
+/// Accumulates a scaled elementwise product into an allocated result cell.
+template <typename Result, typename Left, typename Right, typename Scalar>
+void fused_hadamard_scaled_inplace(Result& result, const Left& left,
+                                   const Right& right, Scalar factor) {
+  if (left.empty() || right.empty()) return;
+  TA_ASSERT(!result.empty());
+  // Preserve historical grouping: r += (l * rr) * factor.
+  inplace_tensor_op(
+      [factor](typename Result::value_type& MADNESS_RESTRICT r,
+               const typename Left::value_type& MADNESS_RESTRICT l,
+               const typename Right::value_type& MADNESS_RESTRICT rr) {
+        r += (l * rr) * factor;
+      },
+      result, left, right);
+}
+
+/// Accumulates a ToT cell scaled by a scalar right operand.
+template <typename Result, typename Left, typename Scalar>
+void fused_scale_tot_x_t_inplace(Result& result, const Left& left,
+                                 const Scalar& s) {
+  if (left.empty()) return;
+  TA_ASSERT(!result.empty());
+  inplace_tensor_op(
+      [s](typename Result::value_type& MADNESS_RESTRICT r,
+          const typename Left::value_type& MADNESS_RESTRICT l) {
+        r += l * s;
+      },
+      result, left);
+}
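+
+// Note: every fused_*_inplace kernel assumes its result cell was produced by
+// arena_outer_init with zero_init=true, so `+=` accumulation is well-defined
+// on first touch.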
+
+/// Accumulates a ToT right operand scaled by a scalar left operand.
+template <typename Result, typename Scalar, typename Right>
+void fused_scale_t_x_tot_inplace(Result& result, const Scalar& s,
+                                 const Right& right) {
+  if (right.empty()) return;
+  TA_ASSERT(!result.empty());
+  inplace_tensor_op(
+      [s](typename Result::value_type& MADNESS_RESTRICT r,
+          const typename Right::value_type& MADNESS_RESTRICT rr) {
+        r += rr * s;
+      },
+      result, right);
+}
+
+/// Creates a fused contraction callback.
+template <typename Result, typename Left, typename Right, typename Op>
+auto make_fused_contraction_lambda(Op contrreduce_op) {
+  return [contrreduce_op](Result& result, const Left& left,
+                          const Right& right) {
+    TA_ASSERT(!contrreduce_op.perm());
+    fused_contraction_inplace(result, left, right,
+                              contrreduce_op.factor(),
+                              contrreduce_op.gemm_helper());
+  };
+}
+
+/// Creates a fused Hadamard callback.
+template <typename Result, typename Left, typename Right>
+auto make_fused_hadamard_lambda() {
+  return [](Result& result, const Left& left, const Right& right) {
+    fused_hadamard_inplace(result, left, right);
+  };
+}
+
+/// Creates a fused scaled-Hadamard callback.
+template <typename Result, typename Left, typename Right, typename Scalar>
+auto make_fused_hadamard_scaled_lambda(Scalar factor) {
+  return [factor](Result& result, const Left& left, const Right& right) {
+    fused_hadamard_scaled_inplace(result, left, right, factor);
+  };
+}
+
+/// Creates a fused ToT-times-scalar callback.
+template <typename Result, typename Left, typename Right>
+auto make_fused_scale_tot_x_t_lambda() {
+  return [](Result& result, const Left& left, const Right& right) {
+    fused_scale_tot_x_t_inplace(result, left, right);
+  };
+}
+
+/// Creates a fused scalar-times-ToT callback.
+template <typename Result, typename Left, typename Right>
+auto make_fused_scale_t_x_tot_lambda() {
+  return [](Result& result, const Left& left, const Right& right) {
+    fused_scale_t_x_tot_inplace(result, left, right);
+  };
+}
+
+/// Discriminates the per-cell operation used by the arena regime-A path.
+enum class RegimeAInnerKind {
+  hadamard,
+  contraction,
+  scale_left,  // ToT × plain T → ToT (right operand contributes scalars)
+  scale_right  // plain T × ToT → ToT (left operand contributes scalars)
+};
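+
+// Illustrative kind selection (shorthand einsum notation; an assumption, not
+// normative):
+//   "ij;mn,ij;mn->ij;mn"  -> hadamard     (inner Hadamard product)
+//   "ij;mk,ij;kn->ij;mn"  -> contraction  (inner GEMM)
+//   ToT left, plain right -> scale_left   (right cells act as scalars)
+//   plain left, ToT right -> scale_right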
+
+/// Holds the inner operation plan for arena regime-A dispatch.
+template <typename ResultTensor>
+struct RegimeAArenaPlan {
+  using Annot = ::Einsum::Index<std::string>;
+
+  bool active = false;
+  RegimeAInnerKind kind = RegimeAInnerKind::hadamard;
+
+  // Exactly one plan optional is engaged; optionals avoid default construction.
+  std::optional<TensorHadamardPlan<Annot>> h_plan{};
+  std::optional<TensorContractPlan<Annot>> c_plan{};
+
+  /// Derives the result inner range from a non-empty input-cell pair.
+  template <typename InnerRange, typename LRange, typename RRange>
+  InnerRange derive_inner_range(const LRange& l_range,
+                                const RRange& r_range) const {
+    switch (kind) {
+      case RegimeAInnerKind::hadamard:
+        TA_ASSERT(h_plan.has_value());
+        return h_plan->perm.AC * l_range;
+      case RegimeAInnerKind::contraction: {
+        TA_ASSERT(c_plan.has_value());
+        const auto& p = *c_plan;
+        using PlanIndices = std::remove_cvref_t<decltype(p.A)>;
+        using PlanIndex = typename PlanIndices::value_type;
+        using Extent = std::remove_cv_t<typename decltype(
+            std::declval<LRange>().extent())::value_type>;
+        using ExtentMap = ::Einsum::index::IndexMap<PlanIndex, Extent>;
+        ExtentMap extent = (ExtentMap{p.A, l_range.extent()} |
+                            ExtentMap{p.B, r_range.extent()});
+        container::vector<Extent> rng;
+        rng.reserve(p.e.size());
+        for (auto&& ix : p.e) rng.emplace_back(extent[ix]);
+        return InnerRange(TiledArray::Range(rng));
+      }
+      case RegimeAInnerKind::scale_left:
+        // Scale-left preserves the ToT operand's inner range.
+        return InnerRange(l_range);
+      case RegimeAInnerKind::scale_right:
+        return InnerRange(r_range);
+    }
+    TA_ASSERT(false && "RegimeAInnerKind: unhandled kind");
+    return InnerRange{};
+  }
+
+  /// Accumulates one input-cell pair into the result cell.
+  template <typename ResultCell, typename LCell, typename RCell>
+  void accumulate(ResultCell& r, const LCell& l, const RCell& rr) const {
+    switch (kind) {
+      case RegimeAInnerKind::hadamard: {
+        if constexpr (is_ta_tensor_v<LCell> && is_ta_tensor_v<RCell>) {
+          if (l.empty() || rr.empty()) return;
+          TA_ASSERT(h_plan.has_value());
+          const auto& hp = *h_plan;
+          TA_ASSERT((hp.no_perm || hp.perm_b) &&
+                    "regime-A arena plan should be inactive for unsupported "
+                    "Hadamard perm branches (perm_to_c/perm_a/else)");
+          fused_hadamard_inplace(r, l, rr);
+        }
+        return;
+      }
+      case RegimeAInnerKind::contraction: {
+        if constexpr (is_ta_tensor_v<LCell> && is_ta_tensor_v<RCell>) {
+          if (l.empty() || rr.empty()) return;
+          TA_ASSERT(c_plan.has_value());
+          auto prod = tensor_contract(l, rr, *c_plan);
+          if (!prod.empty()) r.add_to(prod);
+        }
+        return;
+      }
+      case RegimeAInnerKind::scale_left: {
+        // Scale-left receives a ToT inner cell and a scalar.
+        if constexpr (is_ta_tensor_v<LCell> && !is_ta_tensor_v<RCell>) {
+          if (l.empty()) return;
+          fused_scale_tot_x_t_inplace(r, l, rr);
+        }
+        return;
+      }
+      case RegimeAInnerKind::scale_right: {
+        if constexpr (!is_ta_tensor_v<LCell> && is_ta_tensor_v<RCell>) {
+          if (rr.empty()) return;
+          fused_scale_t_x_tot_inplace(r, l, rr);
+        }
+        return;
+      }
+    }
+  }
+};
+
+/// Builds an arena regime-A plan when result and permutation constraints allow it.
+template <typename ResultTensor, typename A, typename B, typename Inner,
+          typename PermT>
+auto make_regime_a_arena_plan(const A& a, const B& b, const Inner& inner,
+                              const PermT& inner_perm)
+    -> RegimeAArenaPlan<ResultTensor> {
+  using Plan = RegimeAArenaPlan<ResultTensor>;
+  Plan plan;
+  if (arena_disabled()) return plan;
+  if constexpr (!is_tensor_of_tensor_v<ResultTensor> ||
+                !is_ta_tensor_v<typename ResultTensor::value_type>) {
+    return plan;
+  } else {
+    if (bool(inner_perm) && !inner_perm.is_identity()) return plan;
+
+    using ArrayA_t = std::remove_cvref_t<decltype(a.array)>;
+    using ArrayB_t = std::remove_cvref_t<decltype(b.array)>;
+    constexpr bool a_is_tot =
+        is_tensor_of_tensor_v<typename ArrayA_t::value_type>;
+    constexpr bool b_is_tot =
+        is_tensor_of_tensor_v<typename ArrayB_t::value_type>;
+
+    if constexpr (a_is_tot && b_is_tot) {
+      if (static_cast<bool>(inner.h)) {
+        plan.kind = RegimeAInnerKind::hadamard;
+        plan.h_plan.emplace(inner.A, inner.B, inner.C);
+        const auto& hp = *plan.h_plan;
+        if (!(hp.no_perm || hp.perm_b)) return plan;
+      } else {
+        plan.kind = RegimeAInnerKind::contraction;
+        plan.c_plan.emplace(inner.A, inner.B, inner.C);
+        const auto& cp = *plan.c_plan;
+        if (cp.do_perm.C) return plan;
+      }
+    } else if constexpr (a_is_tot && !b_is_tot) {
+      plan.kind = RegimeAInnerKind::scale_left;
+    } else if constexpr (!a_is_tot && b_is_tot) {
+      plan.kind = RegimeAInnerKind::scale_right;
+    } else {
+      return plan;
+    }
+    plan.active = true;
+    (void)a;
+    (void)b;
+    return plan;
+  }
+}
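+
+// Dispatch sketch (simplified and hypothetical; the H/I loop naming follows
+// the einsum driver that adopts this path in a later patch):
+//
+//   auto aplan = make_regime_a_arena_plan<ResultTensor>(A, B, inner, perm_i);
+//   for (auto h : H_tiles) {
+//     if (run_regime_a_arena(aplan, h, batch, A, B, C, C_local_tiles,
+//                            I_tiles, C_trange))
+//       continue;  // arena path produced this H-slice
+//     // ... legacy per-cell heap path ...
+//   }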
+
+/// Runs the arena regime-A path for one H-slice when the plan is active.
+template <typename Plan, typename HIndex, typename TermA, typename TermB,
+          typename TermC, typename LocalTiles, typename Tiles,
+          typename Trange>
+bool run_regime_a_arena(const Plan& plan, const HIndex& h, std::size_t batch,
+                        const TermA& A, const TermB& B, const TermC& C,
+                        LocalTiles& C_local_tiles, const Tiles& tiles,
+                        const Trange& trange) {
+  if (!plan.active) return false;
+
+  using ResultTensor = typename LocalTiles::value_type::second_type;
+  // Guard avoids naming inner-cell APIs for non-ToT instantiations.
+  using ArrayA_t = std::remove_cvref_t<decltype(A.array)>;
+  using ArrayB_t = std::remove_cvref_t<decltype(B.array)>;
+  constexpr bool a_is_tot =
+      is_tensor_of_tensor_v<typename ArrayA_t::value_type>;
+  constexpr bool b_is_tot =
+      is_tensor_of_tensor_v<typename ArrayB_t::value_type>;
+  if constexpr (!is_tensor_of_tensor_v<ResultTensor> ||
+                !is_ta_tensor_v<typename ResultTensor::value_type> ||
+                (!a_is_tot && !b_is_tot)) {
+    (void)h; (void)batch; (void)A; (void)B; (void)C;
+    (void)C_local_tiles; (void)tiles; (void)trange;
+    return false;
+  } else {
+    using InnerT = typename ResultTensor::value_type;
+    using InnerRange = typename InnerT::range_type;
+
+    const auto& pa = A.permutation;
+    const auto& pb = B.permutation;
+    const auto& pc = C.permutation;
+    auto const c = apply(pc, h);
+
+    if constexpr (a_is_tot && b_is_tot) {
+      using IIndex = ::Einsum::index::Index<std::size_t>;
+      auto range_for = [&](std::size_t k) -> InnerRange {
+        if (k >= batch) return InnerRange{};
+        for (IIndex i : tiles) {
+          const auto pahi_inv = apply_inverse(pa, h + i);
+          const auto pbhi_inv = apply_inverse(pb, h + i);
+          if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+          auto ai = A.array.find(pahi_inv).get();
+          auto bi = B.array.find(pbhi_inv).get();
+          if (pa) ai = ai.permute(pa);
+          if (pb) bi = bi.permute(pb);
+          auto shape = trange.tile(i);
+          ai = ai.reshape(shape, batch);
+          bi = bi.reshape(shape, batch);
+          auto aik = ai.batch(k);
+          auto bik = bi.batch(k);
+          auto vol = aik.total_size();
+          TA_ASSERT(vol == bik.total_size());
+          for (decltype(vol) j = 0; j < vol; ++j) {
+            const auto& l_inner = aik.data()[j];
+            const auto& r_inner = bik.data()[j];
+            if (l_inner.empty() || r_inner.empty()) continue;
+            return plan.template derive_inner_range<InnerRange>(
+                l_inner.range(), r_inner.range());
+          }
+        }
+        return InnerRange{};
+      };
+
+      ResultTensor tile = arena_outer_init<ResultTensor>(
+          TiledArray::Range{batch}, /*batch_sz=*/1, range_for,
+          kArenaCachelineAlign, /*zero_init=*/true);
+
+      for (IIndex i : tiles) {
+        const auto pahi_inv = apply_inverse(pa, h + i);
+        const auto pbhi_inv = apply_inverse(pb, h + i);
+        if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+        auto ai = A.array.find(pahi_inv).get();
+        auto bi = B.array.find(pbhi_inv).get();
+        if (pa) ai = ai.permute(pa);
+        if (pb) bi = bi.permute(pb);
+        auto shape = trange.tile(i);
+        ai = ai.reshape(shape, batch);
+        bi = bi.reshape(shape, batch);
+        for (std::size_t k = 0; k < batch; ++k) {
+          auto& cell = tile({k});
+          if (cell.empty()) continue;
+          auto aik = ai.batch(k);
+          auto bik = bi.batch(k);
+          auto vol = aik.total_size();
+          TA_ASSERT(vol == bik.total_size());
+          for (decltype(vol) j = 0; j < vol; ++j) {
+            const auto& l_inner = aik.data()[j];
+            const auto& r_inner = bik.data()[j];
+            plan.accumulate(cell, l_inner, r_inner);
+          }
+        }
+      }
+
+      auto shape = apply_inverse(pc, C.array.trange().tile(c));
+      tile = tile.reshape(shape);
+      if (pc) tile = tile.permute(pc);
+      C_local_tiles.emplace_back(std::move(c), std::move(tile));
+      return true;
+    } else {
+      // Scale path has exactly one ToT operand and one scalar-cell operand.
+      using IIndex = ::Einsum::index::Index<std::size_t>;
+      auto range_for = [&](std::size_t k) -> InnerRange {
+        if (k >= batch) return InnerRange{};
+        for (IIndex i : tiles) {
+          const auto pahi_inv = apply_inverse(pa, h + i);
+          const auto pbhi_inv = apply_inverse(pb, h + i);
+          if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+          auto ai = A.array.find(pahi_inv).get();
+          auto bi = B.array.find(pbhi_inv).get();
+          if (pa) ai = ai.permute(pa);
+          if (pb) bi = bi.permute(pb);
+          auto shape = trange.tile(i);
+          ai = ai.reshape(shape, batch);
+          bi = bi.reshape(shape, batch);
+          auto aik = ai.batch(k);
+          auto bik = bi.batch(k);
+          if constexpr (a_is_tot) {
+            auto vol = aik.total_size();
+            for (decltype(vol) j = 0; j < vol; ++j) {
+              const auto& l_inner = aik.data()[j];
+              if (l_inner.empty()) continue;
+              return InnerRange(l_inner.range());
+            }
+          } else {
+            auto vol = bik.total_size();
+            for (decltype(vol) j = 0; j < vol; ++j) {
+              const auto& r_inner = bik.data()[j];
+              if (r_inner.empty()) continue;
+              return InnerRange(r_inner.range());
+            }
+          }
+        }
+        return InnerRange{};
+      };
+
+      ResultTensor tile = arena_outer_init<ResultTensor>(
+          TiledArray::Range{batch}, /*batch_sz=*/1, range_for,
+          kArenaCachelineAlign, /*zero_init=*/true);
+
+      for (IIndex i : tiles) {
+        const auto pahi_inv = apply_inverse(pa, h + i);
+        const auto pbhi_inv = apply_inverse(pb, h + i);
+        if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+        auto ai = A.array.find(pahi_inv).get();
+        auto bi = B.array.find(pbhi_inv).get();
+        if (pa) ai = ai.permute(pa);
+        if (pb) bi = bi.permute(pb);
+        auto shape = trange.tile(i);
+        ai = ai.reshape(shape, batch);
+        bi = bi.reshape(shape, batch);
+        for (std::size_t k = 0; k < batch; ++k) {
+          auto& cell = tile({k});
+          if (cell.empty()) continue;
+          auto aik = ai.batch(k);
+          auto bik = bi.batch(k);
+          auto vol = aik.total_size();
+          TA_ASSERT(vol == bik.total_size());
+          for (decltype(vol) j = 0; j < vol; ++j) {
+            const auto& l_elem = aik.data()[j];
+            const auto& r_elem = bik.data()[j];
+            plan.accumulate(cell, l_elem, r_elem);
+          }
+        }
+      }
+
+      auto shape = apply_inverse(pc, C.array.trange().tile(c));
+      tile = tile.reshape(shape);
+      if (pc) tile = tile.permute(pc);
+      C_local_tiles.emplace_back(std::move(c), std::move(tile));
+      return true;
+    }
+  }
+}
+
+}  // namespace TiledArray::detail
+
+#endif  // TILEDARRAY_TENSOR_ARENA_EINSUM_H__INCLUDED
diff --git a/src/TiledArray/tensor/arena_kernels.h b/src/TiledArray/tensor/arena_kernels.h
index 7a86615ac3..a9f774a657 100644
--- a/src/TiledArray/tensor/arena_kernels.h
+++ b/src/TiledArray/tensor/arena_kernels.h
@@ -133,13 +133,17 @@ OuterTensor arena_permute_shallow(const SrcOuterTensor& src,
 }
 
 /// Allocate a slab-backed outer tile using caller-provided inner shapes.
+/// `alignment` is the per-cell stride alignment (e.g. kArenaCachelineAlign).
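+/// Worked example: a 3-double cell (24 bytes) rounded up to 128 gives every
+/// cell its own cache-line-aligned stride, so concurrent fills of neighboring
+/// cells cannot false-share.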
 template <typename OuterTensor, typename Range, typename ShapeFn>
 OuterTensor arena_outer_init(const Range& outer_range, std::size_t batch_sz,
-                             ShapeFn&& shape_fn, bool zero_init = true) {
+                             ShapeFn&& shape_fn,
+                             std::size_t alignment = kArenaCachelineAlign,
+                             bool zero_init = true) {
   using inner_t = typename OuterTensor::value_type;
   using elem_t = typename inner_t::value_type;
   using inner_range_t =
       std::decay_t<decltype(shape_fn(std::declval<std::size_t>()))>;
+  TA_ASSERT(alignment >= alignof(elem_t));
   const std::size_t N_cells = outer_range.volume() * batch_sz;
   std::vector<inner_range_t> ranges;
   ranges.reserve(N_cells);
@@ -149,7 +153,7 @@ OuterTensor arena_outer_init(const Range& outer_range, std::size_t batch_sz,
     offsets[ord] = total_bytes;
     ranges.emplace_back(shape_fn(ord));
     const std::size_t bytes = ranges.back().volume() * sizeof(elem_t);
-    total_bytes += arena_align_up(bytes, alignof(elem_t));
+    total_bytes += arena_align_up(bytes, alignment);
   }
   auto arena = std::make_shared<Arena>();
   // Arena::reserve requires a non-empty slab.
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 511f387426..e928a2c504 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -103,6 +103,7 @@ set(ta_test_src_files ta_test.cpp
    btas.cpp
    arena.cpp
    arena_kernels.cpp
+    arena_einsum_unit_suite.cpp
    )
 
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
diff --git a/tests/arena_einsum_unit_suite.cpp b/tests/arena_einsum_unit_suite.cpp
new file mode 100644
index 0000000000..6de13f8c84
--- /dev/null
+++ b/tests/arena_einsum_unit_suite.cpp
@@ -0,0 +1,253 @@
+/// Unit tests for arena einsum plans and dispatch.
+
+#include "TiledArray/tensor/arena_einsum.h"
+
+#include "tiledarray.h"
+#include "unit_test_config.h"
+
+BOOST_AUTO_TEST_SUITE(arena_einsum_unit_suite, TA_UT_LABEL_SERIAL)
+
+namespace TA = TiledArray;
+
+BOOST_AUTO_TEST_CASE(inner_shape_plan_left_range) {
+  TA::Tensor<double> l(TA::Range{3, 4});
+  TA::Tensor<double> r(TA::Range{3, 4});
+  TA::detail::ArenaInnerShapePlan p{
+      TA::detail::ArenaInnerShapeKind::left_range, std::nullopt};
+  auto out = p.make<TA::Range>(l, r);
+  BOOST_CHECK(out == l.range());
+}
+
+BOOST_AUTO_TEST_CASE(inner_shape_plan_right_range) {
+  TA::Tensor<double> l(TA::Range{3, 4});
+  TA::Tensor<double> r(TA::Range{5, 6});
+  TA::detail::ArenaInnerShapePlan p{
+      TA::detail::ArenaInnerShapeKind::right_range, std::nullopt};
+  auto out = p.make<TA::Range>(l, r);
+  BOOST_CHECK(out == r.range());
+}
+
+BOOST_AUTO_TEST_CASE(inner_shape_plan_gemm_result_range) {
+  TA::Tensor<double> l(TA::Range{3, 5});
+  TA::Tensor<double> r(TA::Range{5, 4});
+  TA::math::GemmHelper gh(TA::math::blas::NoTranspose,
+                          TA::math::blas::NoTranspose, 2, 2, 2);
+  TA::detail::ArenaInnerShapePlan p{
+      TA::detail::ArenaInnerShapeKind::gemm_result_range,
+      std::make_optional(gh)};
+  auto out = p.make<TA::Range>(l, r);
+  BOOST_CHECK_EQUAL(out.volume(), std::size_t{12});
+}
+
+BOOST_AUTO_TEST_CASE(is_contraction_arena_tot_v_predicate) {
+  using ToT = TA::Tensor<TA::Tensor<double>>;
+  static_assert(TA::detail::is_contraction_arena_tot_v<ToT, ToT, ToT>);
+  using Plain = TA::Tensor<double>;
+  static_assert(
+      !TA::detail::is_contraction_arena_tot_v<Plain, Plain, Plain>);
+  BOOST_CHECK(true);
+}
+
+BOOST_AUTO_TEST_CASE(arena_plan_storage_t_resolves) {
+  using ToT = TA::Tensor<TA::Tensor<double>>;
+  using Plain = TA::Tensor<double>;
+  using ToTStorage = TA::detail::arena_plan_storage_t<ToT, ToT, ToT>;
+  using PlainStorage = TA::detail::arena_plan_storage_t<Plain, Plain, Plain>;
+  static_assert(!std::is_same_v<ToTStorage, std::monostate>);
+  static_assert(std::is_same_v<PlainStorage, std::monostate>);
+  BOOST_CHECK(true);
+}
+
+BOOST_AUTO_TEST_CASE(make_plan_returns_nullopt_when_disabled) {
+  using ToT = TA::Tensor<TA::Tensor<double>>;
+  TA::detail::arena_disabled() = true;
+  auto plan = TA::detail::make_contraction_arena_plan<ToT, ToT, ToT>(
+      TA::detail::ArenaInnerShapeKind::left_range, std::nullopt,
+      TA::Permutation{});
+  BOOST_CHECK(!plan.has_value());
+  TA::detail::arena_disabled() = false;
+}
+
+BOOST_AUTO_TEST_CASE(make_plan_returns_nullopt_for_plain_tensor) {
+  using Plain = TA::Tensor<double>;
+  // Non-ToT gating happens inside the function body, not in the return type.
+  auto plan = TA::detail::make_contraction_arena_plan<Plain, Plain, Plain>(
+      TA::detail::ArenaInnerShapeKind::left_range, std::nullopt,
+      TA::Permutation{});
+  BOOST_CHECK(!plan.has_value());
+}
+
+BOOST_AUTO_TEST_CASE(make_plan_rejects_nonidentity_inner_perm) {
+  using ToT = TA::Tensor<TA::Tensor<double>>;
+  TA::Permutation perm({1, 0});
+  auto plan = TA::detail::make_contraction_arena_plan<ToT, ToT, ToT>(
+      TA::detail::ArenaInnerShapeKind::left_range, std::nullopt, perm);
+  BOOST_CHECK(!plan.has_value());
+}
+
+BOOST_AUTO_TEST_CASE(make_plan_returns_active_for_tot) {
+  using ToT = TA::Tensor<TA::Tensor<double>>;
+  auto plan = TA::detail::make_contraction_arena_plan<ToT, ToT, ToT>(
+      TA::detail::ArenaInnerShapeKind::left_range, std::nullopt,
+      TA::Permutation{});
+  BOOST_CHECK(plan.has_value());
+}
+
+namespace {
+using ToT = TA::Tensor<TA::Tensor<double>>;
+
+// Placement-new initializes each ToT inner cell in existing tensor storage.
+ToT make_uniform_tot(const TA::Range& outer, const TA::Range& inner,
+                     double fill) {
+  ToT t(outer);
+  const std::size_t vol = outer.volume();
+  for (std::size_t i = 0; i < vol; ++i) {
+    new (t.data() + i) TA::Tensor<double>(inner, fill);
+  }
+  return t;
+}
+}  // namespace
+
+BOOST_AUTO_TEST_CASE(reserve_and_construct_uniform_inner) {
+  TA::math::GemmHelper outer_gh(TA::math::blas::NoTranspose,
+                                TA::math::blas::NoTranspose, 2, 2, 2);
+  TA::math::GemmHelper inner_gh(TA::math::blas::NoTranspose,
+                                TA::math::blas::NoTranspose, 2, 2, 2);
+  auto left = make_uniform_tot(TA::Range{2, 3}, TA::Range{3, 5}, 1.0);
+  auto right = make_uniform_tot(TA::Range{3, 4}, TA::Range{5, 4}, 1.0);
+  TA::detail::ArenaInnerShapePlan inner_plan{
+      TA::detail::ArenaInnerShapeKind::gemm_result_range,
+      std::make_optional(inner_gh)};
+  TA::detail::ContractionArenaPlan<ToT, ToT, ToT> plan(inner_plan);
+  ToT result = plan.reserve_and_construct(left, right, outer_gh);
+  BOOST_CHECK_EQUAL(result.range().volume(), std::size_t{8});
+  BOOST_CHECK_EQUAL(result.data()[0].range().volume(), std::size_t{12});
+}
+
+BOOST_AUTO_TEST_CASE(reserve_and_construct_zero_volume_outer_skips_reserve) {
+  TA::math::GemmHelper outer_gh(TA::math::blas::NoTranspose,
+                                TA::math::blas::NoTranspose, 2, 2, 2);
+  TA::math::GemmHelper inner_gh(TA::math::blas::NoTranspose,
+                                TA::math::blas::NoTranspose, 2, 2, 2);
+  auto left = make_uniform_tot(TA::Range{0, 3}, TA::Range{3, 5}, 1.0);
+  auto right = make_uniform_tot(TA::Range{3, 2}, TA::Range{5, 4}, 1.0);
+  TA::detail::ArenaInnerShapePlan inner_plan{
+      TA::detail::ArenaInnerShapeKind::gemm_result_range,
+      std::make_optional(inner_gh)};
+  TA::detail::ContractionArenaPlan<ToT, ToT, ToT> plan(inner_plan);
+  ToT result = plan.reserve_and_construct(left, right, outer_gh);
+  BOOST_CHECK_EQUAL(result.range().volume(), std::size_t{0});
+}
+
+BOOST_AUTO_TEST_CASE(reserve_and_construct_jagged_inner_per_cell) {
+  // Jagged left cells make first-non-empty K-strip range selection observable.
+  TA::math::GemmHelper outer_gh(TA::math::blas::NoTranspose,
+                                TA::math::blas::NoTranspose, 2, 2, 2);
+  ToT left(TA::Range{2, 3});
+  for (std::size_t m = 0; m < 2; ++m)
+    for (std::size_t k = 0; k < 3; ++k) {
+      TA::Range r{static_cast<long>(m + 1), static_cast<long>(k + 2)};
+      new (left.data() + (m * 3 + k)) TA::Tensor<double>(r, 1.0);
+    }
+  auto right = make_uniform_tot(TA::Range{3, 2}, TA::Range{2, 2}, 1.0);
+  TA::detail::ArenaInnerShapePlan inner_plan{
+      TA::detail::ArenaInnerShapeKind::left_range, std::nullopt};
+  TA::detail::ContractionArenaPlan<ToT, ToT, ToT> plan(inner_plan);
+  ToT result = plan.reserve_and_construct(left, right, outer_gh);
+  BOOST_CHECK_EQUAL(result.range().volume(), std::size_t{4});
+  BOOST_CHECK_EQUAL(result.data()[0].range().volume(), std::size_t{2});
+  BOOST_CHECK_EQUAL(result.data()[1].range().volume(), std::size_t{2});
+  BOOST_CHECK_EQUAL(result.data()[2].range().volume(), std::size_t{4});
+  BOOST_CHECK_EQUAL(result.data()[3].range().volume(), std::size_t{4});
+}
+
+BOOST_AUTO_TEST_CASE(fused_hadamard_inplace_accumulates) {
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> l(TA::Range{4}, 1.0);
+  TA::Tensor<double> rr(TA::Range{4}, 2.0);
+  TA::detail::fused_hadamard_inplace(r, l, rr);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 2.0, 1e-12);
+  TA::detail::fused_hadamard_inplace(r, l, rr);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 4.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_hadamard_scaled_inplace_accumulates) {
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> l(TA::Range{4}, 1.0);
+  TA::Tensor<double> rr(TA::Range{4}, 2.0);
+  TA::detail::fused_hadamard_scaled_inplace(r, l, rr, 3.0);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 6.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_scale_tot_x_t_inplace_accumulates) {
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> l(TA::Range{4}, 1.5);
+  TA::detail::fused_scale_tot_x_t_inplace(r, l, 2.0);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 3.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_scale_t_x_tot_inplace_accumulates) {
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> rr(TA::Range{4}, 2.5);
+  TA::detail::fused_scale_t_x_tot_inplace(r, 4.0, rr);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 10.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_contraction_inplace_accumulates) {
+  TA::Tensor<double> r(TA::Range{2, 2}, 0.0);
+  TA::Tensor<double> l(TA::Range{2, 2}, 1.0);
+  TA::Tensor<double> rr(TA::Range{2, 2}, 2.0);
+  TA::math::GemmHelper gh(TA::math::blas::NoTranspose,
+                          TA::math::blas::NoTranspose, 2, 2, 2);
+  TA::detail::fused_contraction_inplace(r, l, rr, 1.0, gh);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 4.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_hadamard_lambda_round_trip) {
+  auto fn = TA::detail::make_fused_hadamard_lambda<
+      TA::Tensor<double>, TA::Tensor<double>, TA::Tensor<double>>();
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> l(TA::Range{4}, 1.0);
+  TA::Tensor<double> rr(TA::Range{4}, 2.0);
+  fn(r, l, rr);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 2.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_hadamard_scaled_lambda_round_trip) {
+  auto fn = TA::detail::make_fused_hadamard_scaled_lambda<
+      TA::Tensor<double>, TA::Tensor<double>, TA::Tensor<double>, double>(3.0);
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> l(TA::Range{4}, 1.0);
+  TA::Tensor<double> rr(TA::Range{4}, 2.0);
+  fn(r, l, rr);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 6.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_scale_tot_x_t_lambda_round_trip) {
+  auto fn = TA::detail::make_fused_scale_tot_x_t_lambda<
+      TA::Tensor<double>, TA::Tensor<double>, double>();
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> l(TA::Range{4}, 1.5);
+  fn(r, l, 2.0);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 3.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_CASE(fused_scale_t_x_tot_lambda_round_trip) {
+  auto fn = TA::detail::make_fused_scale_t_x_tot_lambda<
+      TA::Tensor<double>, double, TA::Tensor<double>>();
+  TA::Tensor<double> r(TA::Range{4}, 0.0);
+  TA::Tensor<double> rr(TA::Range{4}, 2.5);
+  fn(r, 4.0, rr);
+  for (std::size_t i = 0; i < 4; ++i)
+    BOOST_CHECK_CLOSE(r.data()[i], 10.0, 1e-12);
+}
+
+BOOST_AUTO_TEST_SUITE_END()
From 0bb6cd06e6d8a55a9f11817d5eb0df5c058db67e Mon Sep 17 00:00:00 2001
From: Zhihao Deng
Date: Mon, 11 May 2026 20:09:02 -0400
Subject: [PATCH 4/7] tensor: route ToT trivial ops through arena kernels +
 tests

---
 src/TiledArray/tensor/tensor.h | 132 ++++++++++++++++++++----------
 tests/CMakeLists.txt           |   1 +
 tests/arena_tot_trivial.cpp    | 144 +++++++++++++++++++++++++++++++++
 3 files changed, 236 insertions(+), 41 deletions(-)
 create mode 100644 tests/arena_tot_trivial.cpp

diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h
index ba684cc768..68467d95cf 100644
--- a/src/TiledArray/tensor/tensor.h
+++ b/src/TiledArray/tensor/tensor.h
@@ -29,6 +29,7 @@
 #include "TiledArray/math/gemm_helper.h"
 #include "TiledArray/tensor/complex.h"
 #include "TiledArray/tensor/kernels.h"
+#include "TiledArray/tensor/arena_kernels.h"
 #include "TiledArray/tile_interface/clone.h"
 #include "TiledArray/tile_interface/permute.h"
 #include "TiledArray/tile_interface/trace.h"
@@ -652,8 +653,14 @@ class Tensor {
   Tensor clone() const& {
     Tensor result;
     if (data_) {
-      if constexpr (detail::is_tensor_of_tensor_v<Tensor>) {
-        result = Tensor(*this, [](value_type const& el) { return el.clone(); });
+      if constexpr (detail::is_tensor_of_tensor_v<Tensor> &&
+                    detail::is_ta_tensor_v<value_type>) {
+        auto fill = [](typename value_type::value_type* dst,
+                       const typename value_type::value_type* src,
+                       std::size_t n) {
+          for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
+        };
+        result = detail::arena_trivial_unary<Tensor>(*this, fill);
       } else {
         result = detail::tensor_op<Tensor>(
            [](const numeric_type value) -> numeric_type { return value; },
@@ -1680,10 +1687,20 @@ class Tensor {
     // early exit for empty this
     if (empty()) return {};
 
-    return unary([factor](const value_type& a) {
-      using namespace TiledArray::detail;
-      return a * factor;
-    });
+    if constexpr (detail::is_tensor_of_tensor_v<Tensor> &&
+                  detail::is_ta_tensor_v<value_type>) {
+      auto fill = [factor](typename value_type::value_type* dst,
+                           const typename value_type::value_type* src,
+                           std::size_t n) {
+        for (std::size_t i = 0; i < n; ++i) dst[i] = src[i] * factor;
+      };
+      return detail::arena_trivial_unary<Tensor>(*this, fill);
+    } else {
+      return unary([factor](const value_type& a) {
+        using namespace TiledArray::detail;
+        return a * factor;
+      });
+    }
   }
 
   /// Construct a scaled copy of this tensor
@@ -1756,24 +1773,35 @@ class Tensor {
     // early exit for empty this
     if (empty()) return detail::clone_or_cast<Tensor>(right);
 
-    return binary(
-        right,
-        [](const value_type& l, const value_t<Right>& r) -> decltype(l + r) {
-          if constexpr (detail::is_tensor_v<value_type>) {
-            if (l.empty()) {
-              if (r.empty())
-                return {};
-              else
-                return r.clone();
-            } else {
-              if (r.empty())
-                return l.clone();
-              else
-                return l + r;
-            }
-          }
-          return l + r;
-        });
+    if constexpr (detail::is_tensor_of_tensor_v<Tensor> &&
+                  detail::is_ta_tensor_v<value_type> &&
+                  detail::is_ta_tensor_v<value_t<Right>>) {
+      auto fill = [](typename value_type::value_type* dst,
+                     const typename value_type::value_type* l,
+                     const typename value_type::value_type* r, std::size_t n) {
+        for (std::size_t i = 0; i < n; ++i) dst[i] = l[i] + r[i];
+      };
+      return detail::arena_trivial_binary<Tensor>(*this, right, fill);
+    } else {
+      return binary(
+          right,
+          [](const value_type& l, const value_t<Right>& r) -> decltype(l + r) {
+            if constexpr (detail::is_tensor_v<value_type>) {
+              if (l.empty()) {
+                if (r.empty())
+                  return {};
+                else
+                  return r.clone();
+              } else {
+                if (r.empty())
+                  return l.clone();
+                else
+                  return l + r;
+              }
+            }
+            return l + r;
+          });
+    }
   }
 
   /// Add this and \c other to construct a new tensor
@@ -1956,25 +1984,36 @@ class Tensor {
   template <typename Right,
             typename = std::enable_if<
                 detail::tensors_have_equal_nested_rank_v<Tensor, Right>>>
   Tensor subt(const Right& right) const {
-    return binary(
-        right,
-        [](const value_type& l, const value_t<Right>& r) -> decltype(l - r) {
-          if constexpr (detail::is_tensor_v<value_type>) {
-            if (l.empty()) {
-              if (r.empty())
-                return {};
-              else
-                return -r;
-            } else {
-              if (r.empty())
-                return l.clone();
-              else
-                return l - r;
-            }
-          } else {
-            return l - r;
-          }
-        });
+    if constexpr (detail::is_tensor_of_tensor_v<Tensor> &&
+                  detail::is_ta_tensor_v<value_type> &&
+                  detail::is_ta_tensor_v<value_t<Right>>) {
+      auto fill = [](typename value_type::value_type* dst,
+                     const typename value_type::value_type* l,
+                     const typename value_type::value_type* r, std::size_t n) {
+        for (std::size_t i = 0; i < n; ++i) dst[i] = l[i] - r[i];
+      };
+      return detail::arena_trivial_binary<Tensor>(*this, right, fill);
+    } else {
+      return binary(
+          right,
+          [](const value_type& l, const value_t<Right>& r) -> decltype(l - r) {
+            if constexpr (detail::is_tensor_v<value_type>) {
+              if (l.empty()) {
+                if (r.empty())
+                  return {};
+                else
+                  return -r;
+              } else {
+                if (r.empty())
+                  return l.clone();
+                else
+                  return l - r;
+              }
+            } else {
+              return l - r;
+            }
+          });
+    }
   }
 
   /// Subtract \c right from this and return the result permuted by \c perm
@@ -2122,7 +2161,18 @@ class Tensor {
       return res_t{};
     }
 
-    return binary(right, mult_op);
+    if constexpr (detail::is_tensor_of_tensor_v<Tensor> &&
+                  detail::is_ta_tensor_v<value_type> &&
+                  detail::is_ta_tensor_v<value_t<Right>>) {
+      auto fill = [](typename value_type::value_type* dst,
+                     const typename value_type::value_type* l,
+                     const typename value_type::value_type* r, std::size_t n) {
+        for (std::size_t i = 0; i < n; ++i) dst[i] = l[i] * r[i];
+      };
+      return detail::arena_trivial_binary<Tensor>(*this, right, fill);
+    } else {
+      return binary(right, mult_op);
+    }
   }
 
   /// Multiply this by \c right to create a new, permuted tensor
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index e928a2c504..688ecaf52f 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -104,6 +104,7 @@ set(ta_test_src_files ta_test.cpp
    arena.cpp
    arena_kernels.cpp
    arena_einsum_unit_suite.cpp
+    arena_tot_trivial.cpp
    )
 
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
diff --git a/tests/arena_tot_trivial.cpp b/tests/arena_tot_trivial.cpp
new file mode 100644
index 0000000000..627bd5a7cc
--- /dev/null
+++ b/tests/arena_tot_trivial.cpp
@@ -0,0 +1,144 @@
+/// Arena-aware ToT trivial-op end-to-end tests (add, subt, mult, scale, clone).
+ +#include "TiledArray/tensor.h" +#include "tiledarray.h" +#include "unit_test_config.h" + +#include +#include +#include + +namespace TA = TiledArray; +using inner_t = TA::Tensor; +using outer_t = TA::Tensor; + +namespace { + +outer_t make_tot(std::size_t N_outer, std::size_t n_inner, double base = 1.0) { + outer_t outer(TA::Range{static_cast(N_outer)}, 1); + for (std::size_t ord = 0; ord < N_outer; ++ord) { + inner_t inner(TA::Range{static_cast(n_inner)}); + for (std::size_t i = 0; i < n_inner; ++i) + inner.at_ordinal(i) = base + ord * 100.0 + i; + *(outer.data() + ord) = std::move(inner); + } + return outer; +} + +bool tot_equal(const outer_t& a, const outer_t& b) { + if (a.range().volume() != b.range().volume()) return false; + for (std::size_t ord = 0; ord < a.range().volume(); ++ord) { + const inner_t& ai = *(a.data() + ord); + const inner_t& bi = *(b.data() + ord); + if (ai.range().volume() != bi.range().volume()) return false; + for (std::size_t i = 0; i < ai.range().volume(); ++i) + if (ai.at_ordinal(i) != bi.at_ordinal(i)) return false; + } + return true; +} + +/// All inner cells point into one contiguous slab (monotonic with bounded gap). +bool inners_share_one_slab(const outer_t& tot) { + if (tot.range().volume() == 0) return true; + const double* prev_end = nullptr; + for (std::size_t ord = 0; ord < tot.range().volume(); ++ord) { + const inner_t& cell = *(tot.data() + ord); + if (cell.range().volume() == 0) continue; + const double* cell_begin = cell.data(); + const double* cell_end = cell_begin + cell.range().volume(); + if (prev_end != nullptr && cell_begin < prev_end) return false; + if (prev_end != nullptr && + static_cast(cell_begin - prev_end) > 1024) + return false; + prev_end = cell_end; + } + return true; +} + +} + +BOOST_AUTO_TEST_SUITE(arena_tot_trivial_suite, TA_UT_LABEL_SERIAL) + +BOOST_AUTO_TEST_CASE(scale_bit_equal_and_one_slab) { + outer_t src = make_tot(6, 8, 1.0); + outer_t arena_result = src.scale(2.5); + outer_t baseline(src.range(), 1); + for (std::size_t ord = 0; ord < src.range().volume(); ++ord) { + inner_t inner((src.data() + ord)->range()); + for (std::size_t i = 0; i < inner.range().volume(); ++i) + inner.at_ordinal(i) = (src.data() + ord)->at_ordinal(i) * 2.5; + *(baseline.data() + ord) = std::move(inner); + } + BOOST_CHECK(tot_equal(arena_result, baseline)); + BOOST_CHECK(inners_share_one_slab(arena_result)); +} + +BOOST_AUTO_TEST_CASE(clone_bit_equal_and_one_slab) { + outer_t src = make_tot(6, 8, 3.0); + outer_t arena_result = src.clone(); + BOOST_CHECK(tot_equal(arena_result, src)); + BOOST_CHECK(inners_share_one_slab(arena_result)); +} + +BOOST_AUTO_TEST_CASE(add_bit_equal_and_one_slab) { + outer_t L = make_tot(6, 8, 1.0); + outer_t R = make_tot(6, 8, 0.5); + outer_t arena_result = L.add(R); + outer_t baseline(L.range(), 1); + for (std::size_t ord = 0; ord < L.range().volume(); ++ord) { + inner_t inner((L.data() + ord)->range()); + for (std::size_t i = 0; i < inner.range().volume(); ++i) + inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) + + (R.data() + ord)->at_ordinal(i); + *(baseline.data() + ord) = std::move(inner); + } + BOOST_CHECK(tot_equal(arena_result, baseline)); + BOOST_CHECK(inners_share_one_slab(arena_result)); +} + +BOOST_AUTO_TEST_CASE(subt_bit_equal_and_one_slab) { + outer_t L = make_tot(6, 8, 5.0); + outer_t R = make_tot(6, 8, 1.0); + outer_t arena_result = L.subt(R); + outer_t baseline(L.range(), 1); + for (std::size_t ord = 0; ord < L.range().volume(); ++ord) { + inner_t inner((L.data() + ord)->range()); + for 
(std::size_t i = 0; i < inner.range().volume(); ++i) + inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) - + (R.data() + ord)->at_ordinal(i); + *(baseline.data() + ord) = std::move(inner); + } + BOOST_CHECK(tot_equal(arena_result, baseline)); + BOOST_CHECK(inners_share_one_slab(arena_result)); +} + +BOOST_AUTO_TEST_CASE(mult_elementwise_bit_equal_and_one_slab) { + outer_t L = make_tot(6, 8, 2.0); + outer_t R = make_tot(6, 8, 0.5); + outer_t arena_result = L.mult(R); + outer_t baseline(L.range(), 1); + for (std::size_t ord = 0; ord < L.range().volume(); ++ord) { + inner_t inner((L.data() + ord)->range()); + for (std::size_t i = 0; i < inner.range().volume(); ++i) + inner.at_ordinal(i) = (L.data() + ord)->at_ordinal(i) * + (R.data() + ord)->at_ordinal(i); + *(baseline.data() + ord) = std::move(inner); + } + BOOST_CHECK(tot_equal(arena_result, baseline)); + BOOST_CHECK(inners_share_one_slab(arena_result)); +} + +BOOST_AUTO_TEST_CASE(arena_outlives_source) { + outer_t arena_result; + { + outer_t src = make_tot(3, 4, 9.0); + arena_result = src.scale(2.0); + } + for (std::size_t ord = 0; ord < arena_result.range().volume(); ++ord) + for (std::size_t i = 0; i < (arena_result.data() + ord)->range().volume(); + ++i) + BOOST_CHECK_EQUAL((arena_result.data() + ord)->at_ordinal(i), + (9.0 + ord * 100.0 + i) * 2.0); +} + +BOOST_AUTO_TEST_SUITE_END() From 2c606b9dbd18a9035bddf24e6e2678da11fa338c Mon Sep 17 00:00:00 2001 From: Zhihao Deng Date: Mon, 11 May 2026 20:14:17 -0400 Subject: [PATCH 5/7] cont_engine: thread arena plan + zero-overhead sizeof gate --- src/TiledArray/expressions/cont_engine.h | 300 ++++++++++++++++++----- src/TiledArray/tile_op/contract_reduce.h | 83 +++++-- tests/CMakeLists.txt | 1 + tests/arena_sizeof_invariant_suite.cpp | 85 +++++++ 4 files changed, 390 insertions(+), 79 deletions(-) create mode 100644 tests/arena_sizeof_invariant_suite.cpp diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 907a1632fd..83b97e5f6d 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -128,6 +129,9 @@ class ContEngine : public BinaryEngine { const right_tile_element_type&)> element_return_op_; ///< Same as element_nonreturn_op_ but returns ///< the result + using arena_plan_storage_t = TiledArray::detail::arena_plan_storage_t< + result_tile_type, left_tile_type, right_tile_type>; + TA_NO_UNIQUE_ADDRESS arena_plan_storage_t arena_plan_; TiledArray::detail::ProcGrid proc_grid_; ///< Process grid for the contraction size_type K_ = 1; ///< Inner dimension size @@ -300,7 +304,8 @@ class ContEngine : public BinaryEngine { // factor_ is absorbed into inner_tile_nonreturn_op_ op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_), outer_size(left_indices_), outer_size(right_indices_), - total_perm, this->element_nonreturn_op_); + total_perm, this->element_nonreturn_op_, + std::move(this->arena_plan_)); } trange_ = ContEngine_::make_trange(outer_perm); shape_ = ContEngine_::make_shape(outer_perm); @@ -330,7 +335,8 @@ class ContEngine : public BinaryEngine { // factor_ is absorbed into inner_tile_nonreturn_op_ op_ = op_type(left_op, right_op, scalar_type(1), outer_size(indices_), outer_size(left_indices_), outer_size(right_indices_), - total_perm, this->element_nonreturn_op_); + total_perm, this->element_nonreturn_op_, + std::move(this->arena_plan_)); } trange_ = ContEngine_::make_trange(); shape_ = 
ContEngine_::make_shape(); @@ -541,17 +547,51 @@ class ContEngine : public BinaryEngine { this->factor_, inner_size(this->indices_), inner_size(this->left_indices_), inner_size(this->right_indices_)); - this->element_nonreturn_op_ = - [contrreduce_op, permute_inner = this->product_type() != - TensorProduct::Contraction]( - result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - contrreduce_op(result, left, right); - // permutations of result are applied as "postprocessing" - if (permute_inner && !TA::empty(result)) - result = contrreduce_op(result); - }; + constexpr bool arena_eligible = + TiledArray::detail::is_contraction_arena_tot_v< + result_tile_type, left_tile_type, right_tile_type>; + if constexpr (arena_eligible) { + if (this->product_type() == TensorProduct::Contraction) { + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + TiledArray::detail::ArenaInnerShapeKind::gemm_result_range, + std::make_optional(contrreduce_op.gemm_helper()), + inner(this->perm_)); + } + } + if constexpr (arena_eligible) { + if (this->arena_plan_) { + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_contraction_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(contrreduce_op); + } else { + this->element_nonreturn_op_ = + [contrreduce_op, permute_inner = this->product_type() != + TensorProduct::Contraction]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + contrreduce_op(result, left, right); + // permutations of result are applied as "postprocessing" + if (permute_inner && !TA::empty(result)) + result = contrreduce_op(result); + }; + } + } else { + this->element_nonreturn_op_ = + [contrreduce_op, permute_inner = this->product_type() != + TensorProduct::Contraction]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + contrreduce_op(result, left, right); + // permutations of result are applied as "postprocessing" + if (permute_inner && !TA::empty(result)) + result = contrreduce_op(result); + }; + } } // ToT x ToT } else if (inner_prod == TensorProduct::Hadamard) { TA_ASSERT(tot_x_tot); @@ -574,26 +614,68 @@ class ContEngine : public BinaryEngine { ? 
inner(this->perm_) : Permutation{}) : op_type(base_op_type()); - this->element_nonreturn_op_ = - [mult_op, outer_prod](result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - TA_ASSERT(outer_prod == TensorProduct::Hadamard || - outer_prod == TensorProduct::Contraction); - if (outer_prod == TensorProduct::Hadamard) - result = mult_op(left, right); - else { // outer_prod == TensorProduct::Contraction - // there is currently no fused MultAdd ternary Op, only Add - // and Mult thus implement this as 2 separate steps - // TODO optimize by implementing (ternary) MultAdd - if (empty(result)) + constexpr bool arena_eligible_h_unit = + TiledArray::detail::is_contraction_arena_tot_v< + result_tile_type, left_tile_type, right_tile_type>; + if constexpr (arena_eligible_h_unit) { + if (this->product_type() == TensorProduct::Contraction) { + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + TiledArray::detail::ArenaInnerShapeKind::left_range, + std::nullopt, inner(this->perm_)); + } + } + if constexpr (arena_eligible_h_unit) { + if (this->arena_plan_) { + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_hadamard_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(); + } else { + this->element_nonreturn_op_ = + [mult_op, outer_prod](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + TA_ASSERT(outer_prod == TensorProduct::Hadamard || + outer_prod == TensorProduct::Contraction); + if (outer_prod == TensorProduct::Hadamard) + result = mult_op(left, right); + else { // outer_prod == TensorProduct::Contraction + // there is currently no fused MultAdd ternary Op, only Add + // and Mult thus implement this as 2 separate steps + // TODO optimize by implementing (ternary) MultAdd + if (empty(result)) + result = mult_op(left, right); + else { + auto result_increment = mult_op(left, right); + add_to(result, result_increment); + } + } + }; + } + } else { + this->element_nonreturn_op_ = + [mult_op, outer_prod](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + TA_ASSERT(outer_prod == TensorProduct::Hadamard || + outer_prod == TensorProduct::Contraction); + if (outer_prod == TensorProduct::Hadamard) result = mult_op(left, right); - else { - auto result_increment = mult_op(left, right); - add_to(result, result_increment); + else { // outer_prod == TensorProduct::Contraction + // there is currently no fused MultAdd ternary Op, only Add + // and Mult thus implement this as 2 separate steps + // TODO optimize by implementing (ternary) MultAdd + if (empty(result)) + result = mult_op(left, right); + else { + auto result_increment = mult_op(left, right); + add_to(result, result_increment); + } } - } - }; + }; + } } else { using base_op_type = TiledArray::detail::ScalMult< result_tile_element_type, left_tile_element_type, @@ -607,26 +689,68 @@ class ContEngine : public BinaryEngine { ? 
inner(this->perm_) : Permutation{}) : op_type(base_op_type(this->factor_)); - this->element_nonreturn_op_ = - [mult_op, outer_prod](result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - TA_ASSERT(outer_prod == TensorProduct::Hadamard || - outer_prod == TensorProduct::Contraction); - if (outer_prod == TensorProduct::Hadamard) - result = mult_op(left, right); - else { - // there is currently no fused MultAdd ternary Op, only Add - // and Mult thus implement this as 2 separate steps - // TODO optimize by implementing (ternary) MultAdd - if (empty(result)) + constexpr bool arena_eligible_h_scaled = + TiledArray::detail::is_contraction_arena_tot_v< + result_tile_type, left_tile_type, right_tile_type>; + if constexpr (arena_eligible_h_scaled) { + if (this->product_type() == TensorProduct::Contraction) { + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + TiledArray::detail::ArenaInnerShapeKind::left_range, + std::nullopt, inner(this->perm_)); + } + } + if constexpr (arena_eligible_h_scaled) { + if (this->arena_plan_) { + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_hadamard_scaled_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(this->factor_); + } else { + this->element_nonreturn_op_ = + [mult_op, outer_prod](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + TA_ASSERT(outer_prod == TensorProduct::Hadamard || + outer_prod == TensorProduct::Contraction); + if (outer_prod == TensorProduct::Hadamard) + result = mult_op(left, right); + else { + // there is currently no fused MultAdd ternary Op, only Add + // and Mult thus implement this as 2 separate steps + // TODO optimize by implementing (ternary) MultAdd + if (empty(result)) + result = mult_op(left, right); + else { + auto result_increment = mult_op(left, right); + add_to(result, result_increment); + } + } + }; + } + } else { + this->element_nonreturn_op_ = + [mult_op, outer_prod](result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + TA_ASSERT(outer_prod == TensorProduct::Hadamard || + outer_prod == TensorProduct::Contraction); + if (outer_prod == TensorProduct::Hadamard) result = mult_op(left, right); else { - auto result_increment = mult_op(left, right); - add_to(result, result_increment); + // there is currently no fused MultAdd ternary Op, only Add + // and Mult thus implement this as 2 separate steps + // TODO optimize by implementing (ternary) MultAdd + if (empty(result)) + result = mult_op(left, right); + else { + auto result_increment = mult_op(left, right); + add_to(result, result_increment); + } } - } - }; + }; + } } } // ToT x T or T x ToT } else if (inner_prod == TensorProduct::Scale) { @@ -667,24 +791,72 @@ class ContEngine : public BinaryEngine { } else abort(); // unreachable }; - this->element_nonreturn_op_ = - [scal_op, outer_prod = (this->product_type())]( - result_tile_element_type& result, - const left_tile_element_type& left, - const right_tile_element_type& right) { - if (outer_prod == TensorProduct::Contraction) { - // TODO implement X-permuting AXPY - if (empty(result)) + constexpr auto kind = + tot_x_t ? 
TiledArray::detail::ArenaInnerShapeKind::left_range + : TiledArray::detail::ArenaInnerShapeKind::right_range; + constexpr bool arena_eligible_scale = + TiledArray::detail::is_contraction_arena_tot_v< + result_tile_type, left_tile_type, right_tile_type>; + if constexpr (arena_eligible_scale) { + if (this->product_type() == TensorProduct::Contraction) { + this->arena_plan_ = + TiledArray::detail::make_contraction_arena_plan< + result_tile_type, left_tile_type, right_tile_type>( + kind, std::nullopt, inner(this->perm_)); + } + } + if constexpr (arena_eligible_scale) { + if (this->arena_plan_) { + if constexpr (tot_x_t) + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_scale_tot_x_t_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(); + else + this->element_nonreturn_op_ = + TiledArray::detail::make_fused_scale_t_x_tot_lambda< + result_tile_element_type, left_tile_element_type, + right_tile_element_type>(); + } else { + this->element_nonreturn_op_ = + [scal_op, outer_prod = (this->product_type())]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (outer_prod == TensorProduct::Contraction) { + // TODO implement X-permuting AXPY + if (empty(result)) + result = scal_op(left, right); + else { + auto result_increment = scal_op(left, right); + add_to(result, result_increment); + } + // result += scal_op(left, right); + } else { + result = scal_op(left, right); + } + }; + } + } else { + this->element_nonreturn_op_ = + [scal_op, outer_prod = (this->product_type())]( + result_tile_element_type& result, + const left_tile_element_type& left, + const right_tile_element_type& right) { + if (outer_prod == TensorProduct::Contraction) { + // TODO implement X-permuting AXPY + if (empty(result)) + result = scal_op(left, right); + else { + auto result_increment = scal_op(left, right); + add_to(result, result_increment); + } + // result += scal_op(left, right); + } else { result = scal_op(left, right); - else { - auto result_increment = scal_op(left, right); - add_to(result, result_increment); } - // result += scal_op(left, right); - } else { - result = scal_op(left, right); - } - }; + }; + } } } else abort(); // unsupported TensorProduct type diff --git a/src/TiledArray/tile_op/contract_reduce.h b/src/TiledArray/tile_op/contract_reduce.h index 2a5e90ea5d..843d35212d 100644 --- a/src/TiledArray/tile_op/contract_reduce.h +++ b/src/TiledArray/tile_op/contract_reduce.h @@ -26,8 +26,12 @@ #ifndef TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED #define TILEDARRAY_TILE_OP_CONTRACT_REDUCE_H__INCLUDED +#include +#include +#include #include #include +#include #include #include #include @@ -81,23 +85,32 @@ class ContractReduceBase { private: struct Impl { + using left_tile_type = std::remove_cv_t>; + using right_tile_type = std::remove_cv_t>; + using arena_plan_storage_t = + TiledArray::detail::arena_plan_storage_t; + template < typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref, + typename Plan = arena_plan_storage_t, typename = std::enable_if_t< TiledArray::detail::is_permutation_v< std::remove_reference_t> && std::is_invocable_r_v, result_value_type&, const left_value_type&, - const right_value_type&>>> + const right_value_type&> && + std::is_same_v, arena_plan_storage_t>>> Impl(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - Perm&& 
perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}, + Plan&& arena_plan_in = {}) : gemm_helper_(left_op, right_op, result_rank, left_rank, right_rank), alpha_(alpha), perm_(std::forward(perm)), - elem_muladd_op_(std::forward(elem_muladd_op)) { + elem_muladd_op_(std::forward(elem_muladd_op)), + arena_plan_(std::forward(arena_plan_in)) { // non-unit alpha must be absorbed into elem_muladd_op if (elem_muladd_op_) TA_ASSERT(alpha == scalar_type(1)); } @@ -111,6 +124,8 @@ class ContractReduceBase { /// type-erased reference to custom element multiply-add op /// \note the lifetime is managed by the callee! TiledArray::function_ref elem_muladd_op_; + + TA_NO_UNIQUE_ADDRESS arena_plan_storage_t arena_plan_; }; std::shared_ptr pimpl_; @@ -125,6 +140,8 @@ class ContractReduceBase { ContractReduceBase_& operator=(const ContractReduceBase_&) = default; ContractReduceBase_& operator=(ContractReduceBase_&&) = default; + using arena_plan_storage_t = typename Impl::arena_plan_storage_t; + /// Construct contract/reduce functor /// \tparam Perm a permutation type @@ -141,21 +158,25 @@ class ContractReduceBase { template < typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref, + typename Plan = typename Impl::arena_plan_storage_t, typename = std::enable_if_t< TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, - const right_value_type&>>> + const right_value_type&> && + std::is_same_v, typename Impl::arena_plan_storage_t>>> ContractReduceBase(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, Perm&& perm = {}, - ElemMultAddOp&& elem_muladd_op = {}) + ElemMultAddOp&& elem_muladd_op = {}, + Plan&& arena_plan_in = {}) : pimpl_(std::make_shared( left_op, right_op, alpha, result_rank, left_rank, right_rank, std::forward(perm), - std::forward(elem_muladd_op))) {} + std::forward(elem_muladd_op), + std::forward(arena_plan_in))) {} /// Gemm meta data accessor @@ -189,6 +210,14 @@ class ContractReduceBase { return pimpl_->elem_muladd_op_; } + /// Arena plan accessor + + /// \return A const reference to the arena plan storage + const auto& arena_plan() const { + TA_ASSERT(pimpl_); + return pimpl_->arena_plan_; + } + //-------------- these are only used for unit tests ----------------- /// Compute the number of contracted ranks @@ -277,18 +306,23 @@ class ContractReduce : public ContractReduceBase { template < typename Perm = BipartitePermutation, typename ElemMultAddOp = TiledArray::function_ref, + typename Plan = typename ContractReduceBase_::arena_plan_storage_t, typename = std::enable_if_t< TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, - const right_value_type&>>> + const right_value_type&> && + std::is_same_v, + typename ContractReduceBase_::arena_plan_storage_t>>> ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}, + Plan&& arena_plan_in = {}) : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank, right_rank, std::forward(perm), - std::forward(elem_muladd_op)) {} + std::forward(elem_muladd_op), + 
std::forward(arena_plan_in)) {} /// Create a result type object @@ -332,6 +366,15 @@ class ContractReduce : public ContractReduceBase { if constexpr (!ContractReduceBase_::plain_tensors) { TA_ASSERT(this->elem_muladd_op()); + if constexpr (detail::is_contraction_arena_tot_v< + result_type, + std::remove_cv_t>, + std::remove_cv_t>>) { + if (empty(result) && this->arena_plan().has_value()) { + result = this->arena_plan()->reserve_and_construct( + left, right, this->gemm_helper()); + } + } gemm(result, left, right, ContractReduceBase_::gemm_helper(), this->elem_muladd_op()); } else { // plain tensors @@ -404,18 +447,23 @@ class ContractReduce, + typename Plan = typename ContractReduceBase_::arena_plan_storage_t, typename = std::enable_if_t< TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, - const right_value_type&>>> + const right_value_type&> && + std::is_same_v, + typename ContractReduceBase_::arena_plan_storage_t>>> ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}, + Plan&& arena_plan_in = {}) : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank, right_rank, std::forward(perm), - std::forward(elem_muladd_op)) {} + std::forward(elem_muladd_op), + std::forward(arena_plan_in)) {} /// Create a result type object @@ -530,18 +578,23 @@ class ContractReduce, + typename Plan = typename ContractReduceBase_::arena_plan_storage_t, typename = std::enable_if_t< TiledArray::detail::is_permutation_v> && std::is_invocable_r_v, result_value_type&, const left_value_type&, - const right_value_type&>>> + const right_value_type&> && + std::is_same_v, + typename ContractReduceBase_::arena_plan_storage_t>>> ContractReduce(const math::blas::Op left_op, const math::blas::Op right_op, const scalar_type alpha, const unsigned int result_rank, const unsigned int left_rank, const unsigned int right_rank, - Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}) + Perm&& perm = {}, ElemMultAddOp&& elem_muladd_op = {}, + Plan&& arena_plan_in = {}) : ContractReduceBase_(left_op, right_op, alpha, result_rank, left_rank, right_rank, std::forward(perm), - std::forward(elem_muladd_op)) {} + std::forward(elem_muladd_op), + std::forward(arena_plan_in)) {} /// Create a result type object diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 688ecaf52f..5635bb23d6 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -106,6 +106,7 @@ set(ta_test_src_files ta_test.cpp arena_kernels.cpp arena_einsum_unit_suite.cpp arena_tot_trivial.cpp + arena_sizeof_invariant_suite.cpp ) if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) diff --git a/tests/arena_sizeof_invariant_suite.cpp b/tests/arena_sizeof_invariant_suite.cpp new file mode 100644 index 0000000000..047f62f1ff --- /dev/null +++ b/tests/arena_sizeof_invariant_suite.cpp @@ -0,0 +1,85 @@ +/// Locks plain-tensor zero-overhead from the arena plan storage field. 
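+///
+/// Technique: mirror the field order of ContractReduceBase::Impl in two local
+/// structs, one without and one with the trailing TA_NO_UNIQUE_ADDRESS plan
+/// slot, and require equal sizeof: for plain tensors the plan storage is
+/// std::monostate and must fold entirely into padding.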
+ +#include "TiledArray/tensor.h" +#include "TiledArray/tensor/arena_einsum.h" +#include "TiledArray/tile_op/contract_reduce.h" +#include "TiledArray/util/function.h" +#include "tiledarray.h" +#include "unit_test_config.h" + +#include +#include +#include + +namespace TA = TiledArray; + +namespace { + +using PlainResult = TA::Tensor; +using PlainLeft = TA::Tensor; +using PlainRight = TA::Tensor; +using PlainScalar = double; + +using PlainContractReduceBase = + TA::detail::ContractReduceBase; + +using PlainArenaPlanStorage = + TA::detail::arena_plan_storage_t; + +using PlainElemMulAddOp = + TA::function_ref; + +/// Shadows the public field order of ContractReduceBase::Impl on master. +struct ImplLayoutMaster { + TA::math::GemmHelper gemm_helper_; + PlainScalar alpha_; + TA::BipartitePermutation perm_; + PlainElemMulAddOp elem_muladd_op_; +}; + +/// Same as ImplLayoutMaster + trailing TA_NO_UNIQUE_ADDRESS arena_plan_. +struct ImplLayoutAllocator { + TA::math::GemmHelper gemm_helper_; + PlainScalar alpha_; + TA::BipartitePermutation perm_; + PlainElemMulAddOp elem_muladd_op_; + TA_NO_UNIQUE_ADDRESS PlainArenaPlanStorage arena_plan_; +}; + +/// Sizes captured against master; re-baseline if the toolchain changes. +constexpr std::size_t kTensorDoubleSizeMaster = 328; +constexpr std::size_t kContractReduceBaseSizeMaster = 16; +constexpr std::size_t kImplLayoutSizeMaster = 248; + +static_assert(std::is_same_v, + "plain-tensor arena_plan_storage_t must be std::monostate"); + +static_assert(sizeof(ImplLayoutAllocator) == sizeof(ImplLayoutMaster), + "TA_NO_UNIQUE_ADDRESS failed to fold arena_plan_ into padding"); + +} + +BOOST_AUTO_TEST_SUITE(arena_sizeof_invariant_suite, TA_UT_LABEL_SERIAL) + +BOOST_AUTO_TEST_CASE(tensor_double_sizeof_matches_master) { + BOOST_CHECK_EQUAL(sizeof(TA::Tensor), kTensorDoubleSizeMaster); +} + +BOOST_AUTO_TEST_CASE(contract_reduce_base_sizeof_matches_master) { + BOOST_CHECK_EQUAL(sizeof(PlainContractReduceBase), + kContractReduceBaseSizeMaster); +} + +BOOST_AUTO_TEST_CASE(impl_layout_no_unique_address_invariant) { + BOOST_CHECK_EQUAL(sizeof(ImplLayoutMaster), kImplLayoutSizeMaster); + BOOST_CHECK_EQUAL(sizeof(ImplLayoutAllocator), kImplLayoutSizeMaster); + BOOST_CHECK_EQUAL(sizeof(ImplLayoutAllocator), sizeof(ImplLayoutMaster)); +} + +BOOST_AUTO_TEST_CASE(plain_arena_plan_storage_is_monostate) { + BOOST_CHECK((std::is_same_v)); + BOOST_CHECK_EQUAL(sizeof(PlainArenaPlanStorage), sizeof(std::monostate)); +} + +BOOST_AUTO_TEST_SUITE_END() From a034afd2df62cb1e3324aafb40055cb902d35797 Mon Sep 17 00:00:00 2001 From: Zhihao Deng Date: Mon, 11 May 2026 20:34:31 -0400 Subject: [PATCH 6/7] einsum + tests/cases: hook regime-A arena into einsum + add hec_* case binaries --- src/TiledArray/einsum/tiledarray.h | 6 + tests/CMakeLists.txt | 5 + tests/cases/CMakeLists.txt | 15 + tests/cases/case_4d_e.cpp | 61 ++++ tests/cases/case_common.h | 528 +++++++++++++++++++++++++++++ tests/cases/case_hec_e.cpp | 40 +++ tests/cases/case_hec_ec.cpp | 38 +++ tests/cases/case_hec_h.cpp | 35 ++ tests/cases/case_hec_scale.cpp | 35 ++ 9 files changed, 763 insertions(+) create mode 100644 tests/cases/CMakeLists.txt create mode 100644 tests/cases/case_4d_e.cpp create mode 100644 tests/cases/case_common.h create mode 100644 tests/cases/case_hec_e.cpp create mode 100644 tests/cases/case_hec_ec.cpp create mode 100644 tests/cases/case_hec_h.cpp create mode 100644 tests/cases/case_hec_scale.cpp diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index e0b93886ee..daa9ec96a1 
100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -7,6 +7,7 @@ #include "TiledArray/einsum/range.h" #include "TiledArray/expressions/fwd.h" #include "TiledArray/fwd.h" +#include "TiledArray/tensor/arena_einsum.h" #include "TiledArray/tiled_range.h" #include "TiledArray/tiled_range1.h" @@ -687,6 +688,8 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, auto pa = A.permutation; auto pb = B.permutation; + auto arena_plan = detail::make_regime_a_arena_plan( + A, B, inner, /*inner_perm=*/C.permutation); for (Index h : H.tiles) { auto const pc = C.permutation; auto const c = apply(pc, h); @@ -695,6 +698,9 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, for (size_t i = 0; i < h.size(); ++i) { batch *= H.batch[i].at(h[i]); } + if (detail::run_regime_a_arena(arena_plan, h, batch, A, B, C, + C_local_tiles, tiles, trange)) + continue; ResultTensor tile(TiledArray::Range{batch}, typename ResultTensor::value_type{}); for (Index i : tiles) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 5635bb23d6..a7328a538a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -190,3 +190,8 @@ else() ENVIRONMENT "${TA_UNIT_TESTS_ENVIRONMENT}" ) endif() + +if (NOT TARGET test-cases-tiledarray) + add_custom_target_subproject(tiledarray test-cases) +endif() +add_subdirectory(cases) diff --git a/tests/cases/CMakeLists.txt b/tests/cases/CMakeLists.txt new file mode 100644 index 0000000000..8cc4721163 --- /dev/null +++ b/tests/cases/CMakeLists.txt @@ -0,0 +1,15 @@ +# hec_* + 4d_e per-cell case binaries (arena vs heap). + +set(_cases + case_hec_h + case_hec_e + case_hec_ec + case_hec_scale + case_4d_e +) + +foreach(_case ${_cases}) + add_ta_executable(${_case} "${_case}.cpp" "tiledarray") + target_include_directories(${_case} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) + add_dependencies(test-cases-tiledarray ${_case}) +endforeach() diff --git a/tests/cases/case_4d_e.cpp b/tests/cases/case_4d_e.cpp new file mode 100644 index 0000000000..df1f47b6ea --- /dev/null +++ b/tests/cases/case_4d_e.cpp @@ -0,0 +1,61 @@ +/// 4d_e: outer-4D x outer-3D with one Hadamard, one contracted, three free. + +#include "case_common.h" + +#include + +namespace c = cases; + +namespace { + +/// Deterministic truncated-exponential inner-size, mean ~10, cap 50. 
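+/// Recipe: a Knuth-style multiplicative hash of (p, q) yields u in [0, 1);
+/// inverse-CDF sampling x = -10 * ln(1 - u) is then truncated at 50.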
+inline long a_size(long p, long q) {
+  unsigned long h =
+      (static_cast<unsigned long>(p) * 73ULL +
+       static_cast<unsigned long>(q) * 113ULL + 17ULL) * 2654435761ULL;
+  double u = static_cast<double>(h & 0x7FFFFFFFUL) /
+             static_cast<double>(0x80000000UL);
+  double x = -10.0 * std::log(1.0 - u);
+  if (x > 50.0) x = 50.0;
+  return static_cast<long>(x);
+}
+
+}  // namespace
+
+struct Ops {
+  c::ToT lhs;
+  c::ToT rhs;
+};
+
+int main(int argc, char** argv) {
+  constexpr int I = 20;
+  constexpr int M = 50;
+  constexpr int K = 100;
+
+  auto sl = [](long q, long p, long /*m*/, long /*k*/) {
+    return TiledArray::Range{a_size(p, q)};
+  };
+  auto sr = [](long r, long q, long /*m*/) {
+    return TiledArray::Range{a_size(q, r)};
+  };
+
+  return c::run_case_main_split(
+      argc, argv, "4d_e",
+      [&](TiledArray::World& w) {
+        Ops ops;
+        ops.lhs = c::make_tot_4d_jagged(w, I, I, M, K, 1.0, sl);
+        ops.rhs = c::make_tot_3d_jagged(w, I, I, M, 100.0, sr);
+        return ops;
+      },
+      [&](TiledArray::World& w) {
+        Ops ops;
+        ops.lhs = c::make_tot_4d_jagged_slab(w, I, I, M, K, 1.0, sl);
+        ops.rhs = c::make_tot_3d_jagged_slab(w, I, I, M, 100.0, sr);
+        return ops;
+      },
+      [&](Ops& ops) {
+        return TiledArray::einsum(ops.lhs("q,p,m,k;s"),
+                                  ops.rhs("r,q,m;t"),
+                                  "p,r,q,k;s,t");
+      });
+}
diff --git a/tests/cases/case_common.h b/tests/cases/case_common.h
new file mode 100644
index 0000000000..b097eb56cd
--- /dev/null
+++ b/tests/cases/case_common.h
@@ -0,0 +1,528 @@
+/// Shared bench helpers for arena-vs-heap case binaries.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace cases {
+
+namespace TA = ::TiledArray;
+
+using inner_t = TA::Tensor<double>;
+using tile_t = TA::Tensor<inner_t>;
+using ToT = TA::DistArray<tile_t>;
+using Plain = TA::DistArray<inner_t>;
+
+inline int& g_tile_grid() {
+  static int v = 7;
+  return v;
+}
+
+/// Stores the h-dimension scale set by --h-scale.
+inline int& g_h_scale() {
+  static int v = 1;
+  return v;
+}
+
+inline std::vector<std::size_t> tile_breaks(int n, int ntiles) {
+  if (ntiles <= 1 || n <= 0)
+    return {0, static_cast<std::size_t>(std::max(n, 0))};
+  std::vector<std::size_t> b;
+  b.reserve(ntiles + 1);
+  const int chunk = n / ntiles;
+  for (int t = 0; t < ntiles; ++t) {
+    b.push_back(static_cast<std::size_t>(t * chunk));
+  }
+  b.push_back(static_cast<std::size_t>(n));
+  std::vector<std::size_t> uniq;
+  for (auto x : b)
+    if (uniq.empty() || uniq.back() != x) uniq.push_back(x);
+  return uniq;
+}
+
+inline TA::TiledRange1 tr1_dim(int n) {
+  auto b = tile_breaks(n, g_tile_grid());
+  return TA::TiledRange1(b.begin(), b.end());
+}
+
+inline TA::TiledRange tr3(int a, int b, int c) {
+  return TA::TiledRange{tr1_dim(a), tr1_dim(b), tr1_dim(c)};
+}
+
+inline TA::TiledRange tr4(int a, int b, int c, int d) {
+  return TA::TiledRange{tr1_dim(a), tr1_dim(b), tr1_dim(c), tr1_dim(d)};
+}
+
+/// Builds a 3-D slab-backed jagged ToT.
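+/// Layout: one 64-byte-aligned slab per outer tile; each cell's volume is
+/// padded up to a multiple of 8 doubles and the cells alias the slab via
+/// offset shared_ptrs, so a whole tile is one allocation.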
+template +ToT make_tot_3d_jagged_slab(TA::World& world, int A, int B, int C, + double offset, Fn inner_fn) { + ToT out(world, tr3(A, B, C)); + out.init_tiles([offset, inner_fn](const TA::Range& tile_range) { + const std::size_t n_cells = tile_range.volume(); + std::vector ranges; + ranges.reserve(n_cells); + std::vector cell_offsets(n_cells); + std::size_t total_elems = 0; + { + std::size_t ord = 0; + for (auto outer_idx : tile_range) { + const long o0 = static_cast(outer_idx[0]); + const long o1 = static_cast(outer_idx[1]); + const long o2 = static_cast(outer_idx[2]); + TA::Range ir = inner_fn(o0, o1, o2); + cell_offsets[ord] = total_elems; + const std::size_t vol = ir.volume(); + const std::size_t padded = (vol + 7) & ~std::size_t{7}; + total_elems += padded; + ranges.push_back(std::move(ir)); + ++ord; + } + } + + std::shared_ptr slab; + if (total_elems > 0) { + void* raw = nullptr; + if (posix_memalign(&raw, 64, total_elems * sizeof(double)) != 0) { + std::abort(); + } + slab = std::shared_ptr(static_cast(raw), + [](double* p) { std::free(p); }); + } + + tile_t tile(tile_range); + std::size_t ord = 0; + for (auto outer_idx : tile_range) { + const long o0 = static_cast(outer_idx[0]); + const long o1 = static_cast(outer_idx[1]); + const long o2 = static_cast(outer_idx[2]); + auto& ir = ranges[ord]; + const std::size_t vol = ir.volume(); + if (vol == 0) { + *(tile.data() + ord) = inner_t{}; + } else { + std::shared_ptr alias(slab, + slab.get() + cell_offsets[ord]); + for (std::size_t k = 0; k < vol; ++k) + alias[k] = offset + 1e-4 * static_cast( + o0 * 100000 + o1 * 1000 + o2 * 100 + k); + *(tile.data() + ord) = inner_t(ir, std::move(alias)); + } + ++ord; + } + return tile; + }); + world.gop.fence(); + return out; +} + +/// Builds a 3-D heap-scattered jagged ToT. +template +ToT make_tot_3d_jagged(TA::World& world, int A, int B, int C, double offset, + Fn inner_fn) { + ToT out(world, tr3(A, B, C)); + out.init_tiles([offset, inner_fn](const TA::Range& tile_range) { + tile_t tile(tile_range); + std::size_t ord = 0; + for (auto outer_idx : tile_range) { + const long o0 = static_cast(outer_idx[0]); + const long o1 = static_cast(outer_idx[1]); + const long o2 = static_cast(outer_idx[2]); + TA::Range ir = inner_fn(o0, o1, o2); + const std::size_t vol = ir.volume(); + if (vol == 0) { + *(tile.data() + ord) = inner_t{}; + } else { + inner_t inner(ir); + for (std::size_t k = 0; k < vol; ++k) + inner.at_ordinal(k) = + offset + 1e-4 * static_cast( + o0 * 100000 + o1 * 1000 + o2 * 100 + k); + *(tile.data() + ord) = std::move(inner); + } + ++ord; + } + return tile; + }); + world.gop.fence(); + return out; +} + +/// Builds a 4-D slab-backed jagged ToT. 
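+/// Same one-slab-per-tile scheme as the 3-D builder, extended to a 4-D
+/// outer grid.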
+template +ToT make_tot_4d_jagged_slab(TA::World& world, int A, int B, int C, int D, + double offset, Fn inner_fn) { + ToT out(world, tr4(A, B, C, D)); + out.init_tiles([offset, inner_fn](const TA::Range& tile_range) { + const std::size_t n_cells = tile_range.volume(); + std::vector ranges; + ranges.reserve(n_cells); + std::vector cell_offsets(n_cells); + std::size_t total_elems = 0; + { + std::size_t ord = 0; + for (auto outer_idx : tile_range) { + const long o0 = static_cast(outer_idx[0]); + const long o1 = static_cast(outer_idx[1]); + const long o2 = static_cast(outer_idx[2]); + const long o3 = static_cast(outer_idx[3]); + TA::Range ir = inner_fn(o0, o1, o2, o3); + cell_offsets[ord] = total_elems; + const std::size_t vol = ir.volume(); + const std::size_t padded = (vol + 7) & ~std::size_t{7}; + total_elems += padded; + ranges.push_back(std::move(ir)); + ++ord; + } + } + std::shared_ptr slab; + if (total_elems > 0) { + void* raw = nullptr; + if (posix_memalign(&raw, 64, total_elems * sizeof(double)) != 0) { + std::abort(); + } + slab = std::shared_ptr(static_cast(raw), + [](double* p) { std::free(p); }); + } + tile_t tile(tile_range); + std::size_t ord = 0; + for (auto outer_idx : tile_range) { + const long o0 = static_cast(outer_idx[0]); + const long o1 = static_cast(outer_idx[1]); + const long o2 = static_cast(outer_idx[2]); + const long o3 = static_cast(outer_idx[3]); + auto& ir = ranges[ord]; + const std::size_t vol = ir.volume(); + if (vol == 0) { + *(tile.data() + ord) = inner_t{}; + } else { + std::shared_ptr alias(slab, + slab.get() + cell_offsets[ord]); + for (std::size_t k = 0; k < vol; ++k) + alias[k] = offset + 1e-4 * static_cast( + o0 * 1000000 + o1 * 10000 + + o2 * 100 + o3 * 10 + k); + *(tile.data() + ord) = inner_t(ir, std::move(alias)); + } + ++ord; + } + return tile; + }); + world.gop.fence(); + return out; +} + +/// Builds a 4-D heap-scattered jagged ToT. 
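+/// Per-cell heap allocations: the legacy scattered layout that heap-mode
+/// runs are built with.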
+template +ToT make_tot_4d_jagged(TA::World& world, int A, int B, int C, int D, + double offset, Fn inner_fn) { + ToT out(world, tr4(A, B, C, D)); + out.init_tiles([offset, inner_fn](const TA::Range& tile_range) { + tile_t tile(tile_range); + std::size_t ord = 0; + for (auto outer_idx : tile_range) { + const long o0 = static_cast(outer_idx[0]); + const long o1 = static_cast(outer_idx[1]); + const long o2 = static_cast(outer_idx[2]); + const long o3 = static_cast(outer_idx[3]); + TA::Range ir = inner_fn(o0, o1, o2, o3); + const std::size_t vol = ir.volume(); + if (vol == 0) { + *(tile.data() + ord) = inner_t{}; + } else { + inner_t inner(ir); + for (std::size_t k = 0; k < vol; ++k) + inner.at_ordinal(k) = + offset + 1e-4 * static_cast( + o0 * 1000000 + o1 * 10000 + o2 * 100 + + o3 * 10 + k); + *(tile.data() + ord) = std::move(inner); + } + ++ord; + } + return tile; + }); + world.gop.fence(); + return out; +} + +inline Plain make_plain_3d(TA::World& world, int A, int B, int C, + double offset) { + Plain out(world, tr3(A, B, C)); + out.init_tiles([offset](const TA::Range& r) { + inner_t tile(r); + for (std::size_t k = 0; k < r.volume(); ++k) + tile.at_ordinal(k) = offset + 1e-3 * static_cast(k); + return tile; + }); + world.gop.fence(); + return out; +} + +inline double max_abs_diff(const ToT& a, const ToT& b) { + if (a.trange() != b.trange()) return 1e30; + double mx = 0.0; + const auto& tr = a.trange(); + for (auto t = tr.tiles_range().begin(); t != tr.tiles_range().end(); ++t) { + if (!a.is_local(*t)) continue; + auto ta = a.find(*t).get(); + auto tb = b.find(*t).get(); + if (ta.range().volume() != tb.range().volume()) return 1e30; + for (std::size_t ord = 0; ord < ta.range().volume(); ++ord) { + const auto& ia = *(ta.data() + ord); + const auto& ib = *(tb.data() + ord); + if (ia.range().volume() != ib.range().volume()) { + if (ia.range().volume() == 0 || ib.range().volume() == 0) { + mx = std::max(mx, 1.0); + continue; + } + return 1e30; + } + for (std::size_t k = 0; k < ia.range().volume(); ++k) { + double d = std::abs(ia.at_ordinal(k) - ib.at_ordinal(k)); + if (d > mx) mx = d; + } + } + } + return mx; +} + +struct RunResult { + double wall_ns_min = 0.0; + double wall_ns_med = 0.0; + ToT result; + bool ok = true; + std::string err; +}; + +template +RunResult time_run(TA::World& world, Runner&& run, bool disable_arena, + int repeats) { + RunResult R; + std::vector ns; + ns.reserve(repeats); + for (int r = 0; r < repeats; ++r) { + TA::detail::arena_disabled() = disable_arena; + world.gop.fence(); + auto t0 = std::chrono::steady_clock::now(); + try { + R.result = run(); + world.gop.fence(); + } catch (std::exception& e) { + R.ok = false; + R.err = e.what(); + return R; + } catch (...) { + R.ok = false; + R.err = "unknown"; + return R; + } + auto t1 = std::chrono::steady_clock::now(); + ns.push_back( + std::chrono::duration_cast(t1 - t0).count()); + } + std::sort(ns.begin(), ns.end()); + R.wall_ns_min = ns.front(); + R.wall_ns_med = ns[ns.size() / 2]; + return R; +} + +/// Runs a case binary by building operands once and timing one mode. +template +int run_case_main(int argc, char** argv, const char* case_name, Build build, + Run run) { + // Heap and arena timings must run in separate processes to avoid allocator/cache bias. 
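+  // Typical usage, one process per mode:
+  //   MAD_NUM_THREADS=8 ./case_hec_h --mode heap  --repeat 5
+  //   MAD_NUM_THREADS=8 ./case_hec_h --mode arena --repeat 5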
+ std::string mode; + int repeats = 3; + bool quiet = false; + int tile_grid = 7; + for (int i = 1; i < argc; ++i) { + std::string a = argv[i]; + if (a == "--mode" && i + 1 < argc) { + mode = argv[++i]; + } else if (a == "--repeat" && i + 1 < argc) { + repeats = std::atoi(argv[++i]); + } else if (a == "--tile-grid" && i + 1 < argc) { + tile_grid = std::max(1, std::atoi(argv[++i])); + } else if (a == "--h-scale" && i + 1 < argc) { + g_h_scale() = std::max(1, std::atoi(argv[++i])); + } else if (a == "--quiet") { + quiet = true; + } else if (a == "-h" || a == "--help") { + std::cout + << "Usage: " << argv[0] + << " --mode {heap|arena} [--tile-grid G] [--h-scale S] " + "[--repeat R] [--quiet]\n" + "MAD_NUM_THREADS env var controls thread count.\n" + "Note: --mode is required. heap and arena MUST be benchmarked\n" + "in separate processes — running both in one process biases the\n" + "second run via allocator fragmentation and cache residue.\n"; + return 0; + } + } + if (mode != "heap" && mode != "arena") { + std::cerr << "error: --mode must be 'heap' or 'arena' (got '" + << mode << "')\n"; + return 2; + } + g_tile_grid() = tile_grid; + + TA::World& world = TA::initialize(argc, argv); + + const char* threads_env = std::getenv("MAD_NUM_THREADS"); + std::string threads_label = threads_env ? threads_env : "default"; + + std::cout << "case,mode,tile_grid,threads,wall_ns_min,wall_ns_med,verified\n"; + + if (!quiet) { + std::cerr << "# " << case_name << " tile_grid=" << tile_grid + << " h_scale=" << g_h_scale() + << " threads=" << threads_label << "\n"; + } + + auto operands = build(world); + + auto emit = [&](const char* m, const RunResult& R, const std::string& v) { + if (!R.ok) { + std::cout << case_name << "," << m << "," << tile_grid << "," + << threads_label << ",NA,NA,err:" << R.err << "\n"; + return; + } + std::cout << case_name << "," << m << "," << tile_grid << "," + << threads_label << "," << static_cast(R.wall_ns_min) + << "," << static_cast(R.wall_ns_med) << "," << v + << "\n"; + }; + + if (mode == "heap") { + auto Rh = time_run( + world, [&]() { return run(operands); }, true, + repeats); + emit("heap", Rh, "single"); + if (!quiet) { + std::cerr << " heap=" << Rh.wall_ns_med / 1e6 << "ms\n"; + } + } else { + auto Ra = time_run( + world, [&]() { return run(operands); }, false, + repeats); + emit("arena", Ra, "single"); + if (!quiet) { + std::cerr << " arena=" << Ra.wall_ns_med / 1e6 << "ms\n"; + } + } + + std::cout.flush(); + TA::detail::arena_disabled() = false; + TA::finalize(); + return 0; +} + +/// Runs a case binary with separate heap-scatter and arena-slab input builders. +template +int run_case_main_split(int argc, char** argv, const char* case_name, + BuildHeap build_heap, BuildArena build_arena, + Run run) { + // Heap and arena timings must run in separate processes to avoid allocator/cache bias. 
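+  // Unlike run_case_main, operands are built inside the selected mode branch,
+  // so a process only materializes the input layout (scatter vs slab) that it
+  // actually times.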
+ std::string mode; + int repeats = 3; + bool quiet = false; + int tile_grid = 7; + for (int i = 1; i < argc; ++i) { + std::string a = argv[i]; + if (a == "--mode" && i + 1 < argc) { + mode = argv[++i]; + } else if (a == "--repeat" && i + 1 < argc) { + repeats = std::atoi(argv[++i]); + } else if (a == "--tile-grid" && i + 1 < argc) { + tile_grid = std::max(1, std::atoi(argv[++i])); + } else if (a == "--h-scale" && i + 1 < argc) { + g_h_scale() = std::max(1, std::atoi(argv[++i])); + } else if (a == "--quiet") { + quiet = true; + } else if (a == "-h" || a == "--help") { + std::cout << "Usage: " << argv[0] + << " --mode {heap|arena} [--tile-grid G] [--h-scale S] " + "[--repeat R] [--quiet]\n" + "Heap mode uses scattered (legacy) inputs; arena mode " + "uses slab-backed inputs.\n" + "Note: --mode is required. heap and arena MUST be " + "benchmarked in separate\n" + "processes — running both in one process biases the " + "second run via allocator\n" + "fragmentation and cache residue.\n"; + return 0; + } + } + if (mode != "heap" && mode != "arena") { + std::cerr << "error: --mode must be 'heap' or 'arena' (got '" + << mode << "')\n"; + return 2; + } + g_tile_grid() = tile_grid; + + TA::World& world = TA::initialize(argc, argv); + + const char* threads_env = std::getenv("MAD_NUM_THREADS"); + std::string threads_label = threads_env ? threads_env : "default"; + + std::cout << "case,mode,tile_grid,threads,wall_ns_min,wall_ns_med,verified\n"; + + if (!quiet) { + std::cerr << "# " << case_name << " tile_grid=" << tile_grid + << " h_scale=" << g_h_scale() + << " threads=" << threads_label + << " (split inputs: heap=scatter, arena=slab)\n"; + } + + auto emit = [&](const char* m, const RunResult& R, const std::string& v) { + if (!R.ok) { + std::cout << case_name << "," << m << "," << tile_grid << "," + << threads_label << ",NA,NA,err:" << R.err << "\n"; + return; + } + std::cout << case_name << "," << m << "," << tile_grid << "," + << threads_label << "," << static_cast(R.wall_ns_min) + << "," << static_cast(R.wall_ns_med) << "," << v + << "\n"; + }; + + if (mode == "heap") { + auto operands = build_heap(world); + auto Rh = time_run( + world, [&]() { return run(operands); }, true, + repeats); + emit("heap", Rh, "single"); + if (!quiet) std::cerr << " heap=" << Rh.wall_ns_med / 1e6 << "ms\n"; + } else { + auto operands = build_arena(world); + auto Ra = time_run( + world, [&]() { return run(operands); }, false, + repeats); + emit("arena", Ra, "single"); + if (!quiet) std::cerr << " arena=" << Ra.wall_ns_med / 1e6 << "ms\n"; + } + + std::cout.flush(); + TA::detail::arena_disabled() = false; + TA::finalize(); + return 0; +} + +} diff --git a/tests/cases/case_hec_e.cpp b/tests/cases/case_hec_e.cpp new file mode 100644 index 0000000000..83e767e4f2 --- /dev/null +++ b/tests/cases/case_hec_e.cpp @@ -0,0 +1,40 @@ +/// hec_e: A(h,i,j;m) * B(h,j,k;n) -> C(h,i,k;m,n); inner outer-product (i, k). 
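+///
+/// Inner cells are rank-1: the lhs cell has extent i, the rhs cell extent k,
+/// so the inner op "m" x "n" -> "m,n" is a pure outer product whose output
+/// cells grow as i * k across the outer grid.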
+ +#include "case_common.h" + +namespace c = cases; + +struct Ops { + c::ToT lhs; + c::ToT rhs; +}; + +int main(int argc, char** argv) { + constexpr int N = 30; + auto sl = [](long /*h*/, long i, long /*j*/) { + return TiledArray::Range{i}; + }; + auto sr = [](long /*h*/, long /*j*/, long k) { + return TiledArray::Range{k}; + }; + return c::run_case_main_split( + argc, argv, "hec_e", + [&](TiledArray::World& w) { + const int H = N * c::g_h_scale(); + Ops ops; + ops.lhs = c::make_tot_3d_jagged(w, H, N, N, 1.0, sl); + ops.rhs = c::make_tot_3d_jagged(w, H, N, N, 100.0, sr); + return ops; + }, + [&](TiledArray::World& w) { + const int H = N * c::g_h_scale(); + Ops ops; + ops.lhs = c::make_tot_3d_jagged_slab(w, H, N, N, 1.0, sl); + ops.rhs = c::make_tot_3d_jagged_slab(w, H, N, N, 100.0, sr); + return ops; + }, + [&](Ops& ops) { + return TiledArray::einsum(ops.lhs("h,i,j;m"), ops.rhs("h,j,k;n"), + "h,i,k;m,n"); + }); +} diff --git a/tests/cases/case_hec_ec.cpp b/tests/cases/case_hec_ec.cpp new file mode 100644 index 0000000000..857c8fcba8 --- /dev/null +++ b/tests/cases/case_hec_ec.cpp @@ -0,0 +1,38 @@ +/// hec_ec: A(h,i,j;m,p) * B(h,j,k;p,n) -> C(h,i,k;m,n); inner contracts p. + +#include "case_common.h" + +namespace c = cases; + +struct Ops { + c::ToT lhs; + c::ToT rhs; +}; + +int main(int argc, char** argv) { + constexpr int N = 60; + auto sl = [](long /*h*/, long i, long j) { + return TiledArray::Range{i, j}; + }; + auto sr = [](long /*h*/, long j, long k) { + return TiledArray::Range{j, k}; + }; + return c::run_case_main_split( + argc, argv, "hec_ec", + [&](TiledArray::World& w) { + Ops ops; + ops.lhs = c::make_tot_3d_jagged(w, N, N, N, 1.0, sl); + ops.rhs = c::make_tot_3d_jagged(w, N, N, N, 100.0, sr); + return ops; + }, + [&](TiledArray::World& w) { + Ops ops; + ops.lhs = c::make_tot_3d_jagged_slab(w, N, N, N, 1.0, sl); + ops.rhs = c::make_tot_3d_jagged_slab(w, N, N, N, 100.0, sr); + return ops; + }, + [&](Ops& ops) { + return TiledArray::einsum(ops.lhs("h,i,j;m,p"), ops.rhs("h,j,k;p,n"), + "h,i,k;m,n"); + }); +} diff --git a/tests/cases/case_hec_h.cpp b/tests/cases/case_hec_h.cpp new file mode 100644 index 0000000000..8b4da12071 --- /dev/null +++ b/tests/cases/case_hec_h.cpp @@ -0,0 +1,35 @@ +/// hec_h: A(h,i,j;m,n) * B(h,j,k;m,n) -> C(h,i,k;m,n); inner = (h, h). + +#include "case_common.h" + +namespace c = cases; + +struct Ops { + c::ToT lhs; + c::ToT rhs; +}; + +int main(int argc, char** argv) { + constexpr int N = 56; + auto sf = [](long h, long /*o1*/, long /*o2*/) { + return TiledArray::Range{h, h}; + }; + return c::run_case_main_split( + argc, argv, "hec_h", + [&](TiledArray::World& w) { + Ops ops; + ops.lhs = c::make_tot_3d_jagged(w, N, N, N, /*offset=*/1.0, sf); + ops.rhs = c::make_tot_3d_jagged(w, N, N, N, /*offset=*/100.0, sf); + return ops; + }, + [&](TiledArray::World& w) { + Ops ops; + ops.lhs = c::make_tot_3d_jagged_slab(w, N, N, N, /*offset=*/1.0, sf); + ops.rhs = c::make_tot_3d_jagged_slab(w, N, N, N, /*offset=*/100.0, sf); + return ops; + }, + [&](Ops& ops) { + return TiledArray::einsum(ops.lhs("h,i,j;m,n"), ops.rhs("h,j,k;m,n"), + "h,i,k;m,n"); + }); +} diff --git a/tests/cases/case_hec_scale.cpp b/tests/cases/case_hec_scale.cpp new file mode 100644 index 0000000000..34d399e323 --- /dev/null +++ b/tests/cases/case_hec_scale.cpp @@ -0,0 +1,35 @@ +/// hec_scale: A(h,i,j;m,n) * B_plain(h,j,k) -> C(h,i,k;m,n); inner scale. 
+ +#include "case_common.h" + +namespace c = cases; + +struct Ops { + c::ToT lhs; + c::Plain rhs; +}; + +int main(int argc, char** argv) { + constexpr int N = 56; + auto sl = [](long /*h*/, long i, long /*j*/) { + return TiledArray::Range{i, i}; + }; + return c::run_case_main_split( + argc, argv, "hec_scale", + [&](TiledArray::World& w) { + Ops ops; + ops.lhs = c::make_tot_3d_jagged(w, N, N, N, 1.0, sl); + ops.rhs = c::make_plain_3d(w, N, N, N, 0.5); + return ops; + }, + [&](TiledArray::World& w) { + Ops ops; + ops.lhs = c::make_tot_3d_jagged_slab(w, N, N, N, 1.0, sl); + ops.rhs = c::make_plain_3d(w, N, N, N, 0.5); + return ops; + }, + [&](Ops& ops) { + return TiledArray::einsum(ops.lhs("h,i,j;m,n"), ops.rhs("h,j,k"), + "h,i,k;m,n"); + }); +} From d0d26f189edb652aa2e16ad3ed333a12b08a54ed Mon Sep 17 00:00:00 2001 From: Zhihao Deng Date: Mon, 11 May 2026 20:52:13 -0400 Subject: [PATCH 7/7] review fixes: portable sizeof gate, explicit plan-move, alignment intent - arena_sizeof_invariant_suite: drop platform-specific absolute baselines (328/16/248 were Apple-arm64/libc++ only); keep relative ImplLayoutAllocator == ImplLayoutMaster invariant + monostate static_asserts. - cont_engine: reset arena_plan_ after std::move into op_ so later reads see "no plan" rather than a moved-from optional. - arena_kernels: one-line intent note on trivial kernels' tight packing. --- src/TiledArray/expressions/cont_engine.h | 10 ++++++++++ src/TiledArray/tensor/arena_kernels.h | 2 ++ tests/arena_sizeof_invariant_suite.cpp | 21 --------------------- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/TiledArray/expressions/cont_engine.h b/src/TiledArray/expressions/cont_engine.h index 83b97e5f6d..2fb2763474 100644 --- a/src/TiledArray/expressions/cont_engine.h +++ b/src/TiledArray/expressions/cont_engine.h @@ -306,6 +306,11 @@ class ContEngine : public BinaryEngine { outer_size(left_indices_), outer_size(right_indices_), total_perm, this->element_nonreturn_op_, std::move(this->arena_plan_)); + // Plan ownership transferred to op_; mark carrier slot empty so any + // later use of arena_plan_ reads as "no plan" rather than moved-from. + if constexpr (!std::is_same_v) { + this->arena_plan_.reset(); + } } trange_ = ContEngine_::make_trange(outer_perm); shape_ = ContEngine_::make_shape(outer_perm); @@ -337,6 +342,11 @@ class ContEngine : public BinaryEngine { outer_size(left_indices_), outer_size(right_indices_), total_perm, this->element_nonreturn_op_, std::move(this->arena_plan_)); + // Plan ownership transferred to op_; mark carrier slot empty so any + // later use of arena_plan_ reads as "no plan" rather than moved-from. + if constexpr (!std::is_same_v) { + this->arena_plan_.reset(); + } } trange_ = ContEngine_::make_trange(); shape_ = ContEngine_::make_shape(); diff --git a/src/TiledArray/tensor/arena_kernels.h b/src/TiledArray/tensor/arena_kernels.h index a9f774a657..d3d53b964e 100644 --- a/src/TiledArray/tensor/arena_kernels.h +++ b/src/TiledArray/tensor/arena_kernels.h @@ -47,6 +47,7 @@ OuterTensor arena_trivial_unary(const SrcOuterTensor& src, FillOp&& fill_op) { auto shape_fn = [&src](std::size_t ord) -> decltype(auto) { return src.data()[ord].range(); }; + // Elementwise kernels: pack tight (no cross-cell GEMM to amortize 128B pad). 
ArenaPlan p = plan(N_cells, shape_fn, sizeof(elem_t), alignof(elem_t)); auto arena = std::make_shared(); if (p.total_bytes > 0) arena->reserve(p.total_bytes, false); @@ -80,6 +81,7 @@ OuterTensor arena_trivial_binary(const LeftTensor& left, const RightTensor& righ auto shape_fn = [&left](std::size_t ord) -> decltype(auto) { return left.data()[ord].range(); }; + // Elementwise kernels: pack tight (no cross-cell GEMM to amortize 128B pad). ArenaPlan p = plan(N_cells, shape_fn, sizeof(elem_t), alignof(elem_t)); auto arena = std::make_shared(); if (p.total_bytes > 0) arena->reserve(p.total_bytes, false); diff --git a/tests/arena_sizeof_invariant_suite.cpp b/tests/arena_sizeof_invariant_suite.cpp index 047f62f1ff..649e3a50c5 100644 --- a/tests/arena_sizeof_invariant_suite.cpp +++ b/tests/arena_sizeof_invariant_suite.cpp @@ -2,7 +2,6 @@ #include "TiledArray/tensor.h" #include "TiledArray/tensor/arena_einsum.h" -#include "TiledArray/tile_op/contract_reduce.h" #include "TiledArray/util/function.h" #include "tiledarray.h" #include "unit_test_config.h" @@ -20,10 +19,6 @@ using PlainLeft = TA::Tensor; using PlainRight = TA::Tensor; using PlainScalar = double; -using PlainContractReduceBase = - TA::detail::ContractReduceBase; - using PlainArenaPlanStorage = TA::detail::arena_plan_storage_t; @@ -47,11 +42,6 @@ struct ImplLayoutAllocator { TA_NO_UNIQUE_ADDRESS PlainArenaPlanStorage arena_plan_; }; -/// Sizes captured against master; re-baseline if the toolchain changes. -constexpr std::size_t kTensorDoubleSizeMaster = 328; -constexpr std::size_t kContractReduceBaseSizeMaster = 16; -constexpr std::size_t kImplLayoutSizeMaster = 248; - static_assert(std::is_same_v, "plain-tensor arena_plan_storage_t must be std::monostate"); @@ -62,18 +52,7 @@ static_assert(sizeof(ImplLayoutAllocator) == sizeof(ImplLayoutMaster), BOOST_AUTO_TEST_SUITE(arena_sizeof_invariant_suite, TA_UT_LABEL_SERIAL) -BOOST_AUTO_TEST_CASE(tensor_double_sizeof_matches_master) { - BOOST_CHECK_EQUAL(sizeof(TA::Tensor), kTensorDoubleSizeMaster); -} - -BOOST_AUTO_TEST_CASE(contract_reduce_base_sizeof_matches_master) { - BOOST_CHECK_EQUAL(sizeof(PlainContractReduceBase), - kContractReduceBaseSizeMaster); -} - BOOST_AUTO_TEST_CASE(impl_layout_no_unique_address_invariant) { - BOOST_CHECK_EQUAL(sizeof(ImplLayoutMaster), kImplLayoutSizeMaster); - BOOST_CHECK_EQUAL(sizeof(ImplLayoutAllocator), kImplLayoutSizeMaster); BOOST_CHECK_EQUAL(sizeof(ImplLayoutAllocator), sizeof(ImplLayoutMaster)); }