1313
1414
1515#if defined(__ARM_NEON ) || defined(__ARM_NEON__ )
16+
17+ #if __SIZEOF_POINTER__ == 4
18+ #define _ARM32BIT_ 1
19+ #endif
20+
1621#include <arm_neon.h>
1722
1823extern distance_function_t dispatch_distance_table [VECTOR_DISTANCE_MAX ][VECTOR_TYPE_MAX ];
1924extern char * distance_backend_name ;
2025
// Compat shim for 32-bit ARM (ARMv7): vmaxv_u16 — horizontal max across the
// four u16 lanes — is an AArch64-only intrinsic, so emulate it here with
// pairwise maxes and alias the real name to the shim.
#ifdef _ARM32BIT_
static inline uint16_t vmaxv_u16_compat(uint16x4_t v) {
    /* First fold reduces adjacent pairs; the second fold reduces the
       remaining pair, leaving max(v0..v3) replicated in every lane. */
    uint16x4_t r = vpmax_u16(v, v);
    r = vpmax_u16(r, r);
    return vget_lane_u16(r, 0);
}
#define vmaxv_u16 vmaxv_u16_compat
#endif
36+
2137// MARK: FLOAT32 -
2238
2339float float32_distance_l2_impl_neon (const void * v1 , const void * v2 , int n , bool use_sqrt ) {
@@ -158,6 +174,31 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
158174 const uint16_t * a = (const uint16_t * )v1 ;
159175 const uint16_t * b = (const uint16_t * )v2 ;
160176
177+ #ifdef _ARM32BIT_
178+ // 32-bit ARM: use scalar double accumulation (no float64x2_t in NEON)
179+ double sum = 0.0 ;
180+ int i = 0 ;
181+
182+ for (; i <= n - 4 ; i += 4 ) {
183+ uint16x4_t av16 = vld1_u16 (a + i );
184+ uint16x4_t bv16 = vld1_u16 (b + i );
185+
186+ float32x4_t va = bf16x4_to_f32x4_u16 (av16 );
187+ float32x4_t vb = bf16x4_to_f32x4_u16 (bv16 );
188+ float32x4_t d = vsubq_f32 (va , vb );
189+ // mask-out NaNs: m = (d==d)
190+ uint32x4_t m = vceqq_f32 (d , d );
191+ d = vbslq_f32 (m , d , vdupq_n_f32 (0.0f ));
192+
193+ // Store and accumulate in scalar double
194+ float tmp [4 ];
195+ vst1q_f32 (tmp , d );
196+ for (int j = 0 ; j < 4 ; j ++ ) {
197+ double dj = (double )tmp [j ];
198+ sum = fma (dj , dj , sum );
199+ }
200+ }
201+ #else
161202 // Accumulate in f64 to avoid overflow from huge bf16 values.
162203 float64x2_t acc0 = vdupq_n_f64 (0.0 ), acc1 = vdupq_n_f64 (0.0 );
163204 int i = 0 ;
@@ -205,6 +246,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
205246 }
206247
207248 double sum = vaddvq_f64 (vaddq_f64 (acc0 , acc1 ));
249+ #endif
208250
209251 // scalar tail; treat NaN as 0, Inf as +Inf result
210252 for (; i < n ; ++ i ) {
@@ -409,8 +451,15 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
409451 const uint16x4_t SIGN_MASK = vdup_n_u16 (0x8000u );
410452 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
411453
454+ #ifdef _ARM32BIT_
455+ // 32-bit ARM: use scalar double accumulation
456+ double sum = 0.0 ;
457+ int i = 0 ;
458+ #else
459+ // 64-bit ARM: use float64x2_t NEON intrinsics
412460 float64x2_t acc0 = vdupq_n_f64 (0.0 ), acc1 = vdupq_n_f64 (0.0 );
413461 int i = 0 ;
462+ #endif
414463
415464 for (; i <= n - 4 ; i += 4 ) {
416465 uint16x4_t av16 = vld1_u16 (a + i );
@@ -443,6 +492,16 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
443492 uint32x4_t m = vceqq_f32 (d32 , d32 ); /* true where not-NaN */
444493 d32 = vbslq_f32 (m , d32 , vdupq_n_f32 (0.0f ));
445494
495+ #ifdef _ARM32BIT_
496+ // 32-bit ARM: accumulate in scalar double
497+ float tmp [4 ];
498+ vst1q_f32 (tmp , d32 );
499+ for (int j = 0 ; j < 4 ; j ++ ) {
500+ double dj = (double )tmp [j ];
501+ sum = fma (dj , dj , sum );
502+ }
503+ #else
504+ // 64-bit ARM: use NEON f64 operations
446505 float64x2_t dlo = vcvt_f64_f32 (vget_low_f32 (d32 ));
447506 float64x2_t dhi = vcvt_f64_f32 (vget_high_f32 (d32 ));
448507#if defined(__ARM_FEATURE_FMA )
@@ -451,10 +510,13 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
451510#else
452511 acc0 = vaddq_f64 (acc0 , vmulq_f64 (dlo , dlo ));
453512 acc1 = vaddq_f64 (acc1 , vmulq_f64 (dhi , dhi ));
513+ #endif
454514#endif
455515 }
456516
517+ #ifndef _ARM32BIT_
457518 double sum = vaddvq_f64 (vaddq_f64 (acc0 , acc1 ));
519+ #endif
458520
459521 /* tail (scalar; same Inf/NaN policy) */
460522 for (; i < n ; ++ i ) {
@@ -487,10 +549,17 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
487549 const uint16x4_t FRAC_MASK = vdup_n_u16 (0x03FFu );
488550 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
489551
552+ #ifdef _ARM32BIT_
553+ // 32-bit ARM: use scalar double accumulation
554+ double dot = 0.0 , normx = 0.0 , normy = 0.0 ;
555+ int i = 0 ;
556+ #else
557+ // 64-bit ARM: use float64x2_t NEON intrinsics
490558 float64x2_t acc_dot_lo = vdupq_n_f64 (0.0 ), acc_dot_hi = vdupq_n_f64 (0.0 );
491559 float64x2_t acc_a2_lo = vdupq_n_f64 (0.0 ), acc_a2_hi = vdupq_n_f64 (0.0 );
492560 float64x2_t acc_b2_lo = vdupq_n_f64 (0.0 ), acc_b2_hi = vdupq_n_f64 (0.0 );
493561 int i = 0 ;
562+ #endif
494563
495564 for (; i <= n - 4 ; i += 4 ) {
496565 uint16x4_t av16 = vld1_u16 (a + i );
@@ -512,6 +581,19 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
512581 ax = vbslq_f32 (mx , ax , vdupq_n_f32 (0.0f ));
513582 by = vbslq_f32 (my , by , vdupq_n_f32 (0.0f ));
514583
584+ #ifdef _ARM32BIT_
585+ // 32-bit ARM: accumulate in scalar double
586+ float ax_tmp [4 ], by_tmp [4 ];
587+ vst1q_f32 (ax_tmp , ax );
588+ vst1q_f32 (by_tmp , by );
589+ for (int j = 0 ; j < 4 ; j ++ ) {
590+ double x = (double )ax_tmp [j ];
591+ double y = (double )by_tmp [j ];
592+ dot += x * y ;
593+ normx += x * x ;
594+ normy += y * y ;
595+ }
596+ #else
515597 /* widen to f64 and accumulate */
516598 float64x2_t ax_lo = vcvt_f64_f32 (vget_low_f32 (ax )), ax_hi = vcvt_f64_f32 (vget_high_f32 (ax ));
517599 float64x2_t by_lo = vcvt_f64_f32 (vget_low_f32 (by )), by_hi = vcvt_f64_f32 (vget_high_f32 (by ));
@@ -530,12 +612,15 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
530612 acc_a2_hi = vaddq_f64 (acc_a2_hi , vmulq_f64 (ax_hi , ax_hi ));
531613 acc_b2_lo = vaddq_f64 (acc_b2_lo , vmulq_f64 (by_lo , by_lo ));
532614 acc_b2_hi = vaddq_f64 (acc_b2_hi , vmulq_f64 (by_hi , by_hi ));
615+ #endif
533616#endif
534617 }
535618
619+ #ifndef _ARM32BIT_
536620 double dot = vaddvq_f64 (vaddq_f64 (acc_dot_lo , acc_dot_hi ));
537621 double normx = vaddvq_f64 (vaddq_f64 (acc_a2_lo , acc_a2_hi ));
538622 double normy = vaddvq_f64 (vaddq_f64 (acc_b2_lo , acc_b2_hi ));
623+ #endif
539624
540625 /* tail (scalar) */
541626 for (; i < n ; ++ i ) {
@@ -569,8 +654,15 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
569654 const uint16x4_t FRAC_MASK = vdup_n_u16 (0x03FFu );
570655 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
571656
657+ #ifdef _ARM32BIT_
658+ // 32-bit ARM: use scalar double accumulation
659+ double dot = 0.0 ;
660+ int i = 0 ;
661+ #else
662+ // 64-bit ARM: use float64x2_t NEON intrinsics
572663 float64x2_t acc_lo = vdupq_n_f64 (0.0 ), acc_hi = vdupq_n_f64 (0.0 );
573664 int i = 0 ;
665+ #endif
574666
575667 for (; i <= n - 4 ; i += 4 ) {
576668 uint16x4_t av16 = vld1_u16 (a + i );
@@ -588,7 +680,11 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
588680 if (isnan (x ) || isnan (y )) continue ;
589681 double p = (double )x * (double )y ;
590682 if (isinf (p )) return (p > 0 )? - INFINITY : INFINITY ;
683+ #ifdef _ARM32BIT_
684+ dot += p ;
685+ #else
591686 acc_lo = vsetq_lane_f64 (vgetq_lane_f64 (acc_lo ,0 )+ p , acc_lo , 0 ); /* cheap add */
687+ #endif
592688 }
593689 continue ;
594690 }
@@ -603,13 +699,26 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
603699 by = vbslq_f32 (my , by , vdupq_n_f32 (0.0f ));
604700
605701 float32x4_t prod = vmulq_f32 (ax , by );
702+
703+ #ifdef _ARM32BIT_
704+ // 32-bit ARM: accumulate in scalar double
705+ float prod_tmp [4 ];
706+ vst1q_f32 (prod_tmp , prod );
707+ for (int j = 0 ; j < 4 ; j ++ ) {
708+ dot += (double )prod_tmp [j ];
709+ }
710+ #else
711+ // 64-bit ARM: use NEON f64 operations
606712 float64x2_t lo = vcvt_f64_f32 (vget_low_f32 (prod ));
607713 float64x2_t hi = vcvt_f64_f32 (vget_high_f32 (prod ));
608714 acc_lo = vaddq_f64 (acc_lo , lo );
609715 acc_hi = vaddq_f64 (acc_hi , hi );
716+ #endif
610717 }
611718
719+ #ifndef _ARM32BIT_
612720 double dot = vaddvq_f64 (vaddq_f64 (acc_lo , acc_hi ));
721+ #endif
613722
614723 for (; i < n ; ++ i ) {
615724 float x = float16_to_float32 (a [i ]);
@@ -635,8 +744,15 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
635744 const uint16x4_t SIGN_MASK = vdup_n_u16 (0x8000u );
636745 const uint16x4_t ZERO16 = vdup_n_u16 (0 );
637746
747+ #ifdef _ARM32BIT_
748+ // 32-bit ARM: use scalar double accumulation
749+ double sum = 0.0 ;
750+ int i = 0 ;
751+ #else
752+ // 64-bit ARM: use float64x2_t NEON intrinsics
638753 float64x2_t acc = vdupq_n_f64 (0.0 );
639754 int i = 0 ;
755+ #endif
640756
641757 for (; i <= n - 4 ; i += 4 ) {
642758 uint16x4_t av16 = vld1_u16 (a + i );
@@ -665,13 +781,25 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
665781 uint32x4_t m = vceqq_f32 (d , d ); /* mask NaNs -> 0 */
666782 d = vbslq_f32 (m , d , vdupq_n_f32 (0.0f ));
667783
784+ #ifdef _ARM32BIT_
785+ // 32-bit ARM: accumulate in scalar double
786+ float tmp [4 ];
787+ vst1q_f32 (tmp , d );
788+ for (int j = 0 ; j < 4 ; j ++ ) {
789+ sum += (double )tmp [j ];
790+ }
791+ #else
792+ // 64-bit ARM: use NEON f64 operations
668793 float64x2_t lo = vcvt_f64_f32 (vget_low_f32 (d ));
669794 float64x2_t hi = vcvt_f64_f32 (vget_high_f32 (d ));
670795 acc = vaddq_f64 (acc , lo );
671796 acc = vaddq_f64 (acc , hi );
797+ #endif
672798 }
673799
800+ #ifndef _ARM32BIT_
674801 double sum = vaddvq_f64 (acc );
802+ #endif
675803
676804 for (; i < n ; ++ i ) {
677805 uint16_t ai = a [i ], bi = b [i ];
0 commit comments