1
1
from pathlib import Path
2
2
3
3
import torch
4
- from torch import nn
5
- from einops import rearrange , pack , unpack
4
+ from torch import nn , einsum
5
+ from einops import rearrange , repeat , pack , unpack
6
6
7
7
import joblib
8
8
@@ -54,8 +54,14 @@ def __init__(
54
54
self .model .eval ()
55
55
56
56
kmeans = joblib .load (kmeans_path )
57
+
57
58
self .kmeans = kmeans
58
59
60
+ self .register_buffer (
61
+ 'cluster_centers' ,
62
+ torch .from_numpy (kmeans .cluster_centers_ )
63
+ )
64
+
59
65
@property
def groups(self):
    """Constant group count for this module.

    Always returns 1. NOTE(review): the exact semantics of 'groups'
    (e.g. quantizer/codebook groups) are not visible in this fragment —
    confirm against the full class definition.
    """
    single_group = 1
    return single_group
@@ -76,7 +82,7 @@ def forward(
76
82
flatten = True ,
77
83
input_sample_hz = None
78
84
):
79
- device = wav_input .device
85
+ batch , device = wav_input . shape [ 0 ], wav_input .device
80
86
81
87
if exists (input_sample_hz ):
82
88
wav_input = resample (wav_input , input_sample_hz , self .target_sample_hz )
@@ -89,17 +95,13 @@ def forward(
89
95
features_only = True ,
90
96
mask = False , # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
91
97
output_layer = self .output_layer
92
- )
93
-
94
- embed , packed_shape = pack ([embed ['x' ]], '* d' )
95
-
96
- codebook_indices = self .kmeans .predict (embed .cpu ().detach ().numpy ())
97
-
98
- codebook_indices = torch .from_numpy (codebook_indices ).to (device ).long ()
98
+ )['x' ]
99
99
100
- codebook_indices , = unpack (codebook_indices , packed_shape , '*' )
100
+ batched_cluster_centers = repeat (self .cluster_centers , 'c d -> b c d' , b = embed .shape [0 ])
101
+ dists = - torch .cdist (embed , batched_cluster_centers , p = 2 )
102
+ clusters = dists .argmax (dim = - 1 )
101
103
102
104
if flatten :
103
- return codebook_indices
105
+ return clusters
104
106
105
- return rearrange (codebook_indices , 'b ... -> b (...)' )
107
+ return rearrange (clusters , 'b ... -> b (...)' )
0 commit comments