Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions external/versions.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ set(TA_TRACKED_MADNESS_PREVIOUS_TAG 7d8aaf9d51981e4accf4d84742270d1473f8ca2e)
set(TA_TRACKED_MADNESS_VERSION 0.10.1)
set(TA_TRACKED_MADNESS_PREVIOUS_VERSION 0.10.1)

set(TA_TRACKED_BTAS_TAG 287b145ead818a0332f2b7ce0b7375a83d328bae)
set(TA_TRACKED_BTAS_PREVIOUS_TAG 62d57d9b1e0c733b4b547bc9cfdd07047159dbca)
set(TA_TRACKED_BTAS_TAG 245e49f117981d6124e0f1aa0d1ae72f1c16318b)
set(TA_TRACKED_BTAS_PREVIOUS_TAG 7e64fbad97c76f316f313f4c8ed3fca5445da15f)

set(TA_TRACKED_LIBRETT_TAG 6eed30d4dd2a5aa58840fe895dcffd80be7fbece)
set(TA_TRACKED_LIBRETT_PREVIOUS_TAG 354e0ccee54aeb2f191c3ce2c617ebf437e49d83)
Expand Down
14 changes: 13 additions & 1 deletion src/TiledArray/dist_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -1852,8 +1852,20 @@ size_t volume(const DistArray<Tile, Policy>& array) {

auto local_vol = [&vol](Tile const& in_tile) {
if constexpr (detail::is_tensor_of_tensor_v<Tile>) {
// Inner tile pointer is passed (see is_reduce_op_v in
// tensor/type_traits.h selecting the pointer-passing tensor_reduce
// overload). Prefer `total_size()` (TA::Tensor exposes it, batches
// included); fall back to `size()` for inner tile types that don't
// (e.g. btas::Tensor).
auto reduce_op = [](size_t& MADNESS_RESTRICT result, auto&& arg) {
result += arg->total_size();
using InnerTile =
std::remove_cv_t<std::remove_reference_t<decltype(*arg)>>;
if constexpr (detail::has_member_function_total_size_anyreturn_v<
InnerTile>) {
result += arg->total_size();
} else {
result += arg->size();
}
};
auto join_op = [](auto& MADNESS_RESTRICT result, size_t count) {
result += count;
Expand Down
46 changes: 38 additions & 8 deletions src/TiledArray/einsum/tiledarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,31 @@ template <typename ArrayT1, typename ArrayT2>
constexpr bool AreArraySame =
AreArrayT<ArrayT1, ArrayT2> || AreArrayToT<ArrayT1, ArrayT2>;

// "Denested" companion of a ToT array: drops the inner-tile nesting, leaving
// a regular (non-nested) DistArray. For ToT inputs, the outer tile of the
// denested array is always TA::Tensor — nested inner-tile types (e.g.
// btas::Tensor) are only valid as the *innermost* tile and don't support the
// outer-tile operations einsum needs (permute/reshape/batch/range+lambda
// ctor). So for ToT we drop the inner tile and re-wrap its numeric type in
// TA::Tensor. For non-ToT inputs, the original "drop one level" behavior is
// preserved.
namespace detail_denested {
// Type-level mapping from an array type `Array` to its "denested" DistArray
// type, exposed as the nested alias `type`.
//
// Primary template (used when no specialization below matches): the result's
// tile type is the element type of Array's tile, i.e.
// `Array::value_type::value_type`; Array's policy type is carried over
// unchanged.
//
// `Enabler` exists only so that partial specializations can be switched on
// with std::enable_if_t; callers are expected to leave it defaulted.
template <typename Array, typename Enabler = void>
struct denested {
using type = DistArray<typename Array::value_type::value_type,
typename Array::policy_type>;
};
// Partial specialization chosen when Array's tile type satisfies
// `is_tensor_of_tensor_v`. Instead of reusing the inner tile type as-is, the
// result's tile is `TA::Tensor` over the inner tile's `numeric_type`; Array's
// policy type is again carried over unchanged.
template <typename Array>
struct denested<Array,
std::enable_if_t<TiledArray::detail::is_tensor_of_tensor_v<
typename Array::value_type>>> {
using type = DistArray<
TA::Tensor<typename Array::value_type::value_type::numeric_type>,
typename Array::policy_type>;
};
} // namespace detail_denested
template <typename Array>
using DeNestedArray = DistArray<typename Array::value_type::value_type,
typename Array::policy_type>;
using DeNestedArray = typename detail_denested::denested<Array>::type;

template <typename Array1, typename Array2>
using MaxNestedArray = std::conditional_t<(detail::nested_rank<Array2> >
Expand Down Expand Up @@ -496,13 +518,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
// Step III: C1(ijpqab) -> C2(ijpq)
// Step IV: C2(ijpq) -> C(ipjq)

// Build a "denested" tile: one scalar per outer index, summed over the
// inner tile. The result tile's outer type is TA::Tensor (inner tile
// types like btas::Tensor are only valid as the innermost tile and don't
// expose the range+lambda ctor used here).
auto sum_tot_2_tos = [](auto const &tot) {
using tot_t = std::remove_reference_t<decltype(tot)>;
typename tot_t::value_type result(tot.range(), [tot](auto &&ix) {
using numeric_type = typename tot_t::numeric_type;
TA::Tensor<numeric_type> result(tot.range(), [tot](auto &&ix) {
// unqualified `sum` so ADL finds the right overload for both
// TA::Tensor inner (free fn in namespace TiledArray, calls .sum())
// and btas::Tensor inner (free fn in namespace btas).
if (!tot(ix).empty())
return tot(ix).sum();
return sum(tot(ix));
else
return typename tot_t::numeric_type{};
return numeric_type{};
});
return result;
};
Expand Down Expand Up @@ -722,7 +752,7 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
using TensorT = std::remove_reference_t<decltype(el)>;

for (auto i = 0; i < vol; ++i)
el.add_to(element_product_op(aik.data()[i], bik.data()[i]));
add_to(el, element_product_op(aik.data()[i], bik.data()[i]));

} else if constexpr (!AreArraySame<ArrayA, ArrayB>) {
auto aik = ai.batch(k);
Expand All @@ -734,9 +764,9 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,

for (auto i = 0; i < vol; ++i)
if constexpr (IsArrayToT<ArrayA>) {
el.add_to(aik.data()[i].scale(bik.data()[i]));
add_to(el, scale(aik.data()[i], bik.data()[i]));
} else {
el.add_to(bik.data()[i].scale(aik.data()[i]));
add_to(el, scale(bik.data()[i], aik.data()[i]));
}

} else {
Expand Down
Loading
Loading