Skip to content

Commit 71e9232

Browse files
authored
Merge branch 'main' into export-D72726244
2 parents 59e3d16 + edf25c4 commit 71e9232

File tree

9 files changed

+208
-5
lines changed

9 files changed

+208
-5
lines changed

examples/demo-apps/apple_ios/LLaMA/LLaMA.xcodeproj/project.pbxproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
306A71512DC1DC3D00936B1F /* pre_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A71472DC1DC3D00936B1F /* pre_tokenizer.cpp */; };
6565
306A71522DC1DC3D00936B1F /* token_decoder.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 306A714B2DC1DC3D00936B1F /* token_decoder.cpp */; };
6666
3072D5232DC3EA280083FC83 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = 3072D5222DC3EA280083FC83 /* Constants.swift */; };
67+
F24909E82E207004001E5B69 /* normalizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F24909E72E207004001E5B69 /* normalizer.cpp */; };
6768
F292B0752D88B0C200BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06F2D88B0C200BE6839 /* tiktoken.cpp */; };
6869
F292B0762D88B0C200BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */; };
6970
F292B0772D88B0C200BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */; };
@@ -152,6 +153,7 @@
152153
306A714A2DC1DC3D00936B1F /* std_regex.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = std_regex.cpp; path = src/std_regex.cpp; sourceTree = "<group>"; };
153154
306A714B2DC1DC3D00936B1F /* token_decoder.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = token_decoder.cpp; path = src/token_decoder.cpp; sourceTree = "<group>"; };
154155
3072D5222DC3EA280083FC83 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = "<group>"; };
156+
F24909E72E207004001E5B69 /* normalizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = normalizer.cpp; path = src/normalizer.cpp; sourceTree = "<group>"; };
155157
F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer_base.cpp; path = src/bpe_tokenizer_base.cpp; sourceTree = "<group>"; };
156158
F292B06C2D88B0C200BE6839 /* llama2c_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = llama2c_tokenizer.cpp; path = src/llama2c_tokenizer.cpp; sourceTree = "<group>"; };
157159
F292B06F2D88B0C200BE6839 /* tiktoken.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = tiktoken.cpp; path = src/tiktoken.cpp; sourceTree = "<group>"; };
@@ -309,6 +311,7 @@
309311
03729F0E2BB203D700152F2E /* tokenizers */ = {
310312
isa = PBXGroup;
311313
children = (
314+
F24909E72E207004001E5B69 /* normalizer.cpp */,
312315
F292B06A2D88B0C200BE6839 /* bpe_tokenizer_base.cpp */,
313316
306A71452DC1DC3D00936B1F /* hf_tokenizer.cpp */,
314317
F292B1002D88B20C00BE6839 /* llama_tiktoken.cpp */,
@@ -598,6 +601,7 @@
598601
files = (
599602
03D151B82E0E0908007A38BE /* LLaVARunner.mm in Sources */,
600603
03729EE12BB1F93800152F2E /* LLaMARunner.mm in Sources */,
604+
F24909E82E207004001E5B69 /* normalizer.cpp in Sources */,
601605
0372C3152C89418E00CD942A /* llava_runner.cpp in Sources */,
602606
03D151CA2E0E98C4007A38BE /* sentencepiece.cpp in Sources */,
603607
03D151D92E0E9E43007A38BE /* ExecuTorchTextLLMRunner.mm in Sources */,

exir/emit/test/test_emit.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1875,3 +1875,55 @@ def forward(self, x):
18751875
),
18761876
)
18771877
)
1878+
1879+
def test_emit_sym_min_max(self) -> None:
1880+
class SymMaxModel(nn.Module):
1881+
def __init__(self, test_min=False):
1882+
super().__init__()
1883+
self.test_min = test_min
1884+
1885+
def forward(self, x):
1886+
# Get size of 0th dimension - this creates sym_size op
1887+
batch_size = x.shape[0]
1888+
# Compute max of batch_size and 10 - this should create sym_max op
1889+
if self.test_min:
1890+
out_size = min(batch_size, 10)
1891+
else:
1892+
out_size = max(batch_size, 10)
1893+
# Create a 1D tensor of zeros with the computed size
1894+
result = torch.zeros(out_size, dtype=x.dtype, device=x.device)
1895+
return result
1896+
1897+
for validate_min in [True, False]:
1898+
model = SymMaxModel(test_min=validate_min)
1899+
test_inputs = [
1900+
torch.randn(5, 3), # should output zeros(10) for max zeros(5) for min
1901+
torch.randn(15, 3), # should output zeros(15) for max zeros(10) for min
1902+
torch.randn(10, 3), # should output zeros(10) for max zeros(10) for min
1903+
]
1904+
model.eval()
1905+
reference_outputs = []
1906+
with torch.no_grad():
1907+
for _, inp in enumerate(test_inputs):
1908+
output = model(inp)
1909+
reference_outputs.append(output)
1910+
1911+
batch_dim = Dim("batch", min=1, max=20)
1912+
dynamic_shapes = {"x": {0: batch_dim}} # 0th dimension is dynamic
1913+
exported_program = torch.export.export(
1914+
model, (test_inputs[0],), dynamic_shapes=dynamic_shapes
1915+
)
1916+
edge_program = to_edge(
1917+
exported_program,
1918+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1919+
)
1920+
et_program = edge_program.to_executorch()
1921+
program_buffer = et_program.buffer
1922+
et_module = _load_for_executorch_from_buffer(program_buffer)
1923+
for _, (inp, expected) in enumerate(zip(test_inputs, reference_outputs)):
1924+
# Execute with ExecutorTorch
1925+
et_output = et_module.forward([inp])
1926+
et_result = et_output[0] # Get first output
1927+
# Compare results
1928+
self.assertTrue(expected.shape == et_result.shape)
1929+
self.assertTrue(torch.allclose(expected, et_result))

