
Commit c5ecea6

[llm] Support different shape of input_pos (#11966)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #11869 by @larryliu0820
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/larryliu0820/67/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/larryliu0820/67/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/larryliu0820/66/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/larryliu0820/67/orig
@diff-train-skip-merge

---------

Co-authored-by: Mengwei Liu <[email protected]>
Co-authored-by: Mengwei Liu <[email protected]>
1 parent 8b3b028 commit c5ecea6

18 files changed: +354 -59 lines changed
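
At a glance, the API change: TextDecoderRunner::step() now takes the decode position as a plain int64_t instead of a caller-built TensorPtr, and the runner materializes whatever position tensor the exported method actually expects. A minimal caller-side sketch of the new pattern; the Module and tokens setup here is assumed, not taken from this diff:

// Hypothetical call site: `module` is a loaded executorch Module and
// `tokens` a prepared token TensorPtr (both stand-ins for brevity).
executorch::extension::llm::TextDecoderRunner runner(&module);

// Before this PR, callers built the position tensor themselves:
//   auto pos_tensor = from_blob(&pos, {1}, ScalarType::Long);
//   auto logits = runner.step(tokens, pos_tensor);

// After this PR, the position is a plain integer; step() inspects the
// method metadata and builds a {1} input_pos tensor or an arange-filled
// cache_positions tensor as needed.
int64_t start_pos = 0;
auto logits_res = runner.step(tokens, start_pos);
if (logits_res.ok()) {
  int32_t next_token = runner.logits_to_token(logits_res.get(), 0.0f);
}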

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 6 additions & 3 deletions
@@ -11,25 +11,28 @@
 #pragma once
 
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/tensor/tensor.h>
 
 namespace example {
 
 class ET_EXPERIMENTAL LlavaTextDecoderRunner
     : public executorch::extension::llm::TextDecoderRunner {
  public:
   explicit LlavaTextDecoderRunner(executorch::extension::Module* module)
-      : TextDecoderRunner(module, true) {}
+      : TextDecoderRunner(module) {}
 
   inline executorch::runtime::Result<executorch::aten::Tensor> step(
       executorch::extension::TensorPtr& tokens,
-      executorch::extension::TensorPtr& start_pos) override {
+      int64_t start_pos) override {
     // run token embedding
     auto token_embedding_outputs =
        ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));
 
+    auto start_pos_tensor = ::executorch::extension::from_blob(
+        &start_pos, {1}, executorch::aten::ScalarType::Long);
     // run text model
     auto outputs_res = ET_UNWRAP(module_->execute(
-        kTextModelMethod, {start_pos, token_embedding_outputs[0]}));
+        kTextModelMethod, {start_pos_tensor, token_embedding_outputs[0]}));
 
     ET_CHECK_MSG(
         outputs_res.size() == 1,
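
Worth noting in the override above: from_blob() wraps existing memory without copying, so the resulting tensor aliases the start_pos parameter and is valid only while that parameter is alive. A minimal sketch of the pattern in isolation (the value 5 is illustrative):

// The tensor views &start_pos directly; no copy is made, so start_pos must
// outlive any execute() call that consumes the tensor. That holds above,
// where start_pos is a parameter of step().
int64_t start_pos = 5;
auto start_pos_tensor = ::executorch::extension::from_blob(
    &start_pos, {1}, executorch::aten::ScalarType::Long);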

extension/llm/runner/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ add_subdirectory(
 set(runner_deps executorch_core extension_module extension_tensor tokenizers)
 
 target_link_libraries(extension_llm_runner PUBLIC ${runner_deps})
+set_target_properties(extension_llm_runner PROPERTIES POSITION_INDEPENDENT_CODE ON)
 
 target_include_directories(
   extension_llm_runner

extension/llm/runner/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ def define_common_targets():
         ],
         exported_deps = [
             ":stats",
+            "//executorch/kernels/portable/cpu/util:arange_util" + aten_suffix,
             "//executorch/extension/llm/sampler:sampler" + aten_suffix,
             "//executorch/extension/module:module" + aten_suffix,
             "//executorch/extension/tensor:tensor" + aten_suffix,

extension/llm/runner/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
 include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
 
 set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp
-               test_text_prefiller.cpp
+               test_text_prefiller.cpp test_text_decoder_runner.cpp
 )
 
 et_cxx_test(

extension/llm/runner/test/TARGETS

Lines changed: 16 additions & 1 deletion
@@ -8,7 +8,22 @@
 # targets.bzl. This file can contain fbcode-only targets.
 
 load(":targets.bzl", "define_common_targets")
-
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 oncall("executorch")
 
 define_common_targets()
+
+runtime.cxx_test(
+    name = "test_text_decoder_runner",
+    srcs = ["test_text_decoder_runner.cpp"],
+    deps = [
+        "//executorch/extension/llm/runner:runner_lib",
+        "//executorch/kernels/portable:generated_lib",
+        "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+    ],
+    env = {
+        "KVCACHE_CACHE_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheCachePos.pte])",
+        "KVCACHE_INPUT_POS": "$(location fbcode//executorch/test/models:exported_programs[ModuleKVCacheInputPos.pte])",
+        "NO_KVCACHE": "$(location fbcode//executorch/test/models:exported_programs[ModuleNoKVCache.pte])",
+    }
+)
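
The env entries above point the test at prebuilt .pte programs covering the three export variants (cache_positions, input_pos, no KV cache); the new test file below reads them via std::getenv and skips gracefully when they are absent.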
extension/llm/runner/test/test_text_decoder_runner.cpp

Lines changed: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
+ */
+
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/module/module.h>
+#include <executorch/extension/tensor/tensor.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <cstdlib>
+
+using namespace ::testing;
+using executorch::extension::Module;
+using executorch::extension::TensorPtr;
+using executorch::extension::llm::TextDecoderRunner;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::Result;
+using executorch::runtime::testing::TensorFactory;
+
+// Mock Module class for testing
+class MockModule : public Module {
+ public:
+  MockModule() : Module("") {}
+};
+
+class TextDecoderRunnerTest : public Test {
+ protected:
+  void SetUp() override {
+    mock_module_ = std::make_unique<MockModule>();
+    runner_ = std::make_unique<TextDecoderRunner>(mock_module_.get());
+  }
+
+  std::unique_ptr<MockModule> mock_module_;
+  std::unique_ptr<TextDecoderRunner> runner_;
+};
+
+// Test logits_to_token() method with Float tensor
+TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
+  auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // With temperature 0, should return the argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with 3D tensor (batch, seq_length, vocab_size)
+TEST_F(TextDecoderRunnerTest, LogitsToToken3D) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
+  // Shape: [1, 2, 4] - batch=1, seq_length=2, vocab_size=4
+  auto logits = tf_float.make(
+      {1, 2, 4},
+      {
+          0.1f,
+          0.2f,
+          0.3f,
+          0.4f, // First sequence position
+          0.5f,
+          0.6f,
+          0.9f,
+          0.8f // Second sequence position (last)
+      });
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // Should use the last sequence position and return argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with Half tensor
+TEST_F(TextDecoderRunnerTest, LogitsToTokenHalf) {
+  TensorFactory<executorch::aten::ScalarType::Half> tf_half;
+  auto logits = tf_half.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // With temperature 0, should return the argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with BFloat16 tensor
+TEST_F(TextDecoderRunnerTest, LogitsToTokenBFloat16) {
+  TensorFactory<executorch::aten::ScalarType::BFloat16> tf_bfloat16;
+  auto logits = tf_bfloat16.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature 0 (deterministic)
+  int32_t token = runner_->logits_to_token(logits, 0.0f);
+
+  // With temperature 0, should return the argmax (index 2)
+  EXPECT_EQ(token, 2);
+}
+
+// Test logits_to_token() method with non-zero temperature
+TEST_F(TextDecoderRunnerTest, LogitsToTokenWithTemperature) {
+  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
+  auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
+
+  // Call logits_to_token with temperature > 0 (stochastic)
+  int32_t token = runner_->logits_to_token(logits, 1.0f);
+
+  // With temperature > 0, result should be within valid range
+  EXPECT_GE(token, 0);
+  EXPECT_LT(token, 4);
+}
+
+// Test step() method with all available PTE models
+TEST_F(TextDecoderRunnerTest, StepWithAllModels) {
+  // List of all environment variables for PTE models
+  std::vector<std::pair<std::string, const char*>> env_vars = {
+      {"KVCACHE_CACHE_POS", "KVCACHE_CACHE_POS"},
+      {"KVCACHE_INPUT_POS", "KVCACHE_INPUT_POS"},
+      {"NO_KVCACHE", "NO_KVCACHE"}};
+
+  // Check if any environment variables are set up front
+  bool any_env_set = false;
+  for (const auto& [model_name, env_var] : env_vars) {
+    if (std::getenv(env_var)) {
+      any_env_set = true;
+      break;
+    }
+  }
+
+  // Skip test if no environment variables are set
+  if (!any_env_set) {
+    GTEST_SKIP() << "No PTE model environment variables were set";
+  }
+
+  bool any_model_tested = false;
+
+  // Loop through all available models
+  for (const auto& [model_name, env_var] : env_vars) {
+    const char* model_path = std::getenv(env_var);
+    if (!model_path) {
+      continue; // Skip if environment variable not set
+    }
+
+    SCOPED_TRACE(
+        "Testing model: " + model_name + " from " + std::string(model_path));
+
+    // Load the model
+    auto module = std::make_unique<Module>(model_path);
+    auto load_result = module->load();
+    if (load_result != Error::Ok) {
+      ADD_FAILURE() << "Failed to load model " << model_name << " from "
+                    << model_path << " with error: " << (int)load_result;
+      continue;
+    }
+
+    // Create TextDecoderRunner
+    TextDecoderRunner runner(module.get());
+    auto runner_load_result = runner.load();
+    ASSERT_EQ(runner_load_result, Error::Ok)
+        << "Failed to load runner for " << model_name;
+
+    // Verify method is loaded
+    EXPECT_TRUE(runner.is_method_loaded())
+        << "Method not loaded for " << model_name;
+
+    // Create input tensor pointer
+
+    TensorFactory<executorch::aten::ScalarType::Long> tf_long;
+    auto input_tokens_ =
+        tf_long.make({1, 3}, {50, 7, 11}); // Three-token input
+
+    auto input_ptr = std::make_shared<executorch::aten::Tensor>(input_tokens_);
+    int64_t start_pos = 0;
+
+    // Call step() and verify result is ok
+    auto result = runner.step(input_ptr, start_pos);
+    ASSERT_TRUE(result.ok()) << "step() failed for " << model_name
+                             << " with error: " << (int)result.error();
+
+    // Verify output tensor is valid
+    auto output_tensor = result.get();
+    EXPECT_GT(output_tensor.numel(), 0)
+        << "Output tensor empty for " << model_name;
+
+    // Test logits_to_token works
+    int32_t token = runner.logits_to_token(output_tensor, 0.0f);
+    EXPECT_GE(token, 0) << "Invalid token for " << model_name;
+
+    any_model_tested = true;
+  }
+
+  // This should not happen since we checked environment variables up front
+  ASSERT_TRUE(any_model_tested)
+      << "No models were tested despite environment variables being set";
+}

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 3 additions & 4 deletions
@@ -63,11 +63,11 @@ class MockModule : public ::executorch::extension::Module {
 
 class MockTextDecoderRunner : public TextDecoderRunner {
  public:
-  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
   MOCK_METHOD(
       Result<executorch::aten::Tensor>,
       step,
-      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+      (executorch::extension::TensorPtr&, int64_t),
       ());
   MOCK_METHOD(bool, is_method_loaded, (), ());
   MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
@@ -134,8 +134,7 @@ class RunnerTest : public Test {
   std::unique_ptr<MockTextDecoderRunner> createMockTextDecoderRunner() {
     auto text_decoder_runner = std::make_unique<MockTextDecoderRunner>();
     ON_CALL(*text_decoder_runner, step)
-        .WillByDefault([&](executorch::extension::TensorPtr&,
-                           executorch::extension::TensorPtr&) {
+        .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) {
          return Result<executorch::aten::Tensor>(tensor);
        });
     ON_CALL(*text_decoder_runner, is_method_loaded())

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 3 additions & 4 deletions
@@ -24,11 +24,11 @@ using executorch::runtime::testing::TensorFactory;
 // Mock class for TextDecoderRunner
 class MockTextDecoderRunner : public TextDecoderRunner {
  public:
-  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
   MOCK_METHOD(
       Result<executorch::aten::Tensor>,
       step,
-      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+      (executorch::extension::TensorPtr&, int64_t),
       ());
   MOCK_METHOD(bool, is_method_loaded, (), ());
   MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
@@ -44,8 +44,7 @@ class TextPrefillerTest : public Test {
     ON_CALL(text_decoder_runner_, is_method_loaded())
        .WillByDefault(Return(true));
     ON_CALL(text_decoder_runner_, step)
-        .WillByDefault([&](executorch::extension::TensorPtr&,
-                           executorch::extension::TensorPtr&) {
+        .WillByDefault([&](executorch::extension::TensorPtr&, int64_t) {
          return Result<executorch::aten::Tensor>(tensor);
        });
   }

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 41 additions & 5 deletions
@@ -9,6 +9,7 @@
 // Given inputs, run a text decoder and return logits.
 
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/kernels/portable/cpu/util/arange_util.h>
 
 #include <ctime>
 
@@ -21,18 +22,53 @@ namespace llm {
 // NOTE: we observed ~2x loading performance increase on iPhone 15
 // and a ~5% improvement on Galaxy S22 by switching to
 // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-TextDecoderRunner::TextDecoderRunner(Module* module, bool use_kv_cache)
-    : module_(module), use_kv_cache_(use_kv_cache) {}
+TextDecoderRunner::TextDecoderRunner(Module* module) : module_(module) {}
 
 // This function is functional, meaning it shouldn't modify any state of the
 // input. It should be safe to call multiple times with the same inputs. The
 // outer loop (call site) is responsible for managing state.
 ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
     TensorPtr& tokens,
-    TensorPtr& start_pos) {
+    int64_t start_pos) {
   // ET_LOG(Info, "Input token %" PRIu64, input_token);
-  if (use_kv_cache_) {
-    auto outputs_res = module_->forward({tokens, start_pos});
+  auto method_meta = ET_UNWRAP(module_->method_meta("forward"));
+  // If only 1 input, we are not using kv cache
+  bool use_kv_cache = method_meta.num_inputs() > 1;
+
+  if (use_kv_cache) {
+    // Size of the second argument. This could be either input_pos or
+    // cache_positions
+
+    // Check if we are using cache positions instead of input pos.
+    auto second_input_info = ET_UNWRAP(method_meta.input_tensor_meta(1));
+    // For input_pos, numel is 1, for cache_positions, numel is max_seq_len
+    auto sizes = second_input_info.sizes();
+    // Assuming 1D tensor
+    ET_CHECK_OR_RETURN_ERROR(
+        sizes.size() == 1,
+        InvalidProgram,
+        "The second input tensor is not 1D tensor. Got dimension (%zu)",
+        sizes.size());
+    auto numel = sizes[0];
+    std::vector<::executorch::aten::SizesType> sizes_vec = {numel};
+
+    // Assuming the last dimension is the one with the variable token length,
+    // for example [1, S] or [1, 1, S]
+    sizes_vec[sizes_vec.size() - 1] = numel;
+    TensorPtr start_pos_tensor;
+    if (numel > 1) {
+      // Assuming model is exported with cache_positions, create a tensor with
+      // the same size as cache_positions
+      start_pos_tensor = empty(sizes_vec, ::executorch::aten::ScalarType::Long);
+      torch::executor::native::arange_out_impl(
+          start_pos, start_pos + numel, 1.0, *start_pos_tensor);
+    } else {
+      // Assuming model is exported with input_pos, create a tensor with size 1
+      start_pos_tensor = from_blob(
+          &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
+    }
+    ET_LOG(Info, "Start pos tensor numel: %zu", start_pos_tensor->numel());
+    auto outputs_res = module_->forward({tokens, start_pos_tensor});
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
     ET_CHECK_MSG(
         outputs_res.get().size() == 1,
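
The net effect of the new branch: the runner reads the second input's numel from the method metadata and fills the position tensor accordingly, a single value for an input_pos export and a consecutive range for a cache_positions export. A standalone C++ sketch of the fill rule, not the ExecuTorch implementation itself:

#include <cstdint>
#include <vector>

// Fill rule used above: positions[i] = start_pos + i for i in [0, numel).
// numel == 1 degenerates to the single input_pos value.
std::vector<int64_t> make_positions(int64_t start_pos, int64_t numel) {
  std::vector<int64_t> positions(static_cast<size_t>(numel));
  for (int64_t i = 0; i < numel; ++i) {
    positions[static_cast<size_t>(i)] = start_pos + i;
  }
  return positions;
}
// e.g. make_positions(7, 4) -> {7, 8, 9, 10}  (cache_positions case)
//      make_positions(7, 1) -> {7}            (input_pos case)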
