From b816f6754c6881e71c333b6684a389cf86570b93 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 24 Oct 2025 15:51:27 -0500 Subject: [PATCH 01/14] add flop counting functions --- pytato/analysis/__init__.py | 359 +++++++++++++++++++++++++++++++++++- pytato/reductions.py | 36 ++++ pytato/scalar_expr.py | 136 ++++++++++++++ pytato/transform/calls.py | 94 +++++++++- pytato/utils.py | 21 ++- test/test_pytato.py | 281 ++++++++++++++++++++++++++++ 6 files changed, 923 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 00e36ce32..97cd18a35 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -27,16 +27,19 @@ """ from collections import defaultdict -from typing import TYPE_CHECKING, Any, overload +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast, overload from orderedsets import FrozenOrderedSet from typing_extensions import Never, Self, override from loopy.tools import LoopyKeyBuilder from pymbolic.mapper.optimize import optimize_mapper +from pytools import product from pytato.array import ( Array, + ArrayOrScalar, Concatenate, DictOfNamedArrays, Einsum, @@ -49,13 +52,21 @@ Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.scalar_expr import ( + FlopCounter as ScalarFlopCounter, +) +from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, + ArrayOrNamesTc, CachedWalkMapper, CombineMapper, Mapper, VisitKeyT, + map_and_copy, ) +from pytato.transform.lower_to_index_lambda import to_index_lambda +from pytato.utils import is_materializable if TYPE_CHECKING: @@ -87,6 +98,13 @@ .. autoclass:: TagCountMapper .. autofunction:: get_num_tags_of_type + +.. autoclass:: UndefinedOpFlopCountError +.. autofunction:: get_default_op_name_to_num_flops +.. autofunction:: get_num_flops +.. autofunction:: get_materialized_node_flop_counts +.. autoclass:: UnmaterializedNodeFlopCounts +.. autofunction:: get_unmaterialized_node_flop_counts """ @@ -755,4 +773,343 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # }}} + +# {{{ flop counting + +def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + return ( + is_materializable(expr) + and bool(expr.tags_of_type(ImplStored))) + + +def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + return ( + is_materializable(expr) + and not bool(expr.tags_of_type(ImplStored))) + + +@dataclass +class UndefinedOpFlopCountError(ValueError): + op_name: str + + +class _PerEntryFlopCounter(CombineMapper[int, Never]): + def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() + self.scalar_flop_counter: ScalarFlopCounter = ScalarFlopCounter( + op_name_to_num_flops) + self.node_to_nflops: dict[Array, int] = {} + + @override + def combine(self, *args: int) -> int: + return sum(args) + + @override + def rec(self, expr: ArrayOrNames) -> int: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: int + if _is_unmaterialized(expr): + assert isinstance(expr, Array) + self_nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(self_nflops, int): + from pytato.scalar_expr import InputGatherer as ScalarInputGatherer + var_names: set[str] = set(ScalarInputGatherer()(self_nflops)) + var_names.discard("nflops") + if var_names: + raise UndefinedOpFlopCountError(next(iter(var_names))) from None + else: + raise AssertionError from None + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self_nflops + cast("int", Mapper.rec(self, expr)) + else: + result = 0 + if isinstance(expr, Array): + self.node_to_nflops[expr] = result + return self._cache_add(inputs, result) + + +class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the number of floating point operations of each materialized + expression in a DAG. + + .. note:: + + Flops from nodes inside function calls are accumulated onto the corresponding + call node. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + _visited_functions: set[VisitKeyT] | None = None, + _function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] | None = None + ) -> None: + super().__init__(_visited_functions=_visited_functions) + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} + self.call_to_nflops: dict[Call, ArrayOrScalar] = {} + self._function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] = \ + _function_to_nflops if _function_to_nflops is not None else {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + return type(self)( + op_name_to_num_flops=self.op_name_to_num_flops, + _visited_functions=self._visited_functions, + _function_to_nflops=self._function_to_nflops) + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee(expr) + for subexpr in expr.returns.values(): + # Assume that any calls that haven't been inlined have their functions' + # outputs materialized + assert not _is_unmaterialized(subexpr) + new_mapper(subexpr) + + self._function_to_nflops[expr] = ( + sum(new_mapper.materialized_node_to_nflops.values()) + + sum(new_mapper.call_to_nflops.values())) + + self.post_visit(expr) + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + self.rec_function_definition(expr.function) + for bnd in expr.bindings.values(): + # Assume that any calls that haven't been inlined have their inputs + # materialized + assert not _is_unmaterialized(bnd) + self.rec(bnd) + + self.call_to_nflops[expr] = self._function_to_nflops[expr.function] + + self.post_visit(expr) + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not _is_materialized(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + self._per_entry_flop_counter(unmaterialized_expr) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr]) + + +class _UnmaterializedSubexpressionUseCounter(CombineMapper[dict[Array, int], Never]): + @override + def combine(self, *args: dict[Array, int]) -> dict[Array, int]: + result: dict[Array, int] = defaultdict(int) + for arg in args: + for ary, nuses in arg.items(): + result[ary] += nuses + return result + + @override + def rec(self, expr: ArrayOrNames) -> dict[Array, int]: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + result: dict[Array, int] + if _is_unmaterialized(expr): + assert isinstance(expr, Array) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = self.combine( + {expr: 1}, cast("dict[Array, int]", Mapper.rec(self, expr))) + else: + result = {} + return self._cache_add(inputs, result) + + +@dataclass +class UnmaterializedNodeFlopCounts: + materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] + nflops_if_materialized: ArrayOrScalar + + +class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): + """ + Mapper that counts the accumulated number of floating point operations that each + unmaterialized expression contributes to materialized expressions in the DAG. + + .. note:: + + This mapper does not descend into functions. + """ + def __init__( + self, + op_name_to_num_flops: Mapping[str, int], + _visited_functions: set[VisitKeyT] | None = None) -> None: + super().__init__(_visited_functions=_visited_functions) + self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops + self.unmaterialized_node_to_flop_counts: \ + dict[Array, UnmaterializedNodeFlopCounts] = {} + self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( + self.op_name_to_num_flops) + + @override + def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: + return expr + + @override + def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: + return expr + + @override + def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + if not _is_materialized(expr): + return + assert isinstance(expr, Array) + unmaterialized_expr = expr.without_tags(ImplStored()) + subexpr_to_nuses = _UnmaterializedSubexpressionUseCounter()( + unmaterialized_expr) + del subexpr_to_nuses[unmaterialized_expr] + self._per_entry_flop_counter(unmaterialized_expr) + for subexpr, nuses in subexpr_to_nuses.items(): + per_entry_nflops = self._per_entry_flop_counter.node_to_nflops[subexpr] + if subexpr not in self.unmaterialized_node_to_flop_counts: + nflops_if_materialized = product(subexpr.shape) * per_entry_nflops + flop_counts = UnmaterializedNodeFlopCounts({}, nflops_if_materialized) + self.unmaterialized_node_to_flop_counts[subexpr] = flop_counts + else: + flop_counts = self.unmaterialized_node_to_flop_counts[subexpr] + assert expr not in flop_counts.materialized_successor_to_contrib_nflops + flop_counts.materialized_successor_to_contrib_nflops[expr] = ( + nuses * product(expr.shape) * per_entry_nflops) + + +# FIXME: Should this be added to normalize_outputs? +def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + # Make sure call bindings and function results are materialized + from pytato.transform.calls import normalize_calls + expr = normalize_calls(expr) + + # Make sure outputs are materialized + if isinstance(expr, DictOfNamedArrays): + output_to_materialized_output: dict[Array, Array] = { + ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary + for ary in expr._data.values()} + + def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: + if not isinstance(ary, Array): + return ary + try: + return output_to_materialized_output[ary] + except KeyError: + return ary + + expr = map_and_copy(expr, replace_with_materialized) + + return expr + + +def get_default_op_name_to_num_flops() -> dict[str, int]: + """ + Returns a mapping from operator name to floating point operation count for + operators that are almost always a single flop. + """ + return { + "+": 1, + "*": 1, + "==": 1, + "!=": 1, + "<": 1, + ">": 1, + "<=": 1, + ">=": 1, + "min": 1, + "max": 1} + + +def get_num_flops( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> ArrayOrScalar: + """Count the total number of floating point operations in the DAG *expr*.""" + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return ( + sum(fc.materialized_node_to_nflops.values()) + + sum(fc.call_to_nflops.values())) + + +def get_materialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, ArrayOrScalar]: + """ + Returns a dictionary mapping materialized nodes in DAG *expr* to their floating + point operation count. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = MaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.materialized_node_to_nflops + + +def get_unmaterialized_node_flop_counts( + expr: ArrayOrNames, + op_name_to_num_flops: Mapping[str, int] | None = None, + ) -> dict[Array, UnmaterializedNodeFlopCounts]: + """ + Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a + :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count + information. + """ + from pytato.codegen import normalize_outputs + expr = normalize_outputs(expr) + expr = _normalize_materialization(expr) + + if op_name_to_num_flops is None: + op_name_to_num_flops = get_default_op_name_to_num_flops() + + fc = UnmaterializedNodeFlopCounter(op_name_to_num_flops) + fc(expr) + + return fc.unmaterialized_node_to_flop_counts + +# }}} + + # vim: fdm=marker diff --git a/pytato/reductions.py b/pytato/reductions.py index 6efa45ac2..a35c1d98a 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -34,6 +34,7 @@ import numpy as np from constantdict import constantdict +from typing_extensions import override import pymbolic.primitives as prim from pymbolic import ArithmeticExpression @@ -80,10 +81,15 @@ class _NoValue: class ReductionOperation(ABC): """ + .. automethod:: scalar_op_name .. automethod:: neutral_element .. automethod:: __hash__ .. automethod:: __eq__ """ + @classmethod + @abstractmethod + def scalar_op_name(cls) -> str: + pass @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: @@ -110,16 +116,31 @@ def __eq__(self, other: object) -> bool: class SumReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "+" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 0 class ProductReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "*" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return 1 class MaxReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "max" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("-inf")) @@ -130,6 +151,11 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class MinReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "min" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: if dtype.kind == "f": return dtype.type(float("inf")) @@ -140,11 +166,21 @@ def neutral_element(self, dtype: np.dtype[Any]) -> Any: class AllReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "or" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(True) class AnyReductionOperation(_StatelessReductionOperation): + @override + @classmethod + def scalar_op_name(cls): + return "and" + def neutral_element(self, dtype: np.dtype[Any]) -> Any: return np.bool_(False) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e59cbdce8..72c4adeaf 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -57,6 +57,7 @@ import pymbolic.primitives as prim from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass from pymbolic.mapper import ( + Collector, CombineMapper as CombineMapperBase, IdentityMapper as IdentityMapperBase, P, @@ -70,6 +71,7 @@ ) from pymbolic.mapper.distributor import DistributeMapper as DistributeMapperBase from pymbolic.mapper.evaluator import EvaluationMapper as EvaluationMapperBase +from pymbolic.mapper.flop_counter import FlopCounterBase from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import SubstitutionMapper as SubstitutionMapperBase from pymbolic.typing import Integer @@ -241,6 +243,140 @@ def map_type_cast(self, inner_str = self.rec(expr.inner_expr, PREC_NONE, *args, **kwargs) return f"cast({expr.dtype}, {inner_str})" + +class InputGatherer(Collector[str, []]): + @override + def map_variable(self, expr: prim.Variable) -> set[str]: + return {expr.name} + + +class FlopCounter(FlopCounterBase): + def __init__( + self, + op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): + super().__init__() + self.op_name_to_num_flops: dict[str, ArithmeticExpression] + if op_name_to_num_flops: + self.op_name_to_num_flops = dict(op_name_to_num_flops) + else: + self.op_name_to_num_flops = {} + + def _get_op_nflops(self, name: str) -> ArithmeticExpression: + try: + return self.op_name_to_num_flops[name] + except KeyError: + from pymbolic import var + result = var("nflops")(var(name)) + self.op_name_to_num_flops[name] = result + return result + + @override + def map_call(self, expr: prim.Call) -> ArithmeticExpression: + assert isinstance(expr.function, prim.Variable) + return ( + self._get_op_nflops(expr.function.name) + + sum(self.rec(child) for child in expr.parameters)) + + @override + def map_subscript(self, expr: prim.Subscript) -> ArithmeticExpression: + # Assume calculations inside subscripts are performed on non-floats + return 0 + + @override + def map_sum(self, expr: prim.Sum) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + if expr.children: + return ( + self._get_op_nflops("+") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_product(self, expr: prim.Product) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("*") * (len(expr.children) - 1) + + sum(self.rec(ch) for ch in expr.children)) + else: + return 0 + + @override + def map_quotient(self, expr: prim.Quotient) -> ArithmeticExpression: # pyright: ignore[reportIncompatibleMethodOverride] + return ( + self._get_op_nflops("/") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_floor_div(self, expr: prim.FloorDiv) -> ArithmeticExpression: + return ( + self._get_op_nflops("//") + + self.rec(expr.numerator) + + self.rec(expr.denominator)) + + @override + def map_power(self, expr: prim.Power) -> ArithmeticExpression: + if isinstance(expr.exponent, int): + if expr.exponent >= 0: + return ( + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("/") + + expr.exponent * self._get_op_nflops("*") + + self.rec(expr.base)) + else: + return ( + self._get_op_nflops("**") + + self.rec(expr.base) + + self.rec(expr.exponent)) + + @override + def map_comparison(self, expr: prim.Comparison) -> ArithmeticExpression: + return ( + self._get_op_nflops(expr.operator) + + self.rec(expr.left) + + self.rec(expr.right)) + + @override + def map_if(self, expr: prim.If) -> ArithmeticExpression: + return ( + self.rec(expr.condition) + + self.rec(expr.then) + + self.rec(expr.else_)) + + @override + def map_max(self, expr: prim.Max) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("max") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_min(self, expr: prim.Min) -> ArithmeticExpression: + if expr.children: + return ( + self._get_op_nflops("min") * (len(expr.children) - 1) + + sum(self.rec(child) for child in expr.children)) + else: + return 0 + + @override + def map_nan(self, expr: prim.NaN) -> ArithmeticExpression: + return 0 + + def map_reduce(self, expr: Reduce) -> ArithmeticExpression: + result = self.rec(expr.inner_expr) + nflops_op = self._get_op_nflops(expr.op.scalar_op_name()) + for lower_bd, upper_bd in expr.bounds.values(): + nops = upper_bd - lower_bd + result = result * nops + nflops_op * (nops-1) + + return result + # }}} diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index ffad1e5f7..16c4413b2 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,6 +1,7 @@ """ .. currentmodule:: pytato.transform.calls +.. autofunction:: normalize_calls .. autofunction:: inline_calls .. autofunction:: tag_all_calls_to_be_inlined """ @@ -32,20 +33,26 @@ from typing import TYPE_CHECKING, cast -from typing_extensions import Self +from immutabledict import immutabledict +from typing_extensions import Never, Self, override from pytato.array import ( AbstractResultWithNamedArrays, Array, + DataWrapper, DictOfNamedArrays, Placeholder, + SizeParam, + make_dict_of_named_arrays, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import InlineCallTag +from pytato.tags import ImplStored, InlineCallTag from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, + CombineMapper, CopyMapper, + Mapper, TransformMapperCache, _verify_is_array, deduplicate, @@ -56,6 +63,89 @@ from collections.abc import Mapping +# {{{ normalizing + +class _LocalStackCallBindingCollector(CombineMapper[frozenset[Array], Never]): + """Mapper to collect bindings of calls on the current call stack.""" + @override + def combine(self, *args: frozenset[Array]) -> frozenset[Array]: + from functools import reduce + return reduce(lambda a, b: a | b, args, cast("frozenset[Array]", frozenset())) + + @override + def map_call(self, expr: Call) -> frozenset[Array]: + return frozenset(expr.bindings.values()) + + +class _CallMaterializer(CopyMapper): + """Mapper to add materialization tags for call bindings and function results.""" + def __init__( + self, + local_call_bindings: frozenset[Array], + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__(_cache=_cache, _function_cache=_function_cache) + self.local_call_bindings: frozenset[Array] = local_call_bindings + + @override + def clone_for_callee(self, function: FunctionDefinition) -> Self: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + local_call_bindings = _LocalStackCallBindingCollector()( + make_dict_of_named_arrays(function.returns)) + return type(self)( + local_call_bindings, + _function_cache=cast( + "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) + + def _materialize_if_possible(self, expr: ArrayOrNames) -> ArrayOrNames: + if ( + isinstance(expr, Array) + and not isinstance(expr, + (DataWrapper, Placeholder, SizeParam, NamedCallResult))): + return expr.tagged(ImplStored()) + else: + return expr + + @override + def map_function_definition(self, + expr: FunctionDefinition) -> FunctionDefinition: + new_mapper = self.clone_for_callee(expr) + new_returns: Mapping[str, Array] = immutabledict({ + name: self._materialize_if_possible(_verify_is_array(new_mapper(ret))) + for name, ret in expr.returns.items()}) + return expr.replace_if_different(returns=new_returns) + + @override + def rec(self, expr: ArrayOrNames) -> ArrayOrNames: + inputs = self._make_cache_inputs(expr) + try: + return self._cache_retrieve(inputs) + except KeyError: + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + result = cast("ArrayOrNames", Mapper.rec(self, expr)) + if expr in self.local_call_bindings: + result = self._materialize_if_possible(result) + return self._cache_add(inputs, result) + + +def normalize_calls(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: + """ + Ensure that calls/functions are defined uniformly. + + Adds any missing materialization tags for call bindings and function results. + """ + local_call_bindings = _LocalStackCallBindingCollector()(expr) + return _CallMaterializer(local_call_bindings)(expr) + +# }}} + + # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/pytato/utils.py b/pytato/utils.py index 19dc08ef7..92c1a1cbb 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import CachedMapper +from pytato.transform import ArrayOrNames, CachedMapper if TYPE_CHECKING: @@ -69,6 +69,8 @@ from pytools.tag import Tag + from pytato.function import FunctionDefinition + __doc__ = """ Helper routines @@ -80,6 +82,7 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str +.. autofunction:: is_materializable References ^^^^^^^^^^ @@ -735,4 +738,20 @@ def get_einsum_specification(expr: Einsum) -> str: for i in range(expr.ndim)) return f"{','.join(input_specs)}->{output_spec}" + + +def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool: + """ + Returns *True* if *expr* is an instance of an array type that can be materialized. + """ + from pytato.array import InputArgumentBase, NamedArray + from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder + return ( + isinstance(expr, Array) + and not isinstance(expr, ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, NamedArray, DistributedRecv, + DistributedSendRefHolder))) + + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 8746810ae..1c5ca303a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -770,6 +770,287 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_scalar_flop_count(): + from pytato.scalar_expr import FlopCounter + fc = FlopCounter({ + "+": 1, + "*": 1, + "/": 4, + "//": 4, + "%": 4, + "**": 8, + "<": 1, + "min": 1, + "max": 1, + "f": 32}) + + import pymbolic.primitives as prim + from pymbolic import Variable + + x = Variable("x") + y = Variable("y") + + assert fc(Variable("f")(x)) == 32 + + assert fc(x[0]) == 0 + + assert fc(x + 2) == 1 + assert fc(2 + y) == 1 + assert fc(x + y) == 1 + + assert fc(prim.Sum((2, x, y))) == 2 + + assert fc(x - 2) == 1 + assert fc(2 - y) == 2 + assert fc(x - y) == 2 + + assert fc(x * 2) == 1 + assert fc(2 * y) == 1 + assert fc(x * y) == 1 + + assert fc(prim.Product((2, x, y))) == 2 + + assert fc(x.or_(y)) == 0 + assert fc(x.and_(y)) == 0 + + assert fc(x / 2) == 4 + assert fc(2 / y) == 4 + assert fc(x / y) == 4 + + assert fc(x // 2) == 4 + + assert fc(x % 2) == 0 + + assert fc(x ** 3) == 3 + assert fc(x ** 0.3) == 8 + + assert fc(x.lt(y)) == 1 + + assert fc(prim.If(x, x, y)) == 0 + + assert fc(prim.Min((2, x, y))) == 2 + assert fc(prim.Max((2, x, y))) == 2 + + from constantdict import constantdict + + from pytato.reductions import SumReductionOperation + from pytato.scalar_expr import Reduce + + assert fc(Reduce(x, SumReductionOperation(), constantdict({"_0": (0, 10)}))) == 9 + + +def test_flop_count(): + from pytato.analysis import ( + UndefinedOpFlopCountError, + get_default_op_name_to_num_flops, + get_num_flops, + ) + from pytato.tags import ImplStored + + # {{{ basic expression + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = u - v + + # expr[i, j] = 2*(x[i, j] + y[i, j]) + (-1)*3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*6 + + # }}} + + # {{{ expression with operators that don't have default flop counts + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + expr = pt.cmath.exp(x / y) + + with pytest.raises(UndefinedOpFlopCountError): + get_num_flops(expr) + + op_name_to_num_flops = get_default_op_name_to_num_flops() + op_name_to_num_flops.update({ + "/": 4, + "pytato.c99.exp": 8}) + + assert get_num_flops(expr, op_name_to_num_flops) == 40*12 + + # }}} + + # {{{ multiple expressions + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = pt.make_dict_of_named_arrays({"u": u, "v": v}) + + # expr["u"][i, j] = 2*(x[i, j] + y[i, j]) + # expr["v"][i, j] = 3*(x[i, j] + y[i, j]) + assert get_num_flops(expr) == 40*4 + + # }}} + + # {{{ subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = x + y + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # expr[i, j] = 2*(x[2*i, j] + y[2*i, j]) + (-1)*3*(x[2*i, j] + y[2*i, j]) + assert get_num_flops(expr) == 20*6 + + # }}} + + # {{{ materialized array + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert get_num_flops(expr) == 40 + 40*4 + + # }}} + + # {{{ materialized array and subscripting + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = (u - v)[::2, :] + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[2*i, j] + (-1)*3*z[2*i, j] + assert get_num_flops(expr) == 40 + 20*4 + + # }}} + + # {{{ function call + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + def f(x, y): + z = x + y + return 2*z, 3*z + + u, v = pt.trace_call(f, x, y) + expr = u - v + + # u[i, j] = 2*(x[i, j] + y[i, j]) + # v[i, j] = 3*(x[i, j] + y[i, j]) + # expr[i, j] = u[i, j] + (-1)*v[i, j] + assert get_num_flops(expr) == 40*2 + 40*2 + 40*2 + + # }}} + + # {{{ einsum + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->ijk", x, y) + + # expr[i, j, k] = x[i, j, k] * y[j, k] + assert get_num_flops(expr) == 24 + + x = pt.make_placeholder("x", (2, 3, 4)) + y = pt.make_placeholder("y", (3, 4)) + expr = pt.einsum("ijk,jk->i", x, y) + + # expr[i] = sum(sum(x[i, j, k] * y[j, k], j), k) + assert get_num_flops(expr) == 2*(4 * (3*1 + 2) + 3) + + # }}} + + +def test_materialized_node_flop_counts(): + from pytato.analysis import get_materialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + z = (x + y).tagged(ImplStored()) + u = 2*z + v = 3*z + expr = u - v + + materialized_node_to_flop_count = get_materialized_node_flop_counts(expr) + + # z[i, j] = x[i, j] + y[i, j] + # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] + assert len(materialized_node_to_flop_count) == 2 + assert z in materialized_node_to_flop_count + assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[z] == 40 + assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 + + +def test_unmaterialized_node_flop_counts(): + from pytato.analysis import get_unmaterialized_node_flop_counts + from pytato.tags import ImplStored + + x = pt.make_placeholder("x", (10, 4)) + y = pt.make_placeholder("y", (10, 4)) + + # Make a reduction over a bunch of expressions that reference z + z = x + y + w = [i*z for i in range(1, 11)] + s = [w[0]] + for w_i in w[1:-1]: + s.append(s[-1] + w_i) + expr = s[-1] + w[-1] + + unmaterialized_node_to_flop_counts = get_unmaterialized_node_flop_counts(expr) + + materialized_expr = expr.tagged(ImplStored()) + + # Everything except expr stays unmaterialized + assert len(unmaterialized_node_to_flop_counts) == 1 + 10 + 8 + assert z in unmaterialized_node_to_flop_counts + assert all(w_i in unmaterialized_node_to_flop_counts for w_i in w) + assert all(s_i in unmaterialized_node_to_flop_counts for s_i in s) + flop_counts = unmaterialized_node_to_flop_counts[z] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*10 + assert flop_counts.nflops_if_materialized == 40 + for w_i in w: + flop_counts = unmaterialized_node_to_flop_counts[w_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2 + assert flop_counts.nflops_if_materialized == 40*2 + for i, s_i in enumerate(s): + flop_counts = unmaterialized_node_to_flop_counts[s_i] + assert len(flop_counts.materialized_successor_to_contrib_nflops) == 1 + assert materialized_expr in flop_counts.materialized_successor_to_contrib_nflops + assert flop_counts.materialized_successor_to_contrib_nflops[materialized_expr] \ + == 40*2*(i+1) + 40*i + assert flop_counts.nflops_if_materialized == 40*2*(i+1) + 40*i + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) From f86b92cccfefcd4149e600ffbab68e7ce1218b67 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 27 Oct 2025 16:17:59 -0500 Subject: [PATCH 02/14] Update baseline --- .basedpyright/baseline.json | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index dd72c4312..3aa81747c 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2016,6 +2016,22 @@ "endColumn": 25, "lineCount": 1 } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 51, + "endColumn": 61, + "lineCount": 1 + } + }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 56, + "endColumn": 66, + "lineCount": 1 + } } ], "./pytato/array.py": [ @@ -7963,6 +7979,14 @@ "lineCount": 1 } }, + { + "code": "reportUnknownMemberType", + "range": { + "startColumn": 42, + "endColumn": 52, + "lineCount": 1 + } + }, { "code": "reportUnannotatedClassAttribute", "range": { From a71771ef44b1882d1a9e72634cab8b5489571380 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 10 Nov 2025 16:49:36 -0600 Subject: [PATCH 03/14] add note to docs about assumptions when handling conditional expressions --- pytato/analysis/__init__.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 97cd18a35..fdbc7100d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1051,7 +1051,15 @@ def get_num_flops( expr: ArrayOrNames, op_name_to_num_flops: Mapping[str, int] | None = None, ) -> ArrayOrScalar: - """Count the total number of floating point operations in the DAG *expr*.""" + """ + Count the total number of floating point operations in the DAG *expr*. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. + """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) expr = _normalize_materialization(expr) @@ -1074,6 +1082,12 @@ def get_materialized_node_flop_counts( """ Returns a dictionary mapping materialized nodes in DAG *expr* to their floating point operation count. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1096,6 +1110,12 @@ def get_unmaterialized_node_flop_counts( Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count information. + + .. note:: + + For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, + this function assumes a SIMT-like model of computation in which the per-entry + cost is the sum of the costs of the two branches. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) From 2ac25c45c5d503767b5840446b580375d9e3d499 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 10:46:02 -0600 Subject: [PATCH 04/14] change 'pass' -> '...' in ReductionOperation abstract methods --- pytato/reductions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index a35c1d98a..a229328c8 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -89,19 +89,19 @@ class ReductionOperation(ABC): @classmethod @abstractmethod def scalar_op_name(cls) -> str: - pass + ... @abstractmethod def neutral_element(self, dtype: np.dtype[Any]) -> Any: - pass + ... @abstractmethod def __hash__(self) -> int: - pass + ... @abstractmethod def __eq__(self, other: object) -> bool: - pass + ... class _StatelessReductionOperation(ReductionOperation): From c3921038fc572aceae4b624ae28cfcf8e25eff16 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 10:52:14 -0600 Subject: [PATCH 05/14] move op_name_to_num_flops type declaration to FlopCounter class body --- pytato/scalar_expr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 72c4adeaf..e78d1ce43 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -251,11 +251,12 @@ def map_variable(self, expr: prim.Variable) -> set[str]: class FlopCounter(FlopCounterBase): + op_name_to_num_flops: dict[str, ArithmeticExpression] + def __init__( self, op_name_to_num_flops: Mapping[str, ArithmeticExpression] | None = None): super().__init__() - self.op_name_to_num_flops: dict[str, ArithmeticExpression] if op_name_to_num_flops: self.op_name_to_num_flops = dict(op_name_to_num_flops) else: From 477e9b5cbd39d841d04e51c7a7921cdcd60111ce Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 15:31:55 -0600 Subject: [PATCH 06/14] add some details about how flop counts are computed --- pytato/analysis/__init__.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fdbc7100d..964ab2502 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -949,6 +949,10 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]: @dataclass class UnmaterializedNodeFlopCounts: + """ + Floating point operation counts for an unmaterialized node. See + :func:`get_unmaterialized_node_flop_counts` for details. + """ materialized_successor_to_contrib_nflops: dict[Array, ArrayOrScalar] nflops_if_materialized: ArrayOrScalar @@ -1054,6 +1058,11 @@ def get_num_flops( """ Count the total number of floating point operations in the DAG *expr*. + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. The total flop count is the sum + over all materialized nodes. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, @@ -1083,6 +1092,10 @@ def get_materialized_node_flop_counts( Returns a dictionary mapping materialized nodes in DAG *expr* to their floating point operation count. + Counts flops as if emitting a statement at each materialized node (i.e., a node + tagged with :class:`pytato.tags.ImplStored`) that computes everything up to + (not including) its materialized predecessors. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, @@ -1111,6 +1124,15 @@ def get_unmaterialized_node_flop_counts( :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count information. + The :class:`UnmaterializedNodeFlopCounts` instance for each unmaterialized node + (i.e., a node that can be tagged with :class:`pytato.tags.ImplStored` but isn't) + contains `materialized_successor_to_contrib_nflops` and `nflops_if_materialized` + attributes. The former is a mapping from each materialized successor of the + unmaterialized node to the number of flops the node contributes to evaluating + that successor (this includes flops from the predecessors of the unmaterialized + node). The latter is the number of flops that would be required to evaluate the + unmaterialized node if it was materialized instead. + .. note:: For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, From 2571a89494b9efa4d479052e15acd5236edc6094 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 5 Dec 2025 15:33:20 -0600 Subject: [PATCH 07/14] clarify how flop counting functions behave w.r.t. DAG functions --- pytato/analysis/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 964ab2502..a4d1127de 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -1068,6 +1068,10 @@ def get_num_flops( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1101,6 +1105,10 @@ def get_materialized_node_flop_counts( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does not* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1138,6 +1146,10 @@ def get_unmaterialized_node_flop_counts( For arrays whose index lambda form contains :class:`pymbolic.primitives.If`, this function assumes a SIMT-like model of computation in which the per-entry cost is the sum of the costs of the two branches. + + .. note:: + + This *does not* descend into functions. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) From f6cbcd88003af0d4d41be337ad767beeed627b9f Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 8 Dec 2025 16:57:43 -0600 Subject: [PATCH 08/14] don't try to count flops for function calls --- .basedpyright/baseline.json | 8 ---- pytato/analysis/__init__.py | 81 ++++++++++++-------------------- pytato/transform/calls.py | 94 +------------------------------------ test/test_pytato.py | 19 -------- 4 files changed, 32 insertions(+), 170 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 3aa81747c..c4960237a 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -7979,14 +7979,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 42, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a4d1127de..fac0754c8 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -846,15 +846,10 @@ class MaterializedNodeFlopCounter(CachedWalkMapper[[]]): def __init__( self, op_name_to_num_flops: Mapping[str, int], - _visited_functions: set[VisitKeyT] | None = None, - _function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] | None = None ) -> None: - super().__init__(_visited_functions=_visited_functions) + super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops self.materialized_node_to_nflops: dict[Array, ArrayOrScalar] = {} - self.call_to_nflops: dict[Call, ArrayOrScalar] = {} - self._function_to_nflops: dict[FunctionDefinition, ArrayOrScalar] = \ - _function_to_nflops if _function_to_nflops is not None else {} self._per_entry_flop_counter: _PerEntryFlopCounter = _PerEntryFlopCounter( self.op_name_to_num_flops) @@ -862,50 +857,25 @@ def __init__( def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: return expr - @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: - return expr - @override def clone_for_callee(self, function: FunctionDefinition) -> Self: - return type(self)( - op_name_to_num_flops=self.op_name_to_num_flops, - _visited_functions=self._visited_functions, - _function_to_nflops=self._function_to_nflops) + raise AssertionError("Control shouldn't reach this point.") @override def map_function_definition(self, expr: FunctionDefinition) -> None: if not self.visit(expr): return - new_mapper = self.clone_for_callee(expr) - for subexpr in expr.returns.values(): - # Assume that any calls that haven't been inlined have their functions' - # outputs materialized - assert not _is_unmaterialized(subexpr) - new_mapper(subexpr) - - self._function_to_nflops[expr] = ( - sum(new_mapper.materialized_node_to_nflops.values()) - + sum(new_mapper.call_to_nflops.values())) - - self.post_visit(expr) + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def map_call(self, expr: Call) -> None: if not self.visit(expr): return - self.rec_function_definition(expr.function) - for bnd in expr.bindings.values(): - # Assume that any calls that haven't been inlined have their inputs - # materialized - assert not _is_unmaterialized(bnd) - self.rec(bnd) - - self.call_to_nflops[expr] = self._function_to_nflops[expr.function] - - self.post_visit(expr) + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: @@ -968,9 +938,8 @@ class UnmaterializedNodeFlopCounter(CachedWalkMapper[[]]): """ def __init__( self, - op_name_to_num_flops: Mapping[str, int], - _visited_functions: set[VisitKeyT] | None = None) -> None: - super().__init__(_visited_functions=_visited_functions) + op_name_to_num_flops: Mapping[str, int]) -> None: + super().__init__() self.op_name_to_num_flops: Mapping[str, int] = op_name_to_num_flops self.unmaterialized_node_to_flop_counts: \ dict[Array, UnmaterializedNodeFlopCounts] = {} @@ -982,8 +951,24 @@ def get_cache_key(self, expr: ArrayOrNames) -> VisitKeyT: return expr @override - def get_function_definition_cache_key(self, expr: FunctionDefinition) -> VisitKeyT: - return expr + def clone_for_callee(self, function: FunctionDefinition) -> Self: + raise AssertionError("Control shouldn't reach this point.") + + @override + def map_function_definition(self, expr: FunctionDefinition) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") + + @override + def map_call(self, expr: Call) -> None: + if not self.visit(expr): + return + + raise NotImplementedError( + f"{type(self).__name__} does not support functions.") @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: @@ -1010,10 +995,6 @@ def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: # FIXME: Should this be added to normalize_outputs? def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - # Make sure call bindings and function results are materialized - from pytato.transform.calls import normalize_calls - expr = normalize_calls(expr) - # Make sure outputs are materialized if isinstance(expr, DictOfNamedArrays): output_to_materialized_output: dict[Array, Array] = { @@ -1071,7 +1052,7 @@ def get_num_flops( .. note:: - This *does* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1083,9 +1064,7 @@ def get_num_flops( fc = MaterializedNodeFlopCounter(op_name_to_num_flops) fc(expr) - return ( - sum(fc.materialized_node_to_nflops.values()) - + sum(fc.call_to_nflops.values())) + return sum(fc.materialized_node_to_nflops.values()) def get_materialized_node_flop_counts( @@ -1108,7 +1087,7 @@ def get_materialized_node_flop_counts( .. note:: - This *does not* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) @@ -1149,7 +1128,7 @@ def get_unmaterialized_node_flop_counts( .. note:: - This *does not* descend into functions. + Does not support functions. Function calls must be inlined before calling. """ from pytato.codegen import normalize_outputs expr = normalize_outputs(expr) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index 16c4413b2..ffad1e5f7 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -1,7 +1,6 @@ """ .. currentmodule:: pytato.transform.calls -.. autofunction:: normalize_calls .. autofunction:: inline_calls .. autofunction:: tag_all_calls_to_be_inlined """ @@ -33,26 +32,20 @@ from typing import TYPE_CHECKING, cast -from immutabledict import immutabledict -from typing_extensions import Never, Self, override +from typing_extensions import Self from pytato.array import ( AbstractResultWithNamedArrays, Array, - DataWrapper, DictOfNamedArrays, Placeholder, - SizeParam, - make_dict_of_named_arrays, ) from pytato.function import Call, FunctionDefinition, NamedCallResult -from pytato.tags import ImplStored, InlineCallTag +from pytato.tags import InlineCallTag from pytato.transform import ( ArrayOrNames, ArrayOrNamesTc, - CombineMapper, CopyMapper, - Mapper, TransformMapperCache, _verify_is_array, deduplicate, @@ -63,89 +56,6 @@ from collections.abc import Mapping -# {{{ normalizing - -class _LocalStackCallBindingCollector(CombineMapper[frozenset[Array], Never]): - """Mapper to collect bindings of calls on the current call stack.""" - @override - def combine(self, *args: frozenset[Array]) -> frozenset[Array]: - from functools import reduce - return reduce(lambda a, b: a | b, args, cast("frozenset[Array]", frozenset())) - - @override - def map_call(self, expr: Call) -> frozenset[Array]: - return frozenset(expr.bindings.values()) - - -class _CallMaterializer(CopyMapper): - """Mapper to add materialization tags for call bindings and function results.""" - def __init__( - self, - local_call_bindings: frozenset[Array], - _cache: TransformMapperCache[ArrayOrNames, []] | None = None, - _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None - ) -> None: - super().__init__(_cache=_cache, _function_cache=_function_cache) - self.local_call_bindings: frozenset[Array] = local_call_bindings - - @override - def clone_for_callee(self, function: FunctionDefinition) -> Self: - """ - Called to clone *self* before starting traversal of a - :class:`pytato.function.FunctionDefinition`. - """ - local_call_bindings = _LocalStackCallBindingCollector()( - make_dict_of_named_arrays(function.returns)) - return type(self)( - local_call_bindings, - _function_cache=cast( - "TransformMapperCache[FunctionDefinition, []]", self._function_cache)) - - def _materialize_if_possible(self, expr: ArrayOrNames) -> ArrayOrNames: - if ( - isinstance(expr, Array) - and not isinstance(expr, - (DataWrapper, Placeholder, SizeParam, NamedCallResult))): - return expr.tagged(ImplStored()) - else: - return expr - - @override - def map_function_definition(self, - expr: FunctionDefinition) -> FunctionDefinition: - new_mapper = self.clone_for_callee(expr) - new_returns: Mapping[str, Array] = immutabledict({ - name: self._materialize_if_possible(_verify_is_array(new_mapper(ret))) - for name, ret in expr.returns.items()}) - return expr.replace_if_different(returns=new_returns) - - @override - def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - inputs = self._make_cache_inputs(expr) - try: - return self._cache_retrieve(inputs) - except KeyError: - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = cast("ArrayOrNames", Mapper.rec(self, expr)) - if expr in self.local_call_bindings: - result = self._materialize_if_possible(result) - return self._cache_add(inputs, result) - - -def normalize_calls(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: - """ - Ensure that calls/functions are defined uniformly. - - Adds any missing materialization tags for call bindings and function results. - """ - local_call_bindings = _LocalStackCallBindingCollector()(expr) - return _CallMaterializer(local_call_bindings)(expr) - -# }}} - - # {{{ inlining class PlaceholderSubstitutor(CopyMapper): diff --git a/test/test_pytato.py b/test/test_pytato.py index 1c5ca303a..636932d75 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -944,25 +944,6 @@ def test_flop_count(): # }}} - # {{{ function call - - x = pt.make_placeholder("x", (10, 4)) - y = pt.make_placeholder("y", (10, 4)) - - def f(x, y): - z = x + y - return 2*z, 3*z - - u, v = pt.trace_call(f, x, y) - expr = u - v - - # u[i, j] = 2*(x[i, j] + y[i, j]) - # v[i, j] = 3*(x[i, j] + y[i, j]) - # expr[i, j] = u[i, j] + (-1)*v[i, j] - assert get_num_flops(expr) == 40*2 + 40*2 + 40*2 - - # }}} - # {{{ einsum x = pt.make_placeholder("x", (2, 3, 4)) From 8aeaa6ff6c8707b462aeb0c86b3acdbdbb988cb0 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Dec 2025 13:33:13 -0600 Subject: [PATCH 09/14] move own flop counting out of rec and into its own method in _PerEntryFlopCounter --- .basedpyright/baseline.json | 4 ++-- pytato/analysis/__init__.py | 31 ++++++++++++++++++------------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index c4960237a..bda23a2bf 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2020,8 +2020,8 @@ { "code": "reportUnknownMemberType", "range": { - "startColumn": 51, - "endColumn": 61, + "startColumn": 34, + "endColumn": 44, "lineCount": 1 } }, diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fac0754c8..00fcb0b90 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -804,6 +804,18 @@ def __init__(self, op_name_to_num_flops: Mapping[str, int]) -> None: def combine(self, *args: int) -> int: return sum(args) + def _get_own_flop_count(self, expr: Array) -> int: + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + if not isinstance(nflops, int): + from pytato.scalar_expr import InputGatherer as ScalarInputGatherer + var_names: set[str] = set(ScalarInputGatherer()(nflops)) + var_names.discard("nflops") + if var_names: + raise UndefinedOpFlopCountError(next(iter(var_names))) from None + else: + raise AssertionError from None + return nflops + @override def rec(self, expr: ArrayOrNames) -> int: inputs = self._make_cache_inputs(expr) @@ -813,19 +825,12 @@ def rec(self, expr: ArrayOrNames) -> int: result: int if _is_unmaterialized(expr): assert isinstance(expr, Array) - self_nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - if not isinstance(self_nflops, int): - from pytato.scalar_expr import InputGatherer as ScalarInputGatherer - var_names: set[str] = set(ScalarInputGatherer()(self_nflops)) - var_names.discard("nflops") - if var_names: - raise UndefinedOpFlopCountError(next(iter(var_names))) from None - else: - raise AssertionError from None - # Intentionally going to Mapper instead of super() to avoid - # double caching when subclasses of CachedMapper override rec, - # see https://github.com/inducer/pytato/pull/585 - result = self_nflops + cast("int", Mapper.rec(self, expr)) + result = ( + self._get_own_flop_count(expr) + # Intentionally going to Mapper instead of super() to avoid + # double caching when subclasses of CachedMapper override rec, + # see https://github.com/inducer/pytato/pull/585 + + cast("int", Mapper.rec(self, expr))) else: result = 0 if isinstance(expr, Array): From f53c7a3fa509e410c63338369fb9a3591707bab6 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Tue, 9 Dec 2025 13:27:46 -0600 Subject: [PATCH 10/14] change is_materializable -> is_materialized / has_taggable_materialization --- pytato/analysis/__init__.py | 47 +++++++++++++++++-------------------- pytato/utils.py | 37 +++++++++++++++++++++++------ test/test_pytato.py | 6 ++++- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 00fcb0b90..9ca12937a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -51,6 +51,7 @@ ShapeType, Stack, ) +from pytato.diagnostic import CannotBeLoweredToIndexLambda from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.scalar_expr import ( FlopCounter as ScalarFlopCounter, @@ -66,7 +67,7 @@ map_and_copy, ) from pytato.transform.lower_to_index_lambda import to_index_lambda -from pytato.utils import is_materializable +from pytato.utils import has_taggable_materialization, is_materialized if TYPE_CHECKING: @@ -776,18 +777,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # {{{ flop counting -def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - return ( - is_materializable(expr) - and bool(expr.tags_of_type(ImplStored))) - - -def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool: - return ( - is_materializable(expr) - and not bool(expr.tags_of_type(ImplStored))) - - @dataclass class UndefinedOpFlopCountError(ValueError): op_name: str @@ -805,7 +794,10 @@ def combine(self, *args: int) -> int: return sum(args) def _get_own_flop_count(self, expr: Array) -> int: - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + try: + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) + except CannotBeLoweredToIndexLambda: + nflops = 0 if not isinstance(nflops, int): from pytato.scalar_expr import InputGatherer as ScalarInputGatherer var_names: set[str] = set(ScalarInputGatherer()(nflops)) @@ -823,8 +815,7 @@ def rec(self, expr: ArrayOrNames) -> int: return self._cache_retrieve(inputs) except KeyError: result: int - if _is_unmaterialized(expr): - assert isinstance(expr, Array) + if isinstance(expr, Array) and not is_materialized(expr): result = ( self._get_own_flop_count(expr) # Intentionally going to Mapper instead of super() to avoid @@ -884,14 +875,16 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not _is_materialized(expr): + if not is_materialized(expr): return assert isinstance(expr, Array) - unmaterialized_expr = expr.without_tags(ImplStored()) - self._per_entry_flop_counter(unmaterialized_expr) - self.materialized_node_to_nflops[expr] = ( - product(expr.shape) - * self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr]) + if has_taggable_materialization(expr): + unmaterialized_expr = expr.without_tags(ImplStored()) + self.materialized_node_to_nflops[expr] = ( + product(expr.shape) + * self._per_entry_flop_counter(unmaterialized_expr)) + else: + self.materialized_node_to_nflops[expr] = 0 class _UnmaterializedSubexpressionUseCounter(CombineMapper[dict[Array, int], Never]): @@ -910,8 +903,7 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]: return self._cache_retrieve(inputs) except KeyError: result: dict[Array, int] - if _is_unmaterialized(expr): - assert isinstance(expr, Array) + if isinstance(expr, Array) and not is_materialized(expr): # Intentionally going to Mapper instead of super() to avoid # double caching when subclasses of CachedMapper override rec, # see https://github.com/inducer/pytato/pull/585 @@ -977,7 +969,7 @@ def map_call(self, expr: Call) -> None: @override def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: - if not _is_materialized(expr): + if not is_materialized(expr) or not has_taggable_materialization(expr): return assert isinstance(expr, Array) unmaterialized_expr = expr.without_tags(ImplStored()) @@ -1003,7 +995,10 @@ def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc: # Make sure outputs are materialized if isinstance(expr, DictOfNamedArrays): output_to_materialized_output: dict[Array, Array] = { - ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary + ary: ( + ary.tagged(ImplStored()) + if has_taggable_materialization(ary) + else ary) for ary in expr._data.values()} def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames: diff --git a/pytato/utils.py b/pytato/utils.py index 92c1a1cbb..d5d377966 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -82,7 +82,8 @@ .. autofunction:: dim_to_index_lambda_components .. autofunction:: get_common_dtype_of_ary_or_scalars .. autofunction:: get_einsum_subscript_str -.. autofunction:: is_materializable +.. autofunction:: is_materialized +.. autofunction:: has_taggable_materialization References ^^^^^^^^^^ @@ -740,18 +741,40 @@ def get_einsum_specification(expr: Einsum) -> str: return f"{','.join(input_specs)}->{output_spec}" -def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool: +def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: + """Returns *True* if *expr* is materialized.""" + from pytato.array import InputArgumentBase + from pytato.distributed.nodes import DistributedRecv + from pytato.tags import ImplStored + return ( + ( + isinstance(expr, Array) + and bool(expr.tags_of_type(ImplStored))) + or isinstance( + expr, + ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, + DistributedRecv))) + + +def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool: """ - Returns *True* if *expr* is an instance of an array type that can be materialized. + Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to + determine whether or not it is materialized. """ from pytato.array import InputArgumentBase, NamedArray from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder return ( isinstance(expr, Array) - and not isinstance(expr, ( - # FIXME: Is there a nice way to generalize this? - InputArgumentBase, NamedArray, DistributedRecv, - DistributedSendRefHolder))) + and not isinstance( + expr, + ( + # FIXME: Is there a nice way to generalize this? + InputArgumentBase, + DistributedRecv, + NamedArray, + DistributedSendRefHolder))) # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 636932d75..56716132a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -979,9 +979,13 @@ def test_materialized_node_flop_counts(): # z[i, j] = x[i, j] + y[i, j] # expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j] - assert len(materialized_node_to_flop_count) == 2 + assert len(materialized_node_to_flop_count) == 4 + assert x in materialized_node_to_flop_count + assert y in materialized_node_to_flop_count assert z in materialized_node_to_flop_count assert expr.tagged(ImplStored()) in materialized_node_to_flop_count + assert materialized_node_to_flop_count[x] == 0 + assert materialized_node_to_flop_count[y] == 0 assert materialized_node_to_flop_count[z] == 40 assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4 From 6f072458fdda1cd5aca194caf4c6286017a94fcc Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 10:54:02 -0600 Subject: [PATCH 11/14] use explicit isinstance() check instead of try/except around to_index_lambda() --- pytato/analysis/__init__.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 9ca12937a..8aec9c5c3 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -41,6 +41,7 @@ Array, ArrayOrScalar, Concatenate, + DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -48,10 +49,11 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Placeholder, ShapeType, Stack, ) -from pytato.diagnostic import CannotBeLoweredToIndexLambda +from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.scalar_expr import ( FlopCounter as ScalarFlopCounter, @@ -75,7 +77,6 @@ import pytools.tag - from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder from pytato.loopy import LoopyCall __doc__ = """ @@ -794,10 +795,16 @@ def combine(self, *args: int) -> int: return sum(args) def _get_own_flop_count(self, expr: Array) -> int: - try: - nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) - except CannotBeLoweredToIndexLambda: - nflops = 0 + if isinstance( + expr, + ( + DataWrapper, + Placeholder, + NamedArray, + DistributedRecv, + DistributedSendRefHolder)): + return 0 + nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) if not isinstance(nflops, int): from pytato.scalar_expr import InputGatherer as ScalarInputGatherer var_names: set[str] = set(ScalarInputGatherer()(nflops)) From e209d6a4fe4b08fa00739df48d5280f8e4adbeea Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 11:45:20 -0600 Subject: [PATCH 12/14] remove a couple of FIXMEs --- pytato/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytato/utils.py b/pytato/utils.py index d5d377966..f2b50791c 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -753,7 +753,6 @@ def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: or isinstance( expr, ( - # FIXME: Is there a nice way to generalize this? InputArgumentBase, DistributedRecv))) @@ -770,7 +769,6 @@ def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> boo and not isinstance( expr, ( - # FIXME: Is there a nice way to generalize this? InputArgumentBase, DistributedRecv, NamedArray, From 666ca2a1ea7099e9e3eb315b40aa764183038959 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 26 Jan 2026 12:27:00 -0600 Subject: [PATCH 13/14] add placeholder class for operator flop counts that aren't specified --- pytato/analysis/__init__.py | 15 +++++++++------ pytato/scalar_expr.py | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 8aec9c5c3..056e16d99 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -806,13 +806,16 @@ def _get_own_flop_count(self, expr: Array) -> int: return 0 nflops = self.scalar_flop_counter(to_index_lambda(expr).expr) if not isinstance(nflops, int): - from pytato.scalar_expr import InputGatherer as ScalarInputGatherer - var_names: set[str] = set(ScalarInputGatherer()(nflops)) - var_names.discard("nflops") - if var_names: - raise UndefinedOpFlopCountError(next(iter(var_names))) from None + # Restricting to numerical result here because the flop counters that use + # this mapper subsequently multiply the result by things that are + # potentially arrays (e.g., shape components), and arrays and scalar + # expressions are not interoperable + from pytato.scalar_expr import OpFlops, OpFlopsCollector + op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops) + if op_flops: + raise UndefinedOpFlopCountError(next(iter(op_flops)).op) else: - raise AssertionError from None + raise AssertionError return nflops @override diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index e78d1ce43..e7a796faa 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -44,6 +44,7 @@ import re from collections.abc import Iterable, Mapping, Set as AbstractSet +from functools import reduce from typing import ( TYPE_CHECKING, Any, @@ -266,8 +267,7 @@ def _get_op_nflops(self, name: str) -> ArithmeticExpression: try: return self.op_name_to_num_flops[name] except KeyError: - from pymbolic import var - result = var("nflops")(var(name)) + result = OpFlops(name) self.op_name_to_num_flops[name] = result return result @@ -480,9 +480,42 @@ class TypeCast(ExpressionBase): dtype: np.dtype[Any] inner_expr: ScalarExpression + +@expr_dataclass() +class OpFlops(prim.AlgebraicLeaf): + """ + Placeholder flop count for an operator. + + .. autoattribute:: op + """ + op: str + # }}} +class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]): + """ + Constructs a :class:`frozenset` containing all instances of + :class:`pytato.scalar_expr.OpFlops` found in a scalar expression. + """ + @override + def combine( + self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]: + return reduce( + lambda x, y: x.union(y), + values, + cast("frozenset[OpFlops]", frozenset())) + + @override + def map_algebraic_leaf( + self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]: + return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset() + + @override + def map_constant(self, expr: object) -> frozenset[OpFlops]: + return frozenset() + + class InductionVariableCollector(CombineMapper[AbstractSet[str], []]): def combine(self, values: Iterable[AbstractSet[str]]) -> frozenset[str]: from functools import reduce From 0655a0ce87ffbd4e2b718368099b949b3234acbf Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Mon, 15 Dec 2025 12:00:34 -0600 Subject: [PATCH 14/14] add ArrayOrNamesOrFunctionDef --- pytato/analysis/__init__.py | 27 ++++++++++++++------------- pytato/distributed/verify.py | 9 ++++++--- pytato/equality.py | 6 ++++-- pytato/transform/__init__.py | 11 +++++++---- pytato/utils.py | 8 +++----- pytato/visualization/dot.py | 15 ++++++++++----- 6 files changed, 44 insertions(+), 32 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 056e16d99..13a9bfd81 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -61,6 +61,7 @@ from pytato.tags import ImplStored from pytato.transform import ( ArrayOrNames, + ArrayOrNamesOrFunctionDef, ArrayOrNamesTc, CachedWalkMapper, CombineMapper, @@ -362,7 +363,7 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: class ListOfDirectPredecessorsGetter( Mapper[ - list[ArrayOrNames | FunctionDefinition], + list[ArrayOrNamesOrFunctionDef], list[ArrayOrNames], []]): """ @@ -445,8 +446,8 @@ def map_distributed_send_ref_holder(self, return [expr.send.data, expr.passthrough_data] def map_call( - self, expr: Call) -> list[ArrayOrNames | FunctionDefinition]: - result: list[ArrayOrNames | FunctionDefinition] = [] + self, expr: Call) -> list[ArrayOrNamesOrFunctionDef]: + result: list[ArrayOrNamesOrFunctionDef] = [] if self.include_functions: result.append(expr.function) result += list(expr.bindings.values()) @@ -483,7 +484,7 @@ def __init__(self, *, include_functions: bool = False) -> None: @overload def __call__( self, expr: ArrayOrNames - ) -> FrozenOrderedSet[ArrayOrNames | FunctionDefinition]: + ) -> FrozenOrderedSet[ArrayOrNamesOrFunctionDef]: ... @overload @@ -492,9 +493,9 @@ def __call__(self, expr: FunctionDefinition) -> FrozenOrderedSet[ArrayOrNames]: def __call__( self, - expr: ArrayOrNames | FunctionDefinition, + expr: ArrayOrNamesOrFunctionDef, ) -> ( - FrozenOrderedSet[ArrayOrNames | FunctionDefinition] + FrozenOrderedSet[ArrayOrNamesOrFunctionDef] | FrozenOrderedSet[ArrayOrNames]): """Get the direct predecessors of *expr*.""" return FrozenOrderedSet(self._pred_getter(expr)) @@ -543,7 +544,7 @@ def clone_for_callee(self, function: FunctionDefinition) -> Self: _visited_functions=self._visited_functions) @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 @@ -606,7 +607,7 @@ def __init__(self, _visited_functions: set[Any] | None = None) -> None: super().__init__(_visited_functions=_visited_functions) self.expr_multiplicity_counts: \ - dict[ArrayOrNames | FunctionDefinition, int] = defaultdict(int) + dict[ArrayOrNamesOrFunctionDef, int] = defaultdict(int) @override def get_cache_key(self, expr: ArrayOrNames) -> int: @@ -619,13 +620,13 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_multiplicity_counts[expr] += 1 def get_node_multiplicities( - outputs: ArrayOrNames) -> dict[ArrayOrNames | FunctionDefinition, int]: + outputs: ArrayOrNames) -> dict[ArrayOrNamesOrFunctionDef, int]: """ Returns the multiplicity per `expr`. """ @@ -662,7 +663,7 @@ def get_function_definition_cache_key(self, expr: FunctionDefinition) -> int: return id(expr) @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if isinstance(expr, Call): self.count += 1 @@ -884,7 +885,7 @@ def map_call(self, expr: Call) -> None: f"{type(self).__name__} does not support functions.") @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if not is_materialized(expr): return assert isinstance(expr, Array) @@ -978,7 +979,7 @@ def map_call(self, expr: Call) -> None: f"{type(self).__name__} does not support functions.") @override - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if not is_materialized(expr) or not has_taggable_materialization(expr): return assert isinstance(expr, Array) diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index c7defba6f..69bb9d6d5 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -55,7 +55,11 @@ DistributedGraphPartition, PartId, ) -from pytato.transform import ArrayOrNames, CachedWalkMapper +from pytato.transform import ( + ArrayOrNames, + ArrayOrNamesOrFunctionDef, + CachedWalkMapper, +) logger = logging.getLogger(__name__) @@ -68,7 +72,6 @@ import numpy as np from pytato.distributed.nodes import CommTagType, DistributedRecv - from pytato.function import FunctionDefinition # {{{ data structures @@ -156,7 +159,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @override - def visit(self, expr: ArrayOrNames | FunctionDefinition) -> bool: + def visit(self, expr: ArrayOrNamesOrFunctionDef) -> bool: super().visit(expr) if isinstance(expr, ArrayOrNames): self.seen_nodes.add(expr) diff --git a/pytato/equality.py b/pytato/equality.py index 837530f02..1976be49d 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -63,6 +63,8 @@ ArrayOrNames = Array | AbstractResultWithNamedArrays +ArrayOrNamesOrFunctionDef = \ + Array | AbstractResultWithNamedArrays | FunctionDefinition # {{{ EqualityComparer @@ -87,7 +89,7 @@ def __init__(self) -> None: # Uses the same cache for both arrays and functions self._cache: dict[tuple[int, int], bool] = {} - def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool: + def rec(self, expr1: ArrayOrNamesOrFunctionDef, expr2: object) -> bool: # These cases are simple enough that they don't need to be cached if expr1 is expr2: return True @@ -119,7 +121,7 @@ def rec(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool: self._cache[cache_key] = result return result - def __call__(self, expr1: ArrayOrNames | FunctionDefinition, expr2: object) -> bool: + def __call__(self, expr1: ArrayOrNamesOrFunctionDef, expr2: object) -> bool: return self.rec(expr1, expr2) def handle_unsupported_array(self, expr1: Array, diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 0ed13a63f..27823a746 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -89,6 +89,8 @@ ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays +ArrayOrNamesOrFunctionDef: TypeAlias = \ + Array | AbstractResultWithNamedArrays | FunctionDefinition ArrayOrNamesTc = TypeVar("ArrayOrNamesTc", Array, AbstractResultWithNamedArrays, DictOfNamedArrays) ArrayOrNamesOrFunctionDefTc = TypeVar("ArrayOrNamesOrFunctionDefTc", @@ -150,6 +152,7 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. class:: ArrayOrNames +.. class:: ArrayOrNamesOrFunctionDef .. class:: ArrayOrNamesTc @@ -307,7 +310,7 @@ def __call__( def __call__( self, - expr: ArrayOrNames | FunctionDefinition, + expr: ArrayOrNamesOrFunctionDef, *args: P.args, **kwargs: P.kwargs) -> ResultT | FunctionResultT: """Handle the mapping of *expr*.""" @@ -1569,7 +1572,7 @@ def clone_for_callee( return type(self)() def visit( - self, expr: ArrayOrNames | FunctionDefinition, + self, expr: ArrayOrNamesOrFunctionDef, *args: P.args, **kwargs: P.kwargs) -> bool: """ If this method returns *True*, *expr* is traversed during the walk. @@ -1579,7 +1582,7 @@ def visit( return True def post_visit( - self, expr: ArrayOrNames | FunctionDefinition, + self, expr: ArrayOrNamesOrFunctionDef, *args: P.args, **kwargs: P.kwargs) -> None: """ Callback after *expr* has been traversed. @@ -1841,7 +1844,7 @@ def __init__( def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None: + def post_visit(self, expr: ArrayOrNamesOrFunctionDef) -> None: if isinstance(expr, Array): self.topological_order.append(expr) diff --git a/pytato/utils.py b/pytato/utils.py index f2b50791c..333929d8b 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -61,7 +61,7 @@ ScalarExpression, TypeCast, ) -from pytato.transform import ArrayOrNames, CachedMapper +from pytato.transform import ArrayOrNamesOrFunctionDef, CachedMapper if TYPE_CHECKING: @@ -69,8 +69,6 @@ from pytools.tag import Tag - from pytato.function import FunctionDefinition - __doc__ = """ Helper routines @@ -741,7 +739,7 @@ def get_einsum_specification(expr: Einsum) -> str: return f"{','.join(input_specs)}->{output_spec}" -def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: +def is_materialized(expr: ArrayOrNamesOrFunctionDef) -> bool: """Returns *True* if *expr* is materialized.""" from pytato.array import InputArgumentBase from pytato.distributed.nodes import DistributedRecv @@ -757,7 +755,7 @@ def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool: DistributedRecv))) -def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool: +def has_taggable_materialization(expr: ArrayOrNamesOrFunctionDef) -> bool: """ Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to determine whether or not it is materialized. diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 5b39f83bf..6dcab1d1b 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -59,7 +59,12 @@ ) from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.tags import FunctionIdentifier -from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer +from pytato.transform import ( + ArrayOrNames, + ArrayOrNamesOrFunctionDef, + CachedMapper, + InputGatherer, +) if TYPE_CHECKING: @@ -160,7 +165,7 @@ def emit_subgraph(sg: _SubgraphTree) -> None: class _DotNodeInfo: title: str fields: dict[str, Any] - edges: dict[str, ArrayOrNames | FunctionDefinition] + edges: dict[str, ArrayOrNamesOrFunctionDef] def stringify_tags(tags: frozenset[Tag | None]) -> str: @@ -193,7 +198,7 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: "non_equality_tags": expr.non_equality_tags, } - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNamesOrFunctionDef] = {} return _DotNodeInfo(title, fields, edges) # type-ignore-reason: incompatible with supertype @@ -297,7 +302,7 @@ def map_einsum(self, expr: Einsum) -> None: self.node_to_dot[expr] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNamesOrFunctionDef] = {} for name, val in expr._data.items(): edges[name] = val self.rec(val) @@ -308,7 +313,7 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: dict[str, ArrayOrNames | FunctionDefinition] = {} + edges: dict[str, ArrayOrNamesOrFunctionDef] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): edges[name] = arg