More robust diagonal constant check#2173
Conversation
|
I also upgraded the performance for checking diagonality. The cheapest thing me and mr. bot could come up with is to check if the number of nonzero entries in the matrix is equal to the number on the diagonal. If this is true, the matrix must be diagonal. This avoids allocation of a mask matrix, and seems to be a highly optimized numpy function (~2x speedup vs |
ricardoV94
left a comment
There was a problem hiding this comment.
The diagonal seems good, the permutation can be improved a tiny bit more?
| ) | ||
| if data.dtype.kind in "uib": | ||
| # Fast check, only valid for integer/bool | ||
| is_permutation = data.max(initial=1) <= 1 and data.min(initial=0) >= 0 |
There was a problem hiding this comment.
it's enough to check min(initial=0) >= 0 for discrete, and for bool/unsigned we don't even need anything at this point.
| # Otherwise a matrix is permutation iff there is exactly 1 nonzero entry per row | ||
| # and column. That non-zero value can only be 1 due to previous checks. | ||
| n = data.shape[-1] | ||
| is_permutation = np.count_nonzero(data) == (data.size // n if n else 0) |
There was a problem hiding this comment.
This seems slower than the data == 0 | data == 1, and I don't think it would have had issues with inf/nan
There was a problem hiding this comment.
did you bench it? yours is two loops over the data, mine is less than 1
There was a problem hiding this comment.
yes, count_nonzero is doing something dumb perhaps. I can check again
There was a problem hiding this comment.
i spent some time exploring options and count_nonzero was consistently the fastest in my testing
There was a problem hiding this comment.
let's run the same bench script, will send in ~30m
There was a problem hiding this comment.
You were right, it was conflating overall performance with conditional on the branch:
"""Micro-benchmark for PR #2173: the FLOAT worst-case of the constant
permutation/selection checks -- the only path that differs between impls.
For float data, after the row/column sum==1 checks pass we must still confirm
the entries are 0/1:
prior (jessegrabowski 89156a66f): np.count_nonzero(data) == size // n
HEAD (ricardoV94 1cd347d50): ((data == 0) | (data == 1)).all()
bool/int dtypes are not shown: there the new impl strictly drops redundant
reductions, so it is unambiguously faster. The float branch is the one in
question. Both impls are correct; this only measures speed.
"""
import timeit
import numpy as np
def final_old(data):
n = data.shape[-1]
return np.count_nonzero(data) == (data.size // n if n else 0)
def final_new(data):
return bool(((data == 0) | (data == 1)).all())
def bench(fn, data, number):
return min(timeit.repeat(lambda: fn(data), number=number, repeat=7)) / number
def row(label, t_old, t_new):
speedup = t_old / t_new
tag = "new faster" if speedup > 1 else "OLD faster"
print(
f" {label:<40} old={t_old * 1e6:8.2f}us new={t_new * 1e6:8.2f}us "
f"x{speedup:5.2f} ({tag})"
)
def permutation_matrix(n):
return np.asarray(np.eye(n)[np.random.permutation(n)], dtype="float64")
np.random.seed(0)
for n in (10, 100, 1000, 4000):
num = 5000 if n <= 100 else (200 if n <= 1000 else 20)
print(f"-- n={n} ({n * n} entries, float64) --")
# True permutation: the common "is this a permutation?" -> True full check.
perm = permutation_matrix(n)
assert final_old(perm) and final_new(perm)
row("true permutation (->True)", bench(final_old, perm, num), bench(final_new, perm, num))
# Doubly-stochastic: passes sum checks, every entry nonzero & != 0/1 -> False.
# Worst case for count_nonzero (must scan all) and best case for the OR-scan
# short-circuit possibilities (.all() can bail at the first non-binary entry).
ds = np.full((n, n), 1.0 / n, dtype="float64")
assert not final_old(ds) and not final_new(ds)
row("doubly-stochastic (->False)", bench(final_old, ds, num), bench(final_new, ds, num))So on the float branch we should go back to nonzero
There was a problem hiding this comment.
nonzero is back, separate commit
|
I pushed an opt for permutation/selection and also suppress nan/inf during the sum |
1140c55 to
1cd347d
Compare
ricardoV94
left a comment
There was a problem hiding this comment.
Pushed my changes as separate commit. Review / squash before merging
pymc-extras surfaced a small bug in _diagonal_from_constant, which used multiplication by a masking matrix to check for non-zero off-diagonal values. This raises a warning when there are non-finite values (inf/nan) in the matrix being checked. Use indexing instead to avoid the warnings.