exir/passes/executorch_prim_ops_registry.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ def trunc(a: _SymScalar) -> _SymScalar:
110110
return math.trunc(a) # pyre-ignore
111111

112112

113+
@bind_pattern_to_op(
114+
executorch_prims_lib, "sym_max.Scalar(Scalar a, Scalar b) -> Scalar"
115+
)
116+
def sym_max(a: _SymScalar, b: _SymScalar) -> bool:
117+
return max(a, b) # pyre-ignore
118+
119+
120+
@bind_pattern_to_op(
121+
executorch_prims_lib, "sym_min.Scalar(Scalar a, Scalar b) -> Scalar"
122+
)
123+
def sym_min(a: _SymScalar, b: _SymScalar) -> bool:
124+
return min(a, b) # pyre-ignore
125+
126+
113127
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[Any, OpOverload] = {
114128
builtins.round: ops.backend.executorch_prim.round.Scalar,
115129
math.ceil: ops.backend.executorch_prim.ceil.Scalar,
@@ -127,12 +141,12 @@ def trunc(a: _SymScalar) -> _SymScalar:
127141
operator.mod: ops.backend.executorch_prim.mod.Scalar,
128142
operator.neg: ops.backend.executorch_prim.neg.Scalar,
129143
torch.sym_float: ops.backend.executorch_prim.sym_float.Scalar,
144+
torch.sym_max: ops.backend.executorch_prim.sym_max.Scalar,
145+
torch.sym_min: ops.backend.executorch_prim.sym_min.Scalar,
130146
}
131147

132148

133-
_EXECUTORCH_SYM_OPS: Set[OpOverload] = set(
134-
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.values()
135-
)
149+
_EXECUTORCH_SYM_OPS: Set[Any] = set(_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS.values())
136150
_EXECUTORCH_SYM_OPS.update(
137151
{
138152
torch.ops.aten.sym_stride.int,

extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
30AA4B662DC0766800B1BE50 /* re2_regex.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 30AA4B5C2DC0766800B1BE50 /* re2_regex.cpp */; };
3636
3C6ABD332DFA27DE0015DE55 /* regex_lookahead.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 3C6ABD322DFA27DE0015DE55 /* regex_lookahead.cpp */; };
3737
F22E9E1A2DF2CBB900EC5425 /* text_llm_runner.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F22E9E192DF2CBB900EC5425 /* text_llm_runner.cpp */; };
38+
F24909E22E206FBA001E5B69 /* normalizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F24909E12E206FBA001E5B69 /* normalizer.cpp */; };
3839
F292B01D2D88AF3500BE6839 /* bpe_tokenizer_base.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B0162D88AF3500BE6839 /* bpe_tokenizer_base.cpp */; };
3940
F292B0202D88AF3500BE6839 /* llama2c_tokenizer.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B0172D88AF3500BE6839 /* llama2c_tokenizer.cpp */; };
4041
F292B0212D88AF3500BE6839 /* tiktoken.cpp in Sources */ = {isa = PBXBuildFile; fileRef = F292B01A2D88AF3500BE6839 /* tiktoken.cpp */; };
@@ -100,6 +101,7 @@
100101
3C6ABD322DFA27DE0015DE55 /* regex_lookahead.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = regex_lookahead.cpp; path = src/regex_lookahead.cpp; sourceTree = "<group>"; };
101102
F22E9E182DF2CBB900EC5425 /* text_llm_runner.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = text_llm_runner.h; sourceTree = "<group>"; };
102103
F22E9E192DF2CBB900EC5425 /* text_llm_runner.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = text_llm_runner.cpp; sourceTree = "<group>"; };
104+
F24909E12E206FBA001E5B69 /* normalizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = normalizer.cpp; path = src/normalizer.cpp; sourceTree = "<group>"; };
103105
F292B0162D88AF3500BE6839 /* bpe_tokenizer_base.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = bpe_tokenizer_base.cpp; path = src/bpe_tokenizer_base.cpp; sourceTree = "<group>"; };
104106
F292B0172D88AF3500BE6839 /* llama2c_tokenizer.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = llama2c_tokenizer.cpp; path = src/llama2c_tokenizer.cpp; sourceTree = "<group>"; };
105107
F292B01A2D88AF3500BE6839 /* tiktoken.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; name = tiktoken.cpp; path = src/tiktoken.cpp; sourceTree = "<group>"; };
@@ -185,6 +187,7 @@
185187
032A74022CAFBB7800932D36 /* tokenizers */ = {
186188
isa = PBXGroup;
187189
children = (
190+
F24909E12E206FBA001E5B69 /* normalizer.cpp */,
188191
F2E1B5162E03AC19002C9718 /* sentencepiece.cpp */,
189192
3C6ABD322DFA27DE0015DE55 /* regex_lookahead.cpp */,
190193
30AA4B592DC0766800B1BE50 /* hf_tokenizer.cpp */,
@@ -430,6 +433,7 @@
430433
F292B0202D88AF3500BE6839 /* llama2c_tokenizer.cpp in Sources */,
431434
F292B0212D88AF3500BE6839 /* tiktoken.cpp in Sources */,
432435
F2E1B5172E03AC19002C9718 /* sentencepiece.cpp in Sources */,
436+
F24909E22E206FBA001E5B69 /* normalizer.cpp in Sources */,
433437
03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */,
434438
032A74232CAFC1B300932D36 /* runner.cpp in Sources */,
435439
03B2D37A2C8A515C0046936E /* GenericTests.mm in Sources */,

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313
#include <executorch/runtime/kernel/operator_registry.h>
1414

15+
#include <algorithm>
1516
#include <cmath>
1617

1718
using torch::executor::function::et_copy_index;
@@ -120,6 +121,48 @@ static Kernel prim_ops[] = {
120121
int64_t numel = self_tensor.numel();
121122
out = EValue(numel);
122123
}),
124+
// executorch_prim::sym_max.Scalar(SymInt a, SymInt b) -> SymInt
125+
Kernel(
126+
"executorch_prim::sym_max.Scalar",
127+
[](KernelRuntimeContext& context, EValue** stack) {
128+
(void)context;
129+
EValue& a = *stack[0];
130+
EValue& b = *stack[1];
131+
EValue& out = *stack[2];
132+
if (a.isInt() && b.isInt()) {
133+
out = EValue(std::max(a.toInt(), b.toInt()));
134+
} else {
135+
ET_KERNEL_CHECK_MSG(
136+
context,
137+
false,
138+
InvalidType,
139+
/* void */,
140+
"sym_max only supports int inputs, got %zu, %zu",
141+
(size_t)a.tag,
142+
(size_t)b.tag);
143+
}
144+
}),
145+
// executorch_prim::sym_min.Scalar(SymInt a, SymInt b) -> SymInt
146+
Kernel(
147+
"executorch_prim::sym_min.Scalar",
148+
[](KernelRuntimeContext& context, EValue** stack) {
149+
(void)context;
150+
EValue& a = *stack[0];
151+
EValue& b = *stack[1];
152+
EValue& out = *stack[2];
153+
if (a.isInt() && b.isInt()) {
154+
out = EValue(std::min(a.toInt(), b.toInt()));
155+
} else {
156+
ET_KERNEL_CHECK_MSG(
157+
context,
158+
false,
159+
InvalidType,
160+
/* void */,
161+
"sym_min only supports int inputs, got %zu, %zu",
162+
(size_t)a.tag,
163+
(size_t)b.tag);
164+
}
165+
}),
123166
// executorch_prim::add.Scalar(Scalar, Scalar) -> Scalar
124167
Kernel(
125168
"executorch_prim::add.Scalar",

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class RegisterPrimOpsTest : public OperatorTest {
3737
TEST_F(RegisterPrimOpsTest, OpRegistered) {
3838
EXPECT_TRUE(hasOpsFn("aten::sym_size.int"));
3939
EXPECT_TRUE(hasOpsFn("aten::sym_numel"));
40+
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_max.Scalar"));
41+
EXPECT_TRUE(hasOpsFn("executorch_prim::sym_min.Scalar"));
4042
}
4143

4244
TEST_F(RegisterPrimOpsTest, SymSizeReturnsCorrectValue) {
@@ -81,6 +83,88 @@ TEST_F(RegisterPrimOpsTest, SymNumelReturnsCorrectValue) {
8183
EXPECT_EQ(stack[1]->toInt(), expected);
8284
}
8385

86+
TEST_F(RegisterPrimOpsTest, SymMaxReturnsCorrectValue) {
87+
EValue values[3];
88+
int64_t a = 5;
89+
int64_t b = 3;
90+
int64_t out = 0;
91+
values[0] = EValue(a);
92+
values[1] = EValue(b);
93+
values[2] = EValue(out);
94+
95+
EValue* stack[3];
96+
for (size_t i = 0; i < 3; i++) {
97+
stack[i] = &values[i];
98+
}
99+
100+
getOpsFn("executorch_prim::sym_max.Scalar")(context_, stack);
101+
EXPECT_EQ(stack[2]->toInt(), 5);
102+
103+
// Test with swapped values
104+
values[0] = EValue(b);
105+
values[1] = EValue(a);
106+
values[2] = EValue(out);
107+
getOpsFn("executorch_prim::sym_max.Scalar")(context_, stack);
108+
EXPECT_EQ(stack[2]->toInt(), 5);
109+
110+
// Test with equal values
111+
values[0] = EValue(a);
112+
values[1] = EValue(a);
113+
values[2] = EValue(out);
114+
getOpsFn("executorch_prim::sym_max.Scalar")(context_, stack);
115+
EXPECT_EQ(stack[2]->toInt(), 5);
116+
117+
// Test with negative values
118+
a = -2;
119+
b = -5;
120+
values[0] = EValue(a);
121+
values[1] = EValue(b);
122+
values[2] = EValue(out);
123+
getOpsFn("executorch_prim::sym_max.Scalar")(context_, stack);
124+
EXPECT_EQ(stack[2]->toInt(), -2);
125+
}
126+
127+
TEST_F(RegisterPrimOpsTest, SymMinReturnsCorrectValue) {
128+
EValue values[3];
129+
int64_t a = 5;
130+
int64_t b = 3;
131+
int64_t out = 0;
132+
values[0] = EValue(a);
133+
values[1] = EValue(b);
134+
values[2] = EValue(out);
135+
136+
EValue* stack[3];
137+
for (size_t i = 0; i < 3; i++) {
138+
stack[i] = &values[i];
139+
}
140+
141+
getOpsFn("executorch_prim::sym_min.Scalar")(context_, stack);
142+
EXPECT_EQ(stack[2]->toInt(), 3);
143+
144+
// Test with swapped values
145+
values[0] = EValue(b);
146+
values[1] = EValue(a);
147+
values[2] = EValue(out);
148+
getOpsFn("executorch_prim::sym_min.Scalar")(context_, stack);
149+
EXPECT_EQ(stack[2]->toInt(), 3);
150+
151+
// Test with equal values
152+
values[0] = EValue(a);
153+
values[1] = EValue(a);
154+
values[2] = EValue(out);
155+
getOpsFn("executorch_prim::sym_min.Scalar")(context_, stack);
156+
EXPECT_EQ(stack[2]->toInt(), 5);
157+
158+
// Test with negative values
159+
a = -2;
160+
b = -5;
161+
values[0] = EValue(a);
162+
values[1] = EValue(b);
163+
values[2] = EValue(out);
164+
getOpsFn("executorch_prim::sym_min.Scalar")(context_, stack);
165+
EXPECT_EQ(stack[2]->toInt(), -5);
166+
}
167+
84168
TEST_F(RegisterPrimOpsTest, TestAlgebraOps) {
85169
EValue values[3];
86170
int64_t a = 3;

third-party/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
78
add_subdirectory(json)
89
add_subdirectory(gflags)
910

@@ -86,6 +87,7 @@ ExternalProject_Add(
8687
-DFLATCC_REFLECTION=OFF
8788
-DFLATCC_DEBUG_CLANG_SANITIZE=OFF
8889
-DFLATCC_INSTALL=ON
90+
-DCMAKE_POLICY_VERSION_MINIMUM=3.5
8991
-DCMAKE_INSTALL_PREFIX:PATH=<INSTALL_DIR>
9092
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
9193
-DCMAKE_TOOLCHAIN_FILE=

third-party/ao

Submodule ao updated 260 files

0 commit comments

Comments
 (0)