Skip to content

Commit 2ed19c8

Browse files
committed
fix(distance-neon): update ARM pointer size checks to use _ARM32BIT_ macro for clarity
1 parent 1fa8143 commit 2ed19c8

File tree

1 file changed

+20
-15
lines changed

1 file changed

+20
-15
lines changed

src/distance-neon.c

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,18 @@
1313

1414

1515
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
16+
17+
#if __SIZEOF_POINTER__ == 4
18+
#define _ARM32BIT_ 1
19+
#endif
20+
1621
#include <arm_neon.h>
1722

1823
extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
1924
extern char *distance_backend_name;
2025

2126
// Helper function for 32-bit ARM: vmaxv_u16 is not available in ARMv7 NEON
22-
#if __SIZEOF_POINTER__ == 4
27+
#ifdef _ARM32BIT_
2328
static inline uint16_t vmaxv_u16_compat(uint16x4_t v) {
2429
// Use pairwise max to reduce vector
2530
uint16x4_t m = vpmax_u16(v, v); // [max(v0,v1), max(v2,v3), max(v0,v1), max(v2,v3)]
@@ -169,7 +174,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
169174
const uint16_t *a = (const uint16_t *)v1;
170175
const uint16_t *b = (const uint16_t *)v2;
171176

172-
#if __SIZEOF_POINTER__ == 4
177+
#ifdef _ARM32BIT_
173178
// 32-bit ARM: use scalar double accumulation (no float64x2_t in NEON)
174179
double sum = 0.0;
175180
int i = 0;
@@ -446,7 +451,7 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
446451
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
447452
const uint16x4_t ZERO16 = vdup_n_u16(0);
448453

449-
#if __SIZEOF_POINTER__ == 4
454+
#ifdef _ARM32BIT_
450455
// 32-bit ARM: use scalar double accumulation
451456
double sum = 0.0;
452457
int i = 0;
@@ -487,7 +492,7 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
487492
uint32x4_t m = vceqq_f32(d32, d32); /* true where not-NaN */
488493
d32 = vbslq_f32(m, d32, vdupq_n_f32(0.0f));
489494

490-
#if __SIZEOF_POINTER__ == 4
495+
#ifdef _ARM32BIT_
491496
// 32-bit ARM: accumulate in scalar double
492497
float tmp[4];
493498
vst1q_f32(tmp, d32);
@@ -509,7 +514,7 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
509514
#endif
510515
}
511516

512-
#if __SIZEOF_POINTER__ != 4
517+
#ifndef _ARM32BIT_
513518
double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
514519
#endif
515520

@@ -544,7 +549,7 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
544549
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
545550
const uint16x4_t ZERO16 = vdup_n_u16(0);
546551

547-
#if __SIZEOF_POINTER__ == 4
552+
#ifdef _ARM32BIT_
548553
// 32-bit ARM: use scalar double accumulation
549554
double dot = 0.0, normx = 0.0, normy = 0.0;
550555
int i = 0;
@@ -576,7 +581,7 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
576581
ax = vbslq_f32(mx, ax, vdupq_n_f32(0.0f));
577582
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));
578583

579-
#if __SIZEOF_POINTER__ == 4
584+
#ifdef _ARM32BIT_
580585
// 32-bit ARM: accumulate in scalar double
581586
float ax_tmp[4], by_tmp[4];
582587
vst1q_f32(ax_tmp, ax);
@@ -611,7 +616,7 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
611616
#endif
612617
}
613618

614-
#if __SIZEOF_POINTER__ != 4
619+
#ifndef _ARM32BIT_
615620
double dot = vaddvq_f64(vaddq_f64(acc_dot_lo, acc_dot_hi));
616621
double normx= vaddvq_f64(vaddq_f64(acc_a2_lo, acc_a2_hi));
617622
double normy= vaddvq_f64(vaddq_f64(acc_b2_lo, acc_b2_hi));
@@ -649,7 +654,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
649654
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
650655
const uint16x4_t ZERO16 = vdup_n_u16(0);
651656

652-
#if __SIZEOF_POINTER__ == 4
657+
#ifdef _ARM32BIT_
653658
// 32-bit ARM: use scalar double accumulation
654659
double dot = 0.0;
655660
int i = 0;
@@ -675,7 +680,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
675680
if (isnan(x) || isnan(y)) continue;
676681
double p = (double)x * (double)y;
677682
if (isinf(p)) return (p>0)? -INFINITY : INFINITY;
678-
#if __SIZEOF_POINTER__ == 4
683+
#ifdef _ARM32BIT_
679684
dot += p;
680685
#else
681686
acc_lo = vsetq_lane_f64(vgetq_lane_f64(acc_lo,0)+p, acc_lo, 0); /* cheap add */
@@ -695,7 +700,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
695700

696701
float32x4_t prod = vmulq_f32(ax, by);
697702

698-
#if __SIZEOF_POINTER__ == 4
703+
#ifdef _ARM32BIT_
699704
// 32-bit ARM: accumulate in scalar double
700705
float prod_tmp[4];
701706
vst1q_f32(prod_tmp, prod);
@@ -711,7 +716,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
711716
#endif
712717
}
713718

714-
#if __SIZEOF_POINTER__ != 4
719+
#ifndef _ARM32BIT_
715720
double dot = vaddvq_f64(vaddq_f64(acc_lo, acc_hi));
716721
#endif
717722

@@ -739,7 +744,7 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
739744
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
740745
const uint16x4_t ZERO16 = vdup_n_u16(0);
741746

742-
#if __SIZEOF_POINTER__ == 4
747+
#ifdef _ARM32BIT_
743748
// 32-bit ARM: use scalar double accumulation
744749
double sum = 0.0;
745750
int i = 0;
@@ -776,7 +781,7 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
776781
uint32x4_t m = vceqq_f32(d, d); /* mask NaNs -> 0 */
777782
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));
778783

779-
#if __SIZEOF_POINTER__ == 4
784+
#ifdef _ARM32BIT_
780785
// 32-bit ARM: accumulate in scalar double
781786
float tmp[4];
782787
vst1q_f32(tmp, d);
@@ -792,7 +797,7 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
792797
#endif
793798
}
794799

795-
#if __SIZEOF_POINTER__ != 4
800+
#ifndef _ARM32BIT_
796801
double sum = vaddvq_f64(acc);
797802
#endif
798803

0 commit comments

Comments
 (0)