From 853fbfc3d947e025889d5e2579e74d064cfb51f9 Mon Sep 17 00:00:00 2001
From: Tao Hu <2441597844@qq.com>
Date: Mon, 24 Jul 2017 21:41:33 +0800
Subject: [PATCH 1/6] Add files via upload

---
 include/caffe/util/winograd_inference.hpp | 358 ++++++++++++++++++++++
 1 file changed, 358 insertions(+)
 create mode 100644 include/caffe/util/winograd_inference.hpp

diff --git a/include/caffe/util/winograd_inference.hpp b/include/caffe/util/winograd_inference.hpp
new file mode 100644
index 000000000..9be0d7fb3
--- /dev/null
+++ b/include/caffe/util/winograd_inference.hpp
@@ -0,0 +1,358 @@
+
+#ifndef CAFFE_UTIL_WINOGRAD_INFERENCE_HPP_
+#define CAFFE_UTIL_WINOGRAD_INFERENCE_HPP_
+
+//winograd for cpu inference
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+
+#define DEBUG_WINOGRAD 0
+
+#if DEBUG_WINOGRAD
+  #include <cassert>
+#endif
+
+namespace WINOGRAD_INFERENCE {
+
+  enum WINOGRAD_MATRIX {
+    WINOGRAD_A = 0,
+    WINOGRAD_B,
+    WINOGRAD_G,
+  };
+  enum WINOGRAD_ALG {
+    WT_8X8_F_6X6_3X3 = 0,
+    WT_6X6_F_4X4_3X3,
+    WT_8X8_F_4X4_5X5,
+  };
+
+  const int WINOGRAD_MATRIX_NUM = 3;
+  const int WINOGRAD_ALG_NUM = 3;
+
+  template <WINOGRAD_ALG alg>
+  struct WinogradParameters{};
+
+  /**
+   * compute the Kronecker product of in1 and in2, where in1 is an m by n matrix and in2 is a p by q matrix
+   *
+   * @param out an (m*p) by (n*q) matrix stored in row major
+   * @param in1 an m by n matrix stored in row major
+   * @param in2 a p by q matrix stored in row major
+   */
+  void kronecker_product(float *out, const float *in1, const float *in2, int m, int n, int p, int q);
+
+  //singleton, precomputation before inference
+  void winograd2D_initialize();
+
+  template<>
+  struct WinogradParameters<WT_6X6_F_4X4_3X3>
+  {
+    // wt6x6, F(4x4,3x3)
+  private:
+    static const int O = 4;
+    static const int K = 3;
+    static const int T = O + K - 1;
+
+    static const float *getG() {
+      static const float G[T*K] = {
+        1. / 4., 0, 0,
+        -1. / 6., -1. / 6., -1. / 6.,
+        -1. / 6., 1. / 6., -1. / 6.,
+        1. / 24., 1. / 12., 1. / 6.,
+        1. / 24., -1. / 12., 1. / 6.,
+        0, 0, 1,
+      };
+      return G;
+    }
+
+    static const float *getA() {
+      static const float A[T*O] = {
+        1, 0, 0, 0,
+        1, 1, 1, 1,
+        1, -1, 1, -1,
+        1, 2, 4, 8,
+        1, -2, 4, -8,
+        0, 0, 0, 1,
+      };
+      return A;
+    }
+
+    static const float *getB() {
+      static const float B[T*T] = {
+        4, 0, 0, 0, 0, 0,
+        0, -4, 4, -2, 2, 4,
+        -5, -4, -4, -1, -1, 0,
+        0, 1, -1, 2, -2, -5,
+        1, 1, 1, 1, 1, 0,
+        0, 0, 0, 0, 0, 1,
+      };
+      return B;
+    }
+
+  public:
+    static const float *get(WINOGRAD_MATRIX mat, int &row, int &col) {
+
+#if DEBUG_WINOGRAD
+      assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G);
+#endif
+      switch (mat) {
+      case WINOGRAD_A: row = T; col = O; return getA();
+      case WINOGRAD_B: row = T; col = T; return getB();
+      case WINOGRAD_G: row = T; col = K; return getG();
+      }
+      row = 0; col = 0;
+      return NULL; // unreachable for a valid mat
+    }
+  };
+
+  template<>
+  struct WinogradParameters<WT_8X8_F_6X6_3X3>
+  {
+
+  private:
+
+    // wt8x8, F(6x6,3x3)
+
+    static const int O = 6;
+    static const int K = 3;
+    static const int T = O + K - 1;
+
+  public:
+    static const float *get(WINOGRAD_MATRIX mat, int &row, int &col) {
+
+#if DEBUG_WINOGRAD
+      assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G);
+#endif
+      switch (mat) {
+      case WINOGRAD_A: row = T; col = O; return getA();
+      case WINOGRAD_B: row = T; col = T; return getB();
+      case WINOGRAD_G: row = T; col = K; return getG();
+      }
+      row = 0; col = 0;
+      return NULL; // unreachable for a valid mat
+    }
+
+  private:
+    static const float *getG() {
+      static const float G[T*K] = {
+        1.f, 0.f, 0.f,
+        -2.f / 9, -2.f / 9, -2.f / 9,
+        -2.f / 9, 2.f / 9, -2.f / 9,
+        1.f / 90, 1.f / 45, 2.f / 45,
+        1.f / 90, -1.f / 45, 2.f / 45,
+        32.f / 45, 16.f / 45, 8.f / 45,
+        32.f / 45, -16.f / 45, 8.f / 45,
+        0.f, 0.f, 1.f,
+      };
+      return G;
+    }
+
+    static const float *getA() {
+      static const float A[T*(T - K + 1)] = {
+        1 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f,
+        1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f,
+        1 * 1.f, -1 * 1.f, 1 * 1.f, -1 * 1.f, 1 * 1.f, -1 * 1.f,
+        1 * 1.f, 2 * 1.f, 4 * 1.f, 8 * 1.f, 16 * 1.f, 32 * 1.f,
+        1 * 1.f, -2 * 1.f, 4 * 1.f, -8 * 1.f, 16 * 1.f, -32 * 1.f,
+        1 * 1.f, 0.5*1.f, 0.25*1.f, 0.125*1.f, 0.0625*1.f, 0.03125*1.f,
+        1 * 1.f, -0.5*1.f, 0.25*1.f, -0.125*1.f, 0.0625*1.f, -0.03125*1.f,
+        0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 1 * 1.f,
+      };
+      return A;
+    }
+
+    static const float *getB() {
+      static const float B[T*T] = {
+        1 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f,
+        0 * 1.f, 1 * 1.f, -1 * 1.f, 0.5*1.f, -0.5*1.f, 2 * 1.f, -2 * 1.f, -1 * 1.f,
+        -5.25*1.f, 1 * 1.f, 1 * 1.f, 0.25*1.f, 0.25*1.f, 4 * 1.f, 4 * 1.f, 0 * 1.f,
+        0 * 1.f, -4.25*1.f, 4.25*1.f, -2.5*1.f, 2.5*1.f, -2.5*1.f, 2.5*1.f, 5.25*1.f,
+        5.25*1.f, -4.25*1.f, -4.25*1.f, -1.25*1.f, -1.25*1.f, -5 * 1.f, -5 * 1.f, 0 * 1.f,
+        0 * 1.f, 1 * 1.f, -1 * 1.f, 2 * 1.f, -2 * 1.f, 0.5*1.f, -0.5*1.f, -5.25*1.f,
+        -1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 1 * 1.f, 0 * 1.f,
+        0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 0 * 1.f, 1 * 1.f,
+      };
+      return B;
+    }
+  };
+
+  template<>
+  struct WinogradParameters<WT_8X8_F_4X4_5X5>
+  {
+
+  private:
+    // wt8x8, F(4x4,5x5)
+    static const int T = 5 + 4 - 1;
+    static const int K = 5;
+    static const int O = 4;
+
+  public:
+    static const float *get(WINOGRAD_MATRIX mat, int &row, int &col) {
+
+#if DEBUG_WINOGRAD
+      assert(mat >= WINOGRAD_A && mat <= WINOGRAD_G);
+#endif
+      switch (mat) {
+      case WINOGRAD_A: row = T; col = O; return getA();
+      case WINOGRAD_B: row = T; col = T; return getB();
+      case WINOGRAD_G: row = T; col = K; return getG();
+      }
+      row = 0; col = 0;
+      return NULL; // unreachable for a valid mat
+    }
+
+  private:
+
+    // from https://github.com/Maratyszcza/NNPACK/issues/12
+
+    static const float *getG() {
+      static const float G[T*K] = {
+        1, 0, 0, 0, 0,
+        -2. / 9., -2. / 9., -2. / 9., -2. / 9., -2. / 9.,
+        -2. / 9., 2. / 9., -2. / 9., 2. / 9., -2. / 9.,
+        1. / 90., 1. / 45., 2. / 45., 4. / 45., 8. / 45.,
+        1. / 90., -1. / 45., 2. / 45., -4. / 45., 8. / 45.,
+        4. / 45., 2. / 45., 1. / 45., 1. / 90., 1. / 180.,
+        4. / 45., -2. / 45., 1. / 45., -1. / 90., 1. / 180.,
+        0, 0, 0, 0, 1,
+      };
+      return G;
+    }
+
+    static const float *getA() {
+      static const float A[T*O] = {
+        1, 0, 0, 0,
+        1, 1, 1, 1,
+        1, -1, 1, -1,
+        1, 2, 4, 8,
+        1, -2, 4, -8,
+        8, 4, 2, 1,
+        8, -4, 2, -1,
+        0, 0, 0, 1
+      };
+      return A;
+    }
+
+    static const float *getB() {
+      static const float B[T*T] = {
+        1, 0, 0, 0, 0, 0, 0, 0,
+        0, 1, -1, 1. / 2, -1. / 2, 2, -2, -1,
+        -21. / 4, 1, 1, 1. / 4, 1. / 4, 4, 4, 0,
+        0, -17. / 4, 17. / 4, -5. / 2, 5. / 2, -5. / 2, 5. / 2, 21. / 4,
+        21. / 4, -17. / 4, -17. / 4, -5. / 4, -5. / 4, -5, -5, 0,
+        0, 1, -1, 2, -2, 1. / 2, -1. / 2, -21. / 4,
+        -1, 1, 1, 1, 1, 1, 1, 0,
+        0, 0, 0, 0, 0, 0, 0, 1,
+      };
+      return B;
+    }
+  };
+
+  class Winograd_Kron
+  {
+
+  private:
+
+    Winograd_Kron(WINOGRAD_ALG alg, WINOGRAD_MATRIX mat) {
+
+      isCalc = false;
+
+      switch (alg) {
+      case WT_8X8_F_6X6_3X3:
+        matrix = WinogradParameters<WT_8X8_F_6X6_3X3>::get(mat, row, col); break;
+      case WT_6X6_F_4X4_3X3:
+        matrix = WinogradParameters<WT_6X6_F_4X4_3X3>::get(mat, row, col); break;
+      case WT_8X8_F_4X4_5X5:
+        matrix = WinogradParameters<WT_8X8_F_4X4_5X5>::get(mat, row, col); break;
+      }
+    }
+
+  private:
+    const float *matrix; // = A, B, G
+    int row, col; // matrix: row*col
+                  // A: T*O
+                  // B: T*T
+                  // G: T*K
+
+    boost::shared_ptr<caffe::Blob<float> > kron;
+
+    bool isCalc;
+
+  public:
+
+    static Winograd_Kron *getInstance(WINOGRAD_ALG alg, WINOGRAD_MATRIX mat) {
+
+      // 9 instances = WINOGRAD_ALG_NUM * WINOGRAD_MATRIX_NUM
+      static Winograd_Kron *instances[WINOGRAD_ALG_NUM * WINOGRAD_MATRIX_NUM] = { NULL }; // indexed by alg*WINOGRAD_MATRIX_NUM + mat
+
+      int index = alg*WINOGRAD_MATRIX_NUM + mat;
+
+      if (instances[index] == NULL)
+        instances[index] = new Winograd_Kron(alg, mat);
+
+      return instances[index];
+    }
+
+    const boost::shared_ptr<caffe::Blob<float> > get() {
+      if (isCalc)
+        return kron;
+      else {
+        calc();
+        return kron;
+      }
+    }
+
+  private:
+
+    void calc() {
+
+      // kron(matrix, matrix) is (row*row) x (col*col); the two-axis blob
+      // shape is a reconstruction -- only the total count matters to callers
+      std::vector<int> shape;
+      shape.push_back(row*row);
+      shape.push_back(col*col);
+
+      kron = boost::shared_ptr<caffe::Blob<float> >(new caffe::Blob<float>(shape));
+
+      kronecker_product(kron->mutable_cpu_data(), matrix, matrix, row, col, row, col);
+
+      isCalc = true;
+    }
+  };
+
+  inline void kronecker_product(float *out, const float *in1, const float *in2, int m, int n, int p, int q)
+  {
+    for (int i = 0; i < m; ++i) {
+      for (int j = 0; j < n; ++j) {
+        for (int k = 0; k < p; ++k) {
+          for (int l = 0; l < q; ++l) {
+            out[(p*i + k)*n*q + q*j + l] = in1[n*i + j] * in2[k*q + l];
+            /* compute in float precision in inference */
+          }
+        }
+      }
+    }
+  }
+
+  inline void winograd2D_initialize() {
+    //singleton, precomputation before inference
+
+    Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_A)->get();
+    Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_B)->get();
+    Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_G)->get();
+
+    Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_A)->get();
+    Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_B)->get();
+    Winograd_Kron::getInstance(WT_8X8_F_6X6_3X3, WINOGRAD_G)->get();
+  }
+}
+
+
+#endif
\ No newline at end of file
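
A minimal usage sketch of the API above (standalone snippet for illustration only; the main() harness is an assumption, not part of the series):

    #include "caffe/util/winograd_inference.hpp"
    #include <cstdio>

    int main() {
      using namespace WINOGRAD_INFERENCE;
      winograd2D_initialize();  // precompute kron(X, X) for A, B, G once

      // input-transform matrix for F(4x4,3x3): kron(B, B) is 36 x 36
      const boost::shared_ptr<caffe::Blob<float> > kronB =
          Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_B)->get();
      printf("kron(B,B): %d x %d\n", kronB->shape(0), kronB->shape(1));
      return 0;
    }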
From ff7c0f5328db2d1b346a62610455dea243196b3d Mon Sep 17 00:00:00 2001
From: Tao Hu <2441597844@qq.com>
Date: Mon, 24 Jul 2017 21:42:25 +0800
Subject: [PATCH 2/6] Add files via upload

---
 .../caffe/layers/winograd_layer_inference.hpp | 116 ++++++++++++++++++
 1 file changed, 116 insertions(+)
 create mode 100644 include/caffe/layers/winograd_layer_inference.hpp

diff --git a/include/caffe/layers/winograd_layer_inference.hpp b/include/caffe/layers/winograd_layer_inference.hpp
new file mode 100644
index 000000000..943d5a801
--- /dev/null
+++ b/include/caffe/layers/winograd_layer_inference.hpp
@@ -0,0 +1,116 @@
+#ifndef WINOGRAD_LAYER_INFERENCE_HPP_
+#define WINOGRAD_LAYER_INFERENCE_HPP_
+
+//winograd_layer for cpu inference
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/base_conv_layer.hpp"
+
+namespace caffe {
+
+/**
+ * @brief Convolves the input image with a bank of learned filters,
+ *        and (optionally) adds biases.
+ *
+ * Caffe convolves by reduction to matrix multiplication. This achieves
+ * high-throughput and generality of input and filter dimensions but comes at
+ * the cost of memory for matrices. This makes use of efficiency in BLAS.
+ *
+ * The input is "im2col" transformed to a channel K' x H x W data matrix
+ * for multiplication with the N x K' x H x W filter matrix to yield a
+ * N' x H x W output matrix that is then "col2im" restored. K' is the
+ * input channel * kernel height * kernel width dimension of the unrolled
+ * inputs so that the im2col matrix has a column for each input region to
+ * be filtered. col2im restores the output spatial structure by rolling up
+ * the output channel N' columns of the output matrix.
+ */
+template <typename Dtype>
+class WinogradLayer : public BaseConvolutionLayer<Dtype> {
+ public:
+  /**
+   * @param param provides ConvolutionParameter convolution_param,
+   *    with ConvolutionLayer options:
+   *  - num_output. The number of filters.
+   *  - kernel_size / kernel_h / kernel_w. The filter dimensions, given by
+   *  kernel_size for square filters or kernel_h and kernel_w for rectangular
+   *  filters.
+   *  - stride / stride_h / stride_w (\b optional, default 1). The filter
+   *  stride, given by stride_size for equal dimensions or stride_h and stride_w
+   *  for different strides. By default the convolution is dense with stride 1.
+   *  - pad / pad_h / pad_w (\b optional, default 0). The zero-padding for
+   *  convolution, given by pad for equal dimensions or pad_h and pad_w for
+   *  different padding. Input padding is computed implicitly instead of
+   *  actually padding.
+   *  - dilation (\b optional, default 1). The filter
+   *  dilation, given by dilation_size for equal dimensions for different
+   *  dilation. By default the convolution has dilation 1.
+   *  - group (\b optional, default 1). The number of filter groups. Group
+   *  convolution is a method for reducing parameterization by selectively
+   *  connecting input and output channels. The input and output channel dimensions must be divisible
+   *  by the number of groups. For group @f$ \geq 1 @f$, the
+   *  convolutional filters' input and output channels are separated s.t. each
+   *  group takes 1 / group of the input channels and makes 1 / group of the
+   *  output channels. Concretely 4 input channels, 8 output channels, and
+   *  2 groups separate input channels 1-2 and output channels 1-4 into the
+   *  first group and input channels 3-4 and output channels 5-8 into the second
+   *  group.
+   *  - bias_term (\b optional, default true). Whether to have a bias.
+   *  - engine: convolution has CAFFE (matrix multiplication) and CUDNN (library
+   *  kernels + stream parallelism) engines.
+   */
+  explicit WinogradLayer(const LayerParameter& param)
+      : BaseConvolutionLayer<Dtype>(param) {}
+
+  virtual inline const char* type() const { return "Winograd"; }
+
+  virtual void WeightAlign();
+  bool IsReshapedToWinograd();
+  void ReshapeToWinograd();
+
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+ protected:
+  void WeightAlignLocal();
+
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual inline bool reverse_dimensions() { return false; }
+  virtual void compute_output_shape();
+
+  // used in forward pass
+  void winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff);
+  void winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data);
+
+  // used in backward pass
+  void winograd_output_im2col_cpu(const Dtype *data, Dtype *col_buff);
+  void winograd_input_col2im_cpu(const Dtype *col_buff, Dtype *data);
+
+  Blob<Dtype> temp1_, temp2_, winograd_weight_;
+
+  // The following variables are initialized in WeightAlign
+  int tile_h_in_, tile_w_in_; /* input tile size */
+  int tile_h_out_, tile_w_out_; /* output tile size */
+  int ntiles_h_, ntiles_w_; /* number of tiles */
+
+  shared_ptr<Blob<long> > // Blob<long> is a reconstruction: it only stores raw pointers
+      in_activation_ptrs_, out_activation_ptrs_, weight_ptrs_, weight_diff_ptrs_;
+  /** buffer for pointers to be passed to cublasSgemmBatched */
+
+  bool weight_ptrs_initialized_, weight_diff_ptrs_initialized_;
+};
+
+}  // namespace caffe
+
+#endif  // WINOGRAD_LAYER_INFERENCE_HPP_

From 0b55f62937e1e50b7685df79dd74042cabb3a3ea Mon Sep 17 00:00:00 2001
From: Tao Hu <2441597844@qq.com>
Date: Mon, 24 Jul 2017 21:43:21 +0800
Subject: [PATCH 3/6] Add files via upload

---
 src/caffe/layers/winograd_layer_inference.cpp | 572 ++++++++++++++++++
 1 file changed, 572 insertions(+)
 create mode 100644 src/caffe/layers/winograd_layer_inference.cpp

diff --git a/src/caffe/layers/winograd_layer_inference.cpp b/src/caffe/layers/winograd_layer_inference.cpp
new file mode 100644
index 000000000..e10fedee8
--- /dev/null
+++ b/src/caffe/layers/winograd_layer_inference.cpp
@@ -0,0 +1,572 @@
+#include <vector>
+
+#include "caffe/layers/winograd_layer_inference.hpp"
+#include "caffe/util/winograd_inference.hpp"
+#include "caffe/util/winograd.hpp"  // assumed available in this tree: provides WinogradGKronG used in Reshape
+
+//winograd_layer for cpu inference
+
+namespace caffe {
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::compute_output_shape() {
+  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+  const int* stride_data = this->stride_.cpu_data();
+  const int* pad_data = this->pad_.cpu_data();
+  const int* dilation_data = this->dilation_.cpu_data();
+  this->output_shape_.clear();
+  for (int i = 0; i < this->num_spatial_axes_; ++i) {
+    // i + 1 to skip channel axis
+    const int input_dim = this->input_shape(i + 1);
+    const int kernel_extent = dilation_data[i] * (kernel_shape_data[i] - 1) + 1;
+    const int output_dim = (input_dim + 2 * pad_data[i] - kernel_extent)
+        / stride_data[i] + 1;
+    this->output_shape_.push_back(output_dim);
+  }
+}
+
+template <typename Dtype>
+bool WinogradLayer<Dtype>::IsReshapedToWinograd() {
+  return !(this->blobs_[0]->shape(2) == this->blobs_[0]->shape(3) && (this->blobs_[0]->shape(2) == 3 || this->blobs_[0]->shape(2) == 5));
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::ReshapeToWinograd() {
+  if (!IsReshapedToWinograd()) {
+    // not yet reshaped
+    vector<int> shape;
+    shape.push_back(tile_h_in_);
+    shape.push_back(tile_w_in_);
+    shape.push_back(this->conv_out_channels_);
+    shape.push_back(this->conv_in_channels_/this->group_);
+    this->blobs_[0]->Reshape(shape);
+  }
+}
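+
+// Weight layout note: ReshapeToWinograd() replaces the spatial
+// [out_c x in_c/g x kh x kw] weight blob with the Winograd-domain layout
+// [tile_h_in x tile_w_in x out_c x in_c/g]; WeightAlignLocal() below fills
+// it once via a GEMM with kron(G, G).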
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::WeightAlign() {
+  BaseConvolutionLayer<Dtype>::WeightAlign();
+
+  WeightAlignLocal();
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::WeightAlignLocal() {
+  if (!IsReshapedToWinograd()) {
+    // transform weights to Winograd domain
+    Dtype* weight_orig = new Dtype[this->blobs_[0]->count()];
+    memcpy(weight_orig, this->blobs_[0]->cpu_data(), sizeof(Dtype)*this->blobs_[0]->count());
+
+    ReshapeToWinograd();
+
+    int kernel_h = this->kernel_shape_.cpu_data()[0], kernel_w = this->kernel_shape_.cpu_data()[1];
+
+    caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans,
+        tile_h_in_*tile_w_in_, (this->conv_in_channels_/this->group_)*this->conv_out_channels_, kernel_h*kernel_w,
+        (Dtype)1,
+        (const Dtype *)WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_G)->get()->cpu_data(), //wt6x6; the table is float, so only the float instantiation is functional
+        weight_orig,
+        (Dtype)0, this->blobs_[0]->mutable_cpu_data());
+
+    delete[] weight_orig;
+  }
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  BaseConvolutionLayer<Dtype>::Reshape(bottom, top);
+
+  int height = this->conv_input_shape_.cpu_data()[1], width = this->conv_input_shape_.cpu_data()[2];
+  int kernel_h = this->kernel_shape_.cpu_data()[0], kernel_w = this->kernel_shape_.cpu_data()[1];
+  int stride_h = this->stride_.cpu_data()[0], stride_w = this->stride_.cpu_data()[1];
+  int dilation_h = this->dilation_.cpu_data()[0], dilation_w = this->dilation_.cpu_data()[1];
+
+  if (stride_h != 1 || stride_w != 1 || dilation_h != 1 || dilation_w != 1) {
+    LOG(FATAL) << "non-unit stride or dilation";
+  }
+  if (kernel_h != kernel_w) {
+    LOG(FATAL) << "kernel_h != kernel_w";
+  }
+
+  WinogradGKronG<Dtype> *GKronG = WinogradGKronG<Dtype>::getInstance(kernel_h);
+
+  tile_h_in_ = GKronG->M;
+  tile_w_in_ = GKronG->M;
+  tile_h_out_ = tile_h_in_ - GKronG->N + 1, tile_w_out_ = tile_w_in_ - GKronG->N + 1;
+
+  int pad_h = this->pad_.cpu_data()[0], pad_w = this->pad_.cpu_data()[1];
+  int output_h = (height + 2*pad_h - dilation_h*(kernel_h - 1) - 1)/stride_h + 1, output_w = (width + 2*pad_w - dilation_w*(kernel_w - 1) - 1)/stride_w + 1;
+
+  // to cover input image: (ntiles_h_ - 1)*tile_h_out_ + (tile_h_in_ - 1) - pad_h >= height - 1 => ntiles_h_ > (height + pad_h - tile_h_in_)/tile_h_out_
+  // to cover output image: ntiles_h_ >= output_h/tile_h_out_
+  ntiles_h_ = (std::max(height + pad_h - tile_h_in_ + 1, output_h) + tile_h_out_ - 1)/tile_h_out_;
+  ntiles_w_ = (std::max(width + pad_w - tile_w_in_ + 1, output_w) + tile_w_out_ - 1)/tile_w_out_;
+
+  // create temporary buffers
+  vector<int> shape;
+  shape.push_back(this->num_);
+  shape.push_back(tile_h_in_*tile_w_in_);
+  shape.push_back(std::max(this->conv_in_channels_, this->conv_out_channels_));
+  shape.push_back(ntiles_h_*ntiles_w_);
+
+  if (temp1_.shape() != shape) {
+    temp1_.Reshape(shape);
+    temp2_.Reshape(shape);
+
+    // create arrays of pointers to prepare for cuda batch sgemm
+    shape.clear();
+    shape.push_back(tile_h_in_);
+    shape.push_back(tile_w_in_);
+    shape.push_back(this->group_);
+
+    in_activation_ptrs_.reset(new Blob<long>(shape));  // Blob<long> assumed: raw pointer storage
+    out_activation_ptrs_.reset(new Blob<long>(shape));
+    weight_ptrs_.reset(new Blob<long>(shape));
+    weight_diff_ptrs_.reset(new Blob<long>(shape));
+
+    Dtype **in_ptrs = (Dtype **)in_activation_ptrs_->mutable_cpu_data();
+    Dtype **out_ptrs = (Dtype **)out_activation_ptrs_->mutable_cpu_data();
+
+    for (int j = 0; j < tile_h_in_*tile_w_in_*this->group_; ++j) {
+      in_ptrs[j] =
+          temp1_.mutable_gpu_data() +
+          j*(this->conv_in_channels_/this->group_)*this->num_*ntiles_h_*ntiles_w_;
+
+      out_ptrs[j] =
+          temp2_.mutable_gpu_data() +
+          j*(this->conv_out_channels_/this->group_)*this->num_*ntiles_h_*ntiles_w_;
+    }
+
+    weight_ptrs_initialized_ = false;
+    weight_diff_ptrs_initialized_ = false;
+  }
+
+  WeightAlignLocal();
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff)
+{
+  int height = this->conv_input_shape_.cpu_data()[1], width = this->conv_input_shape_.cpu_data()[2];
+  int pad_h = this->pad_.cpu_data()[0], pad_w = this->pad_.cpu_data()[1];
+
+  for (int c = 0; c < this->conv_in_channels_; ++c) {
+    for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) {
+      for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) {
+        for (int y = 0; y < tile_h_in_; ++y) {
+          for (int x = 0; x < tile_w_in_; ++x) {
+            int in_y = tile_h*tile_h_out_ + y - pad_h;
+            int in_x = tile_w*tile_w_out_ + x - pad_w;
+
+            if (in_y < 0 || in_x < 0 || in_y >= height || in_x >= width) {
+              col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x] = 0;
+            }
+            else {
+              col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x] =
+                  data[(c*height + in_y)*width + in_x];
+            }
+          }
+        }
+      } // for each tile
+    } // for each tile
+  } // for each input channel
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data)
+{
+  const int output_h = this->output_shape_[0], output_w = this->output_shape_[1];
+
+  for (int n = 0; n < this->conv_out_channels_; ++n) {
+    for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) {
+      for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) {
+        for (int y = 0; y < tile_h_out_; ++y) {
+          for (int x = 0; x < tile_w_out_; ++x) {
+            int out_y = tile_h*tile_h_out_ + y;
+            int out_x = tile_w*tile_w_out_ + x;
+
+            if (out_y < output_h && out_x < output_w) {
+              data[(n*output_h + out_y)*output_w + out_x] =
+                  col_buff[(((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x];
+            }
+          }
+        }
+      } // for each tile
+    } // for each tile
+  } // for each output channel
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::winograd_output_im2col_cpu(const Dtype *data, Dtype *col_buff)
+{
+  const int output_h = this->output_shape_[0], output_w = this->output_shape_[1];
+
+  for (int n = 0; n < this->conv_out_channels_; ++n) {
+    for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) {
+      for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) {
+        for (int y = 0; y < tile_h_out_; ++y) {
+          for (int x = 0; x < tile_w_out_; ++x) {
+            int out_y = tile_h*tile_h_out_ + y;
+            int out_x = tile_w*tile_w_out_ + x;
+
+            if (out_y < 0 || out_x < 0 || out_y >= output_h || out_x >= output_w) {
+              col_buff[(((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x] = 0;
+            }
+            else {
+              col_buff[(((n*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_out_ + y)*tile_w_out_ + x] =
+                  data[(n*output_h + out_y)*output_w + out_x];
+            }
+          }
+        }
+      } // for each tile
+    } // for each tile
+  } // for each output channel
+}
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::winograd_input_col2im_cpu(const Dtype *col_buff, Dtype *data)
+{
+  int height = this->conv_input_shape_.cpu_data()[1], width = this->conv_input_shape_.cpu_data()[2];
+  int pad_h = this->pad_.cpu_data()[0], pad_w = this->pad_.cpu_data()[1];
+
+  memset(data, 0, sizeof(Dtype)*this->conv_in_channels_*height*width);
+
+  for (int c = 0; c < this->conv_in_channels_; ++c) {
+    for (int tile_h = 0; tile_h < ntiles_h_; ++tile_h) {
+      for (int tile_w = 0; tile_w < ntiles_w_; ++tile_w) {
+        for (int y = 0; y < tile_h_in_; ++y) {
+          for (int x = 0; x < tile_w_in_; ++x) {
+            int in_y = tile_h*tile_h_out_ + y - pad_h;
+            int in_x = tile_w*tile_w_out_ + x - pad_w;
+
+            if (in_y >= 0 && in_x >= 0 && in_y < height && in_x < width) {
+              data[(c*height + in_y)*width + in_x] +=
+                  col_buff[(((c*ntiles_h_ + tile_h)*ntiles_w_ + tile_w)*tile_h_in_ + y)*tile_w_in_ + x];
+            }
+          }
+        }
+      } // for each tile
+    } // for each tile
+  } // for each input channel
+}
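+
+// Forward pass structure (all in the Winograd domain):
+//   1. winograd_input_im2col_cpu cuts the padded input into overlapping
+//      tile_h_in x tile_w_in tiles;
+//   2. a GEMM with kron(B,B)^T transforms every tile;
+//   3. one small GEMM per tile position multiplies transformed weights and
+//      tiles (the batched element-wise product of the Winograd algorithm);
+//   4. a GEMM with kron(A,A) transforms back, and winograd_output_col2im_cpu
+//      scatters the tile_h_out x tile_w_out tiles into the output image.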
+
+//#define PROFILE_WINOGRAD
+
+template <typename Dtype>
+void WinogradLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+
+  int kernel_h = this->kernel_shape_.cpu_data()[0], kernel_w = this->kernel_shape_.cpu_data()[1];
+
+  const Dtype* weight = this->blobs_[0]->cpu_data();
+
+#ifdef PROFILE_WINOGRAD
+  CPUTimer timer;
+#endif
+
+  for (int i = 0; i < bottom.size(); ++i) {
+    const Dtype* bottom_data = bottom[i]->cpu_data();
+    Dtype* top_data = top[i]->mutable_cpu_data();
+    for (int n = 0; n < this->num_; ++n) { // JSP: this->num_ is batch size
+      int M = this->conv_in_channels_*ntiles_h_*ntiles_w_;
+
+      Dtype *col_buff = this->col_buffer_.mutable_cpu_data();
+
+#ifdef PROFILE_WINOGRAD
+      timer.Start();
+#endif
+      winograd_input_im2col_cpu(bottom_data + n*this->bottom_dim_, col_buff);
+#ifdef PROFILE_WINOGRAD
+      LOG(INFO) << "winograd_input_im2col takes " << timer.MilliSeconds()/1000;
+#endif
+
+      // Transform input to Winograd domain
+#ifdef PROFILE_WINOGRAD
+      timer.Start();
+#endif
+      caffe_cpu_gemm<Dtype>(CblasTrans, CblasTrans,
+          tile_h_in_*tile_w_in_, M, tile_h_in_*tile_w_in_,
+          (Dtype)1,
+          (const Dtype *)WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_B)->get()->cpu_data(), //wt6x6; float table, see WeightAlignLocal
+          col_buff,
+          (Dtype)0, temp1_.mutable_cpu_data());
+      // temp1_ has (tile_h_in*tile_w_in) x (conv_in_channels) x (ntiles_h*ntiles_w) dimension
+#ifdef PROFILE_WINOGRAD
+      LOG(INFO) << "Transformation of bottom takes " << timer.MilliSeconds()/1000;
+#endif
+
+#ifdef PROFILE_WINOGRAD
+      timer.Start();
+#endif
+      // Convolution in Winograd domain
+      for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
+        for (int g = 0; g < this->group_; ++g) {
+          caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+              this->conv_out_channels_/this->group_, ntiles_h_*ntiles_w_, this->conv_in_channels_/this->group_,
+              (Dtype)1,
+              weight + (j*this->group_ + g)*(this->conv_out_channels_/this->group_)*(this->conv_in_channels_/this->group_),
+              temp1_.cpu_data() + (j*this->group_ + g)*(this->conv_in_channels_/this->group_)*ntiles_h_*ntiles_w_,
+              (Dtype)0, col_buff + (j*this->group_ + g)*(this->conv_out_channels_/this->group_)*ntiles_h_*ntiles_w_);
+        }
+      }
+      // col_buff has (tile_h_in*tile_w_in) x (conv_out_channels) x (ntiles_h*ntiles_w)
+#ifdef PROFILE_WINOGRAD
+      LOG(INFO) << "Convolution takes " << timer.MilliSeconds()/1000;
+#endif
+
+      // Transform back to time domain
+#ifdef PROFILE_WINOGRAD
+      timer.Start();
+#endif
+      caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans,
+          this->conv_out_channels_*ntiles_h_*ntiles_w_, tile_h_out_*tile_w_out_, tile_h_in_*tile_w_in_,
+          (Dtype)1, col_buff,
+          (const Dtype *)WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_A)->get()->cpu_data(), //wt6x6, kronA
+          (Dtype)0, temp1_.mutable_cpu_data());
+#ifdef PROFILE_WINOGRAD
+      LOG(INFO) << "Inverse transformation of top takes " << timer.MilliSeconds()/1000;
+#endif
+
+#ifdef PROFILE_WINOGRAD
+      timer.Start();
+#endif
+      winograd_output_col2im_cpu(temp1_.cpu_data(), top_data + n*this->top_dim_);
+#ifdef PROFILE_WINOGRAD
+      LOG(INFO) << "winograd_output_col2im takes " << timer.MilliSeconds()/1000;
+#endif
+
+      if (this->bias_term_) {
+        const Dtype* bias = this->blobs_[1]->cpu_data();
+        this->forward_cpu_bias(top_data + n * this->top_dim_, bias);
+      }
+    }
+  }
+}
+
+template <>
+void WinogradLayer<double>::Backward_cpu(const vector<Blob<double>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<double>*>& bottom) {
+  NOT_IMPLEMENTED;
+}
+
+template <>
+void WinogradLayer<float>::Backward_cpu(const vector<Blob<float>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<float>*>& bottom) {
+
+  int kernel_h = this->kernel_shape_.cpu_data()[0], kernel_w = this->kernel_shape_.cpu_data()[1];
+
+  const float* weight = this->blobs_[0]->cpu_data();
+  float* weight_diff = this->blobs_[0]->mutable_cpu_diff();
+
+//  fprintf(stderr, "weight_winograd\n");
+//  for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
+//    for (int n = 0; n < this->conv_out_channels_; ++n) {
+//      for (int c = 0; c < this->conv_in_channels_; ++c) {
+//        fprintf(stderr, "%g ", weight[(j*this->conv_out_channels_ + n)*this->conv_in_channels_ + c]);
+//      }
+//    }
+//    fprintf(stderr, "\n");
+//  }
+
+#ifdef PROFILE_WINOGRAD
+  CPUTimer timer;
+#endif
+
+  for (int i = 0; i < top.size(); ++i) {
+    const float* top_diff = top[i]->cpu_diff();
+    const float* bottom_data = bottom[i]->cpu_data();
+    float* bottom_diff = bottom[i]->mutable_cpu_diff();
+    // Bias gradient, if necessary.
+    if (this->bias_term_ && this->param_propagate_down_[1]) {
+      float* bias_diff = this->blobs_[1]->mutable_cpu_diff();
+      for (int n = 0; n < this->num_; ++n) {
+        this->backward_cpu_bias(bias_diff, top_diff + n * this->top_dim_);
+      }
+    }
+
+    if (this->param_propagate_down_[0] || propagate_down[i]) {
+      for (int n = 0; n < this->num_; ++n) {
+        int M = this->conv_out_channels_*ntiles_h_*ntiles_w_;
+
+        float *col_buff = this->col_buffer_.mutable_cpu_data();
+
+#ifdef PROFILE_WINOGRAD
+        timer.Start();
+#endif
+        winograd_output_im2col_cpu(top_diff + n*this->top_dim_, col_buff);
+#ifdef PROFILE_WINOGRAD
+        LOG(INFO) << "winograd_output_im2col takes " << timer.MilliSeconds()/1000;
+#endif
+
+        // Transform out_diff to Winograd domain
+#ifdef PROFILE_WINOGRAD
+        timer.Start();
+#endif
+        caffe_cpu_gemm<float>(CblasNoTrans, CblasTrans,
+            tile_h_in_*tile_w_in_, M, tile_h_out_*tile_w_out_,
+            (float)1,
+            WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_A)->get()->cpu_data(), //wt6x6
+            col_buff,
+            (float)0, temp1_.mutable_cpu_data());
+        // temp1_ has (tile_h_in*tile_w_in) x (conv_out_channels) x (ntiles_h*ntiles_w) dimension
+#ifdef PROFILE_WINOGRAD
+        LOG(INFO) << "Transformation of top_diff takes " << timer.MilliSeconds()/1000;
+#endif
+
+        // gradient w.r.t. weight. Note that we will accumulate diffs.
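+        // Per tile position j, the Winograd-domain weight gradient is
+        // weight_diff[j] += top_diff_winograd[j] * bottom_winograd[j]^T,
+        // accumulated over the batch; the spatial-domain gradient could be
+        // recovered via kron(G,G)^T (see the debug block below).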
+        if (this->param_propagate_down_[0]) {
+#ifdef PROFILE_WINOGRAD
+          timer.Start();
+#endif
+          winograd_input_im2col_cpu(bottom_data + n*this->bottom_dim_, col_buff);
+#ifdef PROFILE_WINOGRAD
+          LOG(INFO) << "winograd_input_im2col takes " << timer.MilliSeconds()/1000;
+#endif
+
+          // Transform input to Winograd domain
+#ifdef PROFILE_WINOGRAD
+          timer.Start();
+#endif
+          caffe_cpu_gemm<float>(CblasTrans, CblasTrans,
+              tile_h_in_*tile_w_in_, this->conv_in_channels_*ntiles_h_*ntiles_w_, tile_h_in_*tile_w_in_,
+              (float)1,
+              WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_B)->get()->cpu_data(), //wt6x6
+              col_buff,
+              (float)0, temp2_.mutable_cpu_data());
+          // temp2_ has (tile_h_in*tile_w_in) x (conv_in_channels) x (ntiles_h*ntiles_w) dimension
+#ifdef PROFILE_WINOGRAD
+          LOG(INFO) << "Transformation of bottom takes " << timer.MilliSeconds()/1000;
+#endif
+
+          if (false/*n == 0*/) {
+            fprintf(stderr, "weight_diff_winograd0[0]\n");
+            for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
+              for (int n = 0; n < this->conv_out_channels_; ++n) {
+                for (int c = 0; c < this->conv_in_channels_; ++c) {
+                  fprintf(stderr, "%g ", weight_diff[(j*this->conv_out_channels_ + n)*this->conv_in_channels_ + c]);
+                }
+              }
+              fprintf(stderr, "\n");
+            }
+          }
+
+#ifdef PROFILE_WINOGRAD
+          timer.Start();
+#endif
+          for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
+            for (int g = 0; g < this->group_; ++g) {
+              caffe_cpu_gemm<float>(CblasNoTrans, CblasTrans,
+                  this->conv_out_channels_/this->group_, this->conv_in_channels_/this->group_, ntiles_h_*ntiles_w_,
+                  (float)1,
+                  temp1_.cpu_data() + (j*this->group_ + g)*(this->conv_out_channels_/this->group_)*ntiles_h_*ntiles_w_,
+                  temp2_.cpu_data() + (j*this->group_ + g)*(this->conv_in_channels_/this->group_)*ntiles_h_*ntiles_w_,
+                  (float)1, weight_diff + (j*this->group_ + g)*(this->conv_out_channels_/this->group_)*(this->conv_in_channels_/this->group_));
+            }
+          }
+          // weight_diff has (tile_h_in*tile_w_in) x (conv_out_channels) x (conv_in_channels/group) dimension
+#ifdef PROFILE_WINOGRAD
+          LOG(INFO) << "Convolution for weight gradient takes " << timer.MilliSeconds()/1000;
+#endif
+
+//          for (int i = 0; i < tile_h_in_*tile_w_in_*this->conv_out_channels_*(this->conv_in_channels_/this->group_); ++i) {
+//            if (isnan(weight_diff[i])) {
+//              ostringstream str;
+//              str << "nan at weight_diff[" << i << "]";
+//              LOG(FATAL) << str.str();
+//            }
+//          }
+
+          if (false/*n == this->num_ - 1*/) {
+            float *temp_weight = new float[this->conv_out_channels_*(this->conv_in_channels_/this->group_)*kernel_h*kernel_w];
+
+            caffe_cpu_gemm<float>(CblasTrans, CblasNoTrans,
+                this->conv_out_channels_*(this->conv_in_channels_/this->group_), kernel_h*kernel_w, tile_h_in_*tile_w_in_,
+                (float)1, weight_diff,
+                WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_G)->get()->cpu_data(), //wt6x6
+                (float)0, temp_weight);
+
+            fprintf(stderr, "weight_diff[%d]\n", n);
+            for (int m = 0; m < this->conv_out_channels_; ++m) {
+              for (int c = 0; c < this->conv_in_channels_/this->group_; ++c) {
+                for (int i = 0; i < kernel_h*kernel_w; ++i) {
+                  fprintf(stderr, "%g ", temp_weight[(m*(this->conv_in_channels_/this->group_) + c)*kernel_h*kernel_w + i]);
+                }
+              }
+              fprintf(stderr, "\n");
+            }
+            delete[] temp_weight;
+
+            fprintf(stderr, "weight_diff_winograd[%d]\n", n);
+            for (int n = 0; n < this->conv_out_channels_; ++n) {
+              for (int c = 0; c < this->conv_in_channels_; ++c) {
+                for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
+                  fprintf(stderr, "%g ", weight_diff[(j*this->conv_out_channels_ + n)*this->conv_in_channels_ + c]);
+                }
+              }
+              fprintf(stderr, "\n");
+            }
+          }
+        }
+
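+        // (The block below reuses the transformed top_diff already sitting
+        // in temp1_ and runs the per-tile GEMMs with the weight factor
+        // transposed.)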
+        // gradient w.r.t. bottom data, if necessary.
+        if (propagate_down[i]) {
+#ifdef PROFILE_WINOGRAD
+          timer.Start();
+#endif
+          // Convolution in Winograd domain
+          for (int j = 0; j < tile_h_in_*tile_w_in_; ++j) {
+            for (int g = 0; g < this->group_; ++g) {
+              caffe_cpu_gemm<float>(CblasTrans, CblasNoTrans,
+                  this->conv_in_channels_/this->group_, ntiles_h_*ntiles_w_, this->conv_out_channels_/this->group_,
+                  (float)1,
+                  weight + (j*this->group_ + g)*(this->conv_out_channels_/this->group_)*(this->conv_in_channels_/this->group_),
+                  temp1_.cpu_data() + (j*this->group_ + g)*(this->conv_out_channels_/this->group_)*ntiles_h_*ntiles_w_,
+                  (float)0, col_buff + (j*this->group_ + g)*(this->conv_in_channels_/this->group_)*ntiles_h_*ntiles_w_);
+            }
+          }
+          // col_buff has (tile_h_in*tile_w_in) x (conv_in_channels) x (ntiles_h*ntiles_w)
+#ifdef PROFILE_WINOGRAD
+          LOG(INFO) << "Convolution for bottom gradient takes " << timer.MilliSeconds()/1000;
+#endif
+
+          // Transform back to time domain
+#ifdef PROFILE_WINOGRAD
+          timer.Start();
+#endif
+          caffe_cpu_gemm<float>(CblasTrans, CblasTrans,
+              this->conv_in_channels_*ntiles_h_*ntiles_w_, tile_h_in_*tile_w_in_, tile_h_in_*tile_w_in_,
+              (float)1, col_buff,
+              WINOGRAD_INFERENCE::Winograd_Kron::getInstance(WINOGRAD_INFERENCE::WT_6X6_F_4X4_3X3, WINOGRAD_INFERENCE::WINOGRAD_B)->get()->cpu_data(), //wt6x6
+              (float)0, temp1_.mutable_cpu_data());
+#ifdef PROFILE_WINOGRAD
+          LOG(INFO) << "Inverse transformation of bottom_diff takes " << timer.MilliSeconds()/1000;
+#endif
+
+#ifdef PROFILE_WINOGRAD
+          timer.Start();
+#endif
+          winograd_input_col2im_cpu(temp1_.cpu_data(), bottom_diff + n*this->bottom_dim_);
+#ifdef PROFILE_WINOGRAD
+          LOG(INFO) << "winograd_input_col2im takes " << timer.MilliSeconds()/1000;
+#endif
+
+//          for (int i = 0; i < this->bottom_dim_; ++i) {
+//            if (isnan(bottom_diff[i])) {
+//              ostringstream str;
+//              str << "nan at bottom_diff[" << n << ", " << i << "]";
+//            }
+//          }
+        }
+      } // for each image
+    }
+  }
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(WinogradLayer);
+#endif
+
+INSTANTIATE_CLASS(WinogradLayer);
+REGISTER_LAYER_CLASS(Winograd);
+
+}  // namespace caffe

From 9451feee1c2c65d2cd20d743d66f82415956f87d Mon Sep 17 00:00:00 2001
From: Tao Hu <2441597844@qq.com>
Date: Mon, 24 Jul 2017 21:47:54 +0800
Subject: [PATCH 4/6] Add files via upload

---
 include/caffe/layers/winograd_layer_inference.hpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/include/caffe/layers/winograd_layer_inference.hpp b/include/caffe/layers/winograd_layer_inference.hpp
index 943d5a801..eb3398114 100644
--- a/include/caffe/layers/winograd_layer_inference.hpp
+++ b/include/caffe/layers/winograd_layer_inference.hpp
@@ -30,7 +30,7 @@ namespace caffe {
  * the output channel N' columns of the output matrix.
  */
 template <typename Dtype>
-class WinogradLayer : public BaseConvolutionLayer<Dtype> {
+class WinogradLayerInference : public BaseConvolutionLayer<Dtype> {
  public:
   /**
    * @param param provides ConvolutionParameter convolution_param,
@@ -63,7 +63,7 @@ class WinogradLayer : public BaseConvolutionLayer<Dtype> {
    *  - engine: convolution has CAFFE (matrix multiplication) and CUDNN (library
    *  kernels + stream parallelism) engines.
    */
-  explicit WinogradLayer(const LayerParameter& param)
+  explicit WinogradLayerInference(const LayerParameter& param)
       : BaseConvolutionLayer<Dtype>(param) {}
 
   virtual inline const char* type() const { return "Winograd"; }
From cc6cbffe0673126a447ea2073d6c67523344e0fc Mon Sep 17 00:00:00 2001
From: Tao Hu <2441597844@qq.com>
Date: Mon, 24 Jul 2017 21:49:42 +0800
Subject: [PATCH 5/6] Add files via upload

From 70cbd577e727335d14498a400c7719c7c774398b Mon Sep 17 00:00:00 2001
From: Tao Hu <2441597844@qq.com>
Date: Mon, 24 Jul 2017 21:50:41 +0800
Subject: [PATCH 6/6] Add files via upload

---
 src/caffe/layers/winograd_layer_inference.cpp | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/caffe/layers/winograd_layer_inference.cpp b/src/caffe/layers/winograd_layer_inference.cpp
index e10fedee8..4bb381b84 100644
--- a/src/caffe/layers/winograd_layer_inference.cpp
+++ b/src/caffe/layers/winograd_layer_inference.cpp
@@ -8,7 +8,7 @@
 namespace caffe {
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::compute_output_shape() {
+void WinogradLayerInference<Dtype>::compute_output_shape() {
   const int* kernel_shape_data = this->kernel_shape_.cpu_data();
   const int* stride_data = this->stride_.cpu_data();
   const int* pad_data = this->pad_.cpu_data();
@@ -25,12 +25,12 @@ void WinogradLayerInference<Dtype>::compute_output_shape() {
 }
 
 template <typename Dtype>
-bool WinogradLayer<Dtype>::IsReshapedToWinograd() {
+bool WinogradLayerInference<Dtype>::IsReshapedToWinograd() {
   return !(this->blobs_[0]->shape(2) == this->blobs_[0]->shape(3) && (this->blobs_[0]->shape(2) == 3 || this->blobs_[0]->shape(2) == 5));
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::ReshapeToWinograd() {
+void WinogradLayerInference<Dtype>::ReshapeToWinograd() {
   if (!IsReshapedToWinograd()) {
     // not yet reshaped
     vector<int> shape;
@@ -43,14 +43,14 @@ void WinogradLayerInference<Dtype>::ReshapeToWinograd() {
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::WeightAlign() {
+void WinogradLayerInference<Dtype>::WeightAlign() {
   BaseConvolutionLayer<Dtype>::WeightAlign();
 
   WeightAlignLocal();
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::WeightAlignLocal() {
+void WinogradLayerInference<Dtype>::WeightAlignLocal() {
   if (!IsReshapedToWinograd()) {
     // transform weights to Winograd domain
     Dtype* weight_orig = new Dtype[this->blobs_[0]->count()];
@@ -72,7 +72,7 @@ void WinogradLayerInference<Dtype>::WeightAlignLocal() {
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+void WinogradLayerInference<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   BaseConvolutionLayer<Dtype>::Reshape(bottom, top);
 
@@ -145,7 +145,7 @@ void WinogradLayerInference<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff)
+void WinogradLayerInference<Dtype>::winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff)
 {
   int height = this->conv_input_shape_.cpu_data()[1], width = this->conv_input_shape_.cpu_data()[2];
   int pad_h = this->pad_.cpu_data()[0], pad_w = this->pad_.cpu_data()[1];
@@ -173,7 +173,7 @@ void WinogradLayerInference<Dtype>::winograd_input_im2col_cpu(const Dtype *data, Dtype *col_buff)
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data)
+void WinogradLayerInference<Dtype>::winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data)
 {
   const int output_h = this->output_shape_[0], output_w = this->output_shape_[1];
 
@@ -197,7 +197,7 @@ void WinogradLayerInference<Dtype>::winograd_output_col2im_cpu(const Dtype *col_buff, Dtype *data)
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::winograd_output_im2col_cpu(const Dtype *data, Dtype *col_buff)
+void WinogradLayerInference<Dtype>::winograd_output_im2col_cpu(const Dtype *data, Dtype *col_buff)
 {
   const int output_h = this->output_shape_[0], output_w = this->output_shape_[1];
 
@@ -224,7 +224,7 @@ void WinogradLayerInference<Dtype>::winograd_output_im2col_cpu(const Dtype *data, Dtype *col_buff)
 }
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::winograd_input_col2im_cpu(const Dtype *col_buff, Dtype *data)
+void WinogradLayerInference<Dtype>::winograd_input_col2im_cpu(const Dtype *col_buff, Dtype *data)
 {
   int height = this->conv_input_shape_.cpu_data()[1], width = this->conv_input_shape_.cpu_data()[2];
   int pad_h = this->pad_.cpu_data()[0], pad_w = this->pad_.cpu_data()[1];
@@ -253,7 +253,7 @@ void WinogradLayerInference<Dtype>::winograd_input_col2im_cpu(const Dtype *col_buff, Dtype *data)
 
 //#define PROFILE_WINOGRAD
 
 template <typename Dtype>
-void WinogradLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+void WinogradLayerInference<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
 
   int kernel_h = this->kernel_shape_.cpu_data()[0], kernel_w = this->kernel_shape_.cpu_data()[1];
 
@@ -344,13 +344,13 @@ void WinogradLayerInference<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <>
-void WinogradLayer<double>::Backward_cpu(const vector<Blob<double>*>& top,
+void WinogradLayerInference<double>::Backward_cpu(const vector<Blob<double>*>& top,
     const vector<bool>& propagate_down, const vector<Blob<double>*>& bottom) {
   NOT_IMPLEMENTED;
 }
 
 template <>
-void WinogradLayer<float>::Backward_cpu(const vector<Blob<float>*>& top,
+void WinogradLayerInference<float>::Backward_cpu(const vector<Blob<float>*>& top,
     const vector<bool>& propagate_down, const vector<Blob<float>*>& bottom) {
 
   int kernel_h = this->kernel_shape_.cpu_data()[0], kernel_w = this->kernel_shape_.cpu_data()[1];
 
@@ -563,10 +563,10 @@ void WinogradLayerInference<float>::Backward_cpu(const vector<Blob<float>*>& top,
 }
 
 #ifdef CPU_ONLY
-STUB_GPU(WinogradLayer);
+STUB_GPU(WinogradLayerInference);
 #endif
 
-INSTANTIATE_CLASS(WinogradLayer);
+INSTANTIATE_CLASS(WinogradLayerInference);
 REGISTER_LAYER_CLASS(Winograd);
 
 }  // namespace caffe
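
For reference, a self-contained numerical check of the transform identity the series hardcodes (WT_6X6_F_4X4_3X3): for one 6x6 input tile d and one 3x3 kernel g, y = kron(A,A)^T * ((kron(G,G)*g) .* (kron(B,B)^T*d)) should match the direct 4x4-output correlation. The test harness below is a sketch under that assumption, not part of the patches:

    #include "caffe/util/winograd_inference.hpp"
    #include <cstdio>
    #include <cstdlib>
    #include <cmath>

    using namespace WINOGRAD_INFERENCE;

    int main() {
      winograd2D_initialize();
      // row-major tables: kron(G,G) is 36x9, kron(B,B) is 36x36, kron(A,A) is 36x16
      const float *GG = Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_G)->get()->cpu_data();
      const float *BB = Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_B)->get()->cpu_data();
      const float *AA = Winograd_Kron::getInstance(WT_6X6_F_4X4_3X3, WINOGRAD_A)->get()->cpu_data();

      float d[36], g[9], u[36], v[36], m[36], y[16];
      for (int i = 0; i < 36; ++i) d[i] = rand() / (float)RAND_MAX;
      for (int i = 0; i < 9; ++i)  g[i] = rand() / (float)RAND_MAX;

      for (int i = 0; i < 36; ++i) {            // u = kron(G,G) * g
        u[i] = 0;
        for (int j = 0; j < 9; ++j) u[i] += GG[i*9 + j] * g[j];
      }
      for (int i = 0; i < 36; ++i) {            // v = kron(B,B)^T * d
        v[i] = 0;
        for (int j = 0; j < 36; ++j) v[i] += BB[j*36 + i] * d[j];
      }
      for (int i = 0; i < 36; ++i) m[i] = u[i] * v[i];  // Hadamard product
      for (int i = 0; i < 16; ++i) {            // y = kron(A,A)^T * m
        y[i] = 0;
        for (int j = 0; j < 36; ++j) y[i] += AA[j*16 + i] * m[j];
      }

      for (int oy = 0; oy < 4; ++oy)            // direct correlation reference
        for (int ox = 0; ox < 4; ++ox) {
          float ref = 0;
          for (int ky = 0; ky < 3; ++ky)
            for (int kx = 0; kx < 3; ++kx)
              ref += d[(oy + ky)*6 + (ox + kx)] * g[ky*3 + kx];
          printf("y=%f ref=%f diff=%g\n", y[oy*4 + ox], ref, std::fabs(y[oy*4 + ox] - ref));
        }
      return 0;
    }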