diff --git a/src/csrc/scalar_ops.cpp b/src/csrc/scalar_ops.cpp index 8e81910..1e62c40 100644 --- a/src/csrc/scalar_ops.cpp +++ b/src/csrc/scalar_ops.cpp @@ -144,7 +144,7 @@ quad_richcompare(QuadPrecisionObject *self, PyObject *other, int cmp_op) return NULL; } } - else if (PyLong_CheckExact(other) || PyFloat_CheckExact(other)) { + else if (PyLong_Check(other) || PyFloat_Check(other)) { other_quad = QuadPrecision_from_object(other, backend); if (other_quad == NULL) { return NULL; diff --git a/tests/test_quaddtype.py b/tests/test_quaddtype.py index 58a8b7d..2a244ea 100644 --- a/tests/test_quaddtype.py +++ b/tests/test_quaddtype.py @@ -1350,6 +1350,33 @@ def test_comparisons(op, a, b): assert op_func(quad_a, quad_b) == op_func(float_a, float_b) +@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"]) +@pytest.mark.parametrize( + "quad_val, other", + [ + # bool is a subclass of int — exercises PyLong_Check (regression: gh-100) + (1, True), + (0, False), + (1, False), + (0, True), + (2, True), + # np.float64 is a subclass of float — exercises PyFloat_Check + (1, np.float64(1.0)), + (1, np.float64(2.0)), + (0, np.float64(-0.0)), + ], +) +def test_comparisons_with_python_subclasses(op, quad_val, other): + op_func = getattr(operator, op) + quad_a = QuadPrecision(quad_val) + expected = op_func(float(quad_val), float(other)) + + # Forward: QuadPrecision OP subclass-instance + assert op_func(quad_a, other) == expected, f"Failed {op} between QuadPrecision({quad_val}) and {other} (type {type(other)})" + # Reverse: subclass-instance OP QuadPrecision + assert op_func(other, quad_a) == op_func(float(other), float(quad_val)), f"Failed {op} between {other} (type {type(other)}) and QuadPrecision({quad_val})" + + @pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"]) @pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"]) @pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])