Skip to content

Commit b816617

Browse files
committed
Fix: Skip set intersections in Aarch64 emulator
1 parent db2b793 commit b816617

File tree

2 files changed

+37
-29
lines changed

2 files changed

+37
-29
lines changed

include/simsimd/sparse.h

+22-22
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,18 @@ SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u16, size) // simsimd_intersect_u16_accu
9191
SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accurate
9292

9393
#define SIMSIMD_MAKE_INTERSECT_GALLOPING(name, input_type, accumulator_type) \
94-
SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* b, \
95-
simsimd_size_t start, simsimd_size_t b_length, \
94+
SIMSIMD_PUBLIC simsimd_size_t simsimd_galloping_search_##input_type(simsimd_##input_type##_t const* array, \
95+
simsimd_size_t start, simsimd_size_t length, \
9696
simsimd_##input_type##_t val) { \
9797
simsimd_size_t low = start; \
9898
simsimd_size_t high = start + 1; \
99-
while (high < b_length && b[high] < val) { \
99+
while (high < length && array[high] < val) { \
100100
low = high; \
101-
high = (2 * high < b_length) ? 2 * high : b_length; \
101+
high = (2 * high < length) ? 2 * high : length; \
102102
} \
103103
while (low < high) { \
104104
simsimd_size_t mid = low + (high - low) / 2; \
105-
if (b[mid] < val) { \
105+
if (array[mid] < val) { \
106106
low = mid + 1; \
107107
} else { \
108108
high = mid; \
@@ -112,31 +112,31 @@ SIMSIMD_MAKE_INTERSECT_LINEAR(accurate, u32, size) // simsimd_intersect_u32_accu
112112
} \
113113
\
114114
SIMSIMD_PUBLIC void simsimd_intersect_##input_type##_##name( \
115-
simsimd_##input_type##_t const* a, simsimd_##input_type##_t const* b, simsimd_size_t a_length, \
116-
simsimd_size_t b_length, simsimd_distance_t* result) { \
117-
/* Swap arrays if necessary, as we want "b" to be larger than "a" */ \
118-
if (a_length > b_length) { \
119-
simsimd_##input_type##_t const* temp = a; \
120-
a = b; \
121-
b = temp; \
122-
simsimd_size_t temp_length = a_length; \
123-
a_length = b_length; \
124-
b_length = temp_length; \
115+
simsimd_##input_type##_t const* shorter, simsimd_##input_type##_t const* longer, \
116+
simsimd_size_t shorter_length, simsimd_size_t longer_length, simsimd_distance_t* result) { \
117+
/* Swap arrays if necessary, as we want "longer" to be larger than "shorter" */ \
118+
if (longer_length < shorter_length) { \
119+
simsimd_##input_type##_t const* temp = shorter; \
120+
shorter = longer; \
121+
longer = temp; \
122+
simsimd_size_t temp_length = shorter_length; \
123+
shorter_length = longer_length; \
124+
longer_length = temp_length; \
125125
} \
126126
\
127-
/* Use accurate implementation if galloping is not beneficial */ \
128-
if (b_length < 64 * a_length) { \
129-
simsimd_intersect_##input_type##_accurate(a, b, a_length, b_length, result); \
127+
/* Use the accurate implementation if galloping is not beneficial */ \
128+
if (longer_length < 64 * shorter_length) { \
129+
simsimd_intersect_##input_type##_accurate(shorter, longer, shorter_length, longer_length, result); \
130130
return; \
131131
} \
132132
\
133133
/* Perform galloping, shrinking the target range */ \
134134
simsimd_##accumulator_type##_t intersection = 0; \
135135
simsimd_size_t j = 0; \
136-
for (simsimd_size_t i = 0; i < a_length; ++i) { \
137-
simsimd_##input_type##_t ai = a[i]; \
138-
j = simsimd_galloping_search_##input_type(b, j, b_length, ai); \
139-
if (j < b_length && b[j] == ai) { \
136+
for (simsimd_size_t i = 0; i < shorter_length; ++i) { \
137+
simsimd_##input_type##_t shorter_i = shorter[i]; \
138+
j = simsimd_galloping_search_##input_type(longer, j, longer_length, shorter_i); \
139+
if (j < longer_length && longer[j] == shorter_i) { \
140140
intersection++; \
141141
} \
142142
} \

python/test.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
import platform
3+
24
import pytest
35
import simsimd as simd
46

@@ -479,16 +481,22 @@ def test_dot_complex_explicit(ndim):
479481

480482

481483
@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
482-
@pytest.mark.repeat(300)
484+
@pytest.mark.repeat(100)
483485
@pytest.mark.parametrize("dtype", ["uint16", "uint32"])
484-
@pytest.mark.parametrize("length_bound", [10, 25, 1000])
485-
def test_intersect(dtype, length_bound):
486+
@pytest.mark.parametrize("first_length_bound", [10, 100, 1000])
487+
@pytest.mark.parametrize("second_length_bound", [10, 100, 1000])
488+
def test_intersect(dtype, first_length_bound, second_length_bound):
486489
"""Compares the simd.intersect() function with numpy.intersect1d."""
490+
491+
if is_running_under_qemu() and (platform.machine() == "aarch64" or platform.machine() == "arm64"):
492+
pytest.skip("In QEMU `aarch64` emulation on `x86_64` the `intersect` function is not reliable")
493+
487494
np.random.seed()
488-
a_length = np.random.randint(1, length_bound)
489-
b_length = np.random.randint(1, length_bound)
490-
a = np.random.randint(length_bound * 2, size=a_length, dtype=dtype)
491-
b = np.random.randint(length_bound * 2, size=b_length, dtype=dtype)
495+
496+
a_length = np.random.randint(1, first_length_bound)
497+
b_length = np.random.randint(1, second_length_bound)
498+
a = np.random.randint(first_length_bound * 2, size=a_length, dtype=dtype)
499+
b = np.random.randint(second_length_bound * 2, size=b_length, dtype=dtype)
492500

493501
# Remove duplicates, converting into sorted arrays
494502
a = np.unique(a)

0 commit comments

Comments
 (0)