Skip to content

Commit 06a44d2

Browse files
committed
Refactor Convolution to new structure and fix build failures
Signed-off-by: Jonathan Clohessy <[email protected]>
1 parent 22beb50 commit 06a44d2

File tree

3 files changed

+106
-21
lines changed

3 files changed

+106
-21
lines changed

include/xnnpack.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3049,6 +3049,46 @@ enum xnn_status xnn_create_convolution2d_nhwc_f16(
30493049
xnn_weights_cache_t weights_cache,
30503050
xnn_operator_t* convolution_op_out);
30513051

3052+
enum xnn_status xnn_create_convolution2d_nhwc_pf16(
3053+
uint32_t input_padding_top,
3054+
uint32_t input_padding_right,
3055+
uint32_t input_padding_bottom,
3056+
uint32_t input_padding_left,
3057+
uint32_t kernel_height,
3058+
uint32_t kernel_width,
3059+
uint32_t subsampling_height,
3060+
uint32_t subsampling_width,
3061+
uint32_t dilation_height,
3062+
uint32_t dilation_width,
3063+
uint32_t groups,
3064+
size_t group_input_channels,
3065+
size_t group_output_channels,
3066+
size_t input_channel_stride,
3067+
size_t output_channel_stride,
3068+
const void* kernel,
3069+
const void* bias,
3070+
float output_min,
3071+
float output_max,
3072+
uint32_t flags,
3073+
xnn_weights_cache_t weights_cache,
3074+
xnn_operator_t* convolution_op_out);
3075+
3076+
enum xnn_status xnn_reshape_convolution2d_nhwc_pf16(
3077+
xnn_operator_t convolution_op,
3078+
size_t batch_size,
3079+
size_t input_height,
3080+
size_t input_width,
3081+
size_t* workspace_size,
3082+
size_t* output_height_out,
3083+
size_t* output_width_out,
3084+
pthreadpool_t threadpool);
3085+
3086+
enum xnn_status xnn_setup_convolution2d_nhwc_pf16(
3087+
xnn_operator_t convolution_op,
3088+
void* workspace,
3089+
const void* input,
3090+
void* output);
3091+
30523092
enum xnn_status xnn_reshape_convolution2d_nhwc_f16(
30533093
xnn_operator_t convolution_op,
30543094
size_t batch_size,

src/operators/convolution-nhwc.c

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,6 +1941,46 @@ enum xnn_status xnn_create_convolution2d_nhwc_f32(
19411941
convolution_op_out);
19421942
}
19431943

1944+
enum xnn_status xnn_create_convolution2d_nhwc_pf16(
1945+
uint32_t input_padding_top, uint32_t input_padding_right,
1946+
uint32_t input_padding_bottom, uint32_t input_padding_left,
1947+
uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
1948+
uint32_t subsampling_width, uint32_t dilation_height,
1949+
uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
1950+
size_t group_output_channels, size_t input_channel_stride,
1951+
size_t output_channel_stride, const void* kernel, const void* bias,
1952+
float output_min, float output_max, uint32_t flags,
1953+
xnn_weights_cache_t weights_cache,
1954+
xnn_operator_t* convolution_op_out) {
1955+
struct convolution2d_nhwc_context context = {
1956+
.input_padding_top = input_padding_top,
1957+
.input_padding_right = input_padding_right,
1958+
.input_padding_bottom = input_padding_bottom,
1959+
.input_padding_left = input_padding_left,
1960+
.kernel_height = kernel_height,
1961+
.kernel_width = kernel_width,
1962+
.subsampling_height = subsampling_height,
1963+
.subsampling_width = subsampling_width,
1964+
.dilation_height = dilation_height,
1965+
.dilation_width = dilation_width,
1966+
.groups = groups,
1967+
.group_input_channels = group_input_channels,
1968+
.group_output_channels = group_output_channels,
1969+
.input_channel_stride = input_channel_stride,
1970+
.output_channel_stride = output_channel_stride,
1971+
.kernel = kernel,
1972+
.bias = bias,
1973+
.output_min = output_min,
1974+
.output_max = output_max,
1975+
.flags = flags,
1976+
.weights_cache = weights_cache,
1977+
.gemm_config = xnn_init_pf16_gemm_config(),
1978+
.operator_type = xnn_operator_type_convolution_nhwc_pf16,
1979+
};
1980+
return create_convolution2d_nhwc_helper(&f16_variant, &context,
1981+
convolution_op_out);
1982+
}
1983+
19441984
enum xnn_status xnn_create_convolution2d_nhwc_f32_f16(
19451985
uint32_t input_padding_top, uint32_t input_padding_right,
19461986
uint32_t input_padding_bottom, uint32_t input_padding_left,
@@ -2974,6 +3014,22 @@ enum xnn_status xnn_reshape_convolution2d_nhwc_f32(
29743014
output_width_out, threadpool);
29753015
}
29763016

3017+
enum xnn_status xnn_reshape_convolution2d_nhwc_pf16(
3018+
xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
3019+
size_t input_width, size_t* workspace_size, size_t* output_height_out,
3020+
size_t* output_width_out, pthreadpool_t threadpool) {
3021+
return reshape_convolution2d_nhwc(
3022+
convolution_op, xnn_operator_type_convolution_nhwc_pf16, batch_size,
3023+
input_height, input_width,
3024+
/*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT16,
3025+
/*log2_filter_element_size=*/XNN_LOG2_SIZEOF_FLOAT16,
3026+
/*log2_accumulator_element_size=*/XNN_LOG2_SIZEOF_FLOAT16,
3027+
/*extra_weights_elements_size=*/sizeof(uint16_t),
3028+
/*log2_output_element_size=*/XNN_LOG2_SIZEOF_FLOAT16,
3029+
/*dynamic_quantization=*/false, workspace_size, output_height_out,
3030+
output_width_out, threadpool);
3031+
}
3032+
29773033
static enum xnn_status setup_igemm(xnn_operator_t convolution_op,
29783034
void* workspace,
29793035
uint32_t log2_input_element_size) {
@@ -3172,6 +3228,16 @@ enum xnn_status xnn_setup_convolution2d_nhwc_f16(xnn_operator_t convolution_op,
31723228
/*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT16);
31733229
}
31743230

3231+
enum xnn_status xnn_setup_convolution2d_nhwc_pf16(xnn_operator_t convolution_op,
3232+
void* workspace,
3233+
const void* input,
3234+
void* output) {
3235+
return setup_convolution2d_nhwc(
3236+
convolution_op, xnn_operator_type_convolution_nhwc_pf16, workspace, input,
3237+
output, /*quantization_params=*/NULL,
3238+
/*log2_input_element_size=*/XNN_LOG2_SIZEOF_FLOAT16);
3239+
}
3240+
31753241
enum xnn_status xnn_setup_convolution2d_nhwc_f32(xnn_operator_t convolution_op,
31763242
void* workspace,
31773243
const float* input,

src/xnnpack/internal.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -561,27 +561,6 @@ enum xnn_status xnn_create_fully_connected_nc_qdu8_f32_qb4w_f16_scales(
561561
float output_min, float output_max, uint32_t flags,
562562
xnn_weights_cache_t weights_cache, xnn_operator_t* fully_connected_op_out);
563563

564-
enum xnn_status xnn_create_convolution2d_nhwc_pf16(
565-
uint32_t input_padding_top, uint32_t input_padding_right,
566-
uint32_t input_padding_bottom, uint32_t input_padding_left,
567-
uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
568-
uint32_t subsampling_width, uint32_t dilation_height,
569-
uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
570-
size_t group_output_channels, size_t input_channel_stride,
571-
size_t output_channel_stride, const void* kernel, const void* bias,
572-
float output_min, float output_max, uint32_t flags,
573-
xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
574-
575-
enum xnn_status xnn_reshape_convolution2d_nhwc_pf16(
576-
xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
577-
size_t input_width, size_t* workspace_size, size_t* output_height_out,
578-
size_t* output_width_out, pthreadpool_t threadpool);
579-
580-
enum xnn_status xnn_setup_convolution2d_nhwc_pf16(xnn_operator_t convolution_op,
581-
void* workspace,
582-
const void* input,
583-
void* output);
584-
585564
enum xnn_status xnn_create_convolution2d_nhwc_pqs8_qs8_qc8w(
586565
uint32_t input_padding_top, uint32_t input_padding_right,
587566
uint32_t input_padding_bottom, uint32_t input_padding_left,

0 commit comments

Comments
 (0)