From 604e7a8eaff7f00278e473c7fd7eebf010587c5f Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Tue, 12 May 2026 16:36:22 -0400 Subject: [PATCH 1/5] build: bump BTAS pin to pick up Tensor (Range, generator) ctor BTAS 7e64fbad adds a (Range, F) ctor on btas::Tensor that mirrors TA::Tensor's range+lambda ctor. Needed for tile-type-agnostic per-index inner-tile construction (e.g. MPQC's jacobi_update for btas-inner ToT amplitudes). --- external/versions.cmake | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/versions.cmake b/external/versions.cmake index 2545dff39d..288fa3bd3d 100644 --- a/external/versions.cmake +++ b/external/versions.cmake @@ -17,8 +17,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 7d8aaf9d51981e4accf4d84742270d1473f8ca2e) set(TA_TRACKED_MADNESS_VERSION 0.10.1) set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1) -set(TA_TRACKED_BTAS_TAG 287b145ead818a0332f2b7ce0b7375a83d328bae) -set(TA_TRACKED_BTAS_PREVIOUS_TAG 62d57d9b1e0c733b4b547bc9cfdd07047159dbca) +set(TA_TRACKED_BTAS_TAG 7e64fbad97c76f316f313f4c8ed3fca5445da15f) +set(TA_TRACKED_BTAS_PREVIOUS_TAG 287b145ead818a0332f2b7ce0b7375a83d328bae) set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece) set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 354e0ccee54aeb2f191c3ce2c617ebf437e49d83) From 53dee008a4d04bf4a49ed6219cad5bdbd84e9d60 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Tue, 12 May 2026 16:45:28 -0400 Subject: [PATCH 2/5] enable btas::Tensor as inner tile of TA::Tensor-of-Tensor Closes the remaining gaps so a TA::DistArray<TA::Tensor<btas::Tensor<...>>, Policy> (btas-inner ToT) is usable end-to-end alongside the existing TA-inner ToT, including through einsum's ToT * T and ToT * ToT paths. The outer tile is always TA::Tensor; btas::Tensor is *inner-only* and remains free of outer-tile operations (permute/reshape/batch/...). Notable changes: - external/btas.h: - nested_rank<btas::Tensor<...>> partial spec so einsum's MaxNestedArray correctly classifies btas-inner ToT.
- 6-arg gemm(alpha, A, B, beta, C, helper) overload matching the TA::Tensor signature (the existing 5-arg form is accumulate-only). - size_of(btas::Tensor) in namespace btas so ADL finds it from TA::Tensor's recursive size_of when the inner tile is btas. - tensor/operators.h + new operators_body.ipp: - Lift T+T/T-T/T*T/-T/T*N/N*T/T+=T/T-=T/T*=T/T*=N/Perm*T and the contiguous-tensor operator<< into a shared .ipp body, included once into namespace TiledArray and once into namespace btas, gated by disjoint per-namespace ta_ops_match_tensor_v predicates. Fixes ADL of these operators inside TA::Tensor's lambdas for btas inner. - tensor/print.h + .ipp: - Split NDArrayPrinter::print's Index template param into ExtentIndex / StrideIndex so btas::Tensor (whose range exposes extent_data as unsigned-long* but stride_data as long-long*) drives the printer. - tensor/tensor.h: - Fix broken SFINAE in Tensor::subt template (typename = std::enable_if<...> missing the _t). - tensor/type_traits.h: - Rewrite result_tensor_helper to derive the result via TensorA::rebind_t (the "rebind allocator on numeric" operation both TA::Tensor and btas::Tensor expose). Drops the requirement that the input expose allocator_type. - tensor/kernels.h: - tensor_contract / tensor_hadamard: member-style .permute(), .mult(), .mult_to() -> free CPO calls so ADL dispatches via namespace btas for btas tiles and tile_op/tile_interface.h for TA. - einsum/tiledarray.h: - DeNestedArray for ToT inputs now wraps the inner numeric in TA::Tensor rather than re-using the inner tile type as a new outer tile (preserves the "btas is inner-only" rule across DeNest). - sum_tot_2_tos lambda likewise produces TA::Tensor, and replaces the tot(ix).sum() member with an unqualified sum() so ADL finds the right overload (free fn in TA / namespace btas). - dist_array.h: - volume(): reduce_op uses arg->total_size() when the inner tile exposes it (TA::Tensor), else falls back to arg->size() (btas). 
- tests/btas_zb_inner_tile.cpp (+ tests/CMakeLists.txt): - Sniff tests instantiating TA::Tensor, ...>> as the ToT inner tile, exercising subt / add / scale through the cross-namespace operator path. --- src/TiledArray/dist_array.h | 14 +- src/TiledArray/einsum/tiledarray.h | 46 +++- src/TiledArray/external/btas.h | 135 +++++++++++ src/TiledArray/tensor/kernels.h | 25 +- src/TiledArray/tensor/operators.h | 282 +++-------------------- src/TiledArray/tensor/operators_body.ipp | 172 ++++++++++++++ src/TiledArray/tensor/print.h | 29 ++- src/TiledArray/tensor/print.ipp | 19 +- src/TiledArray/tensor/tensor.h | 2 +- src/TiledArray/tensor/type_traits.h | 31 ++- tests/CMakeLists.txt | 1 + tests/btas_zb_inner_tile.cpp | 136 +++++++++++ 12 files changed, 600 insertions(+), 292 deletions(-) create mode 100644 src/TiledArray/tensor/operators_body.ipp create mode 100644 tests/btas_zb_inner_tile.cpp diff --git a/src/TiledArray/dist_array.h b/src/TiledArray/dist_array.h index 08805bee52..c3cafb5605 100644 --- a/src/TiledArray/dist_array.h +++ b/src/TiledArray/dist_array.h @@ -1852,8 +1852,20 @@ size_t volume(const DistArray& array) { auto local_vol = [&vol](Tile const& in_tile) { if constexpr (detail::is_tensor_of_tensor_v) { + // Inner tile pointer is passed (see is_reduce_op_v in + // tensor/type_traits.h selecting the pointer-passing tensor_reduce + // overload). Prefer `total_size()` (TA::Tensor exposes it, batches + // included); fall back to `size()` for inner tile types that don't + // (e.g. btas::Tensor). 
auto reduce_op = [](size_t& MADNESS_RESTRICT result, auto&& arg) { - result += arg->total_size(); + using InnerTile = + std::remove_cv_t>; + if constexpr (detail::has_member_function_total_size_anyreturn_v< + InnerTile>) { + result += arg->total_size(); + } else { + result += arg->size(); + } }; auto join_op = [](auto& MADNESS_RESTRICT result, size_t count) { result += count; diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h index e0b93886ee..8196803152 100644 --- a/src/TiledArray/einsum/tiledarray.h +++ b/src/TiledArray/einsum/tiledarray.h @@ -185,9 +185,31 @@ template constexpr bool AreArraySame = AreArrayT || AreArrayToT; +// "Denested" companion of a ToT array: drops the inner-tile nesting, leaving +// a regular (non-nested) DistArray. For ToT inputs, the outer tile of the +// denested array is always TA::Tensor — nested inner-tile types (e.g. +// btas::Tensor) are only valid as the *innermost* tile and don't support the +// outer-tile operations einsum needs (permute/reshape/batch/range+lambda +// ctor). So for ToT we drop the inner tile and re-wrap its numeric type in +// TA::Tensor. For non-ToT inputs, the original "drop one level" behavior is +// preserved. +namespace detail_denested { +template +struct denested { + using type = DistArray; +}; +template +struct denested>> { + using type = DistArray< + TA::Tensor, + typename Array::policy_type>; +}; +} // namespace detail_denested template -using DeNestedArray = DistArray; +using DeNestedArray = typename detail_denested::denested::type; template using MaxNestedArray = std::conditional_t<(detail::nested_rank > @@ -496,13 +518,21 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, // Step III: C1(ijpqab) -> C2(ijpq) // Step IV: C2(ijpq) -> C(ipjq) + // Build a "denested" tile: one scalar per outer index, summed over the + // inner tile. 
The result tile's outer type is TA::Tensor (inner tile + // types like btas::Tensor are only valid as the innermost tile and don't + // expose the range+lambda ctor used here). auto sum_tot_2_tos = [](auto const &tot) { using tot_t = std::remove_reference_t; - typename tot_t::value_type result(tot.range(), [tot](auto &&ix) { + using numeric_type = typename tot_t::numeric_type; + TA::Tensor result(tot.range(), [tot](auto &&ix) { + // unqualified `sum` so ADL finds the right overload for both + // TA::Tensor inner (free fn in namespace TiledArray, calls .sum()) + // and btas::Tensor inner (free fn in namespace btas). if (!tot(ix).empty()) - return tot(ix).sum(); + return sum(tot(ix)); else - return typename tot_t::numeric_type{}; + return numeric_type{}; }); return result; }; @@ -722,7 +752,7 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, using TensorT = std::remove_reference_t; for (auto i = 0; i < vol; ++i) - el.add_to(element_product_op(aik.data()[i], bik.data()[i])); + add_to(el, element_product_op(aik.data()[i], bik.data()[i])); } else if constexpr (!AreArraySame) { auto aik = ai.batch(k); @@ -734,9 +764,9 @@ auto einsum(expressions::TsrExpr A, expressions::TsrExpr B, for (auto i = 0; i < vol; ++i) if constexpr (IsArrayToT) { - el.add_to(aik.data()[i].scale(bik.data()[i])); + add_to(el, scale(aik.data()[i], bik.data()[i])); } else { - el.add_to(bik.data()[i].scale(aik.data()[i])); + add_to(el, scale(bik.data()[i], aik.data()[i])); } } else { diff --git a/src/TiledArray/external/btas.h b/src/TiledArray/external/btas.h index 9a49362916..37c9ff70af 100644 --- a/src/TiledArray/external/btas.h +++ b/src/TiledArray/external/btas.h @@ -24,6 +24,7 @@ #define TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED #include +#include #include #include #include @@ -38,6 +39,7 @@ #include #include #include +#include #include @@ -86,6 +88,15 @@ inline TiledArray::Range make_ta_range( return TiledArray::Range(range.lobound(), range.upbound()); } +/// makes TiledArray::Range from 
a btas::zb::RangeNd (zero-based, row-major) + +/// \param[in] range a btas::zb::RangeNd object +template +inline TiledArray::Range make_ta_range( + const btas::zb::RangeNd& range) { + return TiledArray::Range(range.lobound(), range.upbound()); +} + } // namespace detail /// Test if the two ranges are congruent @@ -130,6 +141,41 @@ inline bool is_congruent(const btas::RangeNd& r1, r2.extent_data()); } +namespace zb { + +// ADL on btas::zb::RangeNd reaches namespace btas::zb (the innermost +// enclosing namespace of the class), NOT namespace btas — so is_congruent +// overloads taking a btas::zb::RangeNd must live here. + +/// Test if a btas::zb::RangeNd is congruent with another btas::zb::RangeNd +template +inline bool is_congruent(const btas::zb::RangeNd& r1, + const btas::zb::RangeNd& r2) { + return (r1.rank() == r2.rank()) && + std::equal(r1.extent_data(), r1.extent_data() + r1.rank(), + r2.extent_data()); +} + +/// Test if a btas::zb::RangeNd and a TA range are congruent +template +inline bool is_congruent(const btas::zb::RangeNd& r1, + const TiledArray::Range& r2) { + return (r1.rank() == r2.rank()) && + std::equal(r1.extent_data(), r1.extent_data() + r1.rank(), + r2.extent_data()); +} + +/// Test if a TA range and a btas::zb::RangeNd are congruent +template +inline bool is_congruent(const TiledArray::Range& r1, + const btas::zb::RangeNd& r2) { + return (r1.rank() == r2.rank()) && + std::equal(r1.extent_data(), r1.extent_data() + r1.rank(), + r2.extent_data()); +} + +} // namespace zb + /// Test if a TA range and a BTAS range are congruent /// This function tests that the rank and extent of @@ -837,6 +883,38 @@ inline void gemm(btas::Tensor& result, ldb, T(1), result.data(), ldc); } +// gemm overload matching TA::Tensor's signature +// (alpha, A, B, beta, C, helper): C = alpha*A*B + beta*C +template +inline void gemm(Alpha alpha, const btas::Tensor& left, + const btas::Tensor& right, Beta beta, + btas::Tensor& result, + const TiledArray::math::GemmHelper& 
gemm_helper) { + TA_ASSERT(!result.empty()); + TA_ASSERT(result.range().rank() == gemm_helper.result_rank()); + TA_ASSERT(!left.empty()); + TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(!right.empty()); + TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); + + using integer = TiledArray::math::blas::integer; + integer m, n, k; + gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range()); + + const integer lda = std::max( + integer{1}, + (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m)); + const integer ldb = std::max( + integer{1}, + (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k)); + const integer ldc = std::max(integer{1}, n); + + TiledArray::math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, + n, k, T(alpha), left.data(), lda, right.data(), + ldb, T(beta), result.data(), ldc); +} + // sum of the hyperdiagonal elements template inline T trace(const btas::Tensor& arg) { @@ -941,6 +1019,13 @@ template struct is_contiguous_tensor_helper> : public std::true_type {}; +template +constexpr size_t nested_rank<::btas::Tensor> = 1 + nested_rank; + +template +constexpr size_t nested_rank> = + nested_rank<::btas::Tensor>; + template struct is_btas_tensor : public std::false_type {}; @@ -989,6 +1074,56 @@ struct Cast, return result; } }; + +} // namespace TiledArray + +// Memory-footprint accessor for btas::Tensor. Lives in @c namespace btas so +// that ADL finds it from TA::Tensor's recursive @c size_of when the inner- +// tile type is a btas::Tensor (e.g., btas-inner ToT). The leading template +// parameter @c S (the @c TiledArray::MemorySpace) is selected by the caller.
+namespace btas { +template <::TiledArray::MemorySpace S, typename T, typename R, typename Storage> +std::size_t size_of(const btas::Tensor& t) { + std::size_t result = sizeof(t); + if constexpr (S == ::TiledArray::MemorySpace::Host) { + if constexpr (::TiledArray::is_constexpr_size_of_v) { + result += t.size() * sizeof(T); + } else { + for (auto const& el : t) result += size_of(el); + } + } + return result; +} +} // namespace btas + +// Subtract btas::Tensor out of TA's operator predicate so the TA-side +// operators (in namespace TiledArray) don't match btas tensors, leaving the +// btas-side operators below as the only viable candidates via ADL. This is +// needed because TA's @c is_tensor_helper is explicitly specialized for +// @c btas::Tensor earlier in this file, which would otherwise pull btas types +// into TA's @c is_nested_tensor_v predicate too. +namespace TiledArray { +namespace detail { +template +struct ta_ops_match_tensor<::btas::Tensor> : std::false_type {}; +} // namespace detail } // namespace TiledArray +// Element-wise arithmetic operators (T+T, T-T, T*T, -T) for btas::Tensor, so +// ADL finds them when the inner tile of a ToT is a btas::Tensor (e.g. +// from inside TiledArray::Tensor::subt's lambda). Same shared body as the +// TiledArray-side operators in <TiledArray/tensor/operators.h>; the +// per-namespace ta_ops_match_tensor_v predicate keeps the two copies +// non-overlapping under overload resolution.
+namespace btas { +namespace detail { +template +inline constexpr bool ta_ops_match_tensor_v = + ::TiledArray::detail::is_btas_tensor_v< + ::TiledArray::detail::remove_cvr_t>; +} // namespace detail + +#include +} // namespace btas + #endif /* TILEDARRAY_EXTERNAL_BTAS_H__INCLUDED */ diff --git a/src/TiledArray/tensor/kernels.h b/src/TiledArray/tensor/kernels.h index 64cabbb9d4..c68bad8f7a 100644 --- a/src/TiledArray/tensor/kernels.h +++ b/src/TiledArray/tensor/kernels.h @@ -1258,12 +1258,12 @@ auto tensor_contract(TensorA const& A, TensorB const& B, using Numeric = typename Result::numeric_type; // call gemm - gemm(Numeric{1}, // - plan.do_perm.A ? A.permute(plan.perm.A) : A, // - plan.do_perm.B ? B.permute(plan.perm.B) : B, // + gemm(Numeric{1}, // + plan.do_perm.A ? permute(A, plan.perm.A) : A, // + plan.do_perm.B ? permute(B, plan.perm.B) : B, // Numeric{0}, result, plan.gemm_helper); - return plan.do_perm.C ? result.permute(plan.perm.C.inv()) : result; + return plan.do_perm.C ? permute(result, plan.perm.C.inv()) : result; } /// contracts 2 tensors, with 1 plan construction per call. 
@@ -1334,20 +1334,21 @@ auto tensor_hadamard(TensorA const& A, TensorB const& B, TA_ASSERT(B.range().rank() == plan.B.size()); if (plan.no_perm) { - return A.mult(B); + return mult(A, B); } else if (plan.perm_to_c) { - return A.mult(B, plan.perm.AC); + return mult(A, B, plan.perm.AC); } else if (plan.perm_a) { - auto pA = A.permute(plan.perm.AC); - pA.mult_to(B); + auto pA = permute(A, plan.perm.AC); + mult_to(pA, B); return pA; } else if (plan.perm_b) { - auto pB = B.permute(plan.perm.BC); - pB.mult_to(A); + auto pB = permute(B, plan.perm.BC); + mult_to(pB, A); return pB; } else { - auto pA = A.permute(plan.perm.AC); - return pA.mult_to(B.permute(plan.perm.BC)); + auto pA = permute(A, plan.perm.AC); + mult_to(pA, permute(B, plan.perm.BC)); + return pA; } } diff --git a/src/TiledArray/tensor/operators.h b/src/TiledArray/tensor/operators.h index 97f9e5cdd8..05636c3d7d 100644 --- a/src/TiledArray/tensor/operators.h +++ b/src/TiledArray/tensor/operators.h @@ -26,26 +26,20 @@ #ifndef TILEDARRAY_TENSOR_OPERATORS_H__INCLUDED #define TILEDARRAY_TENSOR_OPERATORS_H__INCLUDED +#include #include namespace TiledArray { // Tensor arithmetic operators - -/// Tensor plus Tensor operator - -/// Add two tensors -/// \tparam T1 The left-hand tensor type -/// \tparam T2 The right-hand tensor type -/// \param left The left-hand tensor argument -/// \param right The right-hand tensor argument -/// \return A tensor where element \c i is equal to left[i] + right[i] -template , detail::remove_cvr_t>>> -inline decltype(auto) operator+(T1&& left, T2&& right) { - return add(std::forward(left), std::forward(right)); -} +// +// Operators that delegate to member functions or TA-only free functions (e.g. +// T+N, T+=N) are defined below. Those shared with btas (T+T, T-T, T*T, -T, +// T*N, N*T, T+=T, T-=T, T*=T, T*=N, Perm*T, contiguous-tensor operator<<) +// live in @c operators_body.ipp so they can be re-injected into +// @c namespace btas via the same file (see @c external/btas.h).
The @c detail::ta_ops_match_tensor +// predicate they use is declared in . +#include /// Tensor plus number operator @@ -54,10 +48,11 @@ inline decltype(auto) operator+(T1&& left, T2&& right) { /// \param tensor The tensor argument /// \param number The number argument /// \return A tensor where element \c i is equal to tensor[i] + number -template >>> +template >>> inline decltype(auto) operator+( - T1&& tensor, detail::numeric_t> number) { + T1&& tensor, TA::detail::numeric_t> number) { return std::forward(tensor).add(number); } @@ -68,28 +63,14 @@ inline decltype(auto) operator+( /// \param number The number argument /// \param tensor The tensor argument /// \return A tensor where element \c i is equal to tensor[i] + number -template >>> +template >>> inline decltype(auto) operator+( - detail::numeric_t> number, T1&& tensor) { + TA::detail::numeric_t> number, T1&& tensor) { return std::forward(tensor).add(number); } -/// Tensor minus Tensor operator - -/// Subtracts two tensors -/// \tparam T1 The left-hand tensor type -/// \tparam T2 The right-hand tensor type -/// \param left The left-hand tensor argument -/// \param right The right-hand tensor argument -/// \return A tensor where element \c i is equal to left[i] - right[i] -template , detail::remove_cvr_t>>> -inline decltype(auto) operator-(T1&& left, T2&& right) { - return subt(std::forward(left), std::forward(right)); -} - /// Tensor minus number operator /// Subtracts a number from a tensor @@ -97,138 +78,14 @@ inline decltype(auto) operator-(T1&& left, T2&& right) { /// \param tensor The tensor argument /// \param number The number argument /// \return A tensor where element \c i is equal to tensor[i] - number -template >>> +template >>> inline decltype(auto) operator-( - T1&& tensor, detail::numeric_t> number) { + T1&& tensor, TA::detail::numeric_t> number) { return std::forward(tensor).subt(number); } -/// Element-wise multiplication operator for Tensors - -/// Element-wise multiplication of two tensors 
-/// \tparam T1 The left-hand tensor type -/// \tparam T2 The right-hand tensor type -/// \param left The left-hand tensor argument -/// \param right The right-hand tensor argument -/// \return A tensor where element \c i is equal to left[i] * right[i] -template < - typename T1, typename T2, - typename std::enable_if, detail::remove_cvr_t>>::type* = nullptr> -inline decltype(auto) operator*(T1&& left, T2&& right) { - return mult(std::forward(left), std::forward(right)); -} - -/// Create a copy of \c left that is scaled by \c right - -/// Scale a tensor -/// \tparam T The left-hand tensor type -/// \tparam N Numeric type -/// \param left The left-hand tensor argument -/// \param right The right-hand scalar argument -/// \return A tensor where element \c i is equal to left[i] * right -template > && - detail::is_numeric_v>::type* = nullptr> -inline decltype(auto) operator*(T&& left, N right) { - return scale(std::forward(left), right); -} - -/// Create a copy of \c right that is scaled by \c left - -/// \tparam N A numeric type -/// \tparam T The right-hand tensor type -/// \param left The left-hand scalar argument -/// \param right The right-hand tensor argument -/// \return A tensor where element \c i is equal to left * right[i] -template < - typename N, typename T, - typename std::enable_if< - detail::is_numeric_v && - detail::is_nested_tensor_v>>::type* = nullptr> -inline decltype(auto) operator*(N left, T&& right) { - return scale(std::forward(right), left); -} - -/// Create a negated copy of \c arg - -/// \tparam T The element type of \c arg -/// \param arg The argument tensor -/// \return A tensor where element \c i is equal to \c -arg[i] -template >::value || - detail::is_tensor_of_tensor< - detail::remove_cvr_t>::value>::type* = nullptr> -inline decltype(auto) operator-(T&& arg) { - return neg(std::forward(arg)); -} - -/// Create a permuted copy of \c arg - -/// \tparam T The argument tensor type -/// \param perm The permutation to be applied to \c arg -/// 
\param arg The argument tensor to be permuted -template >::value || - detail::is_tensor_of_tensor< - detail::remove_cvr_t>::value>::type* = nullptr> -inline decltype(auto) operator*(const Permutation& perm, T&& arg) { - return permute(std::forward(arg), perm); -} - -/// Tensor plus operator - -/// Add the elements of \c right to that of \c left -/// \tparam T1 The left-hand tensor type -/// \tparam T2 The right-hand tensor type -/// \param left The left-hand tensor argument -/// \param right The right-hand tensor argument -/// \return A tensor where element \c i is equal to left[i] + right[i] -template , T2>::value || - detail::is_tensor_of_tensor, - T2>::value>::type* = nullptr> -inline decltype(auto) operator+=(T1&& left, const T2& right) { - return add_to(std::forward(left), right); -} - -/// Tensor minus operator - -/// Subtract the elements of \c right from that of \c left -/// \tparam T1 The left-hand tensor type -/// \tparam T2 The right-hand tensor type -/// \param left The left-hand tensor argument -/// \param right The right-hand tensor argument -/// \return A reference to \c left -template , T2>::value || - detail::is_tensor_of_tensor, - T2>::value>::type* = nullptr> -inline decltype(auto) operator-=(T1&& left, const T2& right) { - return subt_to(std::forward(left), right); -} - -/// In place tensor multiplication - -/// Multiply the elements of left by that of right -/// \tparam T1 The left-hand tensor type -/// \tparam T2 The right-hand tensor type -/// \param left The left-hand tensor argument -/// \param right The right-hand tensor argument -/// \return A reference to \c left -template , T2>::value || - detail::is_tensor_of_tensor, - T2>::value>::type* = nullptr> -inline decltype(auto) operator*=(T1&& left, const T2& right) { - return mult_to(std::forward(left), right); -} - /// In place tensor add constant /// Scale the elements of \c left by \c right @@ -237,11 +94,12 @@ inline decltype(auto) operator*=(T1&& left, const T2& right) { /// \param left 
The left-hand tensor argument /// \param right The right-hand scalar argument /// \return A reference to \c left -template >::value || - detail::is_tensor_of_tensor>::value) && - detail::is_numeric_v>::type* = nullptr> +template < + typename T, typename N, + typename std::enable_if< + (TA::detail::is_tensor>::value || + TA::detail::is_tensor_of_tensor>::value) && + TA::detail::is_numeric_v>::type* = nullptr> inline decltype(auto) operator+=(T&& left, N right) { return add_to(std::forward(left), right); } @@ -254,83 +112,19 @@ inline decltype(auto) operator+=(T&& left, N right) { /// \param left The left-hand tensor argument /// \param right The right-hand scalar argument /// \return A reference to \c left -template >::value || - detail::is_tensor_of_tensor>::value) && - detail::is_numeric_v>::type* = nullptr> +template < + typename T, typename N, + typename std::enable_if< + (TA::detail::is_tensor>::value || + TA::detail::is_tensor_of_tensor>::value) && + TA::detail::is_numeric_v>::type* = nullptr> inline decltype(auto) operator-=(T&& left, N right) { return subt_to(std::forward(left), right); } -/// In place tensor scale - -/// Scale the elements of \c left by \c right -/// \tparam T The left-hand tensor type -/// \tparam N Numeric type -/// \param left The left-hand tensor argument -/// \param right The right-hand scalar argument -/// \return A reference to \c left -template >::value || - detail::is_tensor_of_tensor>::value) && - detail::is_numeric_v>::type* = nullptr> -inline decltype(auto) operator*=(T&& left, N right) { - return scale_to(std::forward(left), right); -} - -/// Tensor output operator - -/// Output tensor \c t to the output stream, \c os . 
-/// \tparam T The tensor type -/// \param os The output stream -/// \param t The tensor to be output -/// \return A reference to the output stream -template < - typename Char, typename CharTraits, typename T, - typename std::enable_if && - detail::is_contiguous_tensor_v>::type* = nullptr> -inline std::basic_ostream& operator<<( - std::basic_ostream& os, const T& t) { - os << t.range() << " {\n"; - const auto n = t.range().volume(); - std::size_t offset = 0ul; - std::size_t nbatch = 1; - if constexpr (detail::has_member_function_nbatch_anyreturn_v) - nbatch = t.nbatch(); - const auto more_than_1_batch = nbatch > 1; - for (auto b = 0ul; b != nbatch; ++b) { - if (more_than_1_batch) { - os << " [batch " << b << "]{\n"; - } - if constexpr (detail::is_tensor_v) { // tensor of scalars - detail::NDArrayPrinter{}.print( - t.data() + offset, t.range().rank(), t.range().extent_data(), - t.range().stride_data(), os, more_than_1_batch ? 4 : 2); - } else { // tensor of tensors, need to annotate each element by its index - for (auto&& idx : t.range()) { // Loop over inner tensors - const auto& inner_t = *(t.data() + offset + t.range().ordinal(idx)); - os << " " << idx << ":"; - detail::NDArrayPrinter{}.print(inner_t.data(), inner_t.range().rank(), - inner_t.range().extent_data(), - inner_t.range().stride_data(), os, - more_than_1_batch ? 6 : 4); - os << "\n"; - } - } - if (more_than_1_batch) { - os << "\n }"; - if (b + 1 != nbatch) os << "\n"; // not last batch - } - offset += n; - } - os << "\n}\n"; - - return os; -} - -/// Tensor output operator +/// Tensor output operator (non-contiguous TensorInterface-style views; the +/// contiguous-tensor case is provided by @c operators_body.ipp above for both +/// TiledArray::Tensor and btas::Tensor) /// Output tensor \c t to the output stream, \c os . 
/// \tparam T The tensor type @@ -339,8 +133,8 @@ inline std::basic_ostream& operator<<( /// \return A reference to the output stream template ::value && - !detail::is_contiguous_tensor::value>::type* = nullptr> + TA::detail::is_tensor::value && + !TA::detail::is_contiguous_tensor::value>::type* = nullptr> inline std::basic_ostream& operator<<( std::basic_ostream& os, const T& t) { const auto stride = inner_size(t); diff --git a/src/TiledArray/tensor/operators_body.ipp b/src/TiledArray/tensor/operators_body.ipp new file mode 100644 index 0000000000..9758af5d24 --- /dev/null +++ b/src/TiledArray/tensor/operators_body.ipp @@ -0,0 +1,172 @@ +// Shared body of tensor arithmetic operators. +// +// NO header guard, NO namespace — this file is *intentionally* #included from +// inside two different namespaces (TiledArray and btas) to make ADL find the +// right overload for each tile type. Before including, the enclosing namespace +// must define: +// +// namespace detail { +// template +// inline constexpr bool ta_ops_match_tensor_v = …; +// } +// +// returning true for the tensor types this namespace's operators should accept +// and false otherwise — that's how we keep the two copies non-overlapping in +// overload resolution (TA's copy accepts TiledArray::Tensor &c.; btas's copy +// accepts btas::Tensor). The two predicates must be *disjoint*; in particular +// TA's copy must not accept btas::Tensor and vice versa. References to +// `::TA::detail::remove_cvr_t` etc. work from either namespace via the +// `namespace TA = TiledArray` alias at global scope. +// +// Only operators whose body delegates to a free CPO available for both tile +// types belong here. Operators that delegate to member functions (e.g. +// `tensor + number` → `tensor.add(number)`) or to TA-only free CPOs (e.g. +// `add_to(Tensor, scalar)`) live directly in +// inside namespace TiledArray. 
+ +/// element-wise tensor + tensor +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator+(T1&& left, T2&& right) { + return add(std::forward(left), std::forward(right)); +} + +/// element-wise tensor - tensor +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator-(T1&& left, T2&& right) { + return subt(std::forward(left), std::forward(right)); +} + +/// element-wise tensor * tensor +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator*(T1&& left, T2&& right) { + return mult(std::forward(left), std::forward(right)); +} + +/// unary negation +template >>> +inline decltype(auto) operator-(T&& arg) { + return neg(std::forward(arg)); +} + +/// tensor * scalar (right-hand scalar) +template > && + TA::detail::is_numeric_v>> +inline decltype(auto) operator*(T&& tensor, N number) { + return scale(std::forward(tensor), number); +} + +/// scalar * tensor (left-hand scalar) +template && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator*(N number, T&& tensor) { + return scale(std::forward(tensor), number); +} + +/// tensor += tensor +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator+=(T1&& left, const T2& right) { + return add_to(std::forward(left), right); +} + +/// tensor -= tensor +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator-=(T1&& left, const T2& right) { + return subt_to(std::forward(left), right); +} + +/// tensor *= tensor (element-wise) +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator*=(T1&& left, const T2& right) { + return mult_to(std::forward(left), right); +} + +/// tensor *= scalar +template > && + TA::detail::is_numeric_v>> +inline decltype(auto) operator*=(T&& left, N right) { + return scale_to(std::forward(left), right); +} + +/// permutation * tensor (templated on the permutation type so this overload +/// is usable with any class that satisfies @c 
TA::detail::is_permutation_v , +/// not only @c TiledArray::Permutation ) +template > && + detail::ta_ops_match_tensor_v>>> +inline decltype(auto) operator*(const P& perm, T&& arg) { + return permute(std::forward(arg), perm); +} + +/// Tensor output operator — NumPy-style printing for any contiguous tensor +/// type whose range exposes the standard accessors (@c rank , @c extent_data , +/// @c stride_data , @c ordinal ). Element-of-tensor (ToT) decoration is +/// emitted when the element type is itself a tensor; the optional @c nbatch +/// member is queried via @c if constexpr so non-batched tensors compile too. +template > && + TA::detail::is_contiguous_tensor_v>>> +inline std::basic_ostream& operator<<( + std::basic_ostream& os, const T& t) { + os << t.range() << " {\n"; + const auto n = t.range().volume(); + std::size_t offset = 0ul; + std::size_t nbatch = 1; + if constexpr (TA::detail::has_member_function_nbatch_anyreturn_v) + nbatch = t.nbatch(); + const auto more_than_1_batch = nbatch > 1; + for (auto b = 0ul; b != nbatch; ++b) { + if (more_than_1_batch) { + os << " [batch " << b << "]{\n"; + } + if constexpr (TA::detail::is_tensor_v) { // tensor of scalars + TA::detail::NDArrayPrinter{}.print( + t.data() + offset, t.range().rank(), t.range().extent_data(), + t.range().stride_data(), os, more_than_1_batch ? 4 : 2); + } else { // tensor of tensors — annotate each element by its index + for (auto&& idx : t.range()) { + const auto& inner_t = *(t.data() + offset + t.range().ordinal(idx)); + os << " " << idx << ":"; + TA::detail::NDArrayPrinter{}.print( + inner_t.data(), inner_t.range().rank(), + inner_t.range().extent_data(), inner_t.range().stride_data(), os, + more_than_1_batch ? 
6 : 4); + os << "\n"; + } + } + if (more_than_1_batch) { + os << "\n }"; + if (b + 1 != nbatch) os << "\n"; + } + offset += n; + } + os << "\n}\n"; + return os; +} diff --git a/src/TiledArray/tensor/print.h b/src/TiledArray/tensor/print.h index 18d09be7a7..abb1ab315e 100644 --- a/src/TiledArray/tensor/print.h +++ b/src/TiledArray/tensor/print.h @@ -80,30 +80,39 @@ class NDArrayPrinter { FloatTruncate truncate_; // Helper function to recursively print the array - template > - void printArray(const T* data, const std::size_t order, const Index* extents, - const Index* strides, + template > + void printArray(const T* data, const std::size_t order, + const ExtentIndex* extents, const StrideIndex* strides, std::basic_ostream& os, size_t level = 0, size_t offset = 0, size_t extra_indentation = 0); public: // Print a row-major array to a stream + // + // @c ExtentIndex and @c StrideIndex are independently deducible so callers + // can pass arrays of different integer types — needed for @c btas::Tensor , + // whose range exposes @c extent_data() as @c unsigned-long* but + // @c stride_data() as @c long-long* . 
template > - void print(const T* data, const std::size_t order, const Index* extents, - const Index* strides, std::basic_ostream& os, + void print(const T* data, const std::size_t order, const ExtentIndex* extents, + const StrideIndex* strides, + std::basic_ostream& os, std::size_t extra_indentation = 0); // Helper function to create a string representation template > std::basic_string toString(const T* data, const std::size_t order, - const Index* extents, - const Index* strides); + const ExtentIndex* extents, + const StrideIndex* strides); }; // Explicit template instantiations diff --git a/src/TiledArray/tensor/print.ipp b/src/TiledArray/tensor/print.ipp index c3535409b3..dae8f64fbd 100644 --- a/src/TiledArray/tensor/print.ipp +++ b/src/TiledArray/tensor/print.ipp @@ -34,9 +34,11 @@ namespace TiledArray { namespace detail { // Class to print n-dimensional arrays in NumPy style but with curly braces -template +template void NDArrayPrinter::printArray(const T* data, const std::size_t order, - const Index* extents, const Index* strides, + const ExtentIndex* extents, + const StrideIndex* strides, std::basic_ostream& os, size_t level, size_t offset, size_t extra_indentation) { @@ -69,9 +71,11 @@ void NDArrayPrinter::printArray(const T* data, const std::size_t order, } // Print a row-major array to a stream -template +template void NDArrayPrinter::print(const T* data, const std::size_t order, - const Index* extents, const Index* strides, + const ExtentIndex* extents, + const StrideIndex* strides, std::basic_ostream& os, std::size_t extra_indentation) { // Note: Can't validate data size with raw pointers, caller must ensure data has sufficient size @@ -80,10 +84,11 @@ void NDArrayPrinter::print(const T* data, const std::size_t order, } // Helper function to create a string representation -template +template std::basic_string NDArrayPrinter::toString( - const T* data, const std::size_t order, const Index* extents, - const Index* strides) { + const T* data, const 
std::size_t order, const ExtentIndex* extents, + const StrideIndex* strides) { std::basic_stringstream oss; print(data, order, extents, strides, oss, /* extra_indentation = */ 0); return oss.str(); diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index ba684cc768..fa04ff7eda 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -1953,7 +1953,7 @@ class Tensor { /// \return A new tensor where the elements are the different between the /// elements of \c this and \c right template >> Tensor subt(const Right& right) const { return binary( diff --git a/src/TiledArray/tensor/type_traits.h b/src/TiledArray/tensor/type_traits.h index 2bf3ce0072..bf4836e968 100644 --- a/src/TiledArray/tensor/type_traits.h +++ b/src/TiledArray/tensor/type_traits.h @@ -117,6 +117,18 @@ struct is_nested_tensor { template inline constexpr const bool is_nested_tensor_v = is_nested_tensor::value; +/// Predicate used by the shared operator body in +/// @c TiledArray/tensor/operators_body.inl to gate the element-wise tensor +/// operators that are injected into @c namespace TiledArray . The btas-side +/// copy of the same operators (in @c external/btas.h) partial-specializes +/// this predicate to @c std::false_type for @c btas::Tensor so the two +/// namespaces' operators stay non-overlapping under ADL. 
+template +struct ta_ops_match_tensor : is_nested_tensor {}; + +template +inline constexpr bool ta_ops_match_tensor_v = ta_ops_match_tensor::value; + //////////////////////////////////////////////////////////////////////////////// template @@ -487,19 +499,20 @@ struct result_tensor_helper { using TensorB_ = std::remove_reference_t; using value_type_A = typename TensorA_::value_type; using value_type_B = typename TensorB_::value_type; - using allocator_type_A = typename TensorA_::allocator_type; - using allocator_type_B = typename TensorB_::allocator_type; public: using numeric_type = binop_result_t; - using allocator_type = - std::conditional_t && - std::is_same_v, - allocator_type_A, Allocator>; + + // Result tensor type stays in TensorA's family with the allocator rebound to + // hold `numeric_type`. Both TA::Tensor and btas::Tensor expose this as + // `rebind_t` (TA::Tensor via std::allocator_traits::rebind_alloc; btas + // via storage_traits::rebind_t). An explicit @tparam Allocator override only + // applies when TensorA is a TA::Tensor. using result_type = - std::conditional_t, - TA::Tensor, - TA::Tensor>; + std::conditional_t || + !is_ta_tensor_v, + typename TensorA_::template rebind_t, + TA::Tensor>; }; } // namespace diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a30770fb18..7a3840ea12 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -38,6 +38,7 @@ set(ta_test_src_files ta_test.cpp type_traits.cpp tensor.cpp tensor_of_tensor.cpp + btas_zb_inner_tile.cpp tensor_tensor_view.cpp tensor_shift_wrapper.cpp tiled_range1.cpp diff --git a/tests/btas_zb_inner_tile.cpp b/tests/btas_zb_inner_tile.cpp new file mode 100644 index 0000000000..a9293ba1c9 --- /dev/null +++ b/tests/btas_zb_inner_tile.cpp @@ -0,0 +1,136 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2026 Virginia Tech + * + * btas_zb_inner_tile.cpp + * + * Sniff tests for using btas::Tensor, ...> as the + * inner tile of a TiledArray Tensor-of-Tensor. 
This is the entry point for + * validating Phase 2 of the inner-tile shrink work (see PR + * ValeevGroup/BTAS#187); ops not yet validated for the new range type will + * surface as compile errors here. + */ + +#include + +#ifdef TILEDARRAY_HAS_BTAS + +#include +#include +#include + +#include "unit_test_config.h" + +using namespace TiledArray; + +// Inner tile under test: btas::Tensor with zero-based packed range and the +// BTAS default storage (boost::container::small_vector wrapper). Uses extent +// type int16_t / ordinal type int32_t by default. +using bTensorIzb_storage = btas::DEFAULT::storage; +using bTensorIzb = btas::Tensor, bTensorIzb_storage>; + +// Sanity-check the size claim — fail loudly here if range layout drifts. +static_assert(sizeof(btas::zb::RangeNd<>) == 14, + "zb::RangeNd default layout must remain 14 bytes"); + +BOOST_AUTO_TEST_SUITE(btas_zb_inner_tile_suite, + *boost::unit_test::label("@serial")) + +// 1. Bare btas::Tensor parametrized on zb::RangeNd compiles, constructs, and +// permits element access via raw storage. Anchor for the whole stack. +BOOST_AUTO_TEST_CASE(btas_tensor_basic) { + bTensorIzb t(3, 4); + BOOST_REQUIRE_EQUAL(t.range().rank(), 2u); + BOOST_REQUIRE_EQUAL(t.range().area(), 12u); + BOOST_REQUIRE_EQUAL(t.range().extent(0), 3); + BOOST_REQUIRE_EQUAL(t.range().extent(1), 4); + + // Fill via raw pointer (smallest possible surface). + auto* p = t.data(); + for (std::size_t i = 0; i < t.range().area(); ++i) p[i] = static_cast(i); + BOOST_CHECK_EQUAL(p[0], 0); + BOOST_CHECK_EQUAL(p[5], 5); + BOOST_CHECK_EQUAL(p[11], 11); +} + +// 2. Wrap as the element type of a TA::Tensor (the ToT shape) and verify +// construction + per-element placement of inner tiles works. 
+BOOST_AUTO_TEST_CASE(tensor_of_btas_tensor_construct) { + Tensor outer(Range({2ul, 3ul})); + BOOST_REQUIRE_EQUAL(outer.range().rank(), 2u); + BOOST_REQUIRE_EQUAL(outer.range().area(), 6u); + + // Each inner tile has non-uniform extents — the load-bearing property of + // ToT — but the inner range is zero-based. + for (std::size_t i = 0; i < 2; ++i) { + for (std::size_t j = 0; j < 3; ++j) { + bTensorIzb inner(static_cast(2 + i), static_cast(3 + j)); + for (std::size_t k = 0; k < inner.range().area(); ++k) + inner.data()[k] = static_cast((i + 1) * 100 + (j + 1) * 10 + k); + outer(i, j) = std::move(inner); + } + } + + BOOST_CHECK_EQUAL(outer(0, 0).range().rank(), 2u); + BOOST_CHECK_EQUAL(outer(0, 0).range().area(), 2 * 3); + BOOST_CHECK_EQUAL(outer(1, 2).range().area(), 3 * 5); + BOOST_CHECK_EQUAL(outer(0, 0).data()[0], 110); + BOOST_CHECK_EQUAL(outer(1, 2).data()[3 * 5 - 1], 230 + 14); +} + +// 3. Helper: build a Tensor with uniform inner-tile extents (so +// both operands are congruent, which the binary ops require). +static Tensor make_uniform_TobT_zb(const Range& r, int fill) { + Tensor tensor(r); + for (decltype(r.extent(0)) i = 0; i < r.extent(0); ++i) { + for (decltype(r.extent(1)) j = 0; j < r.extent(1); ++j) { + bTensorIzb inner(static_cast(4), static_cast(5)); + for (std::size_t k = 0; k < inner.range().area(); ++k) + inner.data()[k] = fill + static_cast(k); + tensor(i, j) = std::move(inner); + } + } + return tensor; +} + +// 4. Exercise the actual ToT ops we care about — these instantiate the +// cross-namespace operator infrastructure (btas::operator- via shared +// body) inside TA::Tensor::{subt,add,scale,neg} lambdas. +// *This* is the test that proves btas::Tensor with btas::zb::RangeNd is +// usable as a ToT inner tile end-to-end. 
+BOOST_AUTO_TEST_CASE(tot_subt_with_zb_inner) { + auto a = make_uniform_TobT_zb(Range({2ul, 3ul}), 10); + auto b = make_uniform_TobT_zb(Range({2ul, 3ul}), 1); + + Tensor c; + BOOST_REQUIRE_NO_THROW(c = a.subt(b)); + + BOOST_REQUIRE_EQUAL(c.range().area(), 6u); + // Inner tile (0,0): each element should be a_inner[k] - b_inner[k] = 9. + for (std::size_t k = 0; k < c(0, 0).range().area(); ++k) + BOOST_CHECK_EQUAL(c(0, 0).data()[k], 9); +} + +BOOST_AUTO_TEST_CASE(tot_add_with_zb_inner) { + auto a = make_uniform_TobT_zb(Range({2ul, 3ul}), 10); + auto b = make_uniform_TobT_zb(Range({2ul, 3ul}), 5); + + Tensor c; + BOOST_REQUIRE_NO_THROW(c = a.add(b)); + + // First inner-tile element: (10 + 0) + (5 + 0) = 15. + BOOST_CHECK_EQUAL(c(0, 0).data()[0], 15); +} + +BOOST_AUTO_TEST_CASE(tot_scale_with_zb_inner) { + auto a = make_uniform_TobT_zb(Range({2ul, 3ul}), 1); + + Tensor c; + BOOST_REQUIRE_NO_THROW(c = a.scale(3)); + + BOOST_CHECK_EQUAL(c(0, 0).data()[0], 3); +} + +BOOST_AUTO_TEST_SUITE_END() + +#endif // TILEDARRAY_HAS_BTAS From f06e9fca6dbd5fe07ae1ce6efc74e6b0a9d9e4eb Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Tue, 12 May 2026 17:47:55 -0400 Subject: [PATCH 3/5] btas inner-tile path: empty-arg handling + rank-1 TI ordinal lookup Two fixes needed by MPQC's CSV CCk validation tests (which exercise TA::Tensor> inner tiles end-to-end): 1. external/btas.h: mirror TA::Tensor's empty-/null-argument early-exits in the btas free-function suite (add/add_to/subt/subt_to/mult/mult_to, their factored and permuted variants, scale/scale_to/neg/neg_to). TA's ToT lambdas default-construct result inner tiles and accumulate into them, so consumers regularly hit the empty-result-then-add_to(arg) pattern; without these guards, TA's congruent-range assertion fires inside the underlying TensorInterface machinery. 2. tensor/tensor_interface.h: replace range_.includes(index_ordinal) with range_.includes_ordinal(index_ordinal) in TensorInterface's operator[] / at_ordinal. 
The integral includes(Ordinal) overload on TA::Range asserts rank!=1 to disambiguate from the includes(Index) coordinate-tuple form; TI's ordinal-lookup path was inadvertently tripping that for rank-1 inner tiles (e.g. PNO-basis 1-D energies). --- src/TiledArray/external/btas.h | 107 +++++++++++++++++++++++ src/TiledArray/tensor/tensor_interface.h | 14 ++- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/src/TiledArray/external/btas.h b/src/TiledArray/external/btas.h index 37c9ff70af..988e51c3a7 100644 --- a/src/TiledArray/external/btas.h +++ b/src/TiledArray/external/btas.h @@ -308,11 +308,21 @@ inline btas::Tensor&& shift_to( return std::move(arg); } +// NOTE on empty-handling: these btas free functions mirror the empty-/null- +// argument behavior of TA::Tensor's same-named member ops, so generic ToT +// code (e.g. TA::Tensor>'s lambdas in +// TA::Tensor::{add,subt,mult,add_to,subt_to,mult_to,neg,scale}) can produce +// or accumulate into default-constructed (empty) btas inner tiles without +// hitting TA's congruent-range assertion in the underlying TensorInterface +// machinery. 
+ /// result[i] = arg1[i] + arg2[i] template inline btas::Tensor add( const btas::Tensor& arg1, const btas::Tensor& arg2) { + if (arg1.empty()) return arg2; + if (arg2.empty()) return arg1; auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.add(arg2_view); @@ -325,6 +335,9 @@ template add( const btas::Tensor& arg1, const btas::Tensor& arg2, const Scalar factor) { + if (arg1.empty() && arg2.empty()) return {}; + if (arg1.empty()) return scale(arg2, factor); + if (arg2.empty()) return scale(arg1, factor); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.add(arg2_view, factor); @@ -337,6 +350,9 @@ template < inline btas::Tensor add( const btas::Tensor& arg1, const btas::Tensor& arg2, const Perm& perm) { + if (arg1.empty() && arg2.empty()) return {}; + if (arg1.empty()) return permute(arg2, perm); + if (arg2.empty()) return permute(arg1, perm); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.add(arg2_view, perm); @@ -352,6 +368,9 @@ inline btas::Tensor add( const btas::Tensor& arg1, const btas::Tensor& arg2, const Scalar factor, const Perm& perm) { + if (arg1.empty() && arg2.empty()) return {}; + if (arg1.empty()) return scale(permute(arg2, perm), factor); + if (arg2.empty()) return scale(permute(arg1, perm), factor); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.add(arg2_view, factor, perm); @@ -362,6 +381,15 @@ template inline btas::Tensor& add_to( btas::Tensor& result, const btas::Tensor& arg) { + // Mirror TA::Tensor::add_to's empty-handling so consumers (e.g. einsum's + // inner-contraction loop, which default-constructs result inner tiles) + // can accumulate into an empty result. Without these the make_ti's view + // ranges differ and TA's congruent-range assertion fires. 
+ if (arg.empty()) return result; + if (result.empty()) { + result = arg; + return result; + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.add_to(arg_view); @@ -375,6 +403,15 @@ template & add_to( btas::Tensor& result, const btas::Tensor& arg, const Scalar factor) { + if (arg.empty()) return result; + if (result.empty()) { + result = arg; + { + auto result_view = make_ti(result); + result_view.scale_to(factor); + } + return result; + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.add_to(arg_view, factor); @@ -386,6 +423,8 @@ template inline btas::Tensor subt( const btas::Tensor& arg1, const btas::Tensor& arg2) { + if (arg2.empty()) return arg1; + if (arg1.empty()) return neg(arg2); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.subt(arg2_view); @@ -398,6 +437,9 @@ template subt( const btas::Tensor& arg1, const btas::Tensor& arg2, const Scalar factor) { + if (arg1.empty() && arg2.empty()) return {}; + if (arg2.empty()) return scale(arg1, factor); + if (arg1.empty()) return scale(arg2, -factor); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.subt(arg2_view, factor); @@ -410,6 +452,9 @@ template < inline btas::Tensor subt( const btas::Tensor& arg1, const btas::Tensor& arg2, const Perm& perm) { + if (arg1.empty() && arg2.empty()) return {}; + if (arg2.empty()) return permute(arg1, perm); + if (arg1.empty()) return neg(permute(arg2, perm)); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.subt(arg2_view, perm); @@ -425,6 +470,9 @@ inline btas::Tensor subt( const btas::Tensor& arg1, const btas::Tensor& arg2, const Scalar factor, const Perm& perm) { + if (arg1.empty() && arg2.empty()) return {}; + if (arg2.empty()) return scale(permute(arg1, perm), factor); + if (arg1.empty()) return scale(permute(arg2, perm), -factor); auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return 
arg1_view.subt(arg2_view, factor, perm); @@ -435,6 +483,11 @@ template inline btas::Tensor& subt_to( btas::Tensor& result, const btas::Tensor& arg) { + if (arg.empty()) return result; + if (result.empty()) { + result = neg(arg); + return result; + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.subt_to(arg_view); @@ -445,6 +498,11 @@ template inline btas::Tensor&& subt_to( btas::Tensor&& result, const btas::Tensor& arg) { + if (arg.empty()) return std::move(result); + if (result.empty()) { + result = neg(arg); + return std::move(result); + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.subt_to(arg_view); @@ -457,6 +515,15 @@ template & subt_to( btas::Tensor& result, const btas::Tensor& arg, const Scalar factor) { + // TA::Tensor::subt_to(right, factor): right empty -> scale_to(factor). + if (arg.empty()) { + if (!result.empty()) scale_to(result, factor); + return result; + } + if (result.empty()) { + result = scale(arg, -factor); + return result; + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.subt_to(arg_view, factor); @@ -469,6 +536,14 @@ template && subt_to( btas::Tensor&& result, const btas::Tensor& arg, const Scalar factor) { + if (arg.empty()) { + if (!result.empty()) scale_to(result, factor); + return std::move(result); + } + if (result.empty()) { + result = scale(arg, -factor); + return std::move(result); + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.subt_to(arg_view, factor); @@ -480,6 +555,7 @@ template inline btas::Tensor mult( const btas::Tensor& arg1, const btas::Tensor& arg2) { + if (arg1.empty() || arg2.empty()) return {}; auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.mult(arg2_view); @@ -492,6 +568,7 @@ template mult( const btas::Tensor& arg1, const btas::Tensor& arg2, const Scalar factor) { + if (arg1.empty() || arg2.empty()) return {}; auto arg1_view = make_ti(arg1); auto 
arg2_view = make_ti(arg2); return arg1_view.mult(arg2_view, factor); @@ -504,6 +581,7 @@ template < inline btas::Tensor mult( const btas::Tensor& arg1, const btas::Tensor& arg2, const Perm& perm) { + if (arg1.empty() || arg2.empty()) return {}; auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.mult(arg2_view, perm); @@ -519,16 +597,23 @@ inline btas::Tensor mult( const btas::Tensor& arg1, const btas::Tensor& arg2, const Scalar factor, const Perm& perm) { + if (arg1.empty() || arg2.empty()) return {}; auto arg1_view = make_ti(arg1); auto arg2_view = make_ti(arg2); return arg1_view.mult(arg2_view, factor, perm); } /// result[i] *= arg[i] +/// (TA::Tensor::mult_to: right empty -> result = {}; left empty -> result.) template inline btas::Tensor& mult_to( btas::Tensor& result, const btas::Tensor& arg) { + if (arg.empty()) { + result = btas::Tensor{}; + return result; + } + if (result.empty()) return result; auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.mult_to(arg_view); @@ -539,6 +624,11 @@ template inline btas::Tensor&& mult_to( btas::Tensor&& result, const btas::Tensor& arg) { + if (arg.empty()) { + result = btas::Tensor{}; + return std::move(result); + } + if (result.empty()) return std::move(result); auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.mult_to(arg_view); @@ -552,6 +642,11 @@ template & mult_to( btas::Tensor& result, const btas::Tensor& arg, const Scalar factor) { + if (result.empty()) return result; + if (arg.empty()) { + result = btas::Tensor{}; + return result; + } auto result_view = make_ti(result); auto arg_view = make_ti(arg); result_view.mult_to(arg_view, factor); @@ -564,6 +659,11 @@ template && mult_to( btas::Tensor&& result, const btas::Tensor& arg, const Scalar factor) { + if (result.empty()) return std::move(result); + if (arg.empty()) { + result = btas::Tensor{}; + return std::move(result); + } auto result_view = make_ti(result); auto arg_view = 
make_ti(arg); result_view.mult_to(arg_view, factor); @@ -605,6 +705,7 @@ template >* = nullptr> inline btas::Tensor& scale_to( btas::Tensor& result, const Scalar factor) { + if (result.empty()) return result; auto result_view = make_ti(result); result_view.scale_to(factor); return result; @@ -614,6 +715,7 @@ template >* = nullptr> inline decltype(auto) scale(const btas::Tensor& result, const Scalar factor) { + if (result.empty()) return btas::Tensor{}; auto result_view = make_ti(result); return result_view.scale(factor); } @@ -625,6 +727,7 @@ template < TiledArray::detail::is_permutation_v>* = nullptr> inline decltype(auto) scale(const btas::Tensor& result, const Scalar factor, const Perm& perm) { + if (result.empty()) return btas::Tensor{}; auto result_view = make_ti(result); return result_view.scale(factor, perm); } @@ -632,6 +735,7 @@ inline decltype(auto) scale(const btas::Tensor& result, template inline btas::Tensor& neg_to( btas::Tensor& result) { + if (result.empty()) return result; auto result_view = make_ti(result); result_view.neg_to(); return result; @@ -640,6 +744,7 @@ inline btas::Tensor& neg_to( template inline btas::Tensor&& neg_to( btas::Tensor&& result) { + if (result.empty()) return std::move(result); auto result_view = make_ti(result); result_view.neg_to(); return std::move(result); @@ -648,6 +753,7 @@ inline btas::Tensor&& neg_to( template inline btas::Tensor neg( const btas::Tensor& arg) { + if (arg.empty()) return {}; auto arg_view = make_ti(arg); return arg_view.neg(); } @@ -657,6 +763,7 @@ template < typename = std::enable_if_t>> inline btas::Tensor neg( const btas::Tensor& arg, const Perm& perm) { + if (arg.empty()) return {}; auto arg_view = make_ti(arg); return arg_view.neg(perm); } diff --git a/src/TiledArray/tensor/tensor_interface.h b/src/TiledArray/tensor/tensor_interface.h index ad4ba7dc34..930f9ef6b4 100644 --- a/src/TiledArray/tensor/tensor_interface.h +++ b/src/TiledArray/tensor/tensor_interface.h @@ -178,16 +178,14 @@ class 
TensorInterface { /// \param range The range of this tensor /// \param data The data pointer for this tensor TensorInterface(const range_type& range, pointer data) - : range_(range), data_(data) { - } + : range_(range), data_(data) {} /// Construct a new view of \c tensor /// \param range The range of this tensor /// \param data The data pointer for this tensor TensorInterface(range_type&& range, pointer data) - : range_(std::move(range)), data_(data) { - } + : range_(std::move(range)), data_(data) {} template ::value>::type* = nullptr> @@ -223,7 +221,7 @@ class TensorInterface { /// \param index_ordinal The ordinal element index /// \return A const reference to the element at \c index_ordinal. const_reference operator[](const ordinal_type index_ordinal) const { - TA_ASSERT(range_.includes(index_ordinal)); + TA_ASSERT(range_.includes_ordinal(index_ordinal)); return data_[range_.ordinal(index_ordinal)]; } @@ -232,7 +230,7 @@ class TensorInterface { /// \param index The ordinal element index /// \return A const reference to the element at \c index_ordinal. reference operator[](const ordinal_type index_ordinal) { - TA_ASSERT(range_.includes(index_ordinal)); + TA_ASSERT(range_.includes_ordinal(index_ordinal)); return data_[range_.ordinal(index_ordinal)]; } @@ -241,7 +239,7 @@ class TensorInterface { /// \param index_ordinal The ordinal element index /// \return A const reference to the element at \c index_ordinal. const_reference at_ordinal(const ordinal_type index_ordinal) const { - TA_ASSERT(range_.includes(index_ordinal)); + TA_ASSERT(range_.includes_ordinal(index_ordinal)); return data_[range_.ordinal(index_ordinal)]; } @@ -250,7 +248,7 @@ class TensorInterface { /// \param index_ordinal The ordinal element index /// \return A const reference to the element at \c index_ordinal. 
reference at_ordinal(const ordinal_type index_ordinal) { - TA_ASSERT(range_.includes(index_ordinal)); + TA_ASSERT(range_.includes_ordinal(index_ordinal)); return data_[range_.ordinal(index_ordinal)]; } From 6d1b16bcede3846f967a652248bdd61ae13ba4c5 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Tue, 12 May 2026 23:11:33 -0400 Subject: [PATCH 4/5] btas inner-tile path: stride/permute plumbing, follow zb::RangeNd refactor Bump BTAS pin to pick up the zb::RangeNd refactor (template params now match btas::RangeNd: Order, Ext, Ord, MaxRank; row- or column-major; stride() exposed; permute uses if constexpr). Adapt TA accordingly: - external/btas.h: reorder includes so zb/range.h is visible to two-phase lookup before BTAS's generic permute is parsed; update the make_ta_range / is_congruent overloads to the new zb::RangeNd template parameter order (and refuse col-major in make_ta_range, matching the default RangeNd overload). - type_traits.h: add has_member_function_stride_data_anyreturn trait. - tensor/operators_body.ipp: in the tensor-of-tensors operator<< branch, call inner_t.range().stride_data() only when the inner range exposes it (zb::RangeNd intentionally doesn't, since synthesizing strides into a pointer-stable buffer would defeat its packed footprint).
--- external/versions.cmake | 4 +-- src/TiledArray/external/btas.h | 34 +++++++++++++++--------- src/TiledArray/tensor/operators_body.ipp | 18 ++++++++++--- src/TiledArray/type_traits.h | 3 +++ 4 files changed, 41 insertions(+), 18 deletions(-) diff --git a/external/versions.cmake b/external/versions.cmake index 288fa3bd3d..78c015ec79 100644 --- a/external/versions.cmake +++ b/external/versions.cmake @@ -17,8 +17,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 7d8aaf9d51981e4accf4d84742270d1473f8ca2e) set(TA_TRACKED_MADNESS_VERSION 0.10.1) set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1) -set(TA_TRACKED_BTAS_TAG 7e64fbad97c76f316f313f4c8ed3fca5445da15f) -set(TA_TRACKED_BTAS_PREVIOUS_TAG 287b145ead818a0332f2b7ce0b7375a83d328bae) +set(TA_TRACKED_BTAS_TAG 245e49f117981d6124e0f1aa0d1ae72f1c16318b) +set(TA_TRACKED_BTAS_PREVIOUS_TAG 7e64fbad97c76f316f313f4c8ed3fca5445da15f) set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece) set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 354e0ccee54aeb2f191c3ce2c617ebf437e49d83) diff --git a/src/TiledArray/external/btas.h b/src/TiledArray/external/btas.h index 988e51c3a7..d8841a8596 100644 --- a/src/TiledArray/external/btas.h +++ b/src/TiledArray/external/btas.h @@ -37,9 +37,14 @@ #include #include + +// zb/range.h before generic/permute.h so the zb-specific permute(zb::RangeNd, +// p) overload is visible to two-phase lookup inside BTAS's generic permute +// template. 
+#include + #include #include -#include #include @@ -88,12 +93,15 @@ inline TiledArray::Range make_ta_range( return TiledArray::Range(range.lobound(), range.upbound()); } -/// makes TiledArray::Range from a btas::zb::RangeNd (zero-based, row-major) +/// makes TiledArray::Range from a btas::zb::RangeNd (zero-based) /// \param[in] range a btas::zb::RangeNd object -template +template <::blas::Layout Order, typename Ext, typename Ord, std::size_t MaxRank> inline TiledArray::Range make_ta_range( - const btas::zb::RangeNd& range) { + const btas::zb::RangeNd& range) { + TA_ASSERT(Order == ::blas::Layout::RowMajor && + "TiledArray::detail::make_ta_range(btas::zb::RangeNd): " + "not supported for col-major Order"); return TiledArray::Range(range.lobound(), range.upbound()); } @@ -148,17 +156,18 @@ namespace zb { // overloads taking a btas::zb::RangeNd must live here. /// Test if a btas::zb::RangeNd is congruent with another btas::zb::RangeNd -template -inline bool is_congruent(const btas::zb::RangeNd& r1, - const btas::zb::RangeNd& r2) { +template <::blas::Layout Order, typename Ext, typename Ord, std::size_t MaxRank> +inline bool is_congruent( + const btas::zb::RangeNd& r1, + const btas::zb::RangeNd& r2) { return (r1.rank() == r2.rank()) && std::equal(r1.extent_data(), r1.extent_data() + r1.rank(), r2.extent_data()); } /// Test if a btas::zb::RangeNd and a TA range are congruent -template -inline bool is_congruent(const btas::zb::RangeNd& r1, +template <::blas::Layout Order, typename Ext, typename Ord, std::size_t MaxRank> +inline bool is_congruent(const btas::zb::RangeNd& r1, const TiledArray::Range& r2) { return (r1.rank() == r2.rank()) && std::equal(r1.extent_data(), r1.extent_data() + r1.rank(), @@ -166,9 +175,10 @@ inline bool is_congruent(const btas::zb::RangeNd& r1, } /// Test if a TA range and a btas::zb::RangeNd are congruent -template -inline bool is_congruent(const TiledArray::Range& r1, - const btas::zb::RangeNd& r2) { +template <::blas::Layout Order, typename 
Ext, typename Ord, std::size_t MaxRank> +inline bool is_congruent( + const TiledArray::Range& r1, + const btas::zb::RangeNd& r2) { return (r1.rank() == r2.rank()) && std::equal(r1.extent_data(), r1.extent_data() + r1.rank(), r2.extent_data()); diff --git a/src/TiledArray/tensor/operators_body.ipp b/src/TiledArray/tensor/operators_body.ipp index 9758af5d24..4e2d736a84 100644 --- a/src/TiledArray/tensor/operators_body.ipp +++ b/src/TiledArray/tensor/operators_body.ipp @@ -154,10 +154,20 @@ inline std::basic_ostream& operator<<( for (auto&& idx : t.range()) { const auto& inner_t = *(t.data() + offset + t.range().ordinal(idx)); os << " " << idx << ":"; - TA::detail::NDArrayPrinter{}.print( - inner_t.data(), inner_t.range().rank(), - inner_t.range().extent_data(), inner_t.range().stride_data(), os, - more_than_1_batch ? 6 : 4); + using inner_range_t = + std::remove_cv_t>; + if constexpr (TA::detail::has_member_function_stride_data_anyreturn_v< + inner_range_t>) { + TA::detail::NDArrayPrinter{}.print( + inner_t.data(), inner_t.range().rank(), + inner_t.range().extent_data(), inner_t.range().stride_data(), os, + more_than_1_batch ? 6 : 4); + } else { + // Inner range doesn't expose stride_data (e.g. btas::zb::RangeNd, + // which intentionally synthesizes row-major strides on demand and + // stores none). Skip the strided pretty-printer for this element. 
+ os << " "; + } os << "\n"; } } diff --git a/src/TiledArray/type_traits.h b/src/TiledArray/type_traits.h index 09d2824464..b18fc8c7ee 100644 --- a/src/TiledArray/type_traits.h +++ b/src/TiledArray/type_traits.h @@ -336,6 +336,9 @@ GENERATE_HAS_MEMBER_FUNCTION(total_size) GENERATE_HAS_MEMBER_FUNCTION_ANYRETURN(nbatch) GENERATE_HAS_MEMBER_FUNCTION(nbatch) +// Range-only +GENERATE_HAS_MEMBER_FUNCTION_ANYRETURN(stride_data) + GENERATE_HAS_MEMBER_FUNCTION_ANYRETURN(begin) GENERATE_HAS_MEMBER_FUNCTION(begin) GENERATE_HAS_MEMBER_FUNCTION_ANYRETURN(end) From d6a281df2ba64d16d005af138656787cf761a8f6 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Tue, 12 May 2026 23:47:25 -0400 Subject: [PATCH 5/5] address PR review feedback on the btas inner-tile path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - tensor/type_traits.h: fix the operators_body.ipp filename in the comment above ta_ops_match_tensor (was .inl, the new shared file is .ipp). And fix result_tensor_helper to fall back to TA::Tensor when TensorA_ lacks rebind_t — TensorInterface and ShiftWrapper satisfy is_tensor_v but don't expose rebind_t, so using it unconditionally would hard-error in code paths like tensor_contract called with view operands. - tensor/print.h, tensor/print.cpp: NDArrayPrinter::print/toString now accept differently-typed ExtentIndex/StrideIndex pointers, but the explicit instantiations only cover the default Range1::index1_type case — calls with other index types (e.g. a TA::Tensor whose tile range exposes mismatched extent/stride integer types) would hit unresolved-symbol link errors. Include print.ipp at the end of print.h so non-default instantiations are produced at use-site; the default-case `extern template` decls still keep the common-case instantiation centralized in print.cpp. 
--- src/TiledArray/tensor/print.cpp | 4 +++- src/TiledArray/tensor/print.h | 8 ++++++++ src/TiledArray/tensor/type_traits.h | 28 ++++++++++++++++++++-------- 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/src/TiledArray/tensor/print.cpp b/src/TiledArray/tensor/print.cpp index 7b02a2caaf..17fffae8c5 100644 --- a/src/TiledArray/tensor/print.cpp +++ b/src/TiledArray/tensor/print.cpp @@ -23,7 +23,9 @@ * */ -#include +// print.h pulls in print.ipp at the end so the template definitions are +// visible here for the explicit instantiations below. +#include namespace TiledArray { diff --git a/src/TiledArray/tensor/print.h b/src/TiledArray/tensor/print.h index abb1ab315e..42db06c615 100644 --- a/src/TiledArray/tensor/print.h +++ b/src/TiledArray/tensor/print.h @@ -142,4 +142,12 @@ TILEDARRAY_MAKE_NDARRAY_PRINTER_INSTANTIATION(int, wchar_t); } // namespace TiledArray +// Pull in the template definitions so callers using non-default +// ExtentIndex/StrideIndex pointer types (e.g. tensors over non-TA ranges +// whose extent_data() and stride_data() expose differently-sized integer +// pointers) get an implicit instantiation. The explicit instantiations +// declared above cover the common Range1::index1_type case and are +// satisfied by the symbols emitted from print.cpp. +#include + #endif // TILEDARRAY_SRC_TILEDARRAY_TENSOR_PRINT_H__INCLUDED diff --git a/src/TiledArray/tensor/type_traits.h b/src/TiledArray/tensor/type_traits.h index bf4836e968..ebc04ebe23 100644 --- a/src/TiledArray/tensor/type_traits.h +++ b/src/TiledArray/tensor/type_traits.h @@ -118,7 +118,7 @@ template inline constexpr const bool is_nested_tensor_v = is_nested_tensor::value; /// Predicate used by the shared operator body in -/// @c TiledArray/tensor/operators_body.inl to gate the element-wise tensor +/// @c TiledArray/tensor/operators_body.ipp to gate the element-wise tensor /// operators that are injected into @c namespace TiledArray . 
The btas-side /// copy of the same operators (in @c external/btas.h) partial-specializes /// this predicate to @c std::false_type for @c btas::Tensor so the two @@ -490,6 +490,15 @@ template constexpr bool is_binop_v>>{true}; +// Detect whether T exposes a `rebind_t` member template. Both TA::Tensor +// and btas::Tensor do; view types like TensorInterface and ShiftWrapper do +// not, so callers must fall back to a concrete tensor for the result type. +template +struct has_rebind_t : std::false_type {}; +template +struct has_rebind_t>> + : std::true_type {}; + template >> @@ -504,15 +513,18 @@ struct result_tensor_helper { using numeric_type = binop_result_t; // Result tensor type stays in TensorA's family with the allocator rebound to - // hold `numeric_type`. Both TA::Tensor and btas::Tensor expose this as + // hold `numeric_type`. TA::Tensor and btas::Tensor expose this as // `rebind_t` (TA::Tensor via std::allocator_traits::rebind_alloc; btas - // via storage_traits::rebind_t). An explicit @tparam Allocator override only - // applies when TensorA is a TA::Tensor. - using result_type = - std::conditional_t || - !is_ta_tensor_v, + // via storage_traits::rebind_t). View types (TensorInterface, ShiftWrapper) + // satisfy is_tensor_v but have no `rebind_t` — fall back to TA::Tensor for + // those. An explicit @tparam Allocator override only applies when TensorA + // is a TA::Tensor. + using result_type = std::conditional_t< + std::is_same_v || !is_ta_tensor_v, + std::conditional_t::value, typename TensorA_::template rebind_t, - TA::Tensor>; + TA::Tensor>, + TA::Tensor>; }; } // namespace