/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
 */

#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using namespace ::testing;
using executorch::extension::Module;
using executorch::extension::TensorPtr;
using executorch::extension::llm::TextDecoderRunner;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::Result;
using executorch::runtime::testing::TensorFactory;

// Minimal Module subclass for tests that never load a model. The empty
// module path is never load()ed: the logits_to_token() tests below only
// exercise sampling logic, which does not touch the underlying module.
class MockModule : public Module {
 public:
  MockModule() : Module("") {}
};

class TextDecoderRunnerTest : public Test {
 protected:
  void SetUp() override {
    mock_module_ = std::make_unique<MockModule>();
    runner_ = std::make_unique<TextDecoderRunner>(mock_module_.get());
  }

  std::unique_ptr<MockModule> mock_module_;
  std::unique_ptr<TextDecoderRunner> runner_;
};
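
// NOTE: TextDecoderRunner is constructed with a raw Module* and does not
// take ownership, so the fixture keeps mock_module_ alive for as long as
// runner_ is in use.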

// Test logits_to_token() with a Float tensor.
TEST_F(TextDecoderRunnerTest, LogitsToTokenFloat) {
  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
  auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

  // Call logits_to_token with temperature 0 (deterministic).
  int32_t token = runner_->logits_to_token(logits, 0.0f);

  // With temperature 0, should return the argmax (index 2).
  EXPECT_EQ(token, 2);
}

// Test logits_to_token() with a 3D tensor (batch, seq_length, vocab_size).
TEST_F(TextDecoderRunnerTest, LogitsToToken3D) {
  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
  // Shape: [1, 2, 4] - batch=1, seq_length=2, vocab_size=4
  auto logits = tf_float.make(
      {1, 2, 4},
      {
          0.1f, 0.2f, 0.3f, 0.4f, // First sequence position
          0.5f, 0.6f, 0.9f, 0.8f // Second (last) sequence position
      });

  // Call logits_to_token with temperature 0 (deterministic).
  int32_t token = runner_->logits_to_token(logits, 0.0f);

  // Should use the last sequence position and return its argmax (index 2).
  EXPECT_EQ(token, 2);
}

// Test logits_to_token() with a Half tensor.
TEST_F(TextDecoderRunnerTest, LogitsToTokenHalf) {
  TensorFactory<executorch::aten::ScalarType::Half> tf_half;
  auto logits = tf_half.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

  // Call logits_to_token with temperature 0 (deterministic).
  int32_t token = runner_->logits_to_token(logits, 0.0f);

  // With temperature 0, should return the argmax (index 2).
  EXPECT_EQ(token, 2);
}

// Test logits_to_token() with a BFloat16 tensor.
TEST_F(TextDecoderRunnerTest, LogitsToTokenBFloat16) {
  TensorFactory<executorch::aten::ScalarType::BFloat16> tf_bfloat16;
  auto logits = tf_bfloat16.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

  // Call logits_to_token with temperature 0 (deterministic).
  int32_t token = runner_->logits_to_token(logits, 0.0f);

  // With temperature 0, should return the argmax (index 2).
  EXPECT_EQ(token, 2);
}

// Test logits_to_token() with a non-zero temperature.
TEST_F(TextDecoderRunnerTest, LogitsToTokenWithTemperature) {
  TensorFactory<executorch::aten::ScalarType::Float> tf_float;
  auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});

  // Call logits_to_token with temperature > 0 (stochastic sampling).
  int32_t token = runner_->logits_to_token(logits, 1.0f);

  // With temperature > 0 the sampled token is random, so only check that it
  // falls within the valid vocabulary range.
  EXPECT_GE(token, 0);
  EXPECT_LT(token, 4);
}

// Test step() against every PTE model whose path is provided via an
// environment variable.
TEST_F(TextDecoderRunnerTest, StepWithAllModels) {
  // (display name, environment variable) pairs for the PTE models; the two
  // strings are currently identical.
  std::vector<std::pair<std::string, const char*>> env_vars = {
      {"KVCACHE_CACHE_POS", "KVCACHE_CACHE_POS"},
      {"KVCACHE_INPUT_POS", "KVCACHE_INPUT_POS"},
      {"NO_KVCACHE", "NO_KVCACHE"}};
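
  // Each variable, when set, should point at a .pte file exported with the
  // matching KV-cache configuration. Example invocation (hypothetical path
  // and test binary name):
  //   KVCACHE_INPUT_POS=/tmp/model_kvcache_input_pos.pte ./text_decoder_runner_test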

  // Check up front whether any of the environment variables are set.
  bool any_env_set = false;
  for (const auto& [model_name, env_var] : env_vars) {
    if (std::getenv(env_var)) {
      any_env_set = true;
      break;
    }
  }

  // Skip the test if none are set.
  if (!any_env_set) {
    GTEST_SKIP() << "No PTE model environment variables were set";
  }

  bool any_model_tested = false;

  // Loop through all available models.
  for (const auto& [model_name, env_var] : env_vars) {
    const char* model_path = std::getenv(env_var);
    if (!model_path) {
      continue; // Skip if this environment variable is not set.
    }

    SCOPED_TRACE(
        "Testing model: " + model_name + " from " + std::string(model_path));

    // Load the model.
    auto module = std::make_unique<Module>(model_path);
    auto load_result = module->load();
    if (load_result != Error::Ok) {
      ADD_FAILURE() << "Failed to load model " << model_name << " from "
                    << model_path
                    << " with error: " << static_cast<int>(load_result);
      continue;
    }

    // Create the TextDecoderRunner.
    TextDecoderRunner runner(module.get());
    auto runner_load_result = runner.load();
    ASSERT_EQ(runner_load_result, Error::Ok)
        << "Failed to load runner for " << model_name;

    // Verify the method is loaded.
    EXPECT_TRUE(runner.is_method_loaded())
        << "Method not loaded for " << model_name;

    // Create the input tensor pointer.
    TensorFactory<executorch::aten::ScalarType::Long> tf_long;
    auto input_tokens = tf_long.make({1, 3}, {50, 7, 11}); // Three-token input
    TensorPtr input_ptr =
        std::make_shared<executorch::aten::Tensor>(input_tokens);
    int64_t start_pos = 0;

    // Call step() and verify the result is ok.
    auto result = runner.step(input_ptr, start_pos);
    ASSERT_TRUE(result.ok()) << "step() failed for " << model_name
                             << " with error: "
                             << static_cast<int>(result.error());

    // Verify the output tensor is valid.
    auto output_tensor = result.get();
    EXPECT_GT(output_tensor.numel(), 0)
        << "Output tensor empty for " << model_name;

    // Verify logits_to_token() works on the output.
    int32_t token = runner.logits_to_token(output_tensor, 0.0f);
    EXPECT_GE(token, 0) << "Invalid token for " << model_name;

    any_model_tested = true;
  }

  // This should not happen, since we checked the environment variables up
  // front.
  ASSERT_TRUE(any_model_tested)
      << "No models were tested despite environment variables being set";
}