Skip to content

Commit b157179

Browse files
committed
address lucidrains#203 by extracting cluster centers from scikit-learn kmeans and doing it all in torch
1 parent 95b02cd commit b157179

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

audiolm_pytorch/hubert_kmeans.py

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
from pathlib import Path
22

33
import torch
4-
from torch import nn
5-
from einops import rearrange, pack, unpack
4+
from torch import nn, einsum
5+
from einops import rearrange, repeat, pack, unpack
66

77
import joblib
88

@@ -54,8 +54,14 @@ def __init__(
5454
self.model.eval()
5555

5656
kmeans = joblib.load(kmeans_path)
57+
5758
self.kmeans = kmeans
5859

60+
self.register_buffer(
61+
'cluster_centers',
62+
torch.from_numpy(kmeans.cluster_centers_)
63+
)
64+
5965
@property
6066
def groups(self):
6167
return 1
@@ -76,7 +82,7 @@ def forward(
7682
flatten = True,
7783
input_sample_hz = None
7884
):
79-
device = wav_input.device
85+
batch, device = wav_input.shape[0], wav_input.device
8086

8187
if exists(input_sample_hz):
8288
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
@@ -89,17 +95,13 @@ def forward(
8995
features_only = True,
9096
mask = False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
9197
output_layer = self.output_layer
92-
)
93-
94-
embed, packed_shape = pack([embed['x']], '* d')
95-
96-
codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
97-
98-
codebook_indices = torch.from_numpy(codebook_indices).to(device).long()
98+
)['x']
9999

100-
codebook_indices, = unpack(codebook_indices, packed_shape, '*')
100+
batched_cluster_centers = repeat(self.cluster_centers, 'c d -> b c d', b = embed.shape[0])
101+
dists = -torch.cdist(embed, batched_cluster_centers, p = 2)
102+
clusters = dists.argmax(dim = -1)
101103

102104
if flatten:
103-
return codebook_indices
105+
return clusters
104106

105-
return rearrange(codebook_indices, 'b ... -> b (...)')
107+
return rearrange(clusters, 'b ... -> b (...)')

audiolm_pytorch/version.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = '1.2.10'
1+
__version__ = '1.2.11'

0 commit comments

Comments
 (0)