1313
1414
1515#if defined(__ARM_NEON ) || defined(__ARM_NEON__ )
16+
17+ #if __SIZEOF_POINTER__ == 4
18+ #define _ARM32BIT_ 1
19+ #endif
20+
1621#include <arm_neon.h>
1722
1823extern distance_function_t dispatch_distance_table [VECTOR_DISTANCE_MAX ][VECTOR_TYPE_MAX ];
1924extern char * distance_backend_name ;
2025
2126// Helper function for 32-bit ARM: vmaxv_u16 is not available in ARMv7 NEON
22- #if __SIZEOF_POINTER__ == 4
27+ #ifdef _ARM32BIT_
2328static inline uint16_t vmaxv_u16_compat (uint16x4_t v ) {
2429 // Use pairwise max to reduce vector
2530 uint16x4_t m = vpmax_u16 (v , v ); // [max(v0,v1), max(v2,v3), max(v0,v1), max(v2,v3)]
@@ -169,7 +174,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
169174 const uint16_t * a = (const uint16_t * )v1 ;
170175 const uint16_t * b = (const uint16_t * )v2 ;
171176
172- #if __SIZEOF_POINTER__ == 4
177+ #ifdef _ARM32BIT_
173178 // 32-bit ARM: use scalar double accumulation (no float64x2_t in NEON)
174179 double sum = 0.0 ;
175180 int i = 0 ;
@@ -446,7 +451,7 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
446451 const uint16x4_t SIGN_MASK = vdup_n_u16 (0x8000u );
447452 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
448453
449- #if __SIZEOF_POINTER__ == 4
454+ #ifdef _ARM32BIT_
450455 // 32-bit ARM: use scalar double accumulation
451456 double sum = 0.0 ;
452457 int i = 0 ;
@@ -487,7 +492,7 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
487492 uint32x4_t m = vceqq_f32 (d32 , d32 ); /* true where not-NaN */
488493 d32 = vbslq_f32 (m , d32 , vdupq_n_f32 (0.0f ));
489494
490- #if __SIZEOF_POINTER__ == 4
495+ #ifdef _ARM32BIT_
491496 // 32-bit ARM: accumulate in scalar double
492497 float tmp [4 ];
493498 vst1q_f32 (tmp , d32 );
@@ -509,7 +514,7 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
509514#endif
510515 }
511516
512- #if __SIZEOF_POINTER__ != 4
517+ #ifndef _ARM32BIT_
513518 double sum = vaddvq_f64 (vaddq_f64 (acc0 , acc1 ));
514519#endif
515520
@@ -544,7 +549,7 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
544549 const uint16x4_t FRAC_MASK = vdup_n_u16 (0x03FFu );
545550 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
546551
547- #if __SIZEOF_POINTER__ == 4
552+ #ifdef _ARM32BIT_
548553 // 32-bit ARM: use scalar double accumulation
549554 double dot = 0.0 , normx = 0.0 , normy = 0.0 ;
550555 int i = 0 ;
@@ -576,7 +581,7 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
576581 ax = vbslq_f32 (mx , ax , vdupq_n_f32 (0.0f ));
577582 by = vbslq_f32 (my , by , vdupq_n_f32 (0.0f ));
578583
579- #if __SIZEOF_POINTER__ == 4
584+ #ifdef _ARM32BIT_
580585 // 32-bit ARM: accumulate in scalar double
581586 float ax_tmp [4 ], by_tmp [4 ];
582587 vst1q_f32 (ax_tmp , ax );
@@ -611,7 +616,7 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
611616#endif
612617 }
613618
614- #if __SIZEOF_POINTER__ != 4
619+ #ifndef _ARM32BIT_
615620 double dot = vaddvq_f64 (vaddq_f64 (acc_dot_lo , acc_dot_hi ));
616621 double normx = vaddvq_f64 (vaddq_f64 (acc_a2_lo , acc_a2_hi ));
617622 double normy = vaddvq_f64 (vaddq_f64 (acc_b2_lo , acc_b2_hi ));
@@ -649,7 +654,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
649654 const uint16x4_t FRAC_MASK = vdup_n_u16 (0x03FFu );
650655 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
651656
652- #if __SIZEOF_POINTER__ == 4
657+ #ifdef _ARM32BIT_
653658 // 32-bit ARM: use scalar double accumulation
654659 double dot = 0.0 ;
655660 int i = 0 ;
@@ -675,7 +680,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
675680 if (isnan (x ) || isnan (y )) continue ;
676681 double p = (double )x * (double )y ;
677682 if (isinf (p )) return (p > 0 )? - INFINITY : INFINITY ;
678- #if __SIZEOF_POINTER__ == 4
683+ #ifdef _ARM32BIT_
679684 dot += p ;
680685#else
681686 acc_lo = vsetq_lane_f64 (vgetq_lane_f64 (acc_lo ,0 )+ p , acc_lo , 0 ); /* cheap add */
@@ -695,7 +700,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
695700
696701 float32x4_t prod = vmulq_f32 (ax , by );
697702
698- #if __SIZEOF_POINTER__ == 4
703+ #ifdef _ARM32BIT_
699704 // 32-bit ARM: accumulate in scalar double
700705 float prod_tmp [4 ];
701706 vst1q_f32 (prod_tmp , prod );
@@ -711,7 +716,7 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
711716#endif
712717 }
713718
714- #if __SIZEOF_POINTER__ != 4
719+ #ifndef _ARM32BIT_
715720 double dot = vaddvq_f64 (vaddq_f64 (acc_lo , acc_hi ));
716721#endif
717722
@@ -739,7 +744,7 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
739744 const uint16x4_t SIGN_MASK = vdup_n_u16 (0x8000u );
740745 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
741746
742- #if __SIZEOF_POINTER__ == 4
747+ #ifdef _ARM32BIT_
743748 // 32-bit ARM: use scalar double accumulation
744749 double sum = 0.0 ;
745750 int i = 0 ;
@@ -776,7 +781,7 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
776781 uint32x4_t m = vceqq_f32 (d , d ); /* mask NaNs -> 0 */
777782 d = vbslq_f32 (m , d , vdupq_n_f32 (0.0f ));
778783
779- #if __SIZEOF_POINTER__ == 4
784+ #ifdef _ARM32BIT_
780785 // 32-bit ARM: accumulate in scalar double
781786 float tmp [4 ];
782787 vst1q_f32 (tmp , d );
@@ -792,7 +797,7 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
792797#endif
793798 }
794799
795- #if __SIZEOF_POINTER__ != 4
800+ #ifndef _ARM32BIT_
796801 double sum = vaddvq_f64 (acc );
797802#endif
798803
0 commit comments