1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15-
1615#include " kws/keyword_spotting.h"
1716
1817#include < iostream>
@@ -35,30 +34,27 @@ KeywordSpotting::KeywordSpotting(const std::string& model_path) {
3534 out_names_ = {" output" , " r_cache" };
3635 auto metadata = session_->GetModelMetadata ();
3736 Ort::AllocatorWithDefaultOptions allocator;
38- cache_dim_ = std::stoi (metadata. LookupCustomMetadataMap ( " cache_dim " ,
39- allocator));
40- cache_len_ = std::stoi (metadata. LookupCustomMetadataMap ( " cache_len " ,
41- allocator));
37+ cache_dim_ =
38+ std::stoi (metadata. LookupCustomMetadataMap ( " cache_dim " , allocator));
39+ cache_len_ =
40+ std::stoi (metadata. LookupCustomMetadataMap ( " cache_len " , allocator));
4241 std::cout << " Kws Model Info:" << std::endl
4342 << " \t cache_dim: " << cache_dim_ << std::endl
4443 << " \t cache_len: " << cache_len_ << std::endl;
4544 Reset ();
4645}
4746
48-
4947void KeywordSpotting::Reset () {
5048 Ort::MemoryInfo memory_info =
51- Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
49+ Ort::MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
5250 cache_.resize (cache_dim_ * cache_len_, 0.0 );
5351 const int64_t cache_shape[] = {1 , cache_dim_, cache_len_};
54- cache_ort_ = Ort::Value::CreateTensor<float >(
55- memory_info, cache_. data (), cache_.size (), cache_shape, 3 );
52+ cache_ort_ = Ort::Value::CreateTensor<float >(memory_info, cache_. data (),
53+ cache_.size (), cache_shape, 3 );
5654}
5755
58-
59- void KeywordSpotting::Forward (
60- const std::vector<std::vector<float >>& feats,
61- std::vector<std::vector<float >>* prob) {
56+ void KeywordSpotting::Forward (const std::vector<std::vector<float >>& feats,
57+ std::vector<std::vector<float >>* prob) {
6258 prob->clear ();
6359 if (feats.size () == 0 ) return ;
6460 Ort::MemoryInfo memory_info =
@@ -78,9 +74,9 @@ void KeywordSpotting::Forward(
7874 inputs.emplace_back (std::move (feats_ort));
7975 inputs.emplace_back (std::move (cache_ort_));
8076 // ort_outputs.size() == 2
81- std::vector<Ort::Value> ort_outputs = session_-> Run (
82- Ort::RunOptions{nullptr }, in_names_.data (), inputs.data (),
83- inputs.size (), out_names_.data (), out_names_.size ());
77+ std::vector<Ort::Value> ort_outputs =
78+ session_-> Run ( Ort::RunOptions{nullptr }, in_names_.data (), inputs.data (),
79+ inputs.size (), out_names_.data (), out_names_.size ());
8480
8581 // 3. Update cache
8682 cache_ort_ = std::move (ort_outputs[1 ]);
@@ -92,9 +88,9 @@ void KeywordSpotting::Forward(
9288 int output_dim = type_info.GetShape ()[2 ];
9389 prob->resize (num_outputs);
9490 for (int i = 0 ; i < num_outputs; i++) {
95- (*prob)[i].resize (output_dim);
96- memcpy ((*prob)[i].data (), data + i * output_dim,
97- sizeof (float ) * output_dim);
91+ (*prob)[i].resize (output_dim);
92+ memcpy ((*prob)[i].data (), data + i * output_dim,
93+ sizeof (float ) * output_dim);
9894 }
9995}
10096
0 commit comments