diff --git a/src/interpcore/kernels.py b/src/interpcore/kernels.py index 243ec77..2809428 100644 --- a/src/interpcore/kernels.py +++ b/src/interpcore/kernels.py @@ -151,8 +151,9 @@ def interpolate_block( # Initialize output for destination points interpolated = np.zeros([neighbours_coords.shape[0], config.num_components]) else: - # Initialize output for destination points (same size as chunked_coords) - interpolated = np.zeros([chunked_coords.shape[0], config.num_components]) + # Initialize output only for the current destination chunk + block_size = max(chunk_idx.stop - chunk_idx.start, 0) + interpolated = np.zeros([block_size, config.num_components]) unmapped = np.zeros([1, config.num_components]) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index c671ea4..6a978df 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -14,6 +14,7 @@ INTERPOLATED_LOAD_TYPE, INTERPOLATION_KERNEL, ) +from interpcore.dest_tree import DestinationTree from interpcore.errors import InterpolationError @@ -294,3 +295,95 @@ def test_interpolate_block_coincident_node(self): assert interpolated[2, 0] == 0.0 # Nothing should be unmapped assert np.all(unmapped == 0) + + def test_destination_to_source_multithread_handles_empty_blocks(self, monkeypatch): + """Dest-to-source multithread should handle blocks larger than points without shape errors""" + dest_coords = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]]) + src_coords = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]]) + dest_ids = np.array([1, 2]) + src_values = np.array([[100.0], [200.0]]) + + config = InterpolationConfig( + kernel=INTERPOLATION_KERNEL.CLOSEST, # Dest-to-source mode + max_distance=1.0, + coincidence_tolerance=1e-9, + method=QUERY_TYPE.K, + param=1, + multithread=True, + interpolated_load=INTERPOLATED_LOAD_TYPE.HEAT_FLUX, + ) + + # Force more worker blocks than destination points (creates empty trailing blocks) + monkeypatch.setattr("interpcore.dest_tree.os.cpu_count", lambda: 4) + + tree = DestinationTree( + dest_coordinates=dest_coords, + src_coordinates=src_coords, + dest_ids=dest_ids, + config=config, + ) + interpolated, unmapped = tree.interpolate(src_values) + + np.testing.assert_array_equal(interpolated, src_values) + np.testing.assert_array_equal(unmapped, np.zeros((1, 1))) + + def test_destination_to_source_multithread_matches_single_core(self, monkeypatch): + """Dest-to-source interpolation should match between multithread and single-core runs""" + dest_coords = np.array( + [ + [0.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [5.0, 0.0, 0.0], + [9.0, 0.0, 0.0], + ] + ) + src_coords = np.array( + [ + [0.1, 0.0, 0.0], + [1.9, 0.0, 0.0], + [5.2, 0.0, 0.0], + [8.8, 0.0, 0.0], + ] + ) + dest_ids = np.array([101, 102, 103, 104]) + src_values = np.array([[10.0], [20.0], [30.0], [40.0]]) + + single_core_config = InterpolationConfig( + kernel=INTERPOLATION_KERNEL.CLOSEST, + max_distance=1.0, + coincidence_tolerance=1e-9, + method=QUERY_TYPE.K, + param=1, + multithread=False, + interpolated_load=INTERPOLATED_LOAD_TYPE.HEAT_FLUX, + ) + multithread_config = InterpolationConfig( + kernel=INTERPOLATION_KERNEL.CLOSEST, + max_distance=1.0, + coincidence_tolerance=1e-9, + method=QUERY_TYPE.K, + param=1, + multithread=True, + interpolated_load=INTERPOLATED_LOAD_TYPE.HEAT_FLUX, + ) + + monkeypatch.setattr("interpcore.dest_tree.os.cpu_count", lambda: 8) + + single_core_tree = DestinationTree( + dest_coordinates=dest_coords, + src_coordinates=src_coords, + dest_ids=dest_ids, + config=single_core_config, + ) + multithread_tree = DestinationTree( + dest_coordinates=dest_coords, + src_coordinates=src_coords, + dest_ids=dest_ids, + config=multithread_config, + ) + + interpolated_single, unmapped_single = single_core_tree.interpolate(src_values) + interpolated_multi, unmapped_multi = multithread_tree.interpolate(src_values) + + np.testing.assert_array_equal(interpolated_multi, interpolated_single) + np.testing.assert_array_equal(unmapped_multi, unmapped_single)