From f9961279f617b7adca4152395b490a90a887d162 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 25 Mar 2026 09:31:39 -0500 Subject: [PATCH 01/31] add MapAsIndexLambdaMixin --- pytato/transform/lower_to_index_lambda.py | 66 ++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 773ee4cbe..31f092bfc 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -2,6 +2,7 @@ .. currentmodule:: pytato.transform.lower_to_index_lambda .. autofunction:: to_index_lambda +.. autoclass:: MapAsIndexLambdaMixin """ from __future__ import annotations @@ -28,8 +29,9 @@ THE SOFTWARE. """ +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from constantdict import constantdict from typing_extensions import Never @@ -65,6 +67,8 @@ from pytato.tags import AssumeNonNegative from pytato.transform import ( Mapper, + P, + ResultT, _verify_is_array, ) from pytato.utils import normalized_slice_does_not_change_axis @@ -790,4 +794,64 @@ def to_index_lambda(expr: Array) -> IndexLambda: assert isinstance(res, IndexLambda) return res + +class MapAsIndexLambdaMixin(ABC, Generic[ResultT, P]): + """ + Mixin that, where possible, lowers arrays to :class:`~pytato.array.IndexLambda` + and calls :meth:`map_as_index_lambda` on them. + + .. automethod:: map_as_index_lambda + """ + @abstractmethod + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + """ + Map *expr* via its :class:`~pytato.array.IndexLambda` representation + *idx_lambda*. 
+ """ + + def map_index_lambda( + self, expr: IndexLambda, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, expr, *args, **kwargs) + + def map_stack(self, expr: Stack, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_roll(self, expr: Roll, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_axis_permutation( + self, expr: AxisPermutation, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_basic_index( + self, expr: BasicIndex, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_contiguous_advanced_index( + self, expr: AdvancedIndexInContiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_non_contiguous_advanced_index( + self, expr: AdvancedIndexInNoncontiguousAxes, + *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_reshape(self, expr: Reshape, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_concatenate( + self, expr: Concatenate, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_einsum(self, expr: Einsum, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + def map_csr_matmul( + self, expr: CSRMatmul, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.map_as_index_lambda(expr, to_index_lambda(expr), *args, **kwargs) + + # 
vim:fdm=marker From 55c8afd211e352b2f3a8cc396368bf2d934e6a6c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 24 Apr 2026 10:13:34 -0500 Subject: [PATCH 02/31] reimplement AxesTagsEquationCollector using MapAsIndexLambdaMixin --- pytato/transform/metadata.py | 74 ++++++------------------------------ 1 file changed, 12 insertions(+), 62 deletions(-) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 36aab2391..957c3ce35 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -54,7 +54,7 @@ ) from bidict import bidict -from typing_extensions import Never +from typing_extensions import Never, override import pymbolic.primitives as prim from pymbolic.typing import Expression @@ -63,12 +63,7 @@ from pytato.array import ( AbstractResultWithNamedArrays, - AdvancedIndexInContiguousAxes, Array, - AxisPermutation, - BasicIndex, - Concatenate, - CSRMatmul, DictOfNamedArrays, Einsum, EinsumReductionAxis, @@ -76,7 +71,6 @@ InputArgumentBase, NamedArray, Reshape, - Stack, ) from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import NamedCallResult @@ -91,7 +85,7 @@ Mapper, TransformMapperCache, ) -from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.transform.lower_to_index_lambda import MapAsIndexLambdaMixin logger = logging.getLogger(__name__) @@ -156,7 +150,9 @@ def map_constant(self, expr: object) -> dict[BindingName, # {{{ AxesTagsEquationCollector -class AxesTagsEquationCollector(Mapper[None, Never, []]): +class AxesTagsEquationCollector( + Mapper[None, Never, []], + MapAsIndexLambdaMixin[None, []]): r""" Records equations arising from operand/output axes equivalence for an array operation. An equation is recorded for "straight-through" axes in expressions, @@ -183,15 +179,12 @@ class AxesTagsEquationCollector(Mapper[None, Never, []]): iaxis)`` to the :class:`str` by which it will be referenced in :attr:`equations`. - .. 
automethod:: map_index_lambda .. automethod:: map_placeholder .. automethod:: map_data_wrapper .. automethod:: map_size_param .. automethod:: map_reshape - .. automethod:: map_basic_index - .. automethod:: map_contiguous_advanced_index - .. automethod:: map_stack - .. automethod:: map_concatenate + + .. automethod:: map_as_index_lambda .. note:: @@ -281,20 +274,16 @@ def _map_input_base(self, expr: InputArgumentBase) -> None: map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_index_lambda(self, expr: IndexLambda) -> None: - for bnd in expr.bindings.values(): - self.rec(bnd) - - self.add_equations_using_index_lambda_version_of_expr(expr) - - def add_equations_using_index_lambda_version_of_expr(self, expr: Array) -> None: + @override + def map_as_index_lambda(self, expr: Array, idx_lambda: IndexLambda) -> None: """ Equations are added between an axis of the bindings of *expr* and an axis of *expr* if the binding's axis is indexed by by a :class:`~pymbolic.Variable` which has a name that follows the reserved iname format, "_[0-9]+", and the axis of the output specified by the iname. """ - idx_lambda = expr if isinstance(expr, IndexLambda) else to_index_lambda(expr) + for bnd in idx_lambda.bindings.values(): + self.rec(bnd) index_expr_used = BindingSubscriptsCollector()(idx_lambda.expr) @@ -327,32 +316,7 @@ def add_equations_using_index_lambda_version_of_expr(self, expr: Array) -> None: # Other cases are considered "complicated" and we won't # handle them here. 
- def map_stack(self, expr: Stack) -> None: - for ary in expr.arrays: - self.rec(ary) - - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_concatenate(self, expr: Concatenate) -> None: - for ary in expr.arrays: - self.rec(ary) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_axis_permutation(self, expr: AxisPermutation - ) -> None: - self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_basic_index(self, expr: BasicIndex) -> None: - self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_contiguous_advanced_index(self, - expr: AdvancedIndexInContiguousAxes - ) -> None: - self.rec(expr.array) - self.add_equations_using_index_lambda_version_of_expr(expr) - + @override def map_reshape(self, expr: Reshape) -> None: """ Reshaping generally does not preserve the axis between its input and @@ -386,20 +350,6 @@ def map_reshape(self, expr: Reshape) -> None: assert i_in_axis == expr.array.ndim - def map_einsum(self, expr: Einsum) -> None: - for arg in expr.args: - self.rec(arg) - self.add_equations_using_index_lambda_version_of_expr(expr) - - def map_csr_matmul(self, expr: CSRMatmul) -> None: - for ary in ( - expr.matrix.elem_values, - expr.matrix.elem_col_indices, - expr.matrix.row_starts, - expr.array): - self.rec(ary) - self.add_equations_using_index_lambda_version_of_expr(expr) - def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: for _, subexpr in sorted(expr._data.items()): self.rec(subexpr) From bf097005d09df22e8fc752a067acf8e2d08fda93 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 24 Oct 2025 15:51:27 -0500 Subject: [PATCH 03/31] add flop counting functions --- pytato/analysis/__init__.py | 361 +++++++++++++++++++++++++++++++++++- pytato/reductions.py | 36 ++++ pytato/scalar_expr.py | 136 ++++++++++++++ pytato/transform/calls.py | 94 +++++++++- pytato/utils.py | 21 ++- test/test_pytato.py | 281 
++++++++++++++++++++++++++++ 6 files changed, 924 insertions(+), 5 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 51d70fba2..9235216d5 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,16 +27,19 @@ """ from collections import defaultdict -from typing import TYPE_CHECKING, Any, overload +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Never, Self, override from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper +from pytools import product from pytato.array import ( Array, + ArrayOrScalar, Concatenate, CSRMatmul, DictOfNamedArrays, @@ -52,15 +55,22 @@ from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall, LoopyCallResult -from pytato.scalar_expr import SCALAR_CLASSES +from pytato.scalar_expr import ( + SCALAR_CLASSES, + FlopCounter as ScalarFlopCounter, +) from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, + ArrayOrNamesTc, CachedWalkMapper, CombineMapper, Mapper, VisitKeyT, + map_and_copy, ) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import is_materializable if TYPE_CHECKING: @@ -92,6 +102,13 @@ .. autoclass:: MaterializedNodeCollector .. autofunction:: collect_materialized_nodes + +.. autoclass:: UndefinedOpFlopCountError +.. autofunction:: get_default_op_name_to_num_flops +.. autofunction:: get_num_flops +.. autofunction:: get_materialized_node_flop_counts +.. autoclass:: UnmaterializedNodeFlopCounts +.. 
autofunction:: get_unmaterialized_node_flop_counts """ @@ -886,4 +903,344 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # }}} + +# {{{ flop counting + +def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + return ( + is_materializable(expr) + and bool(expr.tags_of_type(ImplStored))) + + +def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + return ( + is_materializable(expr) + and not bool(expr.tags_of_type(ImplStored))) + + +@dataclass +class UndefinedOpFlopCountError(ValueError): + op_name: str + + +class _PerEntryFlopCounter(CombineMapper[int, Never, []]): + def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() + self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( + op_name_to_num_flops) + self.node_to_nflops: dict[Array, int] = {} + + @override + def combine(self, *args: int) -> int: + return sum(args) + + @override + def rec(self, expr: ArrayOrNames) -> int: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: int + if _is_unmaterialized(expr): + assert isinstance(expr, Array) + self_nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(self_nflops, int): + from pytato.scalar_expr import InputGatherer as ScalarInputGatherer + var_names: set[str] = set(ScalarInputGatherer()(self_nflops)) + var_names.discard("nflops") + if var_names: + raise UndefinedOpFlopCountError(next(iter(var_names))) from None + else: + raise AssertionError from None + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self_nflops + cast("int", Mapper.rec(self, expr)) + else: + result = 0 + if isinstance(expr, Array): + self.node_to_nflops[expr] = result + return self._cache_add(inputs, result) + + +class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + 
Mapper that counts the number of floating point operations of each materialized + expression in a DAG. + + .. note:: + + Flops from nodes inside function calls are accumulated onto the corresponding + call node. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + _visited_functions: set[VisitKeyT] | None = None, + _function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] | None = None + ) -> None: + super().__init__(_visited_functions=_visited_functions) + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} + self.call_to_nflops: dict[Call, ArrayOrScalar] = {} + self._function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] = \ + _function_to_nflops if _function_to_nflops is not None else {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + op_name_to_num_flops=self.op_name_to_num_flops, + _visited_functions=self._visited_functions, + _function_to_nflops=self._function_to_nflops) + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for subexpr in expr.returns.values(): + # Assume that any calls that haven't been inlined have their functions' + # outputs materialized + assert not _is_unmaterialized(subexpr) + new_mapper(subexpr) + + self._function_to_nflops[expr] = ( + sum(new_mapper.materialized_node_to_nflops.values()) + + sum(new_mapper.call_to_nflops.values())) + + self.post_visit(expr) + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + 
self.rec_function_definition(expr.function) + for bnd in expr.bindings.values(): + # Assume that any calls that haven't been inlined have their inputs + # materialized + assert not _is_unmaterialized(bnd) + self.rec(bnd) + + self.call_to_nflops[expr] = self._function_to_nflops[expr.function] + + self.post_visit(expr) + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not _is_materialized(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + self._per_entry_flop_counter(unmaterialized_expr) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr]) + + +class _UnmaterializedSubexpressionUseCounter( + CombineMapper[dict[Array, int], Never, []]): + @override + def combine(self, *args: dict[Array, int]) -> dict[Array, int]: + result: dict[Array, int] = defaultdict(int) + for arg in args: + for ary, nuses in arg.items(): + result[ary] += nuses + return result + + @override + def rec(self, expr: ArrayOrNames) -> dict[Array, int]: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: dict[Array, int] + if _is_unmaterialized(expr): + assert isinstance(expr, Array) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self.combine( + {expr: 1}, cast("dict[Array, int]", Mapper.rec(self, expr))) + else: + result = {} + return self._cache_add(inputs, result) + + +@dataclass +class UnmaterializedNodeFlopCounts: + materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] + nflops_if_materialized: ArrayOrScalar + + +class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the accumulated number of floating point operations that each + unmaterialized expression contributes to 
materialized expressions in the DAG. + + .. note:: + + This mapper does not descend into functions. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + _visited_functions: set[VisitKeyT] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.unmaterialized_node_to_flop_counts: \ + dict[Array, UnmaterializedNodeFlopCounts] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not _is_materialized(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + subexpr_to_nuses = _UnmaterializedSubexpressionUseCounter()( + unmaterialized_expr) + del subexpr_to_nuses[unmaterialized_expr] + self._per_entry_flop_counter(unmaterialized_expr) + for subexpr, nuses in subexpr_to_nuses.items(): + per_entry_nflops = self._per_entry_flop_counter.node_to_nflops[subexpr] + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = product(subexpr.shape) * per_entry_nflops + flop_counts = UnmaterializedNodeFlopCounts({}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses * product(expr.shape) * per_entry_nflops) + + +# FIXME: Should this be added to normalize_outputs? 
+def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + # Make sure call bindings and function results are materialized + from pytato.transform.calls import normalize_calls + expr = normalize_calls(expr) + + # Make sure outputs are materialized + if isinstance(expr, DictOfNamedArrays): + output_to_materialized_output: dict[Array, Array] = { + ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary + for ary in expr._data.values()} + + def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: + if not isinstance(ary, Array): + return ary + try: + return output_to_materialized_output[ary] + except KeyError: + return ary + + expr = map_and_copy(expr, replace_with_materialized) + + return expr + + +def get_default_op_name_to_num_flops() -> dict[str, int]: + """ + Returns a mapping from operator name to floating point operation count for + operators that are almost always a single flop. + """ + return { + "+": 1, + "*": 1, + "==": 1, + "!=": 1, + "<": 1, + ">": 1, + "<=": 1, + ">=": 1, + "min": 1, + "max": 1} + + +def get_num_flops( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> ArrayOrScalar: + """Count the total number of floating point operations in the DAG *expr*.""" + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return ( + sum(fc.materialized_node_to_nflops.values()) + + sum(fc.call_to_nflops.values())) + + +def get_materialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, ArrayOrScalar]: + """ + Returns a dictionary mapping materialized nodes in DAG *expr* to their floating + point operation count. 
+ """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.materialized_node_to_nflops + + +def get_unmaterialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, UnmaterializedNodeFlopCounts]: + """ + Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a + :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count + information. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.unmaterialized_node_to_flop_counts + +# }}} + + # vim: fdm=marker diff --git a/pytato/reductions.py b/pytato/reductions.py index 6efa45ac2..a35c1d98a 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -34,6 +34,7 @@ import numpy as np from constantdict import constantdict +from typing_extensions import override import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -80,10 +81,15 @@ class _NoValue: class ReductionOperation(ABC): """ + .. automethod:: scalar_op_name .. automethod:: neutral_element .. automethod:: __hash__ .. 
automethod:: __eq__ """ + @classmethod + @abstractmethod + def scalar_op_name(cls) -> str: + pass @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -110,16 +116,31 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "+" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 0 class ProductReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "*" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 1 class MaxReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "max" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("-inf")) @@ -130,6 +151,11 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "min" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("inf")) @@ -140,11 +166,21 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "and" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) class AnyReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls) -> str: + return "or" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 06219e540..4065d63b5 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -60,6 +60,7 @@ from loopy.symbolic import guarded_pwaff_from_expr from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass from pymbolic.mapper import ( 
+ Collector, CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, P, @@ -74,6 +75,7 @@ from pymbolic.mapper.distributor import DistributeMapper as DistributeMapperBase from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase from pymbolic.mapper.flattener import FlattenMapper as FlattenMapperBase +from pymbolic.mapper.flop_counter import FlopCounterBase from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer @@ -247,6 +249,140 @@ def map_type_cast(self, inner_str = self.rec(expr.inner_expr, PREC_NONE, *args, **kwargs) return f"cast({expr.dtype}, {inner_str})" + +class InputGatherer(Collector[str, []]): + @override + def map_variable(self, expr: prim.Variable) -> set[str]: + return {expr.name} + + +class FlopCounter(FlopCounterBase): + def __init__( + self, + op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): + super().__init__() + self.op_name_to_num_flops: dict[str, ArithmeticExpression] + if op_name_to_num_flops: + self.op_name_to_num_flops = dict(op_name_to_num_flops) + else: + self.op_name_to_num_flops = {} + + def _get_op_nflops(self, name: str) -> ArithmeticExpression: + try: + return self.op_name_to_num_flops[name] + except KeyError: + from pymbolic import var + result = var("nflops")(var(name)) + self.op_name_to_num_flops[name] = result + return result + + @override + def map_call(self, expr: prim.Call) -> ArithmeticExpression: + assert isinstance(expr.function, prim.Variable) + return ( + self._get_op_nflops(expr.function.name) + + sum(self.rec(child) for child in expr.parameters)) + + @override + def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: + # Assume calculations inside subscripts are performed on non-floats + return 0 + + @override + def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: 
ignore[reportIncompatibleMethodOverride] + if expr.children: + return ( + self._get_op_nflops("+") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_product(self, expr: prim.Product) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("*") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_quotient(self, expr: prim.Quotient) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + return ( + self._get_op_nflops("/") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + return ( + self._get_op_nflops("//") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_power(self, expr: prim.Power) -> ArithmeticExpression: + if isinstance(expr.exponent, int): + if expr.exponent >= 0: + return ( + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("/") + + (-expr.exponent) * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("**") + + self.rec(expr.base) + + self.rec(expr.exponent)) + + @override + def map_comparison(self, expr: prim.Comparison) -> ArithmeticExpression: + return ( + self._get_op_nflops(expr.operator) + + self.rec(expr.left) + + self.rec(expr.right)) + + @override + def map_if(self, expr: prim.If) -> ArithmeticExpression: + return ( + self.rec(expr.condition) + + self.rec(expr.then) + + self.rec(expr.else_)) + + @override + def map_max(self, expr: prim.Max) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("max") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_min(self, expr: prim.Min) -> ArithmeticExpression: + if expr.children: + return ( + 
self._get_op_nflops("min") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_nan(self, expr: prim.NaN) -> ArithmeticExpression: + return 0 + + def map_reduce(self, expr: Reduce) -> ArithmeticExpression: + result = self.rec(expr.inner_expr) + nflops_op = self._get_op_nflops(expr.op.scalar_op_name()) + for lower_bd, upper_bd in expr.bounds.values(): + nops = upper_bd - lower_bd + result = result * nops + nflops_op * (nops-1) + + return result + # }}} diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index ffad1e5f7..e82c2b1e0 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,6 +1,7 @@ """ .. currentmodule:: pytato.transform.calls +.. autofunction:: normalize_calls .. autofunction:: inline_calls .. autofunction:: tag_all_calls_to_be_inlined """ @@ -32,20 +33,26 @@ from typing import TYPE_CHECKING, cast -from typing_extensions import Self +from immutabledict import immutabledict +from typing_extensions import Never, Self, override from pytato.array import ( AbstractResultWithNamedArrays, Array, + DataWrapper, DictOfNamedArrays, Placeholder, + SizeParam, + make_dict_of_named_arrays, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import InlineCallTag +from pytato.tags import ImplStored, InlineCallTag from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, + CombineMapper, CopyMapper, + Mapper, TransformMapperCache, _verify_is_array, deduplicate, @@ -56,6 +63,89 @@ from collections.abc import Mapping +# {{{ normalizing + +class _LocalStackCallBindingCollector(CombineMapper[frozenset[Array], Never, []]): + """Mapper to collect bindings of calls on the current call stack.""" + @override + def combine(self, *args: frozenset[Array]) -> frozenset[Array]: + from functools import reduce + return reduce(lambda a, b: a | b, args, cast("frozenset[Array]", frozenset())) + + @override + def map_call(self, expr: Call) -> 
frozenset[Array]: + return frozenset(expr.bindings.values()) + + +class _CallMaterializer(CopyMapper): + """Mapper to add materialization tags for call bindings and function results.""" + def __init__( + self, + local_call_bindings: frozenset[Array], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.local_call_bindings: frozenset[Array] = local_call_bindings + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + local_call_bindings = _LocalStackCallBindingCollector()( + make_dict_of_named_arrays(function.returns)) + return type(self)( + local_call_bindings, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + + def _materialize_if_possible(self, expr: ArrayOrNames) -> ArrayOrNames: + if ( + isinstance(expr, Array) + and not isinstance(expr, + (DataWrapper, Placeholder, SizeParam, NamedCallResult))): + return expr.tagged(ImplStored()) + else: + return expr + + @override + def map_function_definition(self, + expr: FunctionDefinition) -> FunctionDefinition: + new_mapper = self.clone_for_callee(expr) + new_returns: Mapping[str, Array] = immutabledict({ + name: self._materialize_if_possible(_verify_is_array(new_mapper(ret))) + for name, ret in expr.returns.items()}) + return expr.replace_if_different(returns=new_returns) + + @override + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = cast("ArrayOrNames", 
Mapper.rec(self, expr)) + if expr in self.local_call_bindings: + result = self._materialize_if_possible(result) + return self._cache_add(inputs, result) + + +def normalize_calls(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + """ + Ensure that calls/functions are defined uniformly. + + Adds any missing materialization tags for call bindings and function results. + """ + local_call_bindings = _LocalStackCallBindingCollector()(expr) + return _CallMaterializer(local_call_bindings)(expr) + +# }}} + + # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/pytato/utils.py b/pytato/utils.py index 19dc08ef7..92c1a1cbb 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import CachedMapper +from pytato.transform import ArrayOrNames, CachedMapper if TYPE_CHECKING: @@ -69,6 +69,8 @@ from pytools.tag import Tag + from pytato.function import FunctionDefinition + __doc__ = """ Helper routines @@ -80,6 +82,7 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str +.. autofunction:: is_materializable References ^^^^^^^^^^ @@ -735,4 +738,20 @@ def get_einsum_specification(expr: Einsum) -> str: for i in range(expr.ndim)) return f"{','.join(input_specs)}->{output_spec}" + + +def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool: + """ + Returns *True* if *expr* is an instance of an array type that can be materialized. + """ + from pytato.array import InputArgumentBase, NamedArray + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder + return ( + isinstance(expr, Array) + and not isinstance(expr, ( + # FIXME: Is there a nice way to generalize this? 
+ InputArgumentBase, NamedArray, DistributedRecv, + DistributedSendRefHolder))) + + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index ca3786e32..848f5b7f0 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -952,6 +952,287 @@ def f(a, b): # }}} +def test_scalar_flop_count(): + from pytato.scalar_expr import FlopCounter + fc = FlopCounter({ + "+": 1, + "*": 1, + "/": 4, + "//": 4, + "%": 4, + "**": 8, + "<": 1, + "min": 1, + "max": 1, + "f": 32}) + + import pymbolic.primitives as prim + from pymbolic import Variable + + x = Variable("x") + y = Variable("y") + + assert fc(Variable("f")(x)) == 32 + + assert fc(x[0]) == 0 + + assert fc(x + 2) == 1 + assert fc(2 + y) == 1 + assert fc(x + y) == 1 + + assert fc(prim.Sum((2, x, y))) == 2 + + assert fc(x - 2) == 1 + assert fc(2 - y) == 2 + assert fc(x - y) == 2 + + assert fc(x * 2) == 1 + assert fc(2 * y) == 1 + assert fc(x * y) == 1 + + assert fc(prim.Product((2, x, y))) == 2 + + assert fc(x.or_(y)) == 0 + assert fc(x.and_(y)) == 0 + + assert fc(x / 2) == 4 + assert fc(2 / y) == 4 + assert fc(x / y) == 4 + + assert fc(x // 2) == 4 + + assert fc(x % 2) == 0 + + assert fc(x ** 3) == 3 + assert fc(x ** 0.3) == 8 + + assert fc(x.lt(y)) == 1 + + assert fc(prim.If(x, x, y)) == 0 + + assert fc(prim.Min((2, x, y))) == 2 + assert fc(prim.Max((2, x, y))) == 2 + + from constantdict import constantdict + + from pytato.reductions import SumReductionOperation + from pytato.scalar_expr import Reduce + + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) == 9 + + +def test_flop_count(): + from pytato.analysis import ( + UndefinedOpFlopCountError, + get_default_op_name_to_num_flops, + get_num_flops, + ) + from pytato.tags import ImplStored + + # {{{ basic expression + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = u - v + + # expr[i, j] = 2*(x[i, j] + y[i, j]) + (-1)*3*(x[i, j] + y[i, j]) + assert 
get_num_flops(expr) == 40*6 + + # }}} + + # {{{ expression with operators that don't have default flop counts + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + expr = pt.cmath.exp(x / y) + + with pytest.raises(UndefinedOpFlopCountError): + get_num_flops(expr) + + op_name_to_num_flops = get_default_op_name_to_num_flops() + op_name_to_num_flops.update({ + "/": 4, + "pytato.c99.exp": 8}) + + assert get_num_flops(expr, op_name_to_num_flops) == 40*12 + + # }}} + + # {{{ multiple expressions + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = pt.make_dict_of_named_arrays({"u": u, "v": v}) + + # expr["u"][i, j] = 2*(x[i, j] + y[i, j]) + # expr["v"][i, j] = 3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*4 + + # }}} + + # {{{ subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # expr[i, j] = 2*(x[2*i, j] + y[2*i, j]) + (-1)*3*(x[2*i, j] + y[2*i, j]) + assert get_num_flops(expr) == 20*6 + + # }}} + + # {{{ materialized array + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert get_num_flops(expr) == 40 + 40*4 + + # }}} + + # {{{ materialized array and subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[2*i, j] + (-1)*3*z[2*i, j] + assert get_num_flops(expr) == 40 + 20*4 + + # }}} + + # {{{ function call + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + def f(x, y): + z = x + y + return 2*z, 3*z + + u, v = pt.trace_call(f, x, y) + expr = u - v + + # u[i, j] = 
2*(x[i, j] + y[i, j]) + # v[i, j] = 3*(x[i, j] + y[i, j]) + # expr[i, j] = u[i, j] + (-1)*v[i, j] + assert get_num_flops(expr) == 40*2 + 40*2 + 40*2 + + # }}} + + # {{{ einsum + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->ijk", x, y) + + # expr[i, j, k] = x[i, j, k] * y[j, k] + assert get_num_flops(expr) == 24 + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->i", x, y) + + # expr[i] = sum(sum(x[i, j, k] * y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*1 + 2) + 3) + + # }}} + + +def test_materialized_node_flop_counts(): + from pytato.analysis import get_materialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert len(materialized_node_to_flop_count) == 2 + assert z in materialized_node_to_flop_count + assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[z] == 40 + assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 + + +def test_unmaterialized_node_flop_counts(): + from pytato.analysis import get_unmaterialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + # Make a reduction over a bunch of expressions that reference z + z = x + y + w = [i*z for i in range(1, 11)] + s = [w[0]] + for w_i in w[1:-1]: + s.append(s[-1] + w_i) + expr = s[-1] + w[-1] + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + materialized_expr = expr.tagged(ImplStored()) + + # Everything except expr stays unmaterialized + assert 
len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 + assert z in unmaterialized_node_to_flop_counts + assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) + assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) + flop_counts = unmaterialized_node_to_flop_counts[z] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*10 + assert flop_counts.nflops_if_materialized == 40 + for w_i in w: + flop_counts = unmaterialized_node_to_flop_counts[w_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2 + assert flop_counts.nflops_if_materialized == 40*2 + for i, s_i in enumerate(s): + flop_counts = unmaterialized_node_to_flop_counts[s_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2*(i+1) + 40*i + assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) From 182f9a1452867c604e7367fc8255324f44b625e5 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 27 Oct 2025 16:17:59 -0500 Subject: [PATCH 04/31] Update baseline --- .basedpyright/baseline.json | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 1e5fb4d7b..a27334ffe 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1936,6 +1936,22 @@ "endColumn": 25, "lineCount": 1 
} + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 51, + "endColumn": 61, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 56, + "endColumn": 66, + "lineCount": 1 + } } ], "./pytato/array.py": [ @@ -7675,6 +7691,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 42, + "endColumn": 52, + "lineCount": 1 + } + }, { "code": "reportUnannotatedClassAttribute", "range": { From e8bad597fc5faac47989cb3816009a59a0acb45c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Nov 2025 16:49:36 -0600 Subject: [PATCH 05/31] add note to docs about assumptions when handling conditional expressions --- pytato/analysis/__init__.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 9235216d5..11df8e79d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1182,7 +1182,15 @@ def get_num_flops( expr: ArrayOrNames, op_name_to_num_flops: Mapping[str, int] | None = None, ) -> ArrayOrScalar: - """Count the total number of floating point operations in the DAG *expr*.""" + """ + Count the total number of floating point operations in the DAG *expr*. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) expr = _normalize_materialization(expr) @@ -1205,6 +1213,12 @@ def get_materialized_node_flop_counts( """ Returns a dictionary mapping materialized nodes in DAG *expr* to their floating point operation count. + + .. 
note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1227,6 +1241,12 @@ def get_unmaterialized_node_flop_counts( Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count information. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) From cb117676bb4ad71854f42a1d256a434cef0eb443 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 10:46:02 -0600 Subject: [PATCH 06/31] change 'pass' -> '...' in ReductionOperation abstract methods --- pytato/reductions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index a35c1d98a..a229328c8 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -89,19 +89,19 @@ class ReductionOperation(ABC): @classmethod @abstractmethod def scalar_op_name(cls) -> str: - pass + ... @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: - pass + ... @abstractmethod def __hash__(self) -> int: - pass + ... @abstractmethod def __eq__(self, other: object) -> bool: - pass + ... 
class _StatelessReductionOperation(ReductionOperation): From 7103c924c0630e4904c597750bf39a016419735f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 10:52:14 -0600 Subject: [PATCH 07/31] move op_name_to_num_flops type declaration to FlopCounter class body --- pytato/scalar_expr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 4065d63b5..a09416805 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -257,11 +257,12 @@ def map_variable(self, expr: prim.Variable) -> set[str]: class FlopCounter(FlopCounterBase): + op_name_to_num_flops: dict[str, ArithmeticExpression] + def __init__( self, op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): super().__init__() - self.op_name_to_num_flops: dict[str, ArithmeticExpression] if op_name_to_num_flops: self.op_name_to_num_flops = dict(op_name_to_num_flops) else: From b529bcf6c24d3ec6dff31fa190d1da1ac5c6aa32 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 15:31:55 -0600 Subject: [PATCH 08/31] add some details about how flop counts are computed --- pytato/analysis/__init__.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 11df8e79d..961d2d986 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1080,6 +1080,10 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]: @dataclass class UnmaterializedNodeFlopCounts: + """ + Floating point operation counts for an unmaterialized node. See + :func:`get_unmaterialized_node_flop_counts` for details. + """ materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] nflops_if_materialized: ArrayOrScalar @@ -1185,6 +1189,11 @@ def get_num_flops( """ Count the total number of floating point operations in the DAG *expr*. 
+ Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. The total flop count is the sum + over all materialized nodes. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, @@ -1214,6 +1223,10 @@ def get_materialized_node_flop_counts( Returns a dictionary mapping materialized nodes in DAG *expr* to their floating point operation count. + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, @@ -1242,6 +1255,15 @@ def get_unmaterialized_node_flop_counts( :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count information. + The :class:`UnmaterializedNodeFlopCounts` instance for each unmaterialized node + (i.e., a node that can be tagged with :class:`pytato.tags.ImplStored` but isn't) + contains `materialized_successor_to_contrib_nflops` and `nflops_if_materialized` + attributes. The former is a mapping from each materialized successor of the + unmaterialized node to the number of flops the node contributes to evaluating + that successor (this includes flops from the predecessors of the unmaterialized + node). The latter is the number of flops that would be required to evaluate the + unmaterialized node if it was materialized instead. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, From c41563b44a5d142bd20bc6bec87dfc36bd5c386b Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 15:33:20 -0600 Subject: [PATCH 09/31] clarify how flop counting functions behave w.r.t. 
DAG functions --- pytato/analysis/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 961d2d986..e0ca041a7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1199,6 +1199,10 @@ def get_num_flops( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1232,6 +1236,10 @@ def get_materialized_node_flop_counts( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does not* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1269,6 +1277,10 @@ def get_unmaterialized_node_flop_counts( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does not* descend into functions. 
""" from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) From d637e1f43604e7114fe99d48b93bb4c368ac5a08 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 8 Dec 2025 16:57:43 -0600 Subject: [PATCH 10/31] don't try to count flops for function calls --- .basedpyright/baseline.json | 8 ---- pytato/analysis/__init__.py | 81 ++++++++++++-------------------- pytato/transform/calls.py | 94 +------------------------------------ test/test_pytato.py | 19 -------- 4 files changed, 32 insertions(+), 170 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index a27334ffe..9a2764602 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -7691,14 +7691,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 42, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e0ca041a7..f39f8cffb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -976,15 +976,10 @@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): def __init__( self, op_name_to_num_flops: Mapping[str, int], - _visited_functions: set[VisitKeyT] | None = None, - _function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] | None = None ) -> None: - super().__init__(_visited_functions=_visited_functions) + super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} - self.call_to_nflops: dict[Call, ArrayOrScalar] = {} - self._function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] = \ - _function_to_nflops if _function_to_nflops is not None else {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( self.op_name_to_num_flops) @@ -992,50 +987,25 @@ def __init__( def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: return expr - 
@override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: - return expr - @override def clone_for_callee(self, function: FunctionDefinition) -> Self: - return type(self)( - op_name_to_num_flops=self.op_name_to_num_flops, - _visited_functions=self._visited_functions, - _function_to_nflops=self._function_to_nflops) + raise AssertionError("Control shouldn't reach this point.") @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return - new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - # Assume that any calls that haven't been inlined have their functions' - # outputs materialized - assert not _is_unmaterialized(subexpr) - new_mapper(subexpr) - - self._function_to_nflops[expr] = ( - sum(new_mapper.materialized_node_to_nflops.values()) - + sum(new_mapper.call_to_nflops.values())) - - self.post_visit(expr) + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def map_call(self, expr: Call) -> None: if not self.visit(expr): return - self.rec_function_definition(expr.function) - for bnd in expr.bindings.values(): - # Assume that any calls that haven't been inlined have their inputs - # materialized - assert not _is_unmaterialized(bnd) - self.rec(bnd) - - self.call_to_nflops[expr] = self._function_to_nflops[expr.function] - - self.post_visit(expr) + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: @@ -1099,9 +1069,8 @@ class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): """ def __init__( self, - op_name_to_num_flops: Mapping[str, int], - _visited_functions: set[VisitKeyT] | None = None) -> None: - super().__init__(_visited_functions=_visited_functions) + op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops 
self.unmaterialized_node_to_flop_counts: \ dict[Array, UnmaterializedNodeFlopCounts] = {} @@ -1113,8 +1082,24 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: return expr @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: - return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: @@ -1141,10 +1126,6 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: # FIXME: Should this be added to normalize_outputs? def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - # Make sure call bindings and function results are materialized - from pytato.transform.calls import normalize_calls - expr = normalize_calls(expr) - # Make sure outputs are materialized if isinstance(expr, DictOfNamedArrays): output_to_materialized_output: dict[Array, Array] = { @@ -1202,7 +1183,7 @@ def get_num_flops( .. note:: - This *does* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1214,9 +1195,7 @@ def get_num_flops( fc = MaterializedNodeFlopCounter(op_name_to_num_flops) fc(expr) - return ( - sum(fc.materialized_node_to_nflops.values()) - + sum(fc.call_to_nflops.values())) + return sum(fc.materialized_node_to_nflops.values()) def get_materialized_node_flop_counts( @@ -1239,7 +1218,7 @@ def get_materialized_node_flop_counts( .. 
note:: - This *does not* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1280,7 +1259,7 @@ def get_unmaterialized_node_flop_counts( .. note:: - This *does not* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index e82c2b1e0..ffad1e5f7 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,7 +1,6 @@ """ .. currentmodule:: pytato.transform.calls -.. autofunction:: normalize_calls .. autofunction:: inline_calls .. autofunction:: tag_all_calls_to_be_inlined """ @@ -33,26 +32,20 @@ from typing import TYPE_CHECKING, cast -from immutabledict import immutabledict -from typing_extensions import Never, Self, override +from typing_extensions import Self from pytato.array import ( AbstractResultWithNamedArrays, Array, - DataWrapper, DictOfNamedArrays, Placeholder, - SizeParam, - make_dict_of_named_arrays, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import ImplStored, InlineCallTag +from pytato.tags import InlineCallTag from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, - CombineMapper, CopyMapper, - Mapper, TransformMapperCache, _verify_is_array, deduplicate, @@ -63,89 +56,6 @@ from collections.abc import Mapping -# {{{ normalizing - -class _LocalStackCallBindingCollector(CombineMapper[frozenset[Array], Never, []]): - """Mapper to collect bindings of calls on the current call stack.""" - @override - def combine(self, *args: frozenset[Array]) -> frozenset[Array]: - from functools import reduce - return reduce(lambda a, b: a | b, args, cast("frozenset[Array]", frozenset())) - - @override - def map_call(self, expr: Call) -> frozenset[Array]: - return 
frozenset(expr.bindings.values()) - - -class _CallMaterializer(CopyMapper): - """Mapper to add materialization tags for call bindings and function results.""" - def __init__( - self, - local_call_bindings: frozenset[Array], - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None - ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) - self.local_call_bindings: frozenset[Array] = local_call_bindings - - @override - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - local_call_bindings = _LocalStackCallBindingCollector()( - make_dict_of_named_arrays(function.returns)) - return type(self)( - local_call_bindings, - _function_cache=cast( - "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) - - def _materialize_if_possible(self, expr: ArrayOrNames) -> ArrayOrNames: - if ( - isinstance(expr, Array) - and not isinstance(expr, - (DataWrapper, Placeholder, SizeParam, NamedCallResult))): - return expr.tagged(ImplStored()) - else: - return expr - - @override - def map_function_definition(self, - expr: FunctionDefinition) -> FunctionDefinition: - new_mapper = self.clone_for_callee(expr) - new_returns: Mapping[str, Array] = immutabledict({ - name: self._materialize_if_possible(_verify_is_array(new_mapper(ret))) - for name, ret in expr.returns.items()}) - return expr.replace_if_different(returns=new_returns) - - @override - def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = cast("ArrayOrNames", Mapper.rec(self, expr)) - if expr in 
self.local_call_bindings: - result = self._materialize_if_possible(result) - return self._cache_add(inputs, result) - - -def normalize_calls(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - """ - Ensure that calls/functions are defined uniformly. - - Adds any missing materialization tags for call bindings and function results. - """ - local_call_bindings = _LocalStackCallBindingCollector()(expr) - return _CallMaterializer(local_call_bindings)(expr) - -# }}} - - # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/test/test_pytato.py b/test/test_pytato.py index 848f5b7f0..8e2a4a1d0 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1126,25 +1126,6 @@ def test_flop_count(): # }}} - # {{{ function call - - x = pt.make_placeholder("x", (10, 4)) - y = pt.make_placeholder("y", (10, 4)) - - def f(x, y): - z = x + y - return 2*z, 3*z - - u, v = pt.trace_call(f, x, y) - expr = u - v - - # u[i, j] = 2*(x[i, j] + y[i, j]) - # v[i, j] = 3*(x[i, j] + y[i, j]) - # expr[i, j] = u[i, j] + (-1)*v[i, j] - assert get_num_flops(expr) == 40*2 + 40*2 + 40*2 - - # }}} - # {{{ einsum x = pt.make_placeholder("x", (2, 3, 4)) From 83f592a7951eb41555e7483e9150eaac60aad430 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Dec 2025 13:33:13 -0600 Subject: [PATCH 11/31] move own flop counting out of rec and into its own method in _PerEntryFlopCounter --- .basedpyright/baseline.json | 4 ++-- pytato/analysis/__init__.py | 31 ++++++++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 9a2764602..b70d84ab8 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1940,8 +1940,8 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 51, - "endColumn": 61, + "startColumn": 34, + "endColumn": 44, "lineCount": 1 } }, diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f39f8cffb..f9af09544 100644 --- 
a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -934,6 +934,18 @@ def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: def combine(self, *args: int) -> int: return sum(args) + def _get_own_flop_count(self, expr: Array) -> int: + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(nflops, int): + from pytato.scalar_expr import InputGatherer as ScalarInputGatherer + var_names: set[str] = set(ScalarInputGatherer()(nflops)) + var_names.discard("nflops") + if var_names: + raise UndefinedOpFlopCountError(next(iter(var_names))) from None + else: + raise AssertionError from None + return nflops + @override def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) @@ -943,19 +955,12 @@ def rec(self, expr: ArrayOrNames) -> int: result: int if _is_unmaterialized(expr): assert isinstance(expr, Array) - self_nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - if not isinstance(self_nflops, int): - from pytato.scalar_expr import InputGatherer as ScalarInputGatherer - var_names: set[str] = set(ScalarInputGatherer()(self_nflops)) - var_names.discard("nflops") - if var_names: - raise UndefinedOpFlopCountError(next(iter(var_names))) from None - else: - raise AssertionError from None - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = self_nflops + cast("int", Mapper.rec(self, expr)) + result = ( + self._get_own_flop_count(expr) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + + cast("int", Mapper.rec(self, expr))) else: result = 0 if isinstance(expr, Array): From d4a8a9503a1fd525955a7e705d965a1c3ea638a7 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Dec 2025 13:27:46 -0600 Subject: [PATCH 12/31] change 
is_materializable -> is_materialized / has_taggable_materialization --- pytato/analysis/__init__.py | 47 +++++++++++++++++-------------------- pytato/utils.py | 37 +++++++++++++++++++++++------ test/test_pytato.py | 6 ++++- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f9af09544..ab00ed62e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -52,6 +52,7 @@ ShapeType, Stack, ) +from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall, LoopyCallResult @@ -70,7 +71,7 @@ map_and_copy, ) from pytato.transform.lower_to_index_lambda import to_index_lambda -from pytato.utils import is_materializable +from pytato.utils import has_taggable_materialization, is_materialized if TYPE_CHECKING: @@ -906,18 +907,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # {{{ flop counting -def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - return ( - is_materializable(expr) - and bool(expr.tags_of_type(ImplStored))) - - -def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - return ( - is_materializable(expr) - and not bool(expr.tags_of_type(ImplStored))) - - @dataclass class UndefinedOpFlopCountError(ValueError): op_name: str @@ -935,7 +924,10 @@ def combine(self, *args: int) -> int: return sum(args) def _get_own_flop_count(self, expr: Array) -> int: - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + try: + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + except CannotBeLoweredToIndexLambda: + nflops = 0 if not isinstance(nflops, int): from pytato.scalar_expr import InputGatherer as ScalarInputGatherer var_names: set[str] = set(ScalarInputGatherer()(nflops)) @@ -953,8 +945,7 @@ def rec(self, expr: ArrayOrNames) -> 
int: return self._cache_retrieve(inputs) except KeyError: result: int - if _is_unmaterialized(expr): - assert isinstance(expr, Array) + if isinstance(expr, Array) and not is_materialized(expr): result = ( self._get_own_flop_count(expr) # Intentionally going to Mapper instead of super() to avoid @@ -1014,14 +1005,16 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not _is_materialized(expr): + if not is_materialized(expr): return assert isinstance(expr, Array) - unmaterialized_expr = expr.without_tags(ImplStored()) - self._per_entry_flop_counter(unmaterialized_expr) - self.materialized_node_to_nflops[expr] = ( - product(expr.shape) - * self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr]) + if has_taggable_materialization(expr): + unmaterialized_expr = expr.without_tags(ImplStored()) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter(unmaterialized_expr)) + else: + self.materialized_node_to_nflops[expr] = 0 class _UnmaterializedSubexpressionUseCounter( @@ -1041,8 +1034,7 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]: return self._cache_retrieve(inputs) except KeyError: result: dict[Array, int] - if _is_unmaterialized(expr): - assert isinstance(expr, Array) + if isinstance(expr, Array) and not is_materialized(expr): # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 @@ -1108,7 +1100,7 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not _is_materialized(expr): + if not is_materialized(expr) or not has_taggable_materialization(expr): return assert isinstance(expr, Array) unmaterialized_expr = expr.without_tags(ImplStored()) @@ -1134,7 +1126,10 @@ def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: # 
Make sure outputs are materialized if isinstance(expr, DictOfNamedArrays): output_to_materialized_output: dict[Array, Array] = { - ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary + ary: ( + ary.tagged(ImplStored()) + if has_taggable_materialization(ary) + else ary) for ary in expr._data.values()} def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: diff --git a/pytato/utils.py b/pytato/utils.py index 92c1a1cbb..d5d377966 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -82,7 +82,8 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str -.. autofunction:: is_materializable +.. autofunction:: is_materialized +.. autofunction:: has_taggable_materialization References ^^^^^^^^^^ @@ -740,18 +741,40 @@ def get_einsum_specification(expr: Einsum) -> str: return f"{','.join(input_specs)}->{output_spec}" -def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool: +def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + """Returns *True* if *expr* is materialized.""" + from pytato.array import InputArgumentBase + from pytato.distributed.nodes import DistributedRecv + from pytato.tags import ImplStored + return ( + ( + isinstance(expr, Array) + and bool(expr.tags_of_type(ImplStored))) + or isinstance( + expr, + ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, + DistributedRecv))) + + +def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool: """ - Returns *True* if *expr* is an instance of an array type that can be materialized. + Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to + determine whether or not it is materialized. 
""" from pytato.array import InputArgumentBase, NamedArray from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder return ( isinstance(expr, Array) - and not isinstance(expr, ( - # FIXME: Is there a nice way to generalize this? - InputArgumentBase, NamedArray, DistributedRecv, - DistributedSendRefHolder))) + and not isinstance( + expr, + ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, + DistributedRecv, + NamedArray, + DistributedSendRefHolder))) # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 8e2a4a1d0..970499799 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1161,9 +1161,13 @@ def test_materialized_node_flop_counts(): # z[i, j] = x[i, j] + y[i, j] # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] - assert len(materialized_node_to_flop_count) == 2 + assert len(materialized_node_to_flop_count) == 4 + assert x in materialized_node_to_flop_count + assert y in materialized_node_to_flop_count assert z in materialized_node_to_flop_count assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[x] == 0 + assert materialized_node_to_flop_count[y] == 0 assert materialized_node_to_flop_count[z] == 40 assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 From a997b2a69ac02374e748b1b5d215f249ff9549d6 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 10:54:02 -0600 Subject: [PATCH 13/31] use explicit isinstance() check instead of try/except around to_index_lambda() --- pytato/analysis/__init__.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ab00ed62e..893a1d591 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -42,6 +42,7 @@ ArrayOrScalar, Concatenate, CSRMatmul, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -49,10 +50,10 @@ IndexRemappingBase, InputArgumentBase, 
NamedArray, + Placeholder, ShapeType, Stack, ) -from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall, LoopyCallResult @@ -924,10 +925,16 @@ def combine(self, *args: int) -> int: return sum(args) def _get_own_flop_count(self, expr: Array) -> int: - try: - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - except CannotBeLoweredToIndexLambda: - nflops = 0 + if isinstance( + expr, + ( + DataWrapper, + Placeholder, + NamedArray, + DistributedRecv, + DistributedSendRefHolder)): + return 0 + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) if not isinstance(nflops, int): from pytato.scalar_expr import InputGatherer as ScalarInputGatherer var_names: set[str] = set(ScalarInputGatherer()(nflops)) From 79da0ab152109ec8df29c11af023a0ed8aa4d9ad Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 11:45:20 -0600 Subject: [PATCH 14/31] remove a couple of FIXMEs --- pytato/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index d5d377966..f2b50791c 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -753,7 +753,6 @@ def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: or isinstance( expr, ( - # FIXME: Is there a nice way to generalize this? InputArgumentBase, DistributedRecv))) @@ -770,7 +769,6 @@ def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> boo and not isinstance( expr, ( - # FIXME: Is there a nice way to generalize this? 
InputArgumentBase, DistributedRecv, NamedArray, From 57dc83f302bd86277bc519cd2c0e347ed2f14893 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 26 Jan 2026 12:27:00 -0600 Subject: [PATCH 15/31] add placeholder class for operator flop counts that aren't specified --- pytato/analysis/__init__.py | 15 +++++++++------ pytato/scalar_expr.py | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 893a1d591..3dc89500b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -936,13 +936,16 @@ def _get_own_flop_count(self, expr: Array) -> int: return 0 nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) if not isinstance(nflops, int): - from pytato.scalar_expr import InputGatherer as ScalarInputGatherer - var_names: set[str] = set(ScalarInputGatherer()(nflops)) - var_names.discard("nflops") - if var_names: - raise UndefinedOpFlopCountError(next(iter(var_names))) from None + # Restricting to numerical result here because the flop counters that use + # this mapper subsequently multiply the result by things that are + # potentially arrays (e.g., shape components), and arrays and scalar + # expressions are not interoperable + from pytato.scalar_expr import OpFlops, OpFlopsCollector + op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) + if op_flops: + raise UndefinedOpFlopCountError(next(iter(op_flops)).op) else: - raise AssertionError from None + raise AssertionError return nflops @override diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index a09416805..3eccd6256 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -45,6 +45,7 @@ import re from collections.abc import Iterable, Mapping, Set as AbstractSet +from functools import reduce from typing import ( TYPE_CHECKING, Any, @@ -272,8 +273,7 @@ def _get_op_nflops(self, name: str) -> ArithmeticExpression: try: return self.op_name_to_num_flops[name] 
except KeyError: - from pymbolic import var - result = var("nflops")(var(name)) + result = OpFlops(name) self.op_name_to_num_flops[name] = result return result @@ -486,9 +486,42 @@ class TypeCast(ExpressionBase): dtype: np.dtype[Any] inner_expr: ScalarExpression + +@expr_dataclass() +class OpFlops(prim.AlgebraicLeaf): + """ + Placeholder flop count for an operator. + + .. autoattribute:: op + """ + op: str + # }}} +class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]): + """ + Constructs a :class:`frozenset` containing all instances of + :class:`pytato.scalar_expr.OpFlops` found in a scalar expression. + """ + @override + def combine( + self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]: + return reduce( + lambda x, y: x.union(y), + values, + cast("frozenset[OpFlops]", frozenset())) + + @override + def map_algebraic_leaf( + self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]: + return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset() + + @override + def map_constant(self, expr: object) -> frozenset[OpFlops]: + return frozenset() + + class InductionVariableCollector(CombineMapper[AbstractSet[str], []]): def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]: from functools import reduce From d33e153f5ff279d9a53f79aa012acf2cb6efe7f7 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 13 Mar 2026 10:26:03 -0500 Subject: [PATCH 16/31] fix swapped and/or op names Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pytato/reductions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index a229328c8..1cb016def 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -169,7 +169,7 @@ class AllReductionOperation(_StatelessReductionOperation): @override @classmethod def scalar_op_name(cls): - return "or" + return "and" def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) @@ -179,7 
+179,7 @@ class AnyReductionOperation(_StatelessReductionOperation): @override @classmethod def scalar_op_name(cls): - return "and" + return "or" def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) From cb621fe117bce5ad50b4537459cb7d22e044e492 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 13 Mar 2026 10:30:21 -0500 Subject: [PATCH 17/31] fix flop counting for negative power Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- pytato/scalar_expr.py | 2 +- test/test_pytato.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 3eccd6256..fbdf07674 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -331,7 +331,7 @@ def map_power(self, expr: prim.Power) -> ArithmeticExpression: else: return ( self._get_op_nflops("/") - + expr.exponent * self._get_op_nflops("*") + + (-expr.exponent) * self._get_op_nflops("*") + self.rec(expr.base)) else: return ( diff --git a/test/test_pytato.py b/test/test_pytato.py index 970499799..52694ec32 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1004,6 +1004,7 @@ def test_scalar_flop_count(): assert fc(x % 2) == 0 assert fc(x ** 3) == 3 + assert fc(x ** (-3)) == 7 assert fc(x ** 0.3) == 8 assert fc(x.lt(y)) == 1 From 8813c9bd939cda26ebb825e4deb98761e4952ddb Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 13 Mar 2026 10:59:34 -0500 Subject: [PATCH 18/31] make UndefinedOpFlopCountError not a dataclass to avoid issues with __init__ --- pytato/analysis/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3dc89500b..75058b038 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -908,9 +908,8 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # {{{ flop counting -@dataclass class UndefinedOpFlopCountError(ValueError): - op_name: str + 
pass class _PerEntryFlopCounter(CombineMapper[int, Never, []]): @@ -943,7 +942,9 @@ def _get_own_flop_count(self, expr: Array) -> int: from pytato.scalar_expr import OpFlops, OpFlopsCollector op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) if op_flops: - raise UndefinedOpFlopCountError(next(iter(op_flops)).op) + op_name = next(iter(op_flops)).op + raise UndefinedOpFlopCountError( + f"Undefined flop count for operation '{op_name}'.") else: raise AssertionError return nflops From 609f3a9c1b66c195e9fd22419f44ec3a006f5233 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 13 Mar 2026 14:21:33 -0500 Subject: [PATCH 19/31] update note about function traversal for MaterializedNodeFlopCounter --- pytato/analysis/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 75058b038..ccdd7cf3c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -977,8 +977,7 @@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): .. note:: - Flops from nodes inside function calls are accumulated onto the corresponding - call node. + This mapper does not descend into functions. 
""" def __init__( self, From 8300604c0fa18e1e23db0f1313111d901693b318 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:24:41 -0500 Subject: [PATCH 20/31] add InputUseCounter in scalar_expr --- pytato/scalar_expr.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index fbdf07674..c2ed3b47e 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -44,6 +44,7 @@ import re +from collections import defaultdict from collections.abc import Iterable, Mapping, Set as AbstractSet from functools import reduce from typing import ( @@ -80,6 +81,7 @@ from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer +from pytools import product if TYPE_CHECKING: @@ -257,6 +259,36 @@ def map_variable(self, expr: prim.Variable) -> set[str]: return {expr.name} +class InputUseCounter(CombineMapper[dict[str, ArithmeticExpression], []]): + @override + def combine( + self, values: Iterable[dict[str, ArithmeticExpression]] + ) -> dict[str, ArithmeticExpression]: + result: dict[str, ArithmeticExpression] = defaultdict(int) + for val in values: + for name, nuses in val.items(): + result[name] += nuses + return result + + @override + def map_constant(self, expr: object) -> dict[str, ArithmeticExpression]: + return {} + + @override + def map_variable(self, expr: prim.Variable) -> dict[str, ArithmeticExpression]: + return {expr.name: 1} + + @override + def map_reduce(self, expr: Reduce) -> dict[str, ArithmeticExpression]: + inner_expr_result = self.rec(expr.inner_expr) + niters = product(( + upper_bd - lower_bd + for lower_bd, upper_bd in expr.bounds.values())) + return { + name: niters * inner_expr_nuses + for name, inner_expr_nuses in inner_expr_result.items()} + + class FlopCounter(FlopCounterBase): op_name_to_num_flops: dict[str, 
ArithmeticExpression] From 7bda5ac88c79edb89e7d7a9a0ef92fd2313c20f6 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:53:57 -0500 Subject: [PATCH 21/31] abandon is_materialized/has_taggable_materialization approach in favor of collecting materialized arrays ahead of time the former doesn't work because some materialized arrays can only be identified by context (e.g., the 'data' attribute of a DistributedSend). --- .basedpyright/baseline.json | 16 -- pytato/analysis/__init__.py | 380 +++++++++++++++++++++++++----------- pytato/utils.py | 40 +--- test/test_pytato.py | 23 +-- 4 files changed, 277 insertions(+), 182 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index b70d84ab8..1e5fb4d7b 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1936,22 +1936,6 @@ "endColumn": 25, "lineCount": 1 } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 34, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 56, - "endColumn": 66, - "lineCount": 1 - } } ], "./pytato/array.py": [ diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ccdd7cf3c..0d2c00fcd 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,7 +28,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, cast, overload +from typing import TYPE_CHECKING, Any, overload from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Never, Self, override @@ -64,15 +64,15 @@ from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, - ArrayOrNamesTc, CachedWalkMapper, + CacheKeyT, CombineMapper, Mapper, VisitKeyT, - map_and_copy, ) -from pytato.transform.lower_to_index_lambda import to_index_lambda -from pytato.utils import has_taggable_materialization, is_materialized +from 
pytato.transform.lower_to_index_lambda import ( + MapAsIndexLambdaMixin, +) if TYPE_CHECKING: @@ -912,62 +912,131 @@ class UndefinedOpFlopCountError(ValueError): pass -class _PerEntryFlopCounter(CombineMapper[int, Never, []]): - def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: +class _NonIntegralPerEntryFlopCountError(ValueError): + pass + + +class _PerEntryFlopCounter( + MapAsIndexLambdaMixin[int, [bool]], + CombineMapper[int, Never, [bool]]): + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array]) -> None: super().__init__() self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( op_name_to_num_flops) - self.node_to_nflops: dict[Array, int] = {} + self.materialized_nodes: frozenset[Array] = materialized_nodes + + @overload + def __call__( + self, + expr: ArrayOrNames, + is_root: bool = True, + ) -> int: + ... + + @overload + def __call__( + self, + expr: FunctionDefinition, + is_root: bool = True, + ) -> Never: + ... 
+ + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + is_root: bool = True, + ) -> int: + return super().__call__(expr, is_root) + + @override + def get_cache_key(self, expr: ArrayOrNames, is_root: bool) -> CacheKeyT: + return expr, is_root @override def combine(self, *args: int) -> int: return sum(args) - def _get_own_flop_count(self, expr: Array) -> int: - if isinstance( - expr, - ( - DataWrapper, - Placeholder, - NamedArray, - DistributedRecv, - DistributedSendRefHolder)): + @override + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, is_root: bool) -> int: + if expr in self.materialized_nodes and not is_root: return 0 - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - if not isinstance(nflops, int): + + self_nflops = self.scalar_flop_counter(idx_lambda.expr) + + if not isinstance(self_nflops, int): # Restricting to numerical result here because the flop counters that use # this mapper subsequently multiply the result by things that are # potentially arrays (e.g., shape components), and arrays and scalar # expressions are not interoperable from pytato.scalar_expr import OpFlops, OpFlopsCollector - op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) + op_flops: frozenset[OpFlops] = OpFlopsCollector()(self_nflops) if op_flops: op_name = next(iter(op_flops)).op raise UndefinedOpFlopCountError( f"Undefined flop count for operation '{op_name}'.") else: - raise AssertionError - return nflops + raise _NonIntegralPerEntryFlopCountError( + "Unable to compute an integer-valued per-entry flop count.") + + return self.combine( + self_nflops, + *( + self.rec(bnd, False) + for _, bnd in sorted(idx_lambda.bindings.items()))) @override - def rec(self, expr: ArrayOrNames) -> int: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - result: int - if isinstance(expr, Array) and not is_materialized(expr): - result = ( - self._get_own_flop_count(expr) - # 
Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - + cast("int", Mapper.rec(self, expr))) - else: - result = 0 - if isinstance(expr, Array): - self.node_to_nflops[expr] = result - return self._cache_add(inputs, result) + def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: + return 0 + + @override + def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> int: + return 0 + + @override + def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> int: + raise NotImplementedError + + @override + def map_named_array(self, expr: NamedArray, is_root: bool) -> int: + assert isinstance(expr._container, DictOfNamedArrays) + return self.combine(self.rec(expr._container._data[expr.name], False)) + + @override + def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, is_root: bool) -> int: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> int: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call_result(self, expr: LoopyCallResult, is_root: bool) -> int: + return 0 + + @override + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, is_root: bool) -> int: + # Ignore expr.send.data; it's not part of the computation being performed + return self.combine(self.rec(expr.passthrough_data, False)) + + @override + def map_distributed_recv(self, expr: DistributedRecv, is_root: bool) -> int: + return 0 + + @override + def map_call(self, expr: Call, is_root: bool) -> int: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_named_call_result(self, expr: NamedCallResult, is_root: bool) -> int: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): @@ -982,12 +1051,14 
@@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): def __init__( self, op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array], ) -> None: super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_nodes: frozenset[Array] = materialized_nodes self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( - self.op_name_to_num_flops) + self.op_name_to_num_flops, self.materialized_nodes) @override def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: @@ -1015,20 +1086,59 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not is_materialized(expr): + if expr not in self.materialized_nodes: return assert isinstance(expr, Array) - if has_taggable_materialization(expr): - unmaterialized_expr = expr.without_tags(ImplStored()) - self.materialized_node_to_nflops[expr] = ( - product(expr.shape) - * self._per_entry_flop_counter(unmaterialized_expr)) + try: + nflops_per_entry = self._per_entry_flop_counter(expr) + except _NonIntegralPerEntryFlopCountError as e: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e else: - self.materialized_node_to_nflops[expr] = 0 + self.materialized_node_to_nflops[expr] = ( + nflops_per_entry + * product(expr.shape)) class _UnmaterializedSubexpressionUseCounter( - CombineMapper[dict[Array, int], Never, []]): + MapAsIndexLambdaMixin[dict[Array, int], [bool]], + CombineMapper[dict[Array, int], Never, [bool]]): + def __init__( + self, + materialized_nodes: frozenset[Array]) -> None: + super().__init__() + self.materialized_nodes: frozenset[Array] = materialized_nodes + + @overload + def __call__( + self, + expr: ArrayOrNames, + is_root: bool = True, + ) -> dict[Array, int]: + ... 
+ + @overload + def __call__( + self, + expr: FunctionDefinition, + is_root: bool = True, + ) -> Never: + ... + + @override + def __call__( + self, + expr: ArrayOrNames | FunctionDefinition, + is_root: bool = True, + ) -> dict[Array, int]: + return super().__call__(expr, is_root) + + @override + def get_cache_key(self, expr: ArrayOrNames, is_root: bool) -> CacheKeyT: + return expr, is_root + @override def combine(self, *args: dict[Array, int]) -> dict[Array, int]: result: dict[Array, int] = defaultdict(int) @@ -1038,21 +1148,70 @@ def combine(self, *args: dict[Array, int]) -> dict[Array, int]: return result @override - def rec(self, expr: ArrayOrNames) -> dict[Array, int]: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - result: dict[Array, int] - if isinstance(expr, Array) and not is_materialized(expr): - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = self.combine( - {expr: 1}, cast("dict[Array, int]", Mapper.rec(self, expr))) - else: - result = {} - return self._cache_add(inputs, result) + def map_as_index_lambda( + self, expr: Array, idx_lambda: IndexLambda, is_root: bool + ) -> dict[Array, int]: + if expr in self.materialized_nodes and not is_root: + return {} + + return self.combine( + {expr: 1} if not is_root else {}, + *( + self.rec(bnd, False) + for _, bnd in sorted(idx_lambda.bindings.items()))) + + @override + def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> dict[Array, int]: + raise NotImplementedError + + @override + def map_named_array(self, expr: NamedArray, is_root: bool) -> dict[Array, int]: + assert isinstance(expr._container, 
DictOfNamedArrays) + return self.combine(self.rec(expr._container._data[expr.name], False)) + + @override + def map_dict_of_named_arrays( + self, expr: DictOfNamedArrays, is_root: bool) -> dict[Array, int]: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> dict[Array, int]: + raise AssertionError("Control shouldn't reach here.") + + @override + def map_loopy_call_result( + self, expr: LoopyCallResult, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_distributed_send_ref_holder( + self, expr: DistributedSendRefHolder, is_root: bool) -> dict[Array, int]: + # Ignore expr.send.data; it's not part of the computation being performed + return self.combine(self.rec(expr.passthrough_data, False)) + + @override + def map_distributed_recv( + self, expr: DistributedRecv, is_root: bool) -> dict[Array, int]: + return {} + + @override + def map_call(self, expr: Call, is_root: bool) -> dict[Array, int]: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") + + @override + def map_named_call_result( + self, expr: NamedCallResult, is_root: bool) -> dict[Array, int]: + # Shouldn't have calls + raise AssertionError("Control shouldn't reach here.") @dataclass @@ -1076,13 +1235,18 @@ class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): """ def __init__( self, - op_name_to_num_flops: Mapping[str, int]) -> None: + op_name_to_num_flops: Mapping[str, int], + materialized_nodes: frozenset[Array]) -> None: super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_nodes: frozenset[Array] = materialized_nodes self.unmaterialized_node_to_flop_counts: \ dict[Array, UnmaterializedNodeFlopCounts] = {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( - self.op_name_to_num_flops) + self.op_name_to_num_flops, self.materialized_nodes) + self._use_counter: _UnmaterializedSubexpressionUseCounter 
= \ + _UnmaterializedSubexpressionUseCounter( + self.materialized_nodes) @override def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: @@ -1110,49 +1274,33 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not is_materialized(expr) or not has_taggable_materialization(expr): + if expr not in self.materialized_nodes: return assert isinstance(expr, Array) - unmaterialized_expr = expr.without_tags(ImplStored()) - subexpr_to_nuses = _UnmaterializedSubexpressionUseCounter()( - unmaterialized_expr) - del subexpr_to_nuses[unmaterialized_expr] - self._per_entry_flop_counter(unmaterialized_expr) - for subexpr, nuses in subexpr_to_nuses.items(): - per_entry_nflops = self._per_entry_flop_counter.node_to_nflops[subexpr] - if subexpr not in self.unmaterialized_node_to_flop_counts: - nflops_if_materialized = product(subexpr.shape) * per_entry_nflops - flop_counts = UnmaterializedNodeFlopCounts({}, nflops_if_materialized) - self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts - else: - flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] - assert expr not in flop_counts.materialized_successor_to_contrib_nflops - flop_counts.materialized_successor_to_contrib_nflops[expr] = ( - nuses * product(expr.shape) * per_entry_nflops) - - -# FIXME: Should this be added to normalize_outputs? 
-def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - # Make sure outputs are materialized - if isinstance(expr, DictOfNamedArrays): - output_to_materialized_output: dict[Array, Array] = { - ary: ( - ary.tagged(ImplStored()) - if has_taggable_materialization(ary) - else ary) - for ary in expr._data.values()} - - def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: - if not isinstance(ary, Array): - return ary - try: - return output_to_materialized_output[ary] - except KeyError: - return ary - - expr = map_and_copy(expr, replace_with_materialized) - - return expr + try: + self._per_entry_flop_counter(expr) + except _NonIntegralPerEntryFlopCountError as e: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e + else: + subexpr_to_nuses = self._use_counter(expr) + for subexpr, nuses in subexpr_to_nuses.items(): + nflops_per_entry = self._per_entry_flop_counter( + subexpr, is_root=False) + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = nflops_per_entry * product(subexpr.shape) + flop_counts = UnmaterializedNodeFlopCounts( + {}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses + * nflops_per_entry + * product(expr.shape)) def get_default_op_name_to_num_flops() -> dict[str, int]: @@ -1197,12 +1345,14 @@ def get_num_flops( """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) - expr = _normalize_materialization(expr) + + materialized_nodes = collect_materialized_nodes(expr) if op_name_to_num_flops is None: op_name_to_num_flops = get_default_op_name_to_num_flops() - fc = 
MaterializedNodeFlopCounter(op_name_to_num_flops) + fc = MaterializedNodeFlopCounter( + op_name_to_num_flops, frozenset(materialized_nodes)) fc(expr) return sum(fc.materialized_node_to_nflops.values()) @@ -1232,12 +1382,14 @@ def get_materialized_node_flop_counts( """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) - expr = _normalize_materialization(expr) + + materialized_nodes = collect_materialized_nodes(expr) if op_name_to_num_flops is None: op_name_to_num_flops = get_default_op_name_to_num_flops() - fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc = MaterializedNodeFlopCounter( + op_name_to_num_flops, frozenset(materialized_nodes)) fc(expr) return fc.materialized_node_to_nflops @@ -1273,12 +1425,14 @@ def get_unmaterialized_node_flop_counts( """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) - expr = _normalize_materialization(expr) + + materialized_nodes = collect_materialized_nodes(expr) if op_name_to_num_flops is None: op_name_to_num_flops = get_default_op_name_to_num_flops() - fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops) + fc = UnmaterializedNodeFlopCounter( + op_name_to_num_flops, frozenset(materialized_nodes)) fc(expr) return fc.unmaterialized_node_to_flop_counts diff --git a/pytato/utils.py b/pytato/utils.py index f2b50791c..d7bc8eade 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import ArrayOrNames, CachedMapper +from pytato.transform import CachedMapper if TYPE_CHECKING: @@ -69,8 +69,6 @@ from pytools.tag import Tag - from pytato.function import FunctionDefinition - __doc__ = """ Helper routines @@ -82,8 +80,6 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str -.. autofunction:: is_materialized -.. 
autofunction:: has_taggable_materialization References ^^^^^^^^^^ @@ -741,38 +737,4 @@ def get_einsum_specification(expr: Einsum) -> str: return f"{','.join(input_specs)}->{output_spec}" -def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - """Returns *True* if *expr* is materialized.""" - from pytato.array import InputArgumentBase - from pytato.distributed.nodes import DistributedRecv - from pytato.tags import ImplStored - return ( - ( - isinstance(expr, Array) - and bool(expr.tags_of_type(ImplStored))) - or isinstance( - expr, - ( - InputArgumentBase, - DistributedRecv))) - - -def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool: - """ - Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to - determine whether or not it is materialized. - """ - from pytato.array import InputArgumentBase, NamedArray - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder - return ( - isinstance(expr, Array) - and not isinstance( - expr, - ( - InputArgumentBase, - DistributedRecv, - NamedArray, - DistributedSendRefHolder))) - - # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 52694ec32..675fa95c0 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1166,16 +1166,15 @@ def test_materialized_node_flop_counts(): assert x in materialized_node_to_flop_count assert y in materialized_node_to_flop_count assert z in materialized_node_to_flop_count - assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert expr in materialized_node_to_flop_count assert materialized_node_to_flop_count[x] == 0 assert materialized_node_to_flop_count[y] == 0 assert materialized_node_to_flop_count[z] == 40 - assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 + assert materialized_node_to_flop_count[expr] == 40*4 def test_unmaterialized_node_flop_counts(): from pytato.analysis import get_unmaterialized_node_flop_counts - from pytato.tags 
import ImplStored x = pt.make_placeholder("x", (10, 4)) y = pt.make_placeholder("y", (10, 4)) @@ -1190,8 +1189,6 @@ def test_unmaterialized_node_flop_counts(): unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) - materialized_expr = expr.tagged(ImplStored()) - # Everything except expr stays unmaterialized assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 assert z in unmaterialized_node_to_flop_counts @@ -1199,23 +1196,21 @@ def test_unmaterialized_node_flop_counts(): assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) flop_counts = unmaterialized_node_to_flop_counts[z] assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 - assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops - assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ - == 40*10 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 40*10 assert flop_counts.nflops_if_materialized == 40 for w_i in w: flop_counts = unmaterialized_node_to_flop_counts[w_i] assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 - assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops - assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ - == 40*2 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 40*2 assert flop_counts.nflops_if_materialized == 40*2 for i, s_i in enumerate(s): flop_counts = unmaterialized_node_to_flop_counts[s_i] assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 - assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops - assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ - == 40*2*(i+1) + 40*i + assert expr in flop_counts.materialized_successor_to_contrib_nflops + 
assert flop_counts.materialized_successor_to_contrib_nflops[expr] == \ + 40*2*(i+1) + 40*i assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i From 6362537284abe07f401a639cf4dbbb7ce43dcd42 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 23 Mar 2026 13:57:02 -0500 Subject: [PATCH 22/31] fix flop counting for reductions --- pytato/analysis/__init__.py | 27 ++++++++++++++++++++++----- test/test_pytato.py | 12 ++++++------ 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 0d2c00fcd..3620f0a3f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -28,7 +28,7 @@ from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet, OrderedSet from typing_extensions import Never, Self, override @@ -60,6 +60,7 @@ from pytato.scalar_expr import ( SCALAR_CLASSES, FlopCounter as ScalarFlopCounter, + InputUseCounter as ScalarInputUseCounter, ) from pytato.tags import ImplStored from pytato.transform import ( @@ -983,11 +984,17 @@ def map_as_index_lambda( raise _NonIntegralPerEntryFlopCountError( "Unable to compute an integer-valued per-entry flop count.") + binding_to_nuses = ScalarInputUseCounter()(idx_lambda.expr) + # self_nflops check above should take care of non-constant use count case + assert all( + isinstance(nuses, int) + for nuses in binding_to_nuses.values()) + return self.combine( self_nflops, *( - self.rec(bnd, False) - for _, bnd in sorted(idx_lambda.bindings.items()))) + cast("int", binding_to_nuses[name]) * self.rec(bnd, False) + for name, bnd in sorted(idx_lambda.bindings.items()))) @override def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: @@ -1154,11 +1161,21 @@ def map_as_index_lambda( if expr in self.materialized_nodes and not is_root: return {} + binding_to_nuses = 
ScalarInputUseCounter()(idx_lambda.expr) + if any( + not isinstance(nuses, int) + for nuses in binding_to_nuses.values()): + raise ValueError( + "Unable to compute integer-valued use counts for the predecessors " + f"of array of type '{type(expr).__name__}'.") + return self.combine( {expr: 1} if not is_root else {}, *( - self.rec(bnd, False) - for _, bnd in sorted(idx_lambda.bindings.items()))) + { + ary: cast("int", binding_to_nuses[name]) * nuses + for ary, nuses in self.rec(bnd, False).items()} + for name, bnd in sorted(idx_lambda.bindings.items()))) @override def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: diff --git a/test/test_pytato.py b/test/test_pytato.py index 675fa95c0..cc2783da3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1131,17 +1131,17 @@ def test_flop_count(): x = pt.make_placeholder("x", (2, 3, 4)) y = pt.make_placeholder("y", (3, 4)) - expr = pt.einsum("ijk,jk->ijk", x, y) + expr = pt.einsum("ijk,jk->ijk", 2*x, 3*y) - # expr[i, j, k] = x[i, j, k] * y[j, k] - assert get_num_flops(expr) == 24 + # expr[i, j, k] = 2*x[i, j, k] * 3*y[j, k] + assert get_num_flops(expr) == 72 x = pt.make_placeholder("x", (2, 3, 4)) y = pt.make_placeholder("y", (3, 4)) - expr = pt.einsum("ijk,jk->i", x, y) + expr = pt.einsum("ijk,jk->i", 2*x, 3*y) - # expr[i] = sum(sum(x[i, j, k] * y[j, k], j), k) - assert get_num_flops(expr) == 2*(4 * (3*1 + 2) + 3) + # expr[i] = sum(sum(2*x[i, j, k] * 3*y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*3 + 2) + 3) # }}} From 6ac5c11b87236e92082e1e1d15e76d0426db59ef Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 24 Mar 2026 14:56:55 -0500 Subject: [PATCH 23/31] handle CSR matmul in flop counting --- pytato/analysis/__init__.py | 80 ++++++++++++++++++----- test/test_pytato.py | 126 +++++++++++++++++++++++++++++++++++- 2 files changed, 189 insertions(+), 17 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3620f0a3f..f69284ee1 
100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1004,10 +1004,6 @@ def map_placeholder(self, expr: Placeholder, is_root: bool) -> int: def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> int: return 0 - @override - def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> int: - raise NotImplementedError - @override def map_named_array(self, expr: NamedArray, is_root: bool) -> int: assert isinstance(expr._container, DictOfNamedArrays) @@ -1099,10 +1095,31 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: try: nflops_per_entry = self._per_entry_flop_counter(expr) except _NonIntegralPerEntryFlopCountError as e: - raise ValueError( - "Unable to compute a flop count for array of type " - f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " - "non-integer-valued.") from e + if isinstance(expr, CSRMatmul): + # _PerEntryFlopCounter chokes on CSRMatmul because the row reduction + # bounds are data dependent. By handling it here instead we can take + # advantage of knowing that the total number of reduction iterations + # is len(elem_values). Note: assumes no flops for elem_col_indices + # and row_starts as they are integer-valued. 
+ nelems = expr.matrix.elem_values.shape[0] + nflops_self = ( + ( + nelems # multiplies + + nelems - expr.shape[0]) # adds + * product(expr.shape[1:])) + nflops_children = ( + nelems + * ( + self._per_entry_flop_counter(expr.matrix.elem_values) + + self._per_entry_flop_counter(expr.array)) + * product(expr.shape[1:])) + self.materialized_node_to_nflops[expr] = ( + nflops_self + nflops_children) + else: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e else: self.materialized_node_to_nflops[expr] = ( nflops_per_entry @@ -1185,10 +1202,6 @@ def map_placeholder(self, expr: Placeholder, is_root: bool) -> dict[Array, int]: def map_data_wrapper(self, expr: DataWrapper, is_root: bool) -> dict[Array, int]: return {} - @override - def map_csr_matmul(self, expr: CSRMatmul, is_root: bool) -> dict[Array, int]: - raise NotImplementedError - @override def map_named_array(self, expr: NamedArray, is_root: bool) -> dict[Array, int]: assert isinstance(expr._container, DictOfNamedArrays) @@ -1297,10 +1310,45 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: try: self._per_entry_flop_counter(expr) except _NonIntegralPerEntryFlopCountError as e: - raise ValueError( - "Unable to compute a flop count for array of type " - f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " - "non-integer-valued.") from e + if isinstance(expr, CSRMatmul): + # _PerEntryFlopCounter chokes on CSRMatmul because the row reduction + # bounds are data dependent. By handling it here instead we can take + # advantage of knowing that the total number of reduction iterations + # is len(elem_values). Note: assumes no flops for elem_col_indices + # and row_starts as they are integer-valued. 
+ nelems = expr.matrix.elem_values.shape[0] + subexpr_to_nuses_per_elem = self._use_counter.combine( + self._use_counter(expr.matrix.elem_values, is_root=False), + self._use_counter(expr.array, is_root=False)) + for subexpr, nuses_per_elem in subexpr_to_nuses_per_elem.items(): + try: + nflops_per_entry = self._per_entry_flop_counter( + subexpr, is_root=False) + except _NonIntegralPerEntryFlopCountError: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(subexpr).__name__}'; per-entry flop count is " + "unexpectedly non-integer-valued.") from e + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = ( + nflops_per_entry * product(subexpr.shape)) + flop_counts = UnmaterializedNodeFlopCounts( + {}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in \ + flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses_per_elem + * nelems + * nflops_per_entry + * product(expr.shape[1:])) + else: + raise ValueError( + "Unable to compute a flop count for array of type " + f"'{type(expr).__name__}'; per-entry flop count is unexpectedly " + "non-integer-valued.") from e else: subexpr_to_nuses = self._use_counter(expr) for subexpr, nuses in subexpr_to_nuses.items(): diff --git a/test/test_pytato.py b/test/test_pytato.py index cc2783da3..7391dd1fb 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1145,11 +1145,60 @@ def test_flop_count(): # }}} + # {{{ CSR matmul (trivial predecessors) + + x = pt.make_csr_matrix( + shape=(8, 10), + elem_values=pt.make_placeholder("x_elem_values", (16,)), + elem_col_indices=pt.make_placeholder("x_elem_col_indices", (16,)), + row_starts=pt.make_placeholder("x_row_starts", (9,))) + y = pt.make_placeholder("y", (10, 5, 3)) + expr = x @ y + + assert get_num_flops(expr) == 5*3*( + 
16 # multiplies + + 16 - 8 # adds + ) + + # }}} + + # {{{ CSR matmul (nontrivial predecessors) + + elem_values = pt.zeros(12) + 1 + elem_col_indices = pt.make_data_wrapper(np.array([ + 0, + 0, 1, + 0, 1, 2, + 1, 2, 3, + 2, 3, + 3])) + row_starts = pt.make_data_wrapper(np.array([0, 1, 3, 6, 9, 11, 12])) + x = pt.make_csr_matrix( + shape=(6, 4), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = pt.zeros((4, 3, 2)) + 1 + expr = x @ y + + assert get_num_flops(expr) == 3*2*( + 3 # row 1 + + 3*2 + 1 # row 2 + + 3*3 + 2 # row 3 + + 3*3 + 2 # row 4 + + 3*2 + 1 # row 5 + + 3 # row 6 + ) + + # }}} + def test_materialized_node_flop_counts(): from pytato.analysis import get_materialized_node_flop_counts from pytato.tags import ImplStored + # {{{ basic DAG + x = pt.make_placeholder("x", (10, 4)) y = pt.make_placeholder("y", (10, 4)) @@ -1172,10 +1221,52 @@ def test_materialized_node_flop_counts(): assert materialized_node_to_flop_count[z] == 40 assert materialized_node_to_flop_count[expr] == 40*4 + # }}} + + # {{{ CSR matmul + + zeros_10 = pt.make_data_wrapper(np.zeros(10)) + zeros_20 = pt.make_data_wrapper(np.zeros(20)) + elem_values = zeros_20 + 1 + elem_col_indices = pt.make_data_wrapper((1 + np.arange(0, 20)) // 2) + row_starts = pt.make_data_wrapper(2*np.arange(0, 11)) + x = pt.make_csr_matrix( + shape=(10, 10), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = zeros_10 + 1 + z = x @ y + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + assert len(materialized_node_to_flop_count) == 6 + assert zeros_20 in materialized_node_to_flop_count + assert elem_col_indices in materialized_node_to_flop_count + assert row_starts in materialized_node_to_flop_count + assert zeros_10 in materialized_node_to_flop_count + assert z in materialized_node_to_flop_count + assert expr in materialized_node_to_flop_count + assert 
materialized_node_to_flop_count[zeros_20] == 0 + assert materialized_node_to_flop_count[elem_col_indices] == 0 + assert materialized_node_to_flop_count[row_starts] == 0 + assert materialized_node_to_flop_count[zeros_10] == 0 + # flops from elem_values/y/z (2 elems per row so y gets used twice) + assert materialized_node_to_flop_count[z] == 20 + 20 + 20 + 10 + # flops from u/v/expr (no flops from z because it's materialized) + assert materialized_node_to_flop_count[expr] == 40 + + # }}} + def test_unmaterialized_node_flop_counts(): from pytato.analysis import get_unmaterialized_node_flop_counts + # {{{ basic DAG + x = pt.make_placeholder("x", (10, 4)) y = pt.make_placeholder("y", (10, 4)) @@ -1189,7 +1280,7 @@ def test_unmaterialized_node_flop_counts(): unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) - # Everything except expr stays unmaterialized + # Everything except x/y/expr stays unmaterialized assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 assert z in unmaterialized_node_to_flop_counts assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) @@ -1213,6 +1304,39 @@ def test_unmaterialized_node_flop_counts(): 40*2*(i+1) + 40*i assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + # }}} + + # {{{ CSR matmul + + elem_values = pt.make_data_wrapper(np.zeros(20)) + 1 + elem_col_indices = pt.make_data_wrapper((1 + np.arange(0, 20)) // 2) + row_starts = pt.make_data_wrapper(2*np.arange(0, 11)) + x = pt.make_csr_matrix( + shape=(10, 10), + elem_values=elem_values, + elem_col_indices=elem_col_indices, + row_starts=row_starts) + y = pt.make_data_wrapper(np.zeros(10)) + 1 + expr = x @ y + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + assert len(unmaterialized_node_to_flop_counts) == 2 + assert elem_values in unmaterialized_node_to_flop_counts + assert y in unmaterialized_node_to_flop_counts + flop_counts = unmaterialized_node_to_flop_counts[elem_values] + assert 
len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 20 + assert flop_counts.nflops_if_materialized == 20 + flop_counts = unmaterialized_node_to_flop_counts[y] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[expr] == 20 + assert flop_counts.nflops_if_materialized == 10 + + # }}} + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) From 7ca4d516c2a38321ebfd7fa7d6fab6c7b6adae85 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 27 Mar 2026 10:18:54 -0500 Subject: [PATCH 24/31] handle unused bindings Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pytato/analysis/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f69284ee1..694a307cc 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -993,7 +993,7 @@ def map_as_index_lambda( return self.combine( self_nflops, *( - cast("int", binding_to_nuses[name]) * self.rec(bnd, False) + cast("int", binding_to_nuses.get(name, 0)) * self.rec(bnd, False) for name, bnd in sorted(idx_lambda.bindings.items()))) @override @@ -1190,7 +1190,7 @@ def map_as_index_lambda( {expr: 1} if not is_root else {}, *( { - ary: cast("int", binding_to_nuses[name]) * nuses + ary: cast("int", binding_to_nuses.get(name, 0)) * nuses for ary, nuses in self.rec(bnd, False).items()} for name, bnd in sorted(idx_lambda.bindings.items()))) From 7a9fe60c7e677827b6d15d7a449cdff5edba155f Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 27 Mar 2026 10:50:49 -0500 Subject: [PATCH 25/31] enumerate all undefined op flop counts in exception message Co-authored-by: Copilot 
<175728472+Copilot@users.noreply.github.com> --- pytato/analysis/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 694a307cc..942a9bfd0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -977,9 +977,10 @@ def map_as_index_lambda( from pytato.scalar_expr import OpFlops, OpFlopsCollector op_flops: frozenset[OpFlops] = OpFlopsCollector()(self_nflops) if op_flops: - op_name = next(iter(op_flops)).op + op_names = sorted({of.op for of in op_flops}) + formatted_ops = ", ".join(f"'{name}'" for name in op_names) raise UndefinedOpFlopCountError( - f"Undefined flop count for operation '{op_name}'.") + f"Undefined flop count for operation(s): {formatted_ops}.") else: raise _NonIntegralPerEntryFlopCountError( "Unable to compute an integer-valued per-entry flop count.") From b083ce7be8fa2ea1e8fb95f9a8caaa6046d6490a Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Mar 2026 10:36:42 -0500 Subject: [PATCH 26/31] forbid loopy calls in flop counting --- pytato/analysis/__init__.py | 38 +++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 942a9bfd0..d96c05d09 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1016,11 +1016,13 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, is_root: bool) -> in @override def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> int: + # Shouldn't have loopy calls raise AssertionError("Control shouldn't reach here.") @override def map_loopy_call_result(self, expr: LoopyCallResult, is_root: bool) -> int: - return 0 + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") @override def map_distributed_send_ref_holder( @@ -1072,6 +1074,21 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: def clone_for_callee(self, function: 
FunctionDefinition) -> Self: raise AssertionError("Control shouldn't reach this point.") + @override + def map_loopy_call(self, expr: LoopyCall) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + def map_loopy_call_result(self, expr: LoopyCallResult) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): @@ -1215,12 +1232,14 @@ def map_dict_of_named_arrays( @override def map_loopy_call(self, expr: LoopyCall, is_root: bool) -> dict[Array, int]: + # Shouldn't have loopy calls raise AssertionError("Control shouldn't reach here.") @override def map_loopy_call_result( self, expr: LoopyCallResult, is_root: bool) -> dict[Array, int]: - return {} + # Shouldn't have loopy calls + raise AssertionError("Control shouldn't reach here.") @override def map_distributed_send_ref_holder( @@ -1287,6 +1306,21 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: def clone_for_callee(self, function: FunctionDefinition) -> Self: raise AssertionError("Control shouldn't reach this point.") + @override + def map_loopy_call(self, expr: LoopyCall) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + + def map_loopy_call_result(self, expr: LoopyCallResult) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support loopy calls.") + @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): From b386f37266c45ab0af0b6f26d470701b42cf58b8 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 27 Mar 2026 10:55:32 -0500 Subject: [PATCH 27/31] add type annotation for scalar_op_name overrides --- pytato/reductions.py | 12 ++++++------ 1 file 
changed, 6 insertions(+), 6 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index 1cb016def..ab8f265aa 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -118,7 +118,7 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "+" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -128,7 +128,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class ProductReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "*" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -138,7 +138,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MaxReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "max" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -153,7 +153,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "min" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -168,7 +168,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "and" def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -178,7 +178,7 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AnyReductionOperation(_StatelessReductionOperation): @override @classmethod - def scalar_op_name(cls): + def scalar_op_name(cls) -> str: return "or" def neutral_element(self, dtype: np.dtype[Any]) -> Any: From 5a6b2422aca1c8e1b9c5c3b9d7950e4077c9a64e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 
27 Mar 2026 16:48:01 -0500 Subject: [PATCH 28/31] make power flop count more precise --- pytato/scalar_expr.py | 40 +++++++++++++++++++++++++++++++++------- test/test_pytato.py | 11 +++++++++-- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index c2ed3b47e..306605756 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -356,15 +356,41 @@ def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: @override def map_power(self, expr: prim.Power) -> ArithmeticExpression: if isinstance(expr.exponent, int): - if expr.exponent >= 0: + # The calculation below is based on the following code (which is an + # approximation of what is done in loopy) + # def pow(x, n): + # if n == 0: return 1 + # if n == 1: return x + # if n == 2: return x*x + # if n < 0: + # x = 1/x + # n = -n + # y = 1 + # while n > 1: + # if n % 2: + # y = x * y + # x = x * x + # n = n/2 + # return x*y + if expr.exponent == 0: + return 0 + elif expr.exponent > 0 and expr.exponent <= 2: return ( - expr.exponent * self._get_op_nflops("*") - + self.rec(expr.base)) - else: - return ( - self._get_op_nflops("/") - + (-expr.exponent) * self._get_op_nflops("*") + (expr.exponent - 1) * self._get_op_nflops("*") + self.rec(expr.base)) + nmults = 1 + remaining_exp = abs(expr.exponent) + while remaining_exp > 1: + if remaining_exp % 2: + nmults += 1 + nmults += 1 + remaining_exp //= 2 + nflops = ( + nmults * self._get_op_nflops("*") + + self.rec(expr.base)) + if expr.exponent < 0: + nflops += self._get_op_nflops("/") + return nflops else: return ( self._get_op_nflops("**") diff --git a/test/test_pytato.py b/test/test_pytato.py index 7391dd1fb..d3fbb41c9 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1003,9 +1003,16 @@ def test_scalar_flop_count(): assert fc(x % 2) == 0 - assert fc(x ** 3) == 3 - assert fc(x ** (-3)) == 7 + assert fc(x ** 0) == 0 + assert fc(x ** 1) == 0 + # x * x + assert fc(x ** 2) == 1 + # compute x^2, 
x^4, x^8, x^16, x^32; multiply x^32 * x^16 * x^8 * x^4 * x * 1 + assert fc(x ** 61) == 5 + 5 + # divide; compute x^2, x^4, x^8, x^16; multiply x^16 * x^4 * x^2 * 1 + assert fc(x ** -22) == 4 + 4 + 3 assert fc(x ** 0.3) == 8 + assert fc(x ** y) == 8 assert fc(x.lt(y)) == 1 From 87b43524d2268362b05ba1d78f9db544cf852134 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Mar 2026 15:52:19 -0500 Subject: [PATCH 29/31] fix remainder flop counting --- pytato/scalar_expr.py | 7 +++++++ test/test_pytato.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 306605756..e5747a75b 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -353,6 +353,13 @@ def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + self.rec(expr.numerator) + self.rec(expr.denominator)) + @override + def map_remainder(self, expr: prim.Remainder) -> ArithmeticExpression: + return ( + self._get_op_nflops("%") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + @override def map_power(self, expr: prim.Power) -> ArithmeticExpression: if isinstance(expr.exponent, int): diff --git a/test/test_pytato.py b/test/test_pytato.py index d3fbb41c9..9bf9593ed 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1001,7 +1001,7 @@ def test_scalar_flop_count(): assert fc(x // 2) == 4 - assert fc(x % 2) == 0 + assert fc(x % 2) == 4 assert fc(x ** 0) == 0 assert fc(x ** 1) == 0 From fe658de2f470ee10f23834ab098a4cf28b673ac0 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Mar 2026 15:57:00 -0500 Subject: [PATCH 30/31] test recursion in test_scalar_flop_count --- pytato/scalar_expr.py | 4 +-- test/test_pytato.py | 70 +++++++++++++++++++++++-------------------- 2 files changed, 39 insertions(+), 35 deletions(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e5747a75b..fcc6ce59d 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -318,8 +318,8 @@ def 
map_call(self, expr: prim.Call) -> ArithmeticExpression: @override def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: - # Assume calculations inside subscripts are performed on non-floats - return 0 + # Assume index calculations are performed on non-floats + return self.rec(expr.aggregate) @override def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] diff --git a/test/test_pytato.py b/test/test_pytato.py index 9bf9593ed..8c1764e6a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -969,64 +969,68 @@ def test_scalar_flop_count(): import pymbolic.primitives as prim from pymbolic import Variable - x = Variable("x") - y = Variable("y") + x = 2*Variable("x") + y = 3 + Variable("y") - assert fc(Variable("f")(x)) == 32 + assert fc(x) == 1 + assert fc(y) == 1 - assert fc(x[0]) == 0 + assert fc(Variable("f")(x)) == 32 + 1 - assert fc(x + 2) == 1 - assert fc(2 + y) == 1 - assert fc(x + y) == 1 + assert fc(x[0]) == 0 + 1 - assert fc(prim.Sum((2, x, y))) == 2 + assert fc(x + 2) == 1 + 1 + assert fc(2 + y) == 1 + 1 + assert fc(x + y) == 1 + 2 - assert fc(x - 2) == 1 - assert fc(2 - y) == 2 - assert fc(x - y) == 2 + assert fc(prim.Sum((2, x, y))) == 2 + 2 - assert fc(x * 2) == 1 - assert fc(2 * y) == 1 - assert fc(x * y) == 1 + assert fc(x - 2) == 1 + 1 + assert fc(2 - y) == 2 + 1 + assert fc(x - y) == 2 + 2 - assert fc(prim.Product((2, x, y))) == 2 + assert fc(x * 2) == 1 + 1 + assert fc(2 * y) == 1 + 1 + assert fc(x * y) == 1 + 2 - assert fc(x.or_(y)) == 0 - assert fc(x.and_(y)) == 0 + assert fc(prim.Product((2, x, y))) == 2 + 2 - assert fc(x / 2) == 4 - assert fc(2 / y) == 4 - assert fc(x / y) == 4 + assert fc(x.or_(y)) == 0 + 2 + assert fc(x.and_(y)) == 0 + 2 - assert fc(x // 2) == 4 + assert fc(x / 2) == 4 + 1 + assert fc(2 / y) == 4 + 1 + assert fc(x / y) == 4 + 2 - assert fc(x % 2) == 4 + assert fc(x // 2) == 4 + 1 + + assert fc(x % 2) == 4 + 1 assert fc(x ** 0) == 0 - assert fc(x ** 
1) == 0 + assert fc(x ** 1) == 0 + 1 # x * x - assert fc(x ** 2) == 1 + assert fc(x ** 2) == 1 + 1 # compute x^2, x^4, x^8, x^16, x^32; multiply x^32 * x^16 * x^8 * x^4 * x * 1 - assert fc(x ** 61) == 5 + 5 + assert fc(x ** 61) == 5 + 5 + 1 # divide; compute x^2, x^4, x^8, x^16; multiply x^16 * x^4 * x^2 * 1 - assert fc(x ** -22) == 4 + 4 + 3 - assert fc(x ** 0.3) == 8 - assert fc(x ** y) == 8 + assert fc(x ** -22) == 4 + 4 + 3 + 1 + assert fc(x ** 0.3) == 8 + 1 + assert fc(x ** y) == 8 + 2 - assert fc(x.lt(y)) == 1 + assert fc(x.lt(y)) == 1 + 2 - assert fc(prim.If(x, x, y)) == 0 + assert fc(prim.If(x, x, y)) == 0 + 3 - assert fc(prim.Min((2, x, y))) == 2 - assert fc(prim.Max((2, x, y))) == 2 + assert fc(prim.Min((2, x, y))) == 2 + 2 + assert fc(prim.Max((2, x, y))) == 2 + 2 from constantdict import constantdict from pytato.reductions import SumReductionOperation from pytato.scalar_expr import Reduce - assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) == 9 + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) \ + == 9 + 10 def test_flop_count(): From 97bfcb5350e883b3e95b1d3b06f90b9765c996a1 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 30 Mar 2026 15:58:14 -0500 Subject: [PATCH 31/31] add note about flop counting for subscript indices --- pytato/analysis/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d96c05d09..6c23bcb4c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1439,6 +1439,11 @@ def get_num_flops( this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + .. note:: Does not support functions. Function calls must be inlined before calling. 
@@ -1476,6 +1481,11 @@ def get_materialized_node_flop_counts( this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + .. note:: Does not support functions. Function calls must be inlined before calling. @@ -1519,6 +1529,11 @@ def get_unmaterialized_node_flop_counts( this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + .. note:: + + Calculations for array subscripts are currently assumed to be integer-typed. + Any floating point operations contained within will be ignored. + .. note:: Does not support functions. Function calls must be inlined before calling.