-
-
Notifications
You must be signed in to change notification settings - Fork 6
[FEAT] Adding vecdot implementation
#86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -315,6 +315,154 @@ quad_matmul_strided_loop_unaligned(PyArrayMethod_Context *context, char *const d | |||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // vecdot: signature (n),(n)->() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| static int | ||||||||||||||||||||||||||||||||
| quad_vecdot_strided_loop_aligned(PyArrayMethod_Context *context, char *const data[], | ||||||||||||||||||||||||||||||||
| npy_intp const dimensions[], npy_intp const strides[], | ||||||||||||||||||||||||||||||||
| NpyAuxData *auxdata) | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| npy_intp N = dimensions[0]; // outer (broadcast) loop length | ||||||||||||||||||||||||||||||||
| npy_intp n = dimensions[1]; // core dim length | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| npy_intp x_outer_stride = strides[0]; | ||||||||||||||||||||||||||||||||
| npy_intp y_outer_stride = strides[1]; | ||||||||||||||||||||||||||||||||
| npy_intp out_outer_stride = strides[2]; | ||||||||||||||||||||||||||||||||
| npy_intp x_n_stride = strides[3]; | ||||||||||||||||||||||||||||||||
| npy_intp y_n_stride = strides[4]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0]; | ||||||||||||||||||||||||||||||||
| if (descr->backend != BACKEND_SLEEF) { | ||||||||||||||||||||||||||||||||
| PyErr_SetString(PyExc_NotImplementedError, | ||||||||||||||||||||||||||||||||
| "QBLAS-accelerated vecdot only supports SLEEF backend."); | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| char *x = data[0]; | ||||||||||||||||||||||||||||||||
| char *y = data[1]; | ||||||||||||||||||||||||||||||||
| char *out = data[2]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| size_t incx = x_n_stride / sizeof(Sleef_quad); | ||||||||||||||||||||||||||||||||
| size_t incy = y_n_stride / sizeof(Sleef_quad); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| for (npy_intp i = 0; i < N; i++) { | ||||||||||||||||||||||||||||||||
| Sleef_quad *x_ptr = (Sleef_quad *)x; | ||||||||||||||||||||||||||||||||
| Sleef_quad *y_ptr = (Sleef_quad *)y; | ||||||||||||||||||||||||||||||||
| Sleef_quad *out_ptr = (Sleef_quad *)out; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if (n == 0) { | ||||||||||||||||||||||||||||||||
| *out_ptr = Sleef_cast_from_doubleq1(0.0); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| else { | ||||||||||||||||||||||||||||||||
| int result = qblas_dot(n, x_ptr, incx, y_ptr, incy, out_ptr); | ||||||||||||||||||||||||||||||||
| if (result != 0) { | ||||||||||||||||||||||||||||||||
| PyErr_SetString(PyExc_RuntimeError, "QBLAS vecdot operation failed"); | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| x += x_outer_stride; | ||||||||||||||||||||||||||||||||
| y += y_outer_stride; | ||||||||||||||||||||||||||||||||
| out += out_outer_stride; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| static int | ||||||||||||||||||||||||||||||||
| quad_vecdot_strided_loop_unaligned(PyArrayMethod_Context *context, char *const data[], | ||||||||||||||||||||||||||||||||
| npy_intp const dimensions[], npy_intp const strides[], | ||||||||||||||||||||||||||||||||
| NpyAuxData *auxdata) | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| npy_intp N = dimensions[0]; | ||||||||||||||||||||||||||||||||
| npy_intp n = dimensions[1]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| npy_intp x_outer_stride = strides[0]; | ||||||||||||||||||||||||||||||||
| npy_intp y_outer_stride = strides[1]; | ||||||||||||||||||||||||||||||||
| npy_intp out_outer_stride = strides[2]; | ||||||||||||||||||||||||||||||||
| npy_intp x_n_stride = strides[3]; | ||||||||||||||||||||||||||||||||
| npy_intp y_n_stride = strides[4]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0]; | ||||||||||||||||||||||||||||||||
| if (descr->backend != BACKEND_SLEEF) { | ||||||||||||||||||||||||||||||||
| PyErr_SetString(PyExc_NotImplementedError, | ||||||||||||||||||||||||||||||||
| "QBLAS-accelerated vecdot only supports SLEEF backend."); | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| char *x = data[0]; | ||||||||||||||||||||||||||||||||
| char *y = data[1]; | ||||||||||||||||||||||||||||||||
| char *out = data[2]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| for (npy_intp i = 0; i < N; i++) { | ||||||||||||||||||||||||||||||||
| Sleef_quad sum = Sleef_cast_from_doubleq1(0.0); | ||||||||||||||||||||||||||||||||
| for (npy_intp k = 0; k < n; k++) { | ||||||||||||||||||||||||||||||||
| Sleef_quad a_val, b_val; | ||||||||||||||||||||||||||||||||
| memcpy(&a_val, x + k * x_n_stride, sizeof(Sleef_quad)); | ||||||||||||||||||||||||||||||||
| memcpy(&b_val, y + k * y_n_stride, sizeof(Sleef_quad)); | ||||||||||||||||||||||||||||||||
| sum = Sleef_fmaq1_u05(a_val, b_val, sum); | ||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unaligned matmul does a different thing and calls into numpy-quaddtype/src/csrc/umath/matmul.cpp Lines 234 to 248 in fc52921
Why the difference? Claude seems to think the |
||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| memcpy(out, &sum, sizeof(Sleef_quad)); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| x += x_outer_stride; | ||||||||||||||||||||||||||||||||
| y += y_outer_stride; | ||||||||||||||||||||||||||||||||
| out += out_outer_stride; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| static int | ||||||||||||||||||||||||||||||||
| naive_vecdot_strided_loop(PyArrayMethod_Context *context, char *const data[], | ||||||||||||||||||||||||||||||||
| npy_intp const dimensions[], npy_intp const strides[], | ||||||||||||||||||||||||||||||||
| NpyAuxData *auxdata) | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| npy_intp N = dimensions[0]; | ||||||||||||||||||||||||||||||||
| npy_intp n = dimensions[1]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| npy_intp x_outer_stride = strides[0]; | ||||||||||||||||||||||||||||||||
| npy_intp y_outer_stride = strides[1]; | ||||||||||||||||||||||||||||||||
| npy_intp out_outer_stride = strides[2]; | ||||||||||||||||||||||||||||||||
| npy_intp x_n_stride = strides[3]; | ||||||||||||||||||||||||||||||||
| npy_intp y_n_stride = strides[4]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| QuadPrecDTypeObject *descr = (QuadPrecDTypeObject *)context->descriptors[0]; | ||||||||||||||||||||||||||||||||
| QuadBackendType backend = descr->backend; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| char *x = data[0]; | ||||||||||||||||||||||||||||||||
| char *y = data[1]; | ||||||||||||||||||||||||||||||||
| char *out = data[2]; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| for (npy_intp i = 0; i < N; i++) { | ||||||||||||||||||||||||||||||||
| if (backend == BACKEND_SLEEF) { | ||||||||||||||||||||||||||||||||
| Sleef_quad sum = Sleef_cast_from_doubleq1(0.0); | ||||||||||||||||||||||||||||||||
| for (npy_intp k = 0; k < n; k++) { | ||||||||||||||||||||||||||||||||
| Sleef_quad a_val, b_val; | ||||||||||||||||||||||||||||||||
| memcpy(&a_val, x + k * x_n_stride, sizeof(Sleef_quad)); | ||||||||||||||||||||||||||||||||
| memcpy(&b_val, y + k * y_n_stride, sizeof(Sleef_quad)); | ||||||||||||||||||||||||||||||||
| sum = Sleef_fmaq1_u05(a_val, b_val, sum); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| memcpy(out, &sum, sizeof(Sleef_quad)); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| else { | ||||||||||||||||||||||||||||||||
| long double sum = 0.0L; | ||||||||||||||||||||||||||||||||
| for (npy_intp k = 0; k < n; k++) { | ||||||||||||||||||||||||||||||||
| long double a_val, b_val; | ||||||||||||||||||||||||||||||||
| memcpy(&a_val, x + k * x_n_stride, sizeof(long double)); | ||||||||||||||||||||||||||||||||
| memcpy(&b_val, y + k * y_n_stride, sizeof(long double)); | ||||||||||||||||||||||||||||||||
| sum += a_val * b_val; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| memcpy(out, &sum, sizeof(long double)); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| x += x_outer_stride; | ||||||||||||||||||||||||||||||||
| y += y_outer_stride; | ||||||||||||||||||||||||||||||||
| out += out_outer_stride; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| static int | ||||||||||||||||||||||||||||||||
| naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[], | ||||||||||||||||||||||||||||||||
| npy_intp const dimensions[], npy_intp const strides[], | ||||||||||||||||||||||||||||||||
|
|
@@ -385,32 +533,26 @@ naive_matmul_strided_loop(PyArrayMethod_Context *context, char *const data[], | |||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| int | ||||||||||||||||||||||||||||||||
| init_matmul_ops(PyObject *numpy) | ||||||||||||||||||||||||||||||||
| static int | ||||||||||||||||||||||||||||||||
| register_matmul_like_ufunc(PyObject *numpy, const char *ufunc_name, const char *spec_name, | ||||||||||||||||||||||||||||||||
| PyArrayMethod_StridedLoop *aligned_loop, | ||||||||||||||||||||||||||||||||
| PyArrayMethod_StridedLoop *unaligned_loop) | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| PyObject *ufunc = PyObject_GetAttrString(numpy, "matmul"); | ||||||||||||||||||||||||||||||||
| PyObject *ufunc = PyObject_GetAttrString(numpy, ufunc_name); | ||||||||||||||||||||||||||||||||
| if (ufunc == NULL) { | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &QuadPrecDType}; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| #ifndef DISABLE_QUADBLAS | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| PyType_Slot slots[] = { | ||||||||||||||||||||||||||||||||
| {NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors}, | ||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Claude points out that numpy-quaddtype/src/csrc/umath/matmul.cpp Lines 38 to 39 in fc52921
Also because |
||||||||||||||||||||||||||||||||
| {NPY_METH_strided_loop, (void *)&quad_matmul_strided_loop_aligned}, | ||||||||||||||||||||||||||||||||
| {NPY_METH_unaligned_strided_loop, (void *)&quad_matmul_strided_loop_unaligned}, | ||||||||||||||||||||||||||||||||
| {NPY_METH_strided_loop, (void *)aligned_loop}, | ||||||||||||||||||||||||||||||||
| {NPY_METH_unaligned_strided_loop, (void *)unaligned_loop}, | ||||||||||||||||||||||||||||||||
| {0, NULL}}; | ||||||||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||||||||
| PyType_Slot slots[] = {{NPY_METH_resolve_descriptors, (void *)&quad_matmul_resolve_descriptors}, | ||||||||||||||||||||||||||||||||
| {NPY_METH_strided_loop, (void *)&naive_matmul_strided_loop}, | ||||||||||||||||||||||||||||||||
| {NPY_METH_unaligned_strided_loop, (void *)&naive_matmul_strided_loop}, | ||||||||||||||||||||||||||||||||
| {0, NULL}}; | ||||||||||||||||||||||||||||||||
| #endif // DISABLE_QUADBLAS | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| PyArrayMethod_Spec Spec = { | ||||||||||||||||||||||||||||||||
| .name = "quad_matmul_qblas", | ||||||||||||||||||||||||||||||||
| .name = spec_name, | ||||||||||||||||||||||||||||||||
| .nin = 2, | ||||||||||||||||||||||||||||||||
| .nout = 1, | ||||||||||||||||||||||||||||||||
| .casting = NPY_NO_CASTING, | ||||||||||||||||||||||||||||||||
|
|
@@ -460,5 +602,35 @@ init_matmul_ops(PyObject *numpy) | |||||||||||||||||||||||||||||||
| Py_DECREF(promoter_capsule); | ||||||||||||||||||||||||||||||||
| Py_DECREF(ufunc); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| int | ||||||||||||||||||||||||||||||||
| init_matmul_ops(PyObject *numpy) | ||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||
| #ifndef DISABLE_QUADBLAS | ||||||||||||||||||||||||||||||||
| if (register_matmul_like_ufunc(numpy, "matmul", "quad_matmul_qblas", | ||||||||||||||||||||||||||||||||
| &quad_matmul_strided_loop_aligned, | ||||||||||||||||||||||||||||||||
| &quad_matmul_strided_loop_unaligned) < 0) { | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| if (register_matmul_like_ufunc(numpy, "vecdot", "quad_vecdot_qblas", | ||||||||||||||||||||||||||||||||
| &quad_vecdot_strided_loop_aligned, | ||||||||||||||||||||||||||||||||
| &quad_vecdot_strided_loop_unaligned) < 0) { | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||||||||
| if (register_matmul_like_ufunc(numpy, "matmul", "quad_matmul_naive", | ||||||||||||||||||||||||||||||||
| &naive_matmul_strided_loop, | ||||||||||||||||||||||||||||||||
| &naive_matmul_strided_loop) < 0) { | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| if (register_matmul_like_ufunc(numpy, "vecdot", "quad_vecdot_naive", | ||||||||||||||||||||||||||||||||
| &naive_vecdot_strided_loop, | ||||||||||||||||||||||||||||||||
| &naive_vecdot_strided_loop) < 0) { | ||||||||||||||||||||||||||||||||
| return -1; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| #endif // DISABLE_QUADBLAS | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,6 @@ | ||
| import pytest | ||
| import numpy as np | ||
| from utils import create_quad_array, assert_quad_equal, assert_quad_array_equal, arrays_equal_with_nan | ||
| from utils import create_quad_array, assert_quad_equal, assert_quad_array_equal, arrays_equal_with_nan, _q, _qarr | ||
| from numpy_quaddtype import QuadPrecision, QuadPrecDType | ||
|
|
||
|
|
||
|
|
@@ -689,4 +689,108 @@ def test_dimension_mismatch_matrices(self): | |
| B = create_quad_array([1, 2, 3, 4, 5, 6], shape=(3, 2)) # Wrong size | ||
|
|
||
| with pytest.raises(ValueError, match=r"matmul: Input operand 1 has a mismatch in its core dimension 0"): | ||
| np.matmul(A, B) | ||
| np.matmul(A, B) | ||
|
|
||
|
|
||
| class TestVecdot: | ||
| """Tests for np.vecdot on QuadPrecision arrays.""" | ||
|
|
||
| def test_simple(self): | ||
| x = create_quad_array([1, 2, 3]) | ||
| y = create_quad_array([4, 5, 6]) | ||
| result = np.vecdot(x, y) | ||
| assert isinstance(result, QuadPrecision) | ||
| assert_quad_equal(result, 32.0) | ||
|
|
||
| def test_orthogonal(self): | ||
| x = create_quad_array([1, 0, 0]) | ||
| y = create_quad_array([0, 1, 0]) | ||
| assert_quad_equal(np.vecdot(x, y), 0.0) | ||
|
|
||
| def test_self_dot(self): | ||
| x = create_quad_array([2, 3, 4]) | ||
| assert_quad_equal(np.vecdot(x, x), 29.0) | ||
|
|
||
| @pytest.mark.parametrize("size", [1, 2, 5, 10, 50, 100]) | ||
| def test_various_sizes(self, size): | ||
| x_vals = [i + 1 for i in range(size)] | ||
| y_vals = [2 * (i + 1) for i in range(size)] | ||
| x = create_quad_array(x_vals) | ||
| y = create_quad_array(y_vals) | ||
| result = np.vecdot(x, y) | ||
| expected = sum(x_vals[i] * y_vals[i] for i in range(size)) | ||
| assert_quad_equal(result, expected) | ||
|
|
||
| def test_negative_and_fractional(self): | ||
| x = create_quad_array([1.5, -2.5, 3.25]) | ||
| y = create_quad_array([-1.25, 2.75, -3.5]) | ||
| expected = 1.5 * -1.25 + -2.5 * 2.75 + 3.25 * -3.5 | ||
| assert_quad_equal(np.vecdot(x, y), expected) | ||
|
|
||
| def test_single_element(self): | ||
| x = create_quad_array([7.0]) | ||
| y = create_quad_array([6.0]) | ||
| result = np.vecdot(x, y) | ||
| assert isinstance(result, QuadPrecision) | ||
| assert_quad_equal(result, 42.0) | ||
|
|
||
| def test_batched_vectors(self): | ||
| """vecdot broadcasts over leading dimensions.""" | ||
| x = _qarr([[1, 2, 3], [4, 5, 6]]) | ||
| y = _qarr([[1, 1, 1], [2, 2, 2]]) | ||
| result = np.vecdot(x, y) | ||
| assert result.shape == (2,) | ||
| assert_quad_equal(result[0], 6.0) | ||
| assert_quad_equal(result[1], 30.0) | ||
|
|
||
| def test_batched_3d(self): | ||
| x = _qarr([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) | ||
| y = _qarr([[[1, 1], [1, 1]], [[2, 2], [2, 2]]]) | ||
| result = np.vecdot(x, y) | ||
| assert result.shape == (2, 2) | ||
| expected = [[3, 7], [22, 30]] | ||
| for i in range(2): | ||
| for j in range(2): | ||
| assert_quad_equal(result[i, j], expected[i][j]) | ||
|
|
||
| def test_broadcast_against_scalar_vector(self): | ||
| """Broadcast a single vector against a stack.""" | ||
| x = _qarr([[1, 2, 3], [4, 5, 6]]) | ||
| y = _qarr([1, 1, 1]) | ||
| result = np.vecdot(x, y) | ||
| assert result.shape == (2,) | ||
| assert_quad_equal(result[0], 6.0) | ||
| assert_quad_equal(result[1], 15.0) | ||
|
|
||
| @pytest.mark.parametrize("special_val", ["0.0", "-0.0", "inf", "-inf", "nan"]) | ||
| def test_special_values(self, special_val): | ||
| x = create_quad_array([1.0, float(special_val), 2.0]) | ||
| y = create_quad_array([3.0, 4.0, 5.0]) | ||
| result = np.vecdot(x, y) | ||
| expected = np.vecdot(np.array([1.0, float(special_val), 2.0], dtype=np.float64), | ||
| np.array([3.0, 4.0, 5.0], dtype=np.float64)) | ||
| if np.isnan(expected): | ||
| assert np.isnan(float(result)) | ||
| elif np.isinf(expected): | ||
| assert np.isinf(float(result)) | ||
| assert np.sign(float(result)) == np.sign(expected) | ||
| else: | ||
| assert_quad_equal(result, expected) | ||
|
|
||
| def test_matches_matmul_for_1d(self): | ||
| """np.matmul of two 1D arrays is equivalent to vecdot.""" | ||
| x = create_quad_array([1.5, 2.5, -3.0, 0.25]) | ||
| y = create_quad_array([4.0, -1.0, 2.0, 8.0]) | ||
| assert_quad_equal(np.vecdot(x, y), np.matmul(x, y)) | ||
|
|
||
| def test_precision_advantage(self): | ||
| """vecdot accumulates with quad precision, beating float64 cancellation.""" | ||
| x = create_quad_array([1e20, 1.0, -1e20]) | ||
| y = create_quad_array([0.0, 1.0, 0.0]) | ||
| assert_quad_equal(np.vecdot(x, y), 1.0, atol=1e-25) | ||
|
|
||
| def test_dimension_mismatch(self): | ||
| x = create_quad_array([1, 2]) | ||
| y = create_quad_array([1, 2, 3]) | ||
| with pytest.raises(ValueError): | ||
| np.vecdot(x, y) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. might be worth adding a test that does vecdot on an empty array, e.g. |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this check is unnecessary because
resolve_descriptorswould have already triggered the same error.