More robust diagonal constant check #2173

Open
jessegrabowski wants to merge 6 commits into pymc-devs:main from jessegrabowski:symmetric-check-fix

Conversation

@jessegrabowski
Member

pymc-extras surfaced a small bug in _diagonal_from_constant, which used multiplication by a masking matrix to check for non-zero off-diagonal values. This raises a warning when there are non-finite values (inf/nan) in the matrix being checked. Use indexing instead to avoid the warnings.
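
To illustrate the failure mode described above, here is a minimal sketch. It is not the actual `_diagonal_from_constant` code: the helper names `offdiag_nonzero_by_mask` and `offdiag_nonzero_by_index` are hypothetical. Multiplying by a 0/1 mask evaluates `inf * 0`, which is nan and makes NumPy emit a RuntimeWarning, whereas boolean indexing never does arithmetic on the diagonal entries.

```python
import warnings

import numpy as np


def offdiag_nonzero_by_mask(x):
    """Hypothetical mask-multiplication check: inf * 0 produces nan and a warning."""
    mask = 1.0 - np.eye(x.shape[-1])  # 1 off the diagonal, 0 on it
    return bool(np.any(x * mask != 0))


def offdiag_nonzero_by_index(x):
    """Hypothetical indexing check: selects off-diagonal entries without arithmetic."""
    off = ~np.eye(x.shape[-1], dtype=bool)
    return bool(np.any(x[off] != 0))


# A diagonal matrix with a non-finite entry on the diagonal.
x = np.diag([1.0, np.inf, 3.0])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mask_result = offdiag_nonzero_by_mask(x)
    n_mask_warnings = len(caught)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    index_result = offdiag_nonzero_by_index(x)
    n_index_warnings = len(caught)

print(mask_result, n_mask_warnings)
print(index_result, n_index_warnings)
```

In this sketch the mask variant also reports a nonzero off-diagonal entry, because the nan created on the diagonal compares unequal to zero. The indexing variant stays silent and returns the expected answer.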

jessegrabowski requested a review from ricardoV94 May 26, 2026 17:48
@jessegrabowski
Member Author

I also improved the performance of the diagonality check. The cheapest thing me and mr. bot could come up with is to check whether the number of nonzero entries in the whole matrix equals the number of nonzero entries on its diagonal. If it does, the matrix must be diagonal. This avoids allocating a mask matrix, and counting nonzeros seems to be highly optimized in numpy (~2x speedup vs np.any(x, where=mask), even excluding the cost of building the mask).
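
The counting idea above can be sketched as follows. This is a minimal illustration for a single 2D matrix, not the code in the PR, and the function name `is_diagonal` is hypothetical. `np.count_nonzero` treats nan and inf as nonzero, so no arithmetic on non-finite values is needed.

```python
import numpy as np


def is_diagonal(x):
    """Hypothetical sketch: True iff every nonzero entry of a 2D matrix is on the diagonal."""
    total = np.count_nonzero(x)
    on_diagonal = np.count_nonzero(np.diagonal(x))
    # Diagonal entries are a subset of all entries, so the counts are equal
    # exactly when there are no nonzero entries off the diagonal.
    return bool(total == on_diagonal)


print(is_diagonal(np.diag([1.0, np.inf, np.nan])))
print(is_diagonal(np.array([[1.0, 2.0], [0.0, 1.0]])))
```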

ricardoV94 changed the title from "More robust digonal constant check" to "More robust diagonal constant check" May 28, 2026
Member

ricardoV94 left a comment

The diagonal check looks good. Can the permutation check be improved a tiny bit more?

Comment thread pytensor/assumptions/permutation.py Outdated
```python
    )
    if data.dtype.kind in "uib":
        # Fast check, only valid for integer/bool
        is_permutation = data.max(initial=1) <= 1 and data.min(initial=0) >= 0
```
Member

It's enough to check min(initial=0) >= 0 for discrete dtypes, and for bool/unsigned we don't need any check at this point.
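
The reasoning behind this suggestion can be sketched as follows, assuming (as the benchmark docstring later in this thread states for floats) that the row and column sums have already been checked to equal 1. Non-negative integers that sum to 1 must be a single 1 and otherwise zeros, and bool/unsigned entries are non-negative by construction. The function name `is_permutation_discrete` is hypothetical and this is not the code in the PR.

```python
import numpy as np


def is_permutation_discrete(data):
    """Hypothetical sketch for one square bool/integer matrix."""
    kind = data.dtype.kind
    if kind not in "uib":
        raise TypeError("sketch only covers bool and integer dtypes")
    # Assumed earlier check: every row and every column sums to exactly 1.
    rows_ok = (data.sum(axis=-1) == 1).all()
    cols_ok = (data.sum(axis=-2) == 1).all()
    if not (rows_ok and cols_ok):
        return False
    if kind == "i":
        # Signed integers can hide cancellations such as 2 + (-1) == 1.
        return bool(data.min(initial=0) >= 0)
    # bool/unsigned: entries are already non-negative, nothing left to check.
    return True


perm_int = np.eye(3, dtype="int64")[[2, 0, 1]]
cancelling = np.array([[2, -1], [-1, 2]], dtype="int64")
print(is_permutation_discrete(perm_int))
print(is_permutation_discrete(cancelling))
```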

Comment thread pytensor/assumptions/permutation.py Outdated
```python
# Otherwise a matrix is permutation iff there is exactly 1 nonzero entry per row
# and column. That non-zero value can only be 1 due to previous checks.
n = data.shape[-1]
is_permutation = np.count_nonzero(data) == (data.size // n if n else 0)
```
Member

This seems slower than ((data == 0) | (data == 1)).all(), and I don't think that would have had issues with inf/nan
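
For readers following the thread, the count-based check under discussion can be sketched in isolation. `data.size // n` is the total number of rows across any leading batch dimensions, and the `if n else 0` guard avoids dividing by zero for empty matrices. Once each row is known to sum to 1 it contains at least one nonzero, so a total nonzero count equal to the number of rows means exactly one nonzero per row. The function name `nonzero_count_matches_rows` and the test arrays are hypothetical.

```python
import numpy as np


def nonzero_count_matches_rows(data):
    """Hypothetical sketch of the final count check; sum checks are assumed done."""
    n = data.shape[-1]
    n_rows = data.size // n if n else 0  # rows across all batch dimensions
    return bool(np.count_nonzero(data) == n_rows)


rng = np.random.default_rng(0)
# A batch of three 4x4 permutation matrices: 12 rows, 12 nonzero entries.
batch = np.stack([np.eye(4)[rng.permutation(4)] for _ in range(3)])
# Doubly stochastic: rows and columns sum to 1, but 16 nonzeros for 4 rows.
doubly_stochastic = np.full((4, 4), 0.25)
empty = np.zeros((0, 0))

print(nonzero_count_matches_rows(batch))
print(nonzero_count_matches_rows(doubly_stochastic))
print(nonzero_count_matches_rows(empty))
```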

Member Author

did you bench it? yours is two loops over the data, mine is less than 1

Member

yes, count_nonzero is doing something dumb perhaps. I can check again

Member Author

i spent some time exploring options and count_nonzero was consistently the fastest in my testing

Member

let's run the same bench script, will send in ~30m

Member

You were right: the earlier comparison was conflating overall performance with performance conditional on the float branch:

```python
"""Micro-benchmark for PR #2173: the FLOAT worst-case of the constant
permutation/selection checks -- the only path that differs between impls.

For float data, after the row/column sum==1 checks pass we must still confirm
the entries are 0/1:

  prior (jessegrabowski 89156a66f): np.count_nonzero(data) == size // n
  HEAD  (ricardoV94 1cd347d50):     ((data == 0) | (data == 1)).all()

bool/int dtypes are not shown: there the new impl strictly drops redundant
reductions, so it is unambiguously faster. The float branch is the one in
question. Both impls are correct; this only measures speed.
"""

import timeit

import numpy as np


def final_old(data):
    n = data.shape[-1]
    return np.count_nonzero(data) == (data.size // n if n else 0)


def final_new(data):
    return bool(((data == 0) | (data == 1)).all())


def bench(fn, data, number):
    return min(timeit.repeat(lambda: fn(data), number=number, repeat=7)) / number


def row(label, t_old, t_new):
    speedup = t_old / t_new
    tag = "new faster" if speedup > 1 else "OLD faster"
    print(
        f"  {label:<40} old={t_old * 1e6:8.2f}us  new={t_new * 1e6:8.2f}us  "
        f"x{speedup:5.2f}  ({tag})"
    )


def permutation_matrix(n):
    return np.asarray(np.eye(n)[np.random.permutation(n)], dtype="float64")


np.random.seed(0)

for n in (10, 100, 1000, 4000):
    num = 5000 if n <= 100 else (200 if n <= 1000 else 20)
    print(f"-- n={n} ({n * n} entries, float64) --")

    # True permutation: the common "is this a permutation?" -> True full check.
    perm = permutation_matrix(n)
    assert final_old(perm) and final_new(perm)
    row("true permutation (->True)", bench(final_old, perm, num), bench(final_new, perm, num))

    # Doubly-stochastic: passes sum checks, every entry nonzero & != 0/1 -> False.
    # Worst case for count_nonzero (must scan all) and best case for the OR-scan
    # short-circuit possibilities (.all() can bail at the first non-binary entry).
    ds = np.full((n, n), 1.0 / n, dtype="float64")
    assert not final_old(ds) and not final_new(ds)
    row("doubly-stochastic (->False)", bench(final_old, ds, num), bench(final_new, ds, num))
```

So on the float branch we should go back to count_nonzero.

Member

nonzero is back, separate commit

@ricardoV94
Member

I pushed an optimization for the permutation/selection checks and also suppressed nan/inf warnings during the sum.
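
One way to silence such warnings during the sum is `np.errstate`. How the pushed commit actually does it is not shown in this conversation, so the following is only a sketch under that assumption, with the hypothetical helper name `rows_and_columns_sum_to_one`. Summing `inf` and `-inf` yields nan, which would normally trigger an "invalid value" RuntimeWarning.

```python
import warnings

import numpy as np


def rows_and_columns_sum_to_one(data):
    """Hypothetical sketch: sum check that stays silent on nan/inf input."""
    # Assumption: floating-point warnings are suppressed only around the sums.
    with np.errstate(invalid="ignore", over="ignore"):
        row_sums = data.sum(axis=-1)
        col_sums = data.sum(axis=-2)
    # nan == 1 is False, so non-finite sums simply fail the check.
    return bool((row_sums == 1).all() and (col_sums == 1).all())


bad = np.array([[np.inf, -np.inf], [0.0, 1.0]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = rows_and_columns_sum_to_one(bad)
    n_warnings = len(caught)

print(result, n_warnings)
```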

ricardoV94 force-pushed the symmetric-check-fix branch from 1140c55 to 1cd347d May 28, 2026 09:11
Member

ricardoV94 left a comment

Pushed my changes as a separate commit. Review / squash before merging.
