@@ -53,6 +53,75 @@ IPEX_FORCE_INLINE void move_ker(at::Half* out, const float* in, int64_t len) {
5353 cvt_fp32_to_fp16 (out, in, len);
5454}
5555
56+ template <>
57+ IPEX_FORCE_INLINE void move_ker (
58+ at::Half* out,
59+ const at::Half* in,
60+ int64_t len) {
61+ int64_t i = 0 ;
62+ #pragma unroll(4)
63+ for (i = 0 ; i < len - 31 ; i += 32 ) {
64+ auto in0 = _mm512_loadu_si512 (in + i);
65+ _mm512_storeu_si512 (out + i, in0);
66+ }
67+
68+ if (i < len) {
69+ auto mask = (1 << (len - i)) - 1 ;
70+ auto in0 = _mm512_maskz_loadu_epi16 (mask, in + i);
71+ _mm512_mask_storeu_epi16 (out + i, mask, in0);
72+ }
73+ }
74+
75+ static IPEX_FORCE_INLINE void zero_ker (at::Half* out, int64_t len) {
76+ int64_t i = 0 ;
77+ __m512i zero_512 = _mm512_setzero_si512 ();
78+ #pragma unroll(4)
79+ for (i = 0 ; i < len - 31 ; i += 32 ) {
80+ _mm512_storeu_si512 (out + i, zero_512);
81+ }
82+
83+ if (i < len) {
84+ auto mask = ((1 << (len - i)) - 1 );
85+ _mm512_mask_storeu_epi16 (out + i, mask, zero_512);
86+ }
87+ }
88+
89+ template <>
90+ IPEX_FORCE_INLINE void add_ker (
91+ at::Half* inout,
92+ const at::Half* in,
93+ int64_t len) {
94+ int64_t i = 0 ;
95+ #pragma unroll(2)
96+ for (i = 0 ; i < len - 31 ; i += 32 ) {
97+ auto inout1 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(inout + i)));
98+ auto inout2 =
99+ cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(inout + i + 16 )));
100+ auto in1 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(in + i)));
101+ auto in2 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(in + i + 16 )));
102+ inout1 = _mm512_add_ps (inout1, in1);
103+ inout2 = _mm512_add_ps (inout2, in2);
104+ _mm256_storeu_si256 ((__m256i*)(inout + i), cvt_fp32_to_fp16 (inout1));
105+ _mm256_storeu_si256 ((__m256i*)(inout + i + 16 ), cvt_fp32_to_fp16 (inout2));
106+ }
107+
108+ if (i < len - 15 ) {
109+ auto inout1 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(inout + i)));
110+ auto in1 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(in + i)));
111+ inout1 = _mm512_add_ps (inout1, in1);
112+ _mm256_storeu_si256 ((__m256i*)(inout + i), cvt_fp32_to_fp16 (inout1));
113+ i += 16 ;
114+ }
115+
116+ if (i < len) {
117+ auto mask = (1 << (len - i)) - 1 ;
118+ auto inout1 = cvt_fp16_to_fp32 (_mm256_maskz_loadu_epi16 (mask, inout + i));
119+ auto in1 = cvt_fp16_to_fp32 (_mm256_maskz_loadu_epi16 (mask, in + i));
120+ inout1 = _mm512_add_ps (inout1, in1);
121+ _mm256_mask_storeu_epi16 (inout + i, mask, cvt_fp32_to_fp16 (inout1));
122+ }
123+ }
124+
56125template <>
57126IPEX_FORCE_INLINE void add_ker (float * inout, const at::Half* in, int64_t len) {
58127 int64_t i = 0 ;
@@ -85,6 +154,39 @@ IPEX_FORCE_INLINE void add_ker(float* inout, const at::Half* in, int64_t len) {
85154 }
86155}
87156
157+ template <>
158+ IPEX_FORCE_INLINE void add_ker (at::Half* inout, const float * in, int64_t len) {
159+ int64_t i = 0 ;
160+ #pragma unroll(2)
161+ for (i = 0 ; i < len - 31 ; i += 32 ) {
162+ auto in1 = _mm512_loadu_ps (in + i);
163+ auto in2 = _mm512_loadu_ps (in + i + 16 );
164+ auto inout1 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(inout + i)));
165+ auto inout2 =
166+ cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(inout + i + 16 )));
167+ inout1 = _mm512_add_ps (inout1, in1);
168+ inout2 = _mm512_add_ps (inout2, in2);
169+ _mm256_storeu_si256 ((__m256i*)(inout + i), cvt_fp32_to_fp16 (inout1));
170+ _mm256_storeu_si256 ((__m256i*)(inout + i + 16 ), cvt_fp32_to_fp16 (inout2));
171+ }
172+
173+ if (i < len - 15 ) {
174+ auto in1 = _mm512_loadu_ps (in + i);
175+ auto inout1 = cvt_fp16_to_fp32 (_mm256_loadu_si256 ((__m256i*)(inout + i)));
176+ inout1 = _mm512_add_ps (inout1, in1);
177+ _mm256_storeu_si256 ((__m256i*)(inout + i), cvt_fp32_to_fp16 (inout1));
178+ i += 16 ;
179+ }
180+
181+ if (i < len) {
182+ auto mask = (1 << (len - i)) - 1 ;
183+ auto in1 = _mm512_maskz_loadu_ps (mask, in + i);
184+ auto inout1 = cvt_fp16_to_fp32 (_mm256_maskz_loadu_epi16 (mask, inout + i));
185+ inout1 = _mm512_add_ps (inout1, in1);
186+ _mm256_mask_storeu_epi16 (inout + i, mask, cvt_fp32_to_fp16 (inout1));
187+ }
188+ }
189+
88190} // namespace kernel
89191} // namespace cpu
90192} // namespace torch_ipex
0 commit comments