Skip to content

Commit 7e241e0

Browse files
authored
Merge pull request #32 from sqliteai/add-amerabi-v7a-support
fix(android): add support for armeabi-v7a (arm NEON 32bit); fixes #30
2 parents c066d9e + 2ed19c8 commit 7e241e0

File tree

3 files changed

+149
-8
lines changed

3 files changed

+149
-8
lines changed

.github/workflows/main.yml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
build:
1212
runs-on: ${{ matrix.os }}
1313
container: ${{ matrix.container && matrix.container || '' }}
14-
name: ${{ matrix.name }}${{ matrix.arch && format('-{0}', matrix.arch) || '' }} build${{ matrix.arch != 'arm64-v8a' && matrix.name != 'ios-sim' && matrix.name != 'ios' && matrix.name != 'apple-xcframework' && matrix.name != 'android-aar' && ( matrix.name != 'macos' || matrix.arch != 'x86_64' ) && ' + test' || ''}}
14+
name: ${{ matrix.name }}${{ matrix.arch && format('-{0}', matrix.arch) || '' }} build${{ matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a' && matrix.name != 'ios-sim' && matrix.name != 'ios' && matrix.name != 'apple-xcframework' && matrix.name != 'android-aar' && ( matrix.name != 'macos' || matrix.arch != 'x86_64' ) && ' + test' || ''}}
1515
timeout-minutes: 20
1616
strategy:
1717
fail-fast: false
@@ -47,6 +47,10 @@ jobs:
4747
arch: arm64-v8a
4848
name: android
4949
make: PLATFORM=android ARCH=arm64-v8a
50+
- os: ubuntu-22.04
51+
arch: armeabi-v7a
52+
name: android
53+
make: PLATFORM=android ARCH=armeabi-v7a
5054
- os: ubuntu-22.04
5155
arch: x86_64
5256
name: android
@@ -140,7 +144,7 @@ jobs:
140144
security delete-keychain build.keychain
141145
142146
- name: android setup test environment
143-
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a'
147+
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a'
144148
run: |
145149
146150
echo "::group::enable kvm group perms"
@@ -168,7 +172,7 @@ jobs:
168172
echo "::endgroup::"
169173
170174
- name: android test sqlite-vector
171-
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a'
175+
if: matrix.name == 'android' && matrix.arch != 'arm64-v8a' && matrix.arch != 'armeabi-v7a'
172176
uses: reactivecircus/[email protected]
173177
with:
174178
api-level: 26

Makefile

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,22 @@ else ifeq ($(PLATFORM),macos)
5959
STRIP = strip -x -S $@
6060
else ifeq ($(PLATFORM),android)
6161
ifndef ARCH # Set ARCH to find Android NDK's Clang compiler, the user should set the ARCH
62-
$(error "Android ARCH must be set to ARCH=x86_64 or ARCH=arm64-v8a")
62+
$(error "Android ARCH must be set to ARCH=x86_64, ARCH=arm64-v8a, or ARCH=armeabi-v7a")
6363
endif
6464
ifndef ANDROID_NDK # Set ANDROID_NDK path to find android build tools; e.g. on MacOS: export ANDROID_NDK=/Users/username/Library/Android/sdk/ndk/25.2.9519653
6565
$(error "Android NDK must be set")
6666
endif
6767
BIN = $(ANDROID_NDK)/toolchains/llvm/prebuilt/$(HOST)-x86_64/bin
6868
ifneq (,$(filter $(ARCH),arm64 arm64-v8a))
6969
override ARCH := aarch64
70+
ANDROID_ABI := android26
71+
else ifeq ($(ARCH),armeabi-v7a)
72+
override ARCH := armv7a
73+
ANDROID_ABI := androideabi26
74+
else
75+
ANDROID_ABI := android26
7076
endif
71-
CC = $(BIN)/$(ARCH)-linux-android26-clang
77+
CC = $(BIN)/$(ARCH)-linux-$(ANDROID_ABI)-clang
7278
TARGET := $(DIST_DIR)/vector.so
7379
LDFLAGS += -lm -shared
7480
STRIP = $(BIN)/llvm-strip --strip-unneeded $@
@@ -184,11 +190,14 @@ $(DIST_DIR)/%.xcframework: $(LIB_NAMES)
184190

185191
xcframework: $(DIST_DIR)/vector.xcframework
186192

187-
AAR_ARM = packages/android/src/main/jniLibs/arm64-v8a/
193+
AAR_ARM64 = packages/android/src/main/jniLibs/arm64-v8a/
194+
AAR_ARM = packages/android/src/main/jniLibs/armeabi-v7a/
188195
AAR_X86 = packages/android/src/main/jniLibs/x86_64/
189196
aar:
190-
mkdir -p $(AAR_ARM) $(AAR_X86)
197+
mkdir -p $(AAR_ARM64) $(AAR_ARM) $(AAR_X86)
191198
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=arm64-v8a
199+
mv $(DIST_DIR)/vector.so $(AAR_ARM64)
200+
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=armeabi-v7a
192201
mv $(DIST_DIR)/vector.so $(AAR_ARM)
193202
$(MAKE) clean && $(MAKE) PLATFORM=android ARCH=x86_64
194203
mv $(DIST_DIR)/vector.so $(AAR_X86)
@@ -208,7 +217,7 @@ help:
208217
@echo " linux (default on Linux)"
209218
@echo " macos (default on macOS)"
210219
@echo " windows (default on Windows)"
211-
@echo " android (needs ARCH to be set to x86_64 or arm64-v8a and ANDROID_NDK to be set)"
220+
@echo " android (needs ARCH to be set to x86_64, arm64-v8a, or armeabi-v7a and ANDROID_NDK to be set)"
212221
@echo " ios (only on macOS)"
213222
@echo " ios-sim (only on macOS)"
214223
@echo ""

src/distance-neon.c

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,27 @@
1313

1414

1515
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
16+
17+
#if __SIZEOF_POINTER__ == 4 // 4-byte pointers: treat this NEON target as 32-bit ARM
18+
#define _ARM32BIT_ 1 // tested with #ifdef/#ifndef _ARM32BIT_ further down to pick the 32-bit fallback paths
19+
#endif
20+
1621
#include <arm_neon.h>
1722

1823
extern distance_function_t dispatch_distance_table[VECTOR_DISTANCE_MAX][VECTOR_TYPE_MAX];
1924
extern char *distance_backend_name;
2025

26+
// Helper function for 32-bit ARM: the vmaxv_u16 intrinsic is not available in ARMv7 NEON, so it is emulated here
27+
#ifdef _ARM32BIT_
28+
static inline uint16_t vmaxv_u16_compat(uint16x4_t v) {
29+
// Horizontal max across the 4 lanes of v, computed with two pairwise-max (vpmax_u16) steps
30+
uint16x4_t m = vpmax_u16(v, v); // [max(v0,v1), max(v2,v3), max(v0,v1), max(v2,v3)]
31+
m = vpmax_u16(m, m); // [max(all), max(all), max(all), max(all)]
32+
return vget_lane_u16(m, 0); // every lane now holds the overall max, so lane 0 is the result
33+
}
34+
#define vmaxv_u16 vmaxv_u16_compat // later uses of vmaxv_u16 resolve to the compat helper on 32-bit builds
35+
#endif
36+
2137
// MARK: FLOAT32 -
2238

2339
float float32_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool use_sqrt) {
@@ -158,6 +174,31 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
158174
const uint16_t *a = (const uint16_t *)v1;
159175
const uint16_t *b = (const uint16_t *)v2;
160176

177+
#ifdef _ARM32BIT_
178+
// 32-bit ARM: use scalar double accumulation (float64x2_t is not available in 32-bit ARM NEON)
179+
double sum = 0.0;
180+
int i = 0;
181+
182+
for (; i <= n - 4; i += 4) {
183+
uint16x4_t av16 = vld1_u16(a + i);
184+
uint16x4_t bv16 = vld1_u16(b + i);
185+
186+
float32x4_t va = bf16x4_to_f32x4_u16(av16);
187+
float32x4_t vb = bf16x4_to_f32x4_u16(bv16);
188+
float32x4_t d = vsubq_f32(va, vb);
189+
// mask-out NaNs: m = (d==d)
190+
uint32x4_t m = vceqq_f32(d, d);
191+
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));
192+
193+
// Store and accumulate in scalar double
194+
float tmp[4];
195+
vst1q_f32(tmp, d);
196+
for (int j = 0; j < 4; j++) {
197+
double dj = (double)tmp[j];
198+
sum = fma(dj, dj, sum);
199+
}
200+
}
201+
#else
161202
// Accumulate in f64 to avoid overflow from huge bf16 values.
162203
float64x2_t acc0 = vdupq_n_f64(0.0), acc1 = vdupq_n_f64(0.0);
163204
int i = 0;
@@ -205,6 +246,7 @@ float bfloat16_distance_l2_impl_neon (const void *v1, const void *v2, int n, boo
205246
}
206247

207248
double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
249+
#endif
208250

209251
// scalar tail; treat NaN as 0, Inf as +Inf result
210252
for (; i < n; ++i) {
@@ -409,8 +451,15 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
409451
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
410452
const uint16x4_t ZERO16 = vdup_n_u16(0);
411453

454+
#ifdef _ARM32BIT_
455+
// 32-bit ARM: use scalar double accumulation
456+
double sum = 0.0;
457+
int i = 0;
458+
#else
459+
// 64-bit ARM: use float64x2_t NEON intrinsics
412460
float64x2_t acc0 = vdupq_n_f64(0.0), acc1 = vdupq_n_f64(0.0);
413461
int i = 0;
462+
#endif
414463

415464
for (; i <= n - 4; i += 4) {
416465
uint16x4_t av16 = vld1_u16(a + i);
@@ -443,6 +492,16 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
443492
uint32x4_t m = vceqq_f32(d32, d32); /* true where not-NaN */
444493
d32 = vbslq_f32(m, d32, vdupq_n_f32(0.0f));
445494

495+
#ifdef _ARM32BIT_
496+
// 32-bit ARM: accumulate in scalar double
497+
float tmp[4];
498+
vst1q_f32(tmp, d32);
499+
for (int j = 0; j < 4; j++) {
500+
double dj = (double)tmp[j];
501+
sum = fma(dj, dj, sum);
502+
}
503+
#else
504+
// 64-bit ARM: use NEON f64 operations
446505
float64x2_t dlo = vcvt_f64_f32(vget_low_f32(d32));
447506
float64x2_t dhi = vcvt_f64_f32(vget_high_f32(d32));
448507
#if defined(__ARM_FEATURE_FMA)
@@ -451,10 +510,13 @@ float float16_distance_l2_impl_neon (const void *v1, const void *v2, int n, bool
451510
#else
452511
acc0 = vaddq_f64(acc0, vmulq_f64(dlo, dlo));
453512
acc1 = vaddq_f64(acc1, vmulq_f64(dhi, dhi));
513+
#endif
454514
#endif
455515
}
456516

517+
#ifndef _ARM32BIT_
457518
double sum = vaddvq_f64(vaddq_f64(acc0, acc1));
519+
#endif
458520

459521
/* tail (scalar; same Inf/NaN policy) */
460522
for (; i < n; ++i) {
@@ -487,10 +549,17 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
487549
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
488550
const uint16x4_t ZERO16 = vdup_n_u16(0);
489551

552+
#ifdef _ARM32BIT_
553+
// 32-bit ARM: use scalar double accumulation
554+
double dot = 0.0, normx = 0.0, normy = 0.0;
555+
int i = 0;
556+
#else
557+
// 64-bit ARM: use float64x2_t NEON intrinsics
490558
float64x2_t acc_dot_lo = vdupq_n_f64(0.0), acc_dot_hi = vdupq_n_f64(0.0);
491559
float64x2_t acc_a2_lo = vdupq_n_f64(0.0), acc_a2_hi = vdupq_n_f64(0.0);
492560
float64x2_t acc_b2_lo = vdupq_n_f64(0.0), acc_b2_hi = vdupq_n_f64(0.0);
493561
int i = 0;
562+
#endif
494563

495564
for (; i <= n - 4; i += 4) {
496565
uint16x4_t av16 = vld1_u16(a + i);
@@ -512,6 +581,19 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
512581
ax = vbslq_f32(mx, ax, vdupq_n_f32(0.0f));
513582
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));
514583

584+
#ifdef _ARM32BIT_
585+
// 32-bit ARM: accumulate in scalar double
586+
float ax_tmp[4], by_tmp[4];
587+
vst1q_f32(ax_tmp, ax);
588+
vst1q_f32(by_tmp, by);
589+
for (int j = 0; j < 4; j++) {
590+
double x = (double)ax_tmp[j];
591+
double y = (double)by_tmp[j];
592+
dot += x * y;
593+
normx += x * x;
594+
normy += y * y;
595+
}
596+
#else
515597
/* widen to f64 and accumulate */
516598
float64x2_t ax_lo = vcvt_f64_f32(vget_low_f32(ax)), ax_hi = vcvt_f64_f32(vget_high_f32(ax));
517599
float64x2_t by_lo = vcvt_f64_f32(vget_low_f32(by)), by_hi = vcvt_f64_f32(vget_high_f32(by));
@@ -530,12 +612,15 @@ float float16_distance_cosine_neon (const void *v1, const void *v2, int n) {
530612
acc_a2_hi = vaddq_f64(acc_a2_hi, vmulq_f64(ax_hi, ax_hi));
531613
acc_b2_lo = vaddq_f64(acc_b2_lo, vmulq_f64(by_lo, by_lo));
532614
acc_b2_hi = vaddq_f64(acc_b2_hi, vmulq_f64(by_hi, by_hi));
615+
#endif
533616
#endif
534617
}
535618

619+
#ifndef _ARM32BIT_
536620
double dot = vaddvq_f64(vaddq_f64(acc_dot_lo, acc_dot_hi));
537621
double normx= vaddvq_f64(vaddq_f64(acc_a2_lo, acc_a2_hi));
538622
double normy= vaddvq_f64(vaddq_f64(acc_b2_lo, acc_b2_hi));
623+
#endif
539624

540625
/* tail (scalar) */
541626
for (; i < n; ++i) {
@@ -569,8 +654,15 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
569654
const uint16x4_t FRAC_MASK = vdup_n_u16(0x03FFu);
570655
const uint16x4_t ZERO16 = vdup_n_u16(0);
571656

657+
#ifdef _ARM32BIT_
658+
// 32-bit ARM: use scalar double accumulation
659+
double dot = 0.0;
660+
int i = 0;
661+
#else
662+
// 64-bit ARM: use float64x2_t NEON intrinsics
572663
float64x2_t acc_lo = vdupq_n_f64(0.0), acc_hi = vdupq_n_f64(0.0);
573664
int i = 0;
665+
#endif
574666

575667
for (; i <= n - 4; i += 4) {
576668
uint16x4_t av16 = vld1_u16(a + i);
@@ -588,7 +680,11 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
588680
if (isnan(x) || isnan(y)) continue;
589681
double p = (double)x * (double)y;
590682
if (isinf(p)) return (p>0)? -INFINITY : INFINITY;
683+
#ifdef _ARM32BIT_
684+
dot += p;
685+
#else
591686
acc_lo = vsetq_lane_f64(vgetq_lane_f64(acc_lo,0)+p, acc_lo, 0); /* cheap add */
687+
#endif
592688
}
593689
continue;
594690
}
@@ -603,13 +699,26 @@ float float16_distance_dot_neon (const void *v1, const void *v2, int n) {
603699
by = vbslq_f32(my, by, vdupq_n_f32(0.0f));
604700

605701
float32x4_t prod = vmulq_f32(ax, by);
702+
703+
#ifdef _ARM32BIT_
704+
// 32-bit ARM: accumulate in scalar double
705+
float prod_tmp[4];
706+
vst1q_f32(prod_tmp, prod);
707+
for (int j = 0; j < 4; j++) {
708+
dot += (double)prod_tmp[j];
709+
}
710+
#else
711+
// 64-bit ARM: use NEON f64 operations
606712
float64x2_t lo = vcvt_f64_f32(vget_low_f32(prod));
607713
float64x2_t hi = vcvt_f64_f32(vget_high_f32(prod));
608714
acc_lo = vaddq_f64(acc_lo, lo);
609715
acc_hi = vaddq_f64(acc_hi, hi);
716+
#endif
610717
}
611718

719+
#ifndef _ARM32BIT_
612720
double dot = vaddvq_f64(vaddq_f64(acc_lo, acc_hi));
721+
#endif
613722

614723
for (; i < n; ++i) {
615724
float x = float16_to_float32(a[i]);
@@ -635,8 +744,15 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
635744
const uint16x4_t SIGN_MASK = vdup_n_u16(0x8000u);
636745
const uint16x4_t ZERO16 = vdup_n_u16(0);
637746

747+
#ifdef _ARM32BIT_
748+
// 32-bit ARM: use scalar double accumulation
749+
double sum = 0.0;
750+
int i = 0;
751+
#else
752+
// 64-bit ARM: use float64x2_t NEON intrinsics
638753
float64x2_t acc = vdupq_n_f64(0.0);
639754
int i = 0;
755+
#endif
640756

641757
for (; i <= n - 4; i += 4) {
642758
uint16x4_t av16 = vld1_u16(a + i);
@@ -665,13 +781,25 @@ float float16_distance_l1_neon (const void *v1, const void *v2, int n) {
665781
uint32x4_t m = vceqq_f32(d, d); /* mask NaNs -> 0 */
666782
d = vbslq_f32(m, d, vdupq_n_f32(0.0f));
667783

784+
#ifdef _ARM32BIT_
785+
// 32-bit ARM: accumulate in scalar double
786+
float tmp[4];
787+
vst1q_f32(tmp, d);
788+
for (int j = 0; j < 4; j++) {
789+
sum += (double)tmp[j];
790+
}
791+
#else
792+
// 64-bit ARM: use NEON f64 operations
668793
float64x2_t lo = vcvt_f64_f32(vget_low_f32(d));
669794
float64x2_t hi = vcvt_f64_f32(vget_high_f32(d));
670795
acc = vaddq_f64(acc, lo);
671796
acc = vaddq_f64(acc, hi);
797+
#endif
672798
}
673799

800+
#ifndef _ARM32BIT_
674801
double sum = vaddvq_f64(acc);
802+
#endif
675803

676804
for (; i < n; ++i) {
677805
uint16_t ai=a[i], bi=b[i];

0 commit comments

Comments
 (0)