diff --git a/SingleSource/UnitTests/Vectorizer/argmin-argmax.cpp b/SingleSource/UnitTests/Vectorizer/argmin-argmax.cpp new file mode 100644 index 0000000000..43f0411544 --- /dev/null +++ b/SingleSource/UnitTests/Vectorizer/argmin-argmax.cpp @@ -0,0 +1,473 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" + +template +using Fn2Ty = std::function; +template +static void checkVectorFunction(Fn2Ty ScalarFn, + Fn2Ty VectorFn, const char *Name) { + std::cout << "Checking " << Name << "\n"; + + unsigned MaxN = 1024; + std::unique_ptr Src1(new Ty[MaxN]); + std::unique_ptr Src2(new Ty[MaxN]); + + // Test with different trip counts, including odd ones and powers of 2. + unsigned TripCounts[] = {1, 7, 15, 16, 31, 127, 128, 999, 1000, 1023}; + for (unsigned N : TripCounts) { + auto runTest = [&](const char *Desc) { + auto Reference = ScalarFn(&Src1[0], &Src2[0], N); + auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); + if (Reference != ToCheck) { + std::cerr << "Miscompare for N=" << N << " (" << Desc << ")\n"; + exit(1); + } + }; + + // Check with random inputs. + init_data(Src1, N); + init_data(Src2, N); + runTest("random"); + + // Check with sorted inputs. + std::sort(&Src1[0], &Src1[N]); + std::sort(&Src2[0], &Src2[N]); + runTest("sorted"); + + // Check with reverse sorted inputs. + std::sort(&Src1[0], &Src1[N], std::greater()); + std::sort(&Src2[0], &Src2[N], std::greater()); + runTest("reverse sorted"); + + // Check with all max values. + for (unsigned I = 0; I != N; ++I) { + Src1[I] = std::numeric_limits::max(); + Src2[I] = std::numeric_limits::max(); + } + runTest("all max"); + + // Check with all min values. + for (unsigned I = 0; I != N; ++I) { + Src1[I] = std::numeric_limits::lowest(); + Src2[I] = std::numeric_limits::lowest(); + } + runTest("all min"); + + // Check with all zeros. + for (unsigned I = 0; I != N; ++I) { + Src1[I] = 0; + Src2[I] = 0; + } + runTest("all zero"); + + // Check with alternating min/max pattern (interesting for ADD tests). + for (unsigned I = 0; I != N; ++I) { + Src1[I] = (I & 1) ? std::numeric_limits::max() + : std::numeric_limits::lowest(); + Src2[I] = (I & 1) ? std::numeric_limits::lowest() + : std::numeric_limits::max(); + } + runTest("alternating"); + + // Helper to test both min and max with specific target indices. + auto testPattern = [&](Ty TargetVal, Ty FillVal, + const std::vector &Indices, + const char *Name) { + for (unsigned I = 0; I != N; ++I) { + Src1[I] = FillVal; + Src2[I] = FillVal; + } + for (unsigned Idx : Indices) { + if (Idx < N) { + Src1[Idx] = TargetVal; + Src2[Idx] = TargetVal; + } + } + runTest(Name); + }; + + // Define test patterns with their index positions. + struct Pattern { + std::function(unsigned)> GetIndices; + const char *Name; + unsigned MinN; + }; + + const Pattern Patterns[] = { + {{[](unsigned N) { return std::vector{0}; }}, + "first element", + 1}, + {{[](unsigned N) { return std::vector{N - 1}; }}, + "last element", + 1}, + {{[](unsigned N) { return std::vector{0, N / 2, N - 1}; }}, + "duplicate", + 3}, + {{[](unsigned N) { + // Mix of even and odd indices at different positions + std::vector Indices = {0, 1}; + if (N >= 6) { + unsigned MidEven = (N / 2) & ~1u; + unsigned MidOdd = (N / 2) | 1; + if (MidEven < N && MidEven != 0) + Indices.push_back(MidEven); + if (MidOdd < N && MidOdd != 1) + Indices.push_back(MidOdd); + } + return Indices; + }}, + "duplicate mixed even/odd indices", + 2}, + {{[](unsigned N) { + std::vector Indices = {1}; + unsigned MidOdd = (N / 2) | 1; + if (MidOdd < N) + Indices.push_back(MidOdd); + unsigned LastOdd = (N - 1) | 1; + if (LastOdd < N && LastOdd != 1 && LastOdd != MidOdd) + Indices.push_back(LastOdd); + return Indices; + }}, + "duplicate odd indices", + 4}, + {{[](unsigned N) { + // Start at index 2 (even), avoiding 0 and 1 + std::vector Indices = {2}; + if (N >= 8) { + unsigned MidEven = (N / 2) & ~1u; + if (MidEven > 2 && MidEven < N) + Indices.push_back(MidEven); + } + if (N >= 4) { + unsigned LastEven = (N - 1) & ~1u; + if (LastEven > 2 && LastEven < N) + Indices.push_back(LastEven); + } + return Indices; + }}, + "duplicate even indices skip 0", + 3}, + {{[](unsigned N) { + // Start at index 3 (odd), avoiding 0 and 1 + std::vector Indices = {3}; + if (N >= 8) { + unsigned MidOdd = (N / 2) | 1; + if (MidOdd > 3 && MidOdd < N) + Indices.push_back(MidOdd); + } + if (N >= 5) { + unsigned LastOdd = (N - 1) | 1; + if (LastOdd > 3 && LastOdd < N) + Indices.push_back(LastOdd); + } + return Indices; + }}, + "duplicate odd indices skip 1", + 4}, + {{[](unsigned N) { + // Only indices in second half of array + std::vector Indices; + unsigned Start = N / 2; + if (Start < 2) + Start = 2; + if (Start < N) + Indices.push_back(Start); + unsigned Mid = N / 2 + N / 4; + if (Mid > Start && Mid < N) + Indices.push_back(Mid); + if (N - 1 > Start) + Indices.push_back(N - 1); + return Indices; + }}, + "duplicate second half only", + 4}, + }; + + // Run all patterns for both min and max. + for (const auto &P : Patterns) { + if (N >= P.MinN) { + auto Indices = P.GetIndices(N); + testPattern(std::numeric_limits::lowest(), + std::numeric_limits::max(), Indices, + (std::string(P.Name) + " min").c_str()); + testPattern(std::numeric_limits::max(), + std::numeric_limits::lowest(), Indices, + (std::string(P.Name) + " max").c_str()); + testPattern(-10, + 20, Indices, + (std::string(P.Name) + " mixed1").c_str()); + testPattern(30, + -20, Indices, + (std::string(P.Name) + " mixed2").c_str()); + } + } + } +} + +#define NAME(Op, MinTy, MinIdxTy, Suffix) \ + "arg" Op "_" #MinTy "_" #MinIdxTy Suffix + +// Generic test: Loop body is passed as macro parameter +#define TEST(Op, MinTy, MinIdxTy, Init, Loop, Suffix) \ + { \ + DEFINE_SCALAR_AND_VECTOR_FN2(Init, Loop) \ + checkVectorFunction(ScalarFn, VectorFn, \ + NAME(Op, MinTy, MinIdxTy, Suffix)); \ + } + +// Test with explicit VF and interleave count +#define TEST_VF_IC(Op, MinTy, MinIdxTy, Init, Loop, VF, IC, Suffix) \ + { \ + DEFINE_SCALAR_AND_VECTOR_FN2_VF_INTERLEAVE(Init, Loop, VF, IC) \ + checkVectorFunction(ScalarFn, VectorFn, \ + NAME(Op, MinTy, MinIdxTy, Suffix)); \ + } + +// Predicate-parameterized tests +#define T_BASIC_P(Op, M, I, Start, Inc, InitVal, InitIdx, Pred, Suf) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = InitIdx; \ + , \ + for (unsigned i = Start; i < TC; i += Inc) { \ + if (A[i] Pred Min) { \ + Min = A[i]; \ + MinIdx = i; \ + } \ + } return MinIdx; \ + , Suf) + +#define T_ADD_P(Op, M, I, Start, Inc, InitVal, InitIdx, Pred, Suf) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = InitIdx; \ + , \ + for (unsigned i = Start; i < TC; i += Inc) { \ + M D = A[i] + B[i]; \ + if (D Pred Min) { \ + Min = D; \ + MinIdx = i; \ + } \ + } return MinIdx; \ + , Suf) + +#define T_SEP_P(Op, M, I, InitVal, Pred, Suf) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = 0; \ + , \ + for (unsigned i = 0; i < TC; i++) { \ + if (B[i] Pred Min) { \ + MinIdx = i; \ + } \ + \ + if (A[i] Pred Min) { \ + Min = A[i]; \ + } \ + } return MinIdx; \ + , Suf) + +#define T_DIFF_P(Op, M, I, InitVal, Pred, Suf) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = 0; \ + , \ + for (unsigned i = 0; i < TC; i++) { \ + if (B[i] Pred Min) { \ + MinIdx = i; \ + } \ + \ + if (B[i] Pred Min) { \ + Min = A[i]; \ + } \ + } return MinIdx; \ + , Suf) + +// Two-predicate tests (for mismatch patterns) +#define T_2PRED(Op, M, I, InitVal, Pred1, Pred2, Suf) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = 0; \ + , \ + for (unsigned i = 0; i < TC; i++) { \ + if (A[i] Pred1 Min) { \ + MinIdx = i; \ + } \ + \ + if (A[i] Pred2 Min) { \ + Min = A[i]; \ + } \ + } return MinIdx; \ + , Suf) + +#define T_2PRED_REV(Op, M, I, InitVal, Pred1, Pred2, Suf) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = 0; \ + , \ + for (unsigned i = 0; i < TC; i++) { \ + if (A[i] Pred1 Min) { \ + Min = A[i]; \ + } \ + \ + if (A[i] Pred2 Min) { \ + MinIdx = i; \ + } \ + } return MinIdx; \ + , Suf) + +#define T_TRUNC(Op, M, I, InitVal) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = 0; \ + , \ + for (uint64_t i = 0; i < TC; i++) { \ + if (A[i] <= Min) { \ + Min = A[i]; \ + MinIdx = i; \ + } \ + } return MinIdx; \ + , "_with_trunc") + +#define T_DEC(Op, M, I, InitVal) \ + TEST( \ + Op, M, I, M Min = InitVal; I MinIdx = 0; \ + , \ + for (unsigned i = TC; i-- > 0;) { \ + if (A[i] <= Min) { \ + Min = A[i]; \ + MinIdx = i; \ + } \ + } return MinIdx; \ + , "_induction_decrement") + +// Variants with explicit VF and interleave count. +#define T_BASIC_VF_IC(Op, M, I, Start, Inc, InitVal, InitIdx, VF, IC, Suf) \ + TEST_VF_IC( \ + Op, M, I, M Min = InitVal; I MinIdx = InitIdx; \ + , \ + for (unsigned i = Start; i < TC; i += Inc) { \ + if (A[i] <= Min) { \ + Min = A[i]; \ + MinIdx = i; \ + } \ + } return MinIdx; \ + , VF, IC, Suf) + +#define RUN_ALL_TESTS_FOR_TYPE(M, I) \ + T_BASIC_P("min", M, I, 0, 1, std::numeric_limits::max(), 0, <=, \ + "_start_0") \ + T_ADD_P("min", M, I, 0, 2, std::numeric_limits::max(), 0, <=, \ + "_start_0_inc_2") \ + T_BASIC_P("min", M, I, 0, 1, std::numeric_limits::max(), 2, <=, \ + "_start_2") \ + T_BASIC_P("min", M, I, 0, 1, std::numeric_limits::max(), \ + std::numeric_limits::max(), <=, "_start_0_min_idx_neg1") \ + T_BASIC_P("min", M, I, 3, 1, std::numeric_limits::max(), 3, <=, \ + "_start_3_min_idx_3") \ + T_ADD_P("min", M, I, 3, 1, std::numeric_limits::max(), 2, <=, \ + "_start_3_min_idx_2") \ + T_BASIC_P("min", M, I, 3, 1, std::numeric_limits::max(), 4, <=, \ + "_start_3_min_idx_4") \ + T_SEP_P("min", M, I, std::numeric_limits::max(), <=, "_separate_selects") \ + T_DIFF_P("min", M, I, std::numeric_limits::max(), <=, "_different_selects") + +// Run tests with different predicates for additional coverage +#define RUN_PRED_TESTS_FOR_TYPE(M, I) \ + T_BASIC_P("min", M, I, 0, 1, std::numeric_limits::max(), 0, <, \ + "_start_0_lt") \ + T_BASIC_P("max", M, I, 0, 1, std::numeric_limits::lowest(), 0, >=, \ + "_start_0_ge") \ + T_BASIC_P("max", M, I, 0, 1, std::numeric_limits::lowest(), 0, >, \ + "_start_0_gt") \ + T_SEP_P("min", M, I, std::numeric_limits::max(), <, \ + "_separate_selects_lt") \ + T_SEP_P("max", M, I, std::numeric_limits::lowest(), >, \ + "_separate_selects_gt") + +// Run tests with explicit VF and interleave count +#define RUN_VF_IC_TESTS_FOR_TYPE(M, I) \ + T_BASIC_VF_IC("min", M, I, 0, 1, std::numeric_limits::max(), 0, 4, 1, \ + "_start_0_vf4_ic1") \ + T_BASIC_VF_IC("min", M, I, 0, 1, std::numeric_limits::max(), 0, 2, 2, \ + "_start_0_vf2_ic2") + +#define RUN_TRUNC_TESTS_FOR_TYPE(M, I) \ + T_TRUNC("min", M, I, std::numeric_limits::max()) \ + T_DEC("min", M, I, std::numeric_limits::max()) + +int main(void) { + rng = std::mt19937(15); + + // Run tests for unsigned 8-bit types + RUN_ALL_TESTS_FOR_TYPE(uint8_t, uint8_t) + RUN_PRED_TESTS_FOR_TYPE(uint8_t, uint8_t) + + // Run tests for unsigned 16-bit types + RUN_ALL_TESTS_FOR_TYPE(uint16_t, uint16_t) + RUN_PRED_TESTS_FOR_TYPE(uint16_t, uint16_t) + + // Run tests for unsigned 32-bit types + RUN_ALL_TESTS_FOR_TYPE(uint32_t, uint32_t) + RUN_PRED_TESTS_FOR_TYPE(uint32_t, uint32_t) + RUN_VF_IC_TESTS_FOR_TYPE(uint32_t, uint32_t) + + // Run tests for unsigned 64-bit types + RUN_ALL_TESTS_FOR_TYPE(uint64_t, uint64_t) + RUN_PRED_TESTS_FOR_TYPE(uint64_t, uint64_t) + RUN_VF_IC_TESTS_FOR_TYPE(uint64_t, uint64_t) + + // Run tests for signed 8-bit types + RUN_ALL_TESTS_FOR_TYPE(int8_t, int8_t) + RUN_PRED_TESTS_FOR_TYPE(int8_t, int8_t) + + // Run tests for signed 16-bit types + RUN_ALL_TESTS_FOR_TYPE(int16_t, int16_t) + RUN_PRED_TESTS_FOR_TYPE(int16_t, int16_t) + + // Run tests for signed 32-bit types + RUN_ALL_TESTS_FOR_TYPE(int32_t, int32_t) + RUN_PRED_TESTS_FOR_TYPE(int32_t, int32_t) + RUN_VF_IC_TESTS_FOR_TYPE(int32_t, int32_t) + + // Run tests for signed 64-bit types + RUN_ALL_TESTS_FOR_TYPE(int64_t, int64_t) + RUN_PRED_TESTS_FOR_TYPE(int64_t, int64_t) + + // Run tests with mixed signedness (unsigned min, signed idx) + RUN_ALL_TESTS_FOR_TYPE(uint8_t, int8_t) + RUN_PRED_TESTS_FOR_TYPE(uint8_t, int8_t) + + RUN_ALL_TESTS_FOR_TYPE(uint16_t, int16_t) + RUN_PRED_TESTS_FOR_TYPE(uint16_t, int16_t) + + RUN_ALL_TESTS_FOR_TYPE(uint32_t, int32_t) + RUN_PRED_TESTS_FOR_TYPE(uint32_t, int32_t) + + RUN_ALL_TESTS_FOR_TYPE(uint64_t, int64_t) + RUN_PRED_TESTS_FOR_TYPE(uint64_t, int64_t) + + // Run tests with mixed signedness (signed min, unsigned idx) + RUN_ALL_TESTS_FOR_TYPE(int8_t, uint8_t) + RUN_PRED_TESTS_FOR_TYPE(int8_t, uint8_t) + + RUN_ALL_TESTS_FOR_TYPE(int16_t, uint16_t) + RUN_PRED_TESTS_FOR_TYPE(int16_t, uint16_t) + + RUN_ALL_TESTS_FOR_TYPE(int32_t, uint32_t) + RUN_PRED_TESTS_FOR_TYPE(int32_t, uint32_t) + + RUN_ALL_TESTS_FOR_TYPE(int64_t, uint64_t) + RUN_PRED_TESTS_FOR_TYPE(int64_t, uint64_t) + + // Run truncation tests with original type combination + RUN_TRUNC_TESTS_FOR_TYPE(uint64_t, uint32_t) + + // Run mismatch tests only for uint64_t (matching original scope) + T_2PRED("mixed", uint64_t, uint64_t, std::numeric_limits::max(), >=, + <=, "_predicate_mismatch_0") + T_2PRED_REV("mixed", uint64_t, uint64_t, std::numeric_limits::max(), + <=, >=, "_predicate_mismatch_1") + + return 0; +} diff --git a/SingleSource/UnitTests/Vectorizer/argmin-argmax.reference_output b/SingleSource/UnitTests/Vectorizer/argmin-argmax.reference_output new file mode 100644 index 0000000000..ee44ea6f0d --- /dev/null +++ b/SingleSource/UnitTests/Vectorizer/argmin-argmax.reference_output @@ -0,0 +1,235 @@ +Checking argmin_uint8_t_uint8_t_start_0 +Checking argmin_uint8_t_uint8_t_start_0_inc_2 +Checking argmin_uint8_t_uint8_t_start_2 +Checking argmin_uint8_t_uint8_t_start_0_min_idx_neg1 +Checking argmin_uint8_t_uint8_t_start_3_min_idx_3 +Checking argmin_uint8_t_uint8_t_start_3_min_idx_2 +Checking argmin_uint8_t_uint8_t_start_3_min_idx_4 +Checking argmin_uint8_t_uint8_t_separate_selects +Checking argmin_uint8_t_uint8_t_different_selects +Checking argmin_uint8_t_uint8_t_start_0_lt +Checking argmax_uint8_t_uint8_t_start_0_ge +Checking argmax_uint8_t_uint8_t_start_0_gt +Checking argmin_uint8_t_uint8_t_separate_selects_lt +Checking argmax_uint8_t_uint8_t_separate_selects_gt +Checking argmin_uint16_t_uint16_t_start_0 +Checking argmin_uint16_t_uint16_t_start_0_inc_2 +Checking argmin_uint16_t_uint16_t_start_2 +Checking argmin_uint16_t_uint16_t_start_0_min_idx_neg1 +Checking argmin_uint16_t_uint16_t_start_3_min_idx_3 +Checking argmin_uint16_t_uint16_t_start_3_min_idx_2 +Checking argmin_uint16_t_uint16_t_start_3_min_idx_4 +Checking argmin_uint16_t_uint16_t_separate_selects +Checking argmin_uint16_t_uint16_t_different_selects +Checking argmin_uint16_t_uint16_t_start_0_lt +Checking argmax_uint16_t_uint16_t_start_0_ge +Checking argmax_uint16_t_uint16_t_start_0_gt +Checking argmin_uint16_t_uint16_t_separate_selects_lt +Checking argmax_uint16_t_uint16_t_separate_selects_gt +Checking argmin_uint32_t_uint32_t_start_0 +Checking argmin_uint32_t_uint32_t_start_0_inc_2 +Checking argmin_uint32_t_uint32_t_start_2 +Checking argmin_uint32_t_uint32_t_start_0_min_idx_neg1 +Checking argmin_uint32_t_uint32_t_start_3_min_idx_3 +Checking argmin_uint32_t_uint32_t_start_3_min_idx_2 +Checking argmin_uint32_t_uint32_t_start_3_min_idx_4 +Checking argmin_uint32_t_uint32_t_separate_selects +Checking argmin_uint32_t_uint32_t_different_selects +Checking argmin_uint32_t_uint32_t_start_0_lt +Checking argmax_uint32_t_uint32_t_start_0_ge +Checking argmax_uint32_t_uint32_t_start_0_gt +Checking argmin_uint32_t_uint32_t_separate_selects_lt +Checking argmax_uint32_t_uint32_t_separate_selects_gt +Checking argmin_uint32_t_uint32_t_start_0_vf4_ic1 +Checking argmin_uint32_t_uint32_t_start_0_vf2_ic2 +Checking argmin_uint64_t_uint64_t_start_0 +Checking argmin_uint64_t_uint64_t_start_0_inc_2 +Checking argmin_uint64_t_uint64_t_start_2 +Checking argmin_uint64_t_uint64_t_start_0_min_idx_neg1 +Checking argmin_uint64_t_uint64_t_start_3_min_idx_3 +Checking argmin_uint64_t_uint64_t_start_3_min_idx_2 +Checking argmin_uint64_t_uint64_t_start_3_min_idx_4 +Checking argmin_uint64_t_uint64_t_separate_selects +Checking argmin_uint64_t_uint64_t_different_selects +Checking argmin_uint64_t_uint64_t_start_0_lt +Checking argmax_uint64_t_uint64_t_start_0_ge +Checking argmax_uint64_t_uint64_t_start_0_gt +Checking argmin_uint64_t_uint64_t_separate_selects_lt +Checking argmax_uint64_t_uint64_t_separate_selects_gt +Checking argmin_uint64_t_uint64_t_start_0_vf4_ic1 +Checking argmin_uint64_t_uint64_t_start_0_vf2_ic2 +Checking argmin_int8_t_int8_t_start_0 +Checking argmin_int8_t_int8_t_start_0_inc_2 +Checking argmin_int8_t_int8_t_start_2 +Checking argmin_int8_t_int8_t_start_0_min_idx_neg1 +Checking argmin_int8_t_int8_t_start_3_min_idx_3 +Checking argmin_int8_t_int8_t_start_3_min_idx_2 +Checking argmin_int8_t_int8_t_start_3_min_idx_4 +Checking argmin_int8_t_int8_t_separate_selects +Checking argmin_int8_t_int8_t_different_selects +Checking argmin_int8_t_int8_t_start_0_lt +Checking argmax_int8_t_int8_t_start_0_ge +Checking argmax_int8_t_int8_t_start_0_gt +Checking argmin_int8_t_int8_t_separate_selects_lt +Checking argmax_int8_t_int8_t_separate_selects_gt +Checking argmin_int16_t_int16_t_start_0 +Checking argmin_int16_t_int16_t_start_0_inc_2 +Checking argmin_int16_t_int16_t_start_2 +Checking argmin_int16_t_int16_t_start_0_min_idx_neg1 +Checking argmin_int16_t_int16_t_start_3_min_idx_3 +Checking argmin_int16_t_int16_t_start_3_min_idx_2 +Checking argmin_int16_t_int16_t_start_3_min_idx_4 +Checking argmin_int16_t_int16_t_separate_selects +Checking argmin_int16_t_int16_t_different_selects +Checking argmin_int16_t_int16_t_start_0_lt +Checking argmax_int16_t_int16_t_start_0_ge +Checking argmax_int16_t_int16_t_start_0_gt +Checking argmin_int16_t_int16_t_separate_selects_lt +Checking argmax_int16_t_int16_t_separate_selects_gt +Checking argmin_int32_t_int32_t_start_0 +Checking argmin_int32_t_int32_t_start_0_inc_2 +Checking argmin_int32_t_int32_t_start_2 +Checking argmin_int32_t_int32_t_start_0_min_idx_neg1 +Checking argmin_int32_t_int32_t_start_3_min_idx_3 +Checking argmin_int32_t_int32_t_start_3_min_idx_2 +Checking argmin_int32_t_int32_t_start_3_min_idx_4 +Checking argmin_int32_t_int32_t_separate_selects +Checking argmin_int32_t_int32_t_different_selects +Checking argmin_int32_t_int32_t_start_0_lt +Checking argmax_int32_t_int32_t_start_0_ge +Checking argmax_int32_t_int32_t_start_0_gt +Checking argmin_int32_t_int32_t_separate_selects_lt +Checking argmax_int32_t_int32_t_separate_selects_gt +Checking argmin_int32_t_int32_t_start_0_vf4_ic1 +Checking argmin_int32_t_int32_t_start_0_vf2_ic2 +Checking argmin_int64_t_int64_t_start_0 +Checking argmin_int64_t_int64_t_start_0_inc_2 +Checking argmin_int64_t_int64_t_start_2 +Checking argmin_int64_t_int64_t_start_0_min_idx_neg1 +Checking argmin_int64_t_int64_t_start_3_min_idx_3 +Checking argmin_int64_t_int64_t_start_3_min_idx_2 +Checking argmin_int64_t_int64_t_start_3_min_idx_4 +Checking argmin_int64_t_int64_t_separate_selects +Checking argmin_int64_t_int64_t_different_selects +Checking argmin_int64_t_int64_t_start_0_lt +Checking argmax_int64_t_int64_t_start_0_ge +Checking argmax_int64_t_int64_t_start_0_gt +Checking argmin_int64_t_int64_t_separate_selects_lt +Checking argmax_int64_t_int64_t_separate_selects_gt +Checking argmin_uint8_t_int8_t_start_0 +Checking argmin_uint8_t_int8_t_start_0_inc_2 +Checking argmin_uint8_t_int8_t_start_2 +Checking argmin_uint8_t_int8_t_start_0_min_idx_neg1 +Checking argmin_uint8_t_int8_t_start_3_min_idx_3 +Checking argmin_uint8_t_int8_t_start_3_min_idx_2 +Checking argmin_uint8_t_int8_t_start_3_min_idx_4 +Checking argmin_uint8_t_int8_t_separate_selects +Checking argmin_uint8_t_int8_t_different_selects +Checking argmin_uint8_t_int8_t_start_0_lt +Checking argmax_uint8_t_int8_t_start_0_ge +Checking argmax_uint8_t_int8_t_start_0_gt +Checking argmin_uint8_t_int8_t_separate_selects_lt +Checking argmax_uint8_t_int8_t_separate_selects_gt +Checking argmin_uint16_t_int16_t_start_0 +Checking argmin_uint16_t_int16_t_start_0_inc_2 +Checking argmin_uint16_t_int16_t_start_2 +Checking argmin_uint16_t_int16_t_start_0_min_idx_neg1 +Checking argmin_uint16_t_int16_t_start_3_min_idx_3 +Checking argmin_uint16_t_int16_t_start_3_min_idx_2 +Checking argmin_uint16_t_int16_t_start_3_min_idx_4 +Checking argmin_uint16_t_int16_t_separate_selects +Checking argmin_uint16_t_int16_t_different_selects +Checking argmin_uint16_t_int16_t_start_0_lt +Checking argmax_uint16_t_int16_t_start_0_ge +Checking argmax_uint16_t_int16_t_start_0_gt +Checking argmin_uint16_t_int16_t_separate_selects_lt +Checking argmax_uint16_t_int16_t_separate_selects_gt +Checking argmin_uint32_t_int32_t_start_0 +Checking argmin_uint32_t_int32_t_start_0_inc_2 +Checking argmin_uint32_t_int32_t_start_2 +Checking argmin_uint32_t_int32_t_start_0_min_idx_neg1 +Checking argmin_uint32_t_int32_t_start_3_min_idx_3 +Checking argmin_uint32_t_int32_t_start_3_min_idx_2 +Checking argmin_uint32_t_int32_t_start_3_min_idx_4 +Checking argmin_uint32_t_int32_t_separate_selects +Checking argmin_uint32_t_int32_t_different_selects +Checking argmin_uint32_t_int32_t_start_0_lt +Checking argmax_uint32_t_int32_t_start_0_ge +Checking argmax_uint32_t_int32_t_start_0_gt +Checking argmin_uint32_t_int32_t_separate_selects_lt +Checking argmax_uint32_t_int32_t_separate_selects_gt +Checking argmin_uint64_t_int64_t_start_0 +Checking argmin_uint64_t_int64_t_start_0_inc_2 +Checking argmin_uint64_t_int64_t_start_2 +Checking argmin_uint64_t_int64_t_start_0_min_idx_neg1 +Checking argmin_uint64_t_int64_t_start_3_min_idx_3 +Checking argmin_uint64_t_int64_t_start_3_min_idx_2 +Checking argmin_uint64_t_int64_t_start_3_min_idx_4 +Checking argmin_uint64_t_int64_t_separate_selects +Checking argmin_uint64_t_int64_t_different_selects +Checking argmin_uint64_t_int64_t_start_0_lt +Checking argmax_uint64_t_int64_t_start_0_ge +Checking argmax_uint64_t_int64_t_start_0_gt +Checking argmin_uint64_t_int64_t_separate_selects_lt +Checking argmax_uint64_t_int64_t_separate_selects_gt +Checking argmin_int8_t_uint8_t_start_0 +Checking argmin_int8_t_uint8_t_start_0_inc_2 +Checking argmin_int8_t_uint8_t_start_2 +Checking argmin_int8_t_uint8_t_start_0_min_idx_neg1 +Checking argmin_int8_t_uint8_t_start_3_min_idx_3 +Checking argmin_int8_t_uint8_t_start_3_min_idx_2 +Checking argmin_int8_t_uint8_t_start_3_min_idx_4 +Checking argmin_int8_t_uint8_t_separate_selects +Checking argmin_int8_t_uint8_t_different_selects +Checking argmin_int8_t_uint8_t_start_0_lt +Checking argmax_int8_t_uint8_t_start_0_ge +Checking argmax_int8_t_uint8_t_start_0_gt +Checking argmin_int8_t_uint8_t_separate_selects_lt +Checking argmax_int8_t_uint8_t_separate_selects_gt +Checking argmin_int16_t_uint16_t_start_0 +Checking argmin_int16_t_uint16_t_start_0_inc_2 +Checking argmin_int16_t_uint16_t_start_2 +Checking argmin_int16_t_uint16_t_start_0_min_idx_neg1 +Checking argmin_int16_t_uint16_t_start_3_min_idx_3 +Checking argmin_int16_t_uint16_t_start_3_min_idx_2 +Checking argmin_int16_t_uint16_t_start_3_min_idx_4 +Checking argmin_int16_t_uint16_t_separate_selects +Checking argmin_int16_t_uint16_t_different_selects +Checking argmin_int16_t_uint16_t_start_0_lt +Checking argmax_int16_t_uint16_t_start_0_ge +Checking argmax_int16_t_uint16_t_start_0_gt +Checking argmin_int16_t_uint16_t_separate_selects_lt +Checking argmax_int16_t_uint16_t_separate_selects_gt +Checking argmin_int32_t_uint32_t_start_0 +Checking argmin_int32_t_uint32_t_start_0_inc_2 +Checking argmin_int32_t_uint32_t_start_2 +Checking argmin_int32_t_uint32_t_start_0_min_idx_neg1 +Checking argmin_int32_t_uint32_t_start_3_min_idx_3 +Checking argmin_int32_t_uint32_t_start_3_min_idx_2 +Checking argmin_int32_t_uint32_t_start_3_min_idx_4 +Checking argmin_int32_t_uint32_t_separate_selects +Checking argmin_int32_t_uint32_t_different_selects +Checking argmin_int32_t_uint32_t_start_0_lt +Checking argmax_int32_t_uint32_t_start_0_ge +Checking argmax_int32_t_uint32_t_start_0_gt +Checking argmin_int32_t_uint32_t_separate_selects_lt +Checking argmax_int32_t_uint32_t_separate_selects_gt +Checking argmin_int64_t_uint64_t_start_0 +Checking argmin_int64_t_uint64_t_start_0_inc_2 +Checking argmin_int64_t_uint64_t_start_2 +Checking argmin_int64_t_uint64_t_start_0_min_idx_neg1 +Checking argmin_int64_t_uint64_t_start_3_min_idx_3 +Checking argmin_int64_t_uint64_t_start_3_min_idx_2 +Checking argmin_int64_t_uint64_t_start_3_min_idx_4 +Checking argmin_int64_t_uint64_t_separate_selects +Checking argmin_int64_t_uint64_t_different_selects +Checking argmin_int64_t_uint64_t_start_0_lt +Checking argmax_int64_t_uint64_t_start_0_ge +Checking argmax_int64_t_uint64_t_start_0_gt +Checking argmin_int64_t_uint64_t_separate_selects_lt +Checking argmax_int64_t_uint64_t_separate_selects_gt +Checking argmin_uint64_t_uint32_t_with_trunc +Checking argmin_uint64_t_uint32_t_induction_decrement +Checking argmixed_uint64_t_uint64_t_predicate_mismatch_0 +Checking argmixed_uint64_t_uint64_t_predicate_mismatch_1 +exit 0 diff --git a/SingleSource/UnitTests/Vectorizer/common.h b/SingleSource/UnitTests/Vectorizer/common.h index ac779d5447..e02abb17f5 100644 --- a/SingleSource/UnitTests/Vectorizer/common.h +++ b/SingleSource/UnitTests/Vectorizer/common.h @@ -1,6 +1,12 @@ #include #include +// Helper macros for stringification in _Pragma +#define XSTR(s) STR(s) +#define STR(s) #s +#define PRAGMA_VF(VF) _Pragma(STR(clang loop vectorize_width(VF))) +#define PRAGMA_IC(IC) _Pragma(STR(clang loop interleave_count(IC))) + #define DEFINE_SCALAR_AND_VECTOR_FN1_TYPE(Init, Loop, Type) \ auto ScalarFn = [](auto *A, Type TC) -> Type { \ Init _Pragma("clang loop vectorize(disable) interleave_count(1)") Loop \ @@ -17,6 +23,15 @@ Init _Pragma("clang loop vectorize(enable)") Loop \ }; +// Macro with explicit VF and interleave count control +#define DEFINE_SCALAR_AND_VECTOR_FN2_VF_INTERLEAVE(Init, Loop, VF, IC) \ + auto ScalarFn = [](auto *A, auto *B, unsigned TC) { \ + Init _Pragma("clang loop vectorize(disable) interleave_count(1)") Loop \ + }; \ + auto VectorFn = [](auto *A, auto *B, unsigned TC) { \ + Init PRAGMA_VF(VF) PRAGMA_IC(IC) Loop \ + }; + #define DEFINE_SCALAR_AND_VECTOR_FN2_TYPE(Init, Loop, Type) \ auto ScalarFn = [](auto *A, auto *B, Type TC) -> Type { \ Init _Pragma("clang loop vectorize(disable) interleave_count(1)") Loop \ @@ -36,34 +51,34 @@ #define DEFINE_NESTED_SCALAR_AND_VECTOR_FN4(InnerLoopCode) \ auto ScalarFn = [](auto *A, auto *B, unsigned OuterTC, unsigned InnerTC) { \ for (unsigned long i = 0; i < OuterTC; i++) { \ - _Pragma("clang loop vectorize(disable) interleave_count(1)") \ - for (unsigned long j = 0; j < InnerTC; j++) { \ + _Pragma("clang loop vectorize(disable) interleave_count(1)") for ( \ + unsigned long j = 0; j < InnerTC; j++) { \ InnerLoopCode \ } \ } \ }; \ auto VectorFn = [](auto *A, auto *B, unsigned OuterTC, unsigned InnerTC) { \ for (unsigned long i = 0; i < OuterTC; i++) { \ - _Pragma("clang loop vectorize(enable)") \ - for (unsigned long j = 0; j < InnerTC; j++) { \ + _Pragma("clang loop vectorize(enable)") for (unsigned long j = 0; \ + j < InnerTC; j++) { \ InnerLoopCode \ } \ } \ }; -#define DEFINE_NESTED_SCALAR_AND_VECTOR_FN5(InnerLoopCode) \ +#define DEFINE_NESTED_SCALAR_AND_VECTOR_FN5(InnerLoopCode) \ auto ScalarFn = [](auto *A, auto *B, unsigned OuterTC, unsigned InnerTC) { \ - for (long i = OuterTC - 1; i >= 0; i--) { \ - _Pragma("clang loop vectorize(disable) interleave_count(1)") \ - for (unsigned long j = 0; j < InnerTC; j++) { \ + for (long i = OuterTC - 1; i >= 0; i--) { \ + _Pragma("clang loop vectorize(disable) interleave_count(1)") for ( \ + unsigned long j = 0; j < InnerTC; j++) { \ InnerLoopCode \ } \ } \ }; \ auto VectorFn = [](auto *A, auto *B, unsigned OuterTC, unsigned InnerTC) { \ - for (long i = OuterTC - 1; i >= 0; i--) { \ - _Pragma("clang loop vectorize(enable)") \ - for (unsigned long j = 0; j < InnerTC; j++) { \ + for (long i = OuterTC - 1; i >= 0; i--) { \ + _Pragma("clang loop vectorize(enable)") for (unsigned long j = 0; \ + j < InnerTC; j++) { \ InnerLoopCode \ } \ } \ diff --git a/SingleSource/UnitTests/Vectorizer/index-select.cpp b/SingleSource/UnitTests/Vectorizer/index-select.cpp deleted file mode 100644 index 26cad979d6..0000000000 --- a/SingleSource/UnitTests/Vectorizer/index-select.cpp +++ /dev/null @@ -1,271 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "common.h" - -template -using Fn2Ty = std::function; -template -static void checkVectorFunction(Fn2Ty ScalarFn, - Fn2Ty VectorFn, const char *Name) { - std::cout << "Checking " << Name << "\n"; - - unsigned N = 1000; - std::unique_ptr Src1(new Ty[N]); - std::unique_ptr Src2(new Ty[N]); - init_data(Src1, N); - init_data(Src2, N); - - // Test VectorFn with different input data. - { - // Check with random inputs. - auto Reference = ScalarFn(&Src1[0], &Src2[0], N); - auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); - if (Reference != ToCheck) { - std::cerr << "Miscompare\n"; - exit(1); - } - } - - { - // Check with sorted inputs. - std::sort(&Src1[0], &Src1[N]); - std::sort(&Src2[0], &Src2[N]); - auto Reference = ScalarFn(&Src1[0], &Src2[0], N); - auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); - if (Reference != ToCheck) { - std::cerr << "Miscompare\n"; - exit(1); - } - } - - { - // Check with all max values. - for (unsigned I = 0; I != N; ++I) { - Src1[I] = std::numeric_limits::max(); - Src2[I] = std::numeric_limits::max(); - } - auto Reference = ScalarFn(&Src1[0], &Src2[0], N); - auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); - if (Reference != ToCheck) { - std::cerr << "Miscompare\n"; - exit(1); - } - } - - { - // Check with first input all zeros and second input -1. - for (unsigned I = 0; I != N; ++I) { - Src1[I] = 0; - Src2[I] = std::numeric_limits::max(); - } - auto Reference = ScalarFn(&Src1[0], &Src2[0], N); - auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); - if (Reference != ToCheck) { - std::cerr << "Miscompare\n"; - exit(1); - } - } - - { - // Check with first input all max values and second input all zeros. - for (unsigned I = 0; I != N; ++I) { - Src1[I] = std::numeric_limits::max(); - Src2[I] = 0; - } - auto Reference = ScalarFn(&Src1[0], &Src2[0], N); - auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); - if (Reference != ToCheck) { - std::cerr << "Miscompare\n"; - exit(1); - } - } - - { - // Check with inputs all zero. - for (unsigned I = 0; I != N; ++I) { - Src1[I] = 0; - Src2[I] = 0; - } - auto Reference = ScalarFn(&Src1[0], &Src2[0], N); - auto ToCheck = VectorFn(&Src1[0], &Src2[0], N); - if (Reference != ToCheck) { - std::cerr << "Miscompare\n"; - exit(1); - } - } -} - -int main(void) { - rng = std::mt19937(15); - - // Tests select-minimum-index loops, where the loop selects the minimum value - // and the first index of that value. - { - // Check loop starting at index 0 and stepping by 1. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 0;, - for (unsigned I = 0; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction(ScalarFn, VectorFn, - "min_index_select_u32_u32_start_0"); - } - - { - // Check loop starting at index 0 and stepping by 1. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 0;, - for (unsigned I = 0; I < TC / 2; I += 2) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_start_0_inc_2"); - } - - { - // Check loop starting at index 0 and stepping by 1. MinIdx starting at 2. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 2;, - for (unsigned I = 0; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction(ScalarFn, VectorFn, - "min_index_select_u32_u32_start_2"); - } - - { - // Index is truncated in the loop. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 0;, - for (uint64_t I = 0; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_with_trunc"); - } - - { - // Check loop where induction is truncated. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 0;, - for (unsigned I = TC; I > 0; I--) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_induction_decrement"); - } - - { - // Check loop where both Min and MinIdx starts at the maximum value. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = std::numeric_limits::max();, - for (unsigned I = 0; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_start_0_min_idx_neg1"); - } - - { - // Check loop starting at index 3 and stepping by 1. MinIdx starts at 3. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 3;, - for (unsigned I = 3; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_start_3_min_idx_3"); - } - - { - // Check loop starting at index 3 and stepping by 1. MinIdx starts at 2. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 2;, - for (unsigned I = 3; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_start_3_min_idx_2"); - } - - { - // Check loop starting at index 3 and stepping by 1. MinIdx starts at 4. - DEFINE_SCALAR_AND_VECTOR_FN2( - uint32_t Min = std::numeric_limits::max(); - uint32_t MinIdx = 4;, - for (unsigned I = 3; I < TC; I++) { - uint32_t D = A[I] + B[I]; - if (D < Min) { - Min = D; - MinIdx = I; - } - } - return MinIdx; - ); - checkVectorFunction( - ScalarFn, VectorFn, "min_index_select_u32_u32_start_3_min_idx_4"); - } - - return 0; -}