-
Notifications
You must be signed in to change notification settings - Fork 131
Description
使用example的hi_xiaowen跟著run一路做完
自己弄了個測試程式
`import torch
import torchaudio
import yaml
from wekws.model.kws_model import init_model
from wekws.utils.checkpoint import load_checkpoint
from wenet.text.char_tokenizer import CharTokenizer
=== 設定 ===
DICT_PATH = '/wekws/examples/hi_xiaowen/s0/dict/dict.txt'
MODEL_PATH = '/wekws/examples/hi_xiaowen/s0/exp/ds_tcn/final.pt'
CONFIG_PATH = '/wekws/examples/hi_xiaowen/s0/exp/ds_tcn/config.yaml'
WAV_PATH = '/wekws/examples/hi_xiaowen/s0/data/local/mobvoi_hotword_dataset/2a5d869e59b8dd33b584f88e015e56fc.wav'
SAMPLING_RATE = 16000
THRESHOLD = 0.9
=== 載入模型與設定 ===
with open(CONFIG_PATH, 'r') as f:
configs = yaml.load(f, Loader=yaml.FullLoader)
model = init_model(configs['model'])
load_checkpoint(model, MODEL_PATH)
model.eval().to('cpu')
=== Tokenizer ===
tokenizer = CharTokenizer(DICT_PATH, DICT_PATH.replace("dict.txt", "words.txt"), unk='')
=== 讀取音檔 ===
waveform, sr = torchaudio.load(WAV_PATH)
if sr != SAMPLING_RATE:
waveform = torchaudio.functional.resample(waveform, sr, SAMPLING_RATE)
waveform = waveform[0].unsqueeze(0) # (1, N)
=== 轉 FBank 特徵 ===
fbank = torchaudio.compliance.kaldi.fbank(
waveform,
num_mel_bins=40,
frame_length=25,
frame_shift=10,
dither=0.0,
energy_floor=0.0,
sample_frequency=SAMPLING_RATE
)
feats = fbank.unsqueeze(0) # (1, T, F)
=== 模型推論 ===
with torch.no_grad():
logits, _ = model(feats)
avg_scores = torch.mean(logits[0], dim=0) # 對時間軸平均
best_score, best_idx = torch.max(avg_scores, dim=0)
keyword = tokenizer.ids2tokens([best_idx.item()])[0]
print(f"辨識結果:{keyword},分數:{best_score.item():.3f}")
if best_score.item() > THRESHOLD:
print(f" 偵測到關鍵字:{keyword}")
else:
print("偵測失敗")
`
發現不管使用測試集還是訓練集都無法成功偵測