Skip to content

Commit 783a38a

Browse files
committed
Add portable randn kernel implementation
1 parent f2fb351 commit 783a38a

File tree

6 files changed

+173
-0
lines changed

6 files changed

+173
-0
lines changed

kernels/portable/cpu/op_randn.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <c10/util/irange.h>
9+
10+
#include <executorch/kernels/portable/cpu/scalar_utils.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
#include <random>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using executorch::aten::IntArrayRef;
20+
using Tensor = executorch::aten::Tensor;
21+
using ScalarType = executorch::aten::ScalarType;
22+
23+
template <class CTYPE>
24+
void impl(
25+
CTYPE* data,
26+
int64_t numel,
27+
std::mt19937& gen,
28+
std::normal_distribution<double>& dist) {
29+
for (const auto i : c10::irange(numel)) {
30+
auto val = dist(gen);
31+
data[i] = static_cast<CTYPE>(val);
32+
}
33+
}
34+
35+
Tensor&
36+
randn_out(KernelRuntimeContext& ctx, const IntArrayRef sizes, Tensor& out) {
37+
(void)ctx;
38+
39+
std::mt19937 gen((std::random_device())());
40+
std::normal_distribution<double> dist(0.0, 1.0);
41+
42+
// Resize for dynamic shape
43+
ET_KERNEL_CHECK_MSG(
44+
ctx,
45+
resize_tensor(out, sizes) == Error::Ok,
46+
InvalidArgument,
47+
out,
48+
"Failed to resize output tensor.");
49+
50+
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, "randn.out", CTYPE, [&] {
51+
auto data_out = out.mutable_data_ptr<CTYPE>();
52+
impl(data_out, out.numel(), gen, dist);
53+
/*
54+
for (const auto i : c10::irange(out.numel())) {
55+
data_out[i] = static_cast<CTYPE>(dist(gen));
56+
}*/
57+
});
58+
59+
return out;
60+
}
61+
62+
} // namespace native
63+
} // namespace executor
64+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,12 @@
713713
- arg_meta: null
714714
kernel_name: torch::executor::prod_out
715715

716+
- op: randn.out
717+
kernels:
718+
- arg_meta: null
719+
kernel_name: torch::executor::randn_out
720+
tags: nondeterministic_seeded
721+
716722
- op: reciprocal.out
717723
kernels:
718724
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ set(all_test_sources
197197
"op_permute_copy_test.cpp"
198198
"op_pixel_shuffle_test.cpp"
199199
"op_prod_test.cpp"
200+
"op_randn_test.cpp"
200201
"op_reciprocal_test.cpp"
201202
"op_relu_test.cpp"
202203
"op_remainder_test.cpp"

kernels/test/op_randn_test.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <c10/util/irange.h>
10+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
#include <cmath>
20+
#include <numeric>
21+
22+
using executorch::aten::IntArrayRef;
23+
using executorch::aten::ScalarType;
24+
using executorch::aten::Tensor;
25+
using torch::executor::testing::TensorFactory;
26+
27+
class OpRandnTest : public OperatorTest {
28+
protected:
29+
void op_randn_out(const IntArrayRef sizes, Tensor& out) {
30+
torch::executor::aten::randn_outf(context_, sizes, out);
31+
}
32+
33+
template <typename CTYPE, ScalarType DTYPE>
34+
void test_randn(std::vector<int64_t>& sizes) {
35+
TensorFactory<DTYPE> tf;
36+
37+
// Tensor factory wants int32 scales, op kernel wants int64.
38+
std::vector<int32_t> sizes_i32;
39+
std::transform(
40+
sizes.begin(),
41+
sizes.end(),
42+
std::back_inserter(sizes_i32),
43+
[](int64_t s) { return static_cast<int32_t>(s); });
44+
Tensor out = tf.zeros(sizes_i32);
45+
46+
IntArrayRef sizes_ref(sizes.data(), sizes.size());
47+
op_randn_out(sizes_ref, out);
48+
49+
// Check mean and standard deviation. To avoid flaky CI, test pretty
50+
// loosely.
51+
auto out_data = out.const_data_ptr<CTYPE>();
52+
double mean =
53+
std::accumulate(
54+
out_data,
55+
out_data + out.numel(),
56+
0.0,
57+
[](double acc, CTYPE n) { return acc + static_cast<double>(n); }) /
58+
out.numel();
59+
double var = std::accumulate(
60+
out_data,
61+
out_data + out.numel(),
62+
0.0,
63+
[=](double acc, CTYPE n) {
64+
return acc + std::pow(static_cast<double>(n) - mean, 2);
65+
}) /
66+
out.numel();
67+
auto stdev = std::sqrt(var);
68+
69+
// These are very rough thresholds. A better test implementation would
70+
// probably do a proper statistical test to compare the generated empirical
71+
// data to the reference distribution, but this should do for now.
72+
EXPECT_LE(std::abs(mean), 5.0 / std::sqrt(out.numel()));
73+
EXPECT_LE(std::abs(stdev - 1.0), 0.1);
74+
EXPECT_GT(stdev, 0);
75+
}
76+
};
77+
78+
TEST_F(OpRandnTest, SmokeTest) {
79+
std::vector<int64_t> sizes = {2, 3, 4, 128};
80+
81+
#define TEST_ENTRY(ctype, dtype) test_randn<ctype, ScalarType::dtype>(sizes);
82+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
83+
#undef TEST_ENTRY
84+
}
85+
86+
TEST_F(OpRandnTest, Rank) {
87+
std::vector<int64_t> sizes = {1024};
88+
89+
for (int64_t i = 0; i < 4; i++) {
90+
sizes.push_back(i + 1);
91+
test_randn<float, executorch::aten::ScalarType::Float>(sizes);
92+
}
93+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ def define_common_targets():
285285
_common_op_test("op_pixel_unshuffle_test", ["aten", "portable"])
286286
_common_op_test("op_pow_test", ["aten", "portable"])
287287
_common_op_test("op_prod_test", ["aten", "portable"])
288+
_common_op_test("op_randn_test", ["aten", "portable"])
288289
_common_op_test("op_reciprocal_test", ["aten", "portable"])
289290
_common_op_test("op_relu_test", ["aten", "portable"])
290291
_common_op_test("op_remainder_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,14 @@ ATEN_OPS = (
973973
"//executorch/kernels/portable/cpu/util:reduce_util",
974974
],
975975
),
976+
op_target(
977+
name = "op_randn",
978+
deps = [
979+
":scalar_utils",
980+
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
981+
"//executorch/runtime/core/exec_aten/util:tensor_util",
982+
]
983+
),
976984
op_target(
977985
name = "op_reciprocal",
978986
deps = [

0 commit comments

Comments
 (0)