Skip to content

Commit 95a1db5

Browse files
authored
Add support for bias in optimized op_linear.cpp.
Differential Revision: D75491158 Pull Request resolved: #11210
1 parent 6875c8e commit 95a1db5

File tree

2 files changed

+203
-29
lines changed

2 files changed

+203
-29
lines changed

kernels/optimized/cpu/op_linear.cpp

Lines changed: 113 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,76 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <array>
10+
11+
#include <c10/util/irange.h>
12+
913
#include <executorch/kernels/optimized/blas/CPUBlas.h>
14+
#include <executorch/kernels/optimized/vec/functional_base.h>
15+
#include <executorch/kernels/optimized/vec/vec_base.h>
1016
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
1117
#include <executorch/runtime/kernel/kernel_includes.h>
1218

13-
#include <array>
14-
1519
namespace torch {
1620
namespace executor {
1721
namespace native {
1822

19-
using Tensor = executorch::aten::Tensor;
23+
namespace {
24+
using ::executorch::aten::Tensor;
25+
using ::executorch::cpublas::gemm;
26+
using ::executorch::cpublas::TransposeType;
27+
using ::executorch::runtime::toString;
28+
using ::executorch::vec::map;
29+
using ::executorch::vec::Vectorized;
30+
31+
// Use vector store to initialize with scalar bias.
32+
template <typename scalar_t>
33+
void initialize_scalar(
34+
const ssize_t out_numel,
35+
const scalar_t init,
36+
scalar_t* out) {
37+
using Vec = Vectorized<scalar_t>;
38+
39+
// Initialize a vector with the scalar initial value.
40+
Vec init_vec(init);
41+
42+
ssize_t d = 0;
43+
for (; d < out_numel - (out_numel % Vec::size()); d += Vec::size()) {
44+
// Vector-length store.
45+
init_vec.store(out + d);
46+
}
47+
if (out_numel - d > 0) {
48+
// Sub-vector-length store.
49+
init_vec.store(out + d, static_cast<int>(out_numel - d));
50+
}
51+
}
52+
53+
// Use std::memcpy to initialize with vector bias.
54+
template <typename scalar_t>
55+
void initialize_to_vector(
56+
const ssize_t n,
57+
const ssize_t m,
58+
const scalar_t* bias,
59+
scalar_t* out) {
60+
// Output is a n x m x scalar_t, while bias is m x scalar_t.
61+
const size_t row_size = static_cast<size_t>(m) * sizeof(scalar_t);
62+
for (const auto col : c10::irange(n)) {
63+
std::memcpy(
64+
// Point to Column `col` of the output tensor.
65+
out + col * m,
66+
bias,
67+
row_size);
68+
}
69+
}
70+
71+
} // namespace
2072

2173
Tensor& opt_linear_out(
2274
RuntimeContext& ctx,
2375
const Tensor& in,
2476
const Tensor& mat2,
2577
const optional<Tensor>& bias,
2678
Tensor& out) {
27-
ET_KERNEL_CHECK_MSG(
28-
ctx,
29-
!bias.has_value(),
30-
InvalidArgument,
31-
out,
32-
"bias not supported yet in linear");
3379
ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);
3480

3581
size_t output_ndim = 0;
@@ -46,28 +92,74 @@ Tensor& opt_linear_out(
4692
return out;
4793
}
4894

49-
int flattened_input_dim = 1;
95+
ssize_t n = 1;
5096
for (int ii = 0; ii < in.dim() - 1; ++ii) {
51-
flattened_input_dim *= in.sizes()[ii];
97+
n *= in.sizes()[ii];
5298
}
99+
const ssize_t k = in.sizes()[in.dim() - 1];
100+
const ssize_t m = mat2.size(0);
101+
102+
if (bias.has_value()) {
103+
ET_KERNEL_CHECK_MSG(
104+
ctx,
105+
// Bias and output dtype must match.
106+
bias->dtype() == out.dtype(),
107+
InvalidArgument,
108+
out,
109+
"Bias has wrong dtype! Expected bias dtype to be the same as out dtype %s"
110+
" but got %s",
111+
toString(bias->dtype()),
112+
toString(out.dtype()));
113+
114+
ET_KERNEL_CHECK_MSG(
115+
ctx,
116+
// Either no bias or bias is a 1D tensor of size m or 1.
117+
bias->dim() == 1 && (bias->size(0) == m || bias->size(0) == 1),
118+
InvalidArgument,
119+
out,
120+
"Bias has wrong dimensionality! Expected 1-D tensor of size %d or empty,"
121+
" but got %d-D tensor with %d elements",
122+
static_cast<int>(m),
123+
static_cast<int>(bias->dim()),
124+
static_cast<int>(bias->numel()));
125+
}
126+
53127
ET_SWITCH_REAL_TYPES_AND2(
54-
Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
55-
size_t n = flattened_input_dim;
56-
size_t k = in.sizes()[in.dim() - 1];
57-
size_t m = mat2.size(0);
58-
59-
executorch::cpublas::gemm(
60-
executorch::cpublas::TransposeType::Transpose,
61-
executorch::cpublas::TransposeType::NoTranspose,
128+
Half, BFloat16, out.scalar_type(), ctx, "linear.out", CTYPE, [&] {
129+
// Fill output with bias if it is provided.
130+
if (bias.has_value() && bias->numel() == 1) {
131+
// Scalar version of initialization.
132+
initialize_scalar<CTYPE>(
133+
out.numel(),
134+
*bias->const_data_ptr<CTYPE>(),
135+
out.mutable_data_ptr<CTYPE>());
136+
} else if (bias.has_value()) {
137+
// Assume bias is a 1D tensor of size m.
138+
initialize_to_vector<CTYPE>(
139+
n,
140+
m,
141+
bias->const_data_ptr<CTYPE>(),
142+
out.mutable_data_ptr<CTYPE>());
143+
}
144+
145+
// Set beta to 1 if bias was applied so that GEMM adds to the pre-filled
146+
// bias, otherwise beta remains 0 (i.e. the output is fully overwritten
147+
// by GEMM).
148+
const CTYPE beta =
149+
bias.has_value() ? static_cast<CTYPE>(1) : static_cast<CTYPE>(0);
150+
151+
gemm(
152+
/*transa=*/TransposeType::Transpose,
153+
/*transb=*/TransposeType::NoTranspose,
62154
m,
63155
n,
64156
k,
65-
static_cast<CTYPE>(1),
157+
/*alpha=*/static_cast<CTYPE>(1),
66158
mat2.const_data_ptr<CTYPE>(),
67159
k,
68160
in.const_data_ptr<CTYPE>(),
69161
k,
70-
static_cast<CTYPE>(0),
162+
beta,
71163
out.mutable_data_ptr<CTYPE>(),
72164
m);
73165
});

kernels/test/op_linear_test.cpp

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
#include <gtest/gtest.h>
1919
#include <limits>
2020

21-
using namespace ::testing;
21+
namespace {
22+
2223
using executorch::aten::ArrayRef;
2324
using executorch::aten::Scalar;
2425
using executorch::aten::ScalarType;
@@ -31,7 +32,15 @@ class OpLinearOutTest : public OperatorTest {
3132
return torch::executor::aten::linear_outf(context_, self, mat2, {}, out);
3233
}
3334

34-
template <class CTYPE, executorch::aten::ScalarType DTYPE>
35+
Tensor& op_linear_out(
36+
const Tensor& self,
37+
const Tensor& mat2,
38+
const Tensor& bias,
39+
Tensor& out) {
40+
return torch::executor::aten::linear_outf(context_, self, mat2, bias, out);
41+
}
42+
43+
template <class CTYPE, ScalarType DTYPE>
3544
void test_dtype() {
3645
TensorFactory<DTYPE> tf;
3746

@@ -43,16 +52,16 @@ class OpLinearOutTest : public OperatorTest {
4352
}
4453
}
4554

46-
// matmul gives 32 * 2 * 3 = 192
47-
Tensor x = tf.full({3, 32}, 2);
48-
Tensor y = tf.full({5, 32}, 3);
55+
// matmul gives 19 * 2 * 3 = 114
56+
Tensor x = tf.full({3, 19}, 2);
57+
Tensor y = tf.full({5, 19}, 3);
4958

5059
// Output shape should be (3, 5)
5160
Tensor out = tf.zeros({3, 5});
5261

5362
op_linear_out(x, y, out);
5463

55-
Tensor expected = tf.full({3, 5}, 192);
64+
Tensor expected = tf.full({3, 5}, 114);
5665

5766
EXPECT_TENSOR_EQ(out, expected);
5867
}
@@ -88,6 +97,80 @@ TEST_F(OpLinearOutTest, AllDtypesSupported) {
8897
// for those types.
8998
}
9099

100+
TEST_F(OpLinearOutTest, BiasTest) {
101+
TensorFactory<ScalarType::Int> tf;
102+
103+
// Initialize input tensors.
104+
constexpr int kReduceDim = 4;
105+
constexpr int kDimX = 3, kDimY = 2;
106+
constexpr int kValueX = 1;
107+
constexpr int kValueY = 2;
108+
constexpr int kValueBias0 = 4, kValueBias1 = 7;
109+
const Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
110+
const Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
111+
const Tensor b = tf.make({kDimY}, {kValueBias0, kValueBias1});
112+
// Output matrix is also empty
113+
Tensor out = tf.zeros({kDimX, kDimY});
114+
// Initialize expected tensor.
115+
constexpr int kValueExpected0 = kValueX * kValueY * kReduceDim + kValueBias0;
116+
constexpr int kValueExpected1 = kValueX * kValueY * kReduceDim + kValueBias1;
117+
// Check that the bias is added to the correct position in the output matrix.
118+
const Tensor expected = tf.make(
119+
{kDimX, kDimY},
120+
{kValueExpected0,
121+
kValueExpected1,
122+
kValueExpected0,
123+
kValueExpected1,
124+
kValueExpected0,
125+
kValueExpected1});
126+
127+
EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
128+
}
129+
130+
TEST_F(OpLinearOutTest, BiasBroadcastTest) {
131+
TensorFactory<ScalarType::Int> tf;
132+
133+
// Initialize input tensors.
134+
constexpr int kReduceDim = 4;
135+
constexpr int kDimX = 3, kDimY = 5;
136+
constexpr int kValueX = 1;
137+
constexpr int kValueY = 2;
138+
constexpr int kValueBias = 4;
139+
const Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
140+
const Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
141+
const Tensor b = tf.full({1}, kValueBias);
142+
// Output matrix is also empty
143+
Tensor out = tf.zeros({kDimX, kDimY});
144+
// Initialize expected tensor.
145+
constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
146+
const Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
147+
148+
EXPECT_TENSOR_EQ(op_linear_out(x, y, b, out), expected);
149+
}
150+
151+
TEST_F(OpLinearOutTest, BiasDtypeMismatch) {
152+
TensorFactory<ScalarType::Int> tf;
153+
TensorFactory<ScalarType::Short> tf_bias;
154+
155+
// Initialize input tensors.
156+
constexpr int kReduceDim = 4;
157+
constexpr int kDimX = 3, kDimY = 5;
158+
constexpr int kValueX = 1;
159+
constexpr int kValueY = 2;
160+
constexpr int kValueBias = 4;
161+
Tensor x = tf.full({kDimX, kReduceDim}, kValueX);
162+
Tensor y = tf.full({kDimY, kReduceDim}, kValueY);
163+
// Same size as output.
164+
Tensor b = tf_bias.full({kDimY}, kValueBias);
165+
// Output matrix is also empty
166+
Tensor out = tf.zeros({kDimX, kDimY});
167+
// Initialize expected tensor.
168+
constexpr int kValueExpected = kValueX * kValueY * kReduceDim + kValueBias;
169+
Tensor expected = tf.full({kDimX, kDimY}, kValueExpected);
170+
171+
ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, b, out));
172+
}
173+
91174
TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) {
92175
TensorFactory<ScalarType::Float> tf;
93176

@@ -297,5 +380,4 @@ TEST_F(OpLinearOutTest, DynamicShapeUnbound) {
297380
Tensor ret = op_linear_out(x, y, out);
298381
EXPECT_TENSOR_CLOSE(out, expected_result);
299382
}
300-
301-
// TODO: support and test bias
383+
} // namespace

0 commit comments

Comments
 (0)