Skip to content

Commit 6292a33

Browse files
committed
Add portable randn kernel implementation
1 parent 809a1fd commit 6292a33

File tree

7 files changed

+160
-0
lines changed

7 files changed

+160
-0
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,8 @@
317317

318318
- op: rand.out
319319

320+
- op: randn.out
321+
320322
- op: reciprocal.out
321323

322324
- op: relu.out

kernels/portable/cpu/op_randn.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <c10/util/irange.h>
9+
10+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
#include <random>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using executorch::aten::IntArrayRef;
20+
using Tensor = executorch::aten::Tensor;
21+
using ScalarType = executorch::aten::ScalarType;
22+
23+
/**
 * Fills `out` with samples drawn from a normal distribution with mean 0.0 and
 * standard deviation 1.0.
 *
 * The generator is a std::mt19937 seeded from std::random_device on every
 * call, so successive calls do not produce a reproducible sequence.
 *
 * @param ctx Kernel runtime context, passed to the check/dispatch macros.
 * @param sizes Shape that `out` is resized to before it is filled.
 * @param out Destination tensor. Its dtype is dispatched through
 *     ET_SWITCH_FLOATHBF16_TYPES; samples are drawn as double and cast to the
 *     element type.
 * @return A reference to `out`.
 */
Tensor&
randn_out(KernelRuntimeContext& ctx, const IntArrayRef sizes, Tensor& out) {
  // Resize for dynamic shape. This is validated first so that a failed resize
  // does not pay for querying the entropy source and seeding the generator.
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(out, sizes) == Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  std::mt19937 gen((std::random_device())());
  std::normal_distribution<double> dist(0.0, 1.0);

  ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "randn.out", CTYPE, [&] {
    auto data_out = out.mutable_data_ptr<CTYPE>();
    for (const auto i : c10::irange(out.numel())) {
      data_out[i] = static_cast<CTYPE>(dist(gen));
    }
  });

  return out;
}
47+
48+
} // namespace native
49+
} // namespace executor
50+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,11 @@
717717
kernels:
718718
- arg_meta: null
719719
kernel_name: torch::executor::rand_out
720+
- op: randn.out
721+
722+
kernels:
723+
- arg_meta: null
724+
kernel_name: torch::executor::randn_out
720725
tags: nondeterministic_seeded
721726

722727
- op: reciprocal.out

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ set(all_test_sources
198198
"op_pixel_shuffle_test.cpp"
199199
"op_prod_test.cpp"
200200
"op_rand_test.cpp"
201+
"op_randn_test.cpp"
201202
"op_reciprocal_test.cpp"
202203
"op_relu_test.cpp"
203204
"op_remainder_test.cpp"

kernels/test/op_randn_test.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <c10/util/irange.h>
10+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
#include <cmath>
20+
#include <numeric>
21+
22+
using executorch::aten::IntArrayRef;
23+
using executorch::aten::ScalarType;
24+
using executorch::aten::Tensor;
25+
using torch::executor::testing::TensorFactory;
26+
27+
class OpRandnTest : public OperatorTest {
28+
protected:
29+
void op_randn_out(const IntArrayRef sizes, Tensor& out) {
30+
torch::executor::aten::randn_outf(context_, sizes, out);
31+
}
32+
33+
template <typename CTYPE, ScalarType DTYPE>
34+
void test_randn(std::vector<int64_t>& sizes) {
35+
TensorFactory<DTYPE> tf;
36+
37+
// Tensor factory wants int32 scales, op kernel wants int64.
38+
std::vector<int32_t> sizes_i32;
39+
std::transform(
40+
sizes.begin(),
41+
sizes.end(),
42+
std::back_inserter(sizes_i32),
43+
[](int64_t s) { return static_cast<int32_t>(s); });
44+
Tensor out = tf.zeros(sizes_i32);
45+
46+
IntArrayRef sizes_ref(sizes.data(), sizes.size());
47+
op_randn_out(sizes_ref, out);
48+
49+
// Check mean and standard deviation. To avoid flaky CI, test pretty
50+
// loosely.
51+
auto out_data = out.const_data_ptr<CTYPE>();
52+
double mean =
53+
std::accumulate(
54+
out_data,
55+
out_data + out.numel(),
56+
0.0,
57+
[](double acc, CTYPE n) { return acc + static_cast<double>(n); }) /
58+
out.numel();
59+
double var = std::accumulate(
60+
out_data,
61+
out_data + out.numel(),
62+
0.0,
63+
[=](double acc, CTYPE n) {
64+
return acc + std::pow(static_cast<double>(n) - mean, 2);
65+
}) /
66+
out.numel();
67+
auto stdev = std::sqrt(var);
68+
69+
// These are very rough thresholds. A better test implementation would
70+
// probably do a proper statistical test to compare the generated empirical
71+
// data to the reference distribution, but this should do.
72+
EXPECT_LE(std::abs(mean), 5.0 / std::sqrt(out.numel()));
73+
EXPECT_LE(std::abs(stdev - 1.0), 0.1);
74+
EXPECT_GT(stdev, 0);
75+
}
76+
};
77+
78+
// Runs the distribution check on one fixed 4-D shape for every
// (ctype, dtype) pair that ET_FORALL_FLOATHBF16_TYPES expands to.
TEST_F(OpRandnTest, SmokeTest) {
  std::vector<int64_t> shape = {2, 3, 4, 128};

#define RUN_RANDN(ctype, dtype) test_randn<ctype, ScalarType::dtype>(shape);
  ET_FORALL_FLOATHBF16_TYPES(RUN_RANDN);
#undef RUN_RANDN
}
85+
86+
// Exercises float output at increasing rank: starting from {1024}, a trailing
// dimension of 1, 2, 3 and then 4 is appended, and the distribution check is
// run after each append (shapes {1024, 1} through {1024, 1, 2, 3, 4}).
TEST_F(OpRandnTest, Rank) {
  std::vector<int64_t> shape = {1024};

  for (int64_t trailing_dim = 1; trailing_dim <= 4; ++trailing_dim) {
    shape.push_back(trailing_dim);
    test_randn<float, ScalarType::Float>(shape);
  }
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def define_common_targets():
286286
_common_op_test("op_pow_test", ["aten", "portable"])
287287
_common_op_test("op_prod_test", ["aten", "portable"])
288288
_common_op_test("op_rand_test", ["aten", "portable"])
289+
_common_op_test("op_randn_test", ["aten", "portable"])
289290
_common_op_test("op_reciprocal_test", ["aten", "portable"])
290291
_common_op_test("op_relu_test", ["aten", "portable"])
291292
_common_op_test("op_remainder_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,14 @@ ATEN_OPS = (
981981
"//executorch/runtime/core/exec_aten/util:tensor_util",
982982
]
983983
),
984+
op_target(
985+
name = "op_randn",
986+
deps = [
987+
":scalar_utils",
988+
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
989+
"//executorch/runtime/core/exec_aten/util:tensor_util",
990+
]
991+
),
984992
op_target(
985993
name = "op_reciprocal",
986994
deps = [

0 commit comments

Comments
 (0)