Commit ab16752

first commit (0 parents)

20 files changed: +1535 -0 lines

.gitignore

Lines changed: 2 additions & 0 deletions

outputs
checkpoints

.gitmodules

Lines changed: 4 additions & 0 deletions

[submodule "fairseq"]
	path = fairseq
	url = git@github.com:lstrgar/fairseq.git
	branch = lvs
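
Note: since fairseq is tracked as a submodule, clone recursively (the repository URL is a placeholder):

`git clone --recurse-submodules <repository-url>`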

LICENSE

Lines changed: 674 additions & 0 deletions

README.md

Lines changed: 58 additions & 0 deletions

# Phoneme Segmentation Using Self-Supervised Speech Models

## Usage

### Obtain Pre-trained Model Checkpoints

wav2vec2.0 and HuBERT checkpoints are available via fairseq at the following links. Download these models and place them in a new folder titled `checkpoints`.

https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/README.md
https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt

https://github.com/facebookresearch/fairseq/blob/main/examples/hubert/README.md
https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt
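
For example, one way to fetch them (any download method works; `wget` is assumed to be available):

`mkdir checkpoints`

`wget -P checkpoints https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt`

`wget -P checkpoints https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt`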

### Obtain and Process TIMIT and/or Buckeye Speech Corpus

Once the data has been obtained, it must be stored on disk in a fashion that can be read by the provided dataloader, the core of which is borrowed from Kreuk et al. (https://github.com/felixkreuk/UnsupSeg). See the Data Structure section of this repo for specifics, or simply use the provided `utils/make_timit.py` and `utils/make_buckeye.py` to split and organize the data exactly how we did it. Note: we also credit both of these scripts to Kreuk et al., save a few minor changes.

You can run `make_timit.py` and `make_buckeye.py` as follows:

`python utils/make_timit.py --inpath /path/to/original/timit --outpath /path/to/output/timit`

`python utils/make_buckeye.py --spkr --source /path/to/original/buckeye --target /path/to/output/buckeye --min_phonemes 20 --max_phonemes 50`

Note: we do not provide the infrastructure to train these models using pseudo-labels derived from a trained unsupervised model; however, the core implementation can easily be extended to train with alternate label supervision so long as the dataloader's interface remains unchanged. For those interested in training such a model, we direct you to Kreuk et al., where a pretrained unsupervised model can be used to generate pseudo-labels for TIMIT.

### Update Configuration YAML

The following fields will need to be updated to reflect local paths on your machine:

- timit_path
- buckeye_path
- base_ckpt_path

You may also want to experiment with the `num_workers` attribute depending on your hardware.

### Training and Testing

To freeze the pre-trained model weights and train only a classifier readout model on TIMIT with a wav2vec2.0 backbone, run the following:

`python run.py data=timit lr=0.001 base_ckpt_path=/path/to/wav2vec2.0_ckpt mode=readout`

`data=timit` can easily be swapped for `data=buckeye`, just as `base_ckpt_path=/path/to/wav2vec2.0_ckpt` can be swapped for `base_ckpt_path=/path/to/hubert_ckpt`.

To fine-tune the whole pre-trained model and simply project final features with a linear readout, set `lr=0.0001` and `mode=finetune`. Otherwise, the same swapping for TIMIT/Buckeye and wav2vec2.0/HuBERT applies.
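
For example, to fine-tune a HuBERT backbone on Buckeye (the checkpoint path is a placeholder):

`python run.py data=buckeye lr=0.0001 base_ckpt_path=/path/to/hubert_ckpt mode=finetune`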

Invoking `run.py` will train a model from scratch for 50 epochs, printing training stats every 10 batches and running model validation every 50 batches. Print preferences can be changed in the config via the `print_interval` and `val_interval` attributes. `epochs` can also be modified if desired.

During training, models are saved to disk whenever they achieve the best-so-far R-value on the validation set. After training is complete, the best model is loaded from disk and evaluated on the test set. Performance metrics under both the harsh and lenient evaluation schemes are logged to standard out.
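
The metric implementation lives in `utils/eval.py`, which is not shown in this commit. For reference, a minimal sketch of the standard boundary-detection R-value (Räsänen et al., 2009), assuming boundary-level precision and recall as inputs:

```python
import math

def r_value(precision: float, recall: float) -> float:
    """Standard segmentation R-value (Rasanen et al., 2009).

    Hit rate HR is boundary recall; over-segmentation is
    OS = recall / precision - 1.
    """
    # Guard against zero precision in this sketch
    os_ = recall / precision - 1.0 if precision > 0 else -1.0
    r1 = math.sqrt((1.0 - recall) ** 2 + os_ ** 2)
    r2 = (-os_ + recall - 1.0) / math.sqrt(2.0)
    return 1.0 - (abs(r1) + abs(r2)) / 2.0
```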

Lastly, every invocation of `run.py` creates an output folder under `outputs/<datestamp>/<exp_name>-<timestamp>` (see `hydra.run.dir` in the config), which is where model checkpoints are saved along with the full runtime config and a `run.log`. Everything logged to standard output during training is also logged to the `run.log` file.

### Additional

This codebase assumes CUDA availability.

The config `seed` attribute can be changed to control random shuffling and initialization.

`train_percent` indicates the fraction of the training set to use. Some may be interested in observing model / training-data efficiency by sweeping over this attribute; sweeps can be easily accommodated using Hydra's multi-run command-line option, as in the example below. For more, see the Hydra docs.
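
For example, a sweep over `train_percent` using Hydra's multi-run flag (other overrides as in the training example above):

`python run.py --multirun train_percent=0.25,0.5,0.75,1.0 data=timit mode=readout lr=0.001 base_ckpt_path=/path/to/wav2vec2.0_ckpt`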

config/conf.yaml

Lines changed: 27 additions & 0 deletions

# NOTE: do not name me "config.yml" to avoid conflict with fairseq defaults

hydra:
  run:
    dir: ./outputs/${now:%Y-%m-%d}/${exp_name}-${now:%H-%M-%S}
exp_name: null
buckeye_path: /home/lvs/data/buckeye
timit_path: /home/lvs/data/timit
data: timit
val_ratio: 0.1 # Ratio of the training set to use for TIMIT validation
train_percent: 1.0 # Fraction of the training data to use
num_workers: 5
base_ckpt_path: /home/lvs/code/segment-public/checkpoints/w2v2_small_lib.pt
seed: 0
mode: readout
label_dist_threshold: 1 # Boundary tolerance in frames (1 frame = 20 ms)
print_interval: 10 # Train batches between loss-stat prints
val_interval: 50 # Train batches between validation steps
optim_type: adam
beta1: 0.9
beta2: 0.999
momentum: 0.9
weight_decay: 0
pos_weight: 1.0 # BCE loss positive-class weighting
epochs: 50
batch_size: 16
lr: 0.001
Binary file (4.22 KB) not shown.

experiment/train_test.py

Lines changed: 125 additions & 0 deletions

import torch
from torch.nn import BCEWithLogitsLoss
from utils.eval import PrecisionRecallMetric
from utils.dataloader import construct_mask
from models.classifier import get_features
from utils.misc import load_from_checkpoint, save_checkpoint, get_optimizer


def train_test(cfg, model, classifier, trainloader, valloader, testloader, logger):
    device = next(model.parameters()).device
    logger.info("TRAINING MODEL")
    ckpt, _ = train(model, classifier, trainloader, valloader, cfg, logger, device)
    logger.info("Training complete. Loading best model from checkpoint: {}".format(ckpt))
    model, _, classifier, _, metrics = load_from_checkpoint(cfg, device, ckpt)
    logger.info("Best model's VALIDATION METRICS:")
    for k, v in metrics.items():
        logger.info(f"{k}:")
        for m, s in v.items():
            logger.info(f"\t{m+':':<10} {s:>4.4f}")
    logger.info("Testing best model")
    test(model, classifier, testloader, cfg, logger, device)


def train(model, classifier, trainloader, valloader, cfg, logger, device):
    # Per-element BCE so padded frames can be masked out before reduction
    loss_fn = BCEWithLogitsLoss(
        reduction="none",
        pos_weight=torch.tensor([cfg.pos_weight]).to(device)
    )

    params_dict = {
        "classifier": classifier.parameters(),
    }
    if cfg.mode == "finetune":
        logger.info("Fine-tuning encoder layers")
        params_dict["model"] = model.parameters()
    else:
        logger.info("Training readout (classifier) weights ONLY")

    optimizer = get_optimizer(cfg, params_dict)

    global_step = 0
    best_rval = 0
    best_model = None

    for e in range(cfg.epochs):
        running_loss = 0.0
        for i, samp in enumerate(trainloader):
            if cfg.mode == "finetune":
                model.train()
            else:
                model.eval()
            classifier.train()
            wavs, _, labels, _, lengths, _ = samp
            mask = construct_mask(lengths, device).float()
            wavs = wavs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            results = model.extract_features(wavs, padding_mask=None)
            features = get_features(results, cfg.mode)
            logits = classifier(features).squeeze()
            if len(logits.shape) == 1:
                # Restore the batch dimension squeezed away when batch_size == 1
                logits = logits.unsqueeze(0)
            # Average the per-frame loss over real (non-padded) frames only
            bce_loss = (loss_fn(logits, labels) * mask).sum() / mask.sum()
            loss = bce_loss
            running_loss += loss.item()
            loss.backward()
            optimizer.step()

            if global_step % cfg.print_interval == cfg.print_interval - 1:
                logger.info("Epoch: {}/{} | Batch: {}/{} | Loss: {:.4f}".format(
                    e+1, cfg.epochs, i+1, len(trainloader), running_loss/cfg.print_interval,
                ))
                running_loss = 0.0

            if cfg.val_interval and global_step % cfg.val_interval == cfg.val_interval - 1:
                logger.info("MODEL VALIDATION: Epoch: {}/{} | Batch: {}/{}".format(e+1, cfg.epochs, i+1, len(trainloader)))
                harsh_metrics_val, lenient_metrics_val = test(model, classifier, valloader, cfg, logger, device)
                if harsh_metrics_val["rval"] > best_rval:
                    best_rval = harsh_metrics_val["rval"]
                    logger.info("New best (harsh) validation rval: {:.4f}".format(best_rval))
                    metrics = {"harsh": harsh_metrics_val, "lenient": lenient_metrics_val}
                    checkpoint_path = save_checkpoint(model, classifier, optimizer, metrics, e+1)
                    best_model = checkpoint_path
                    logger.info("Checkpoint saved to: {}".format(checkpoint_path))

            global_step += 1

    return best_model, best_rval


def test(model, classifier, dataloader, cfg, logger, device):
    model.eval()
    classifier.eval()
    metric_tracker_harsh = PrecisionRecallMetric(tolerance=cfg.label_dist_threshold, mode="harsh")
    metric_tracker_lenient = PrecisionRecallMetric(tolerance=cfg.label_dist_threshold, mode="lenient")
    sigmoid = torch.nn.Sigmoid()
    logger.info("Evaluating model on {} samples".format(len(dataloader.dataset)))

    for samp in dataloader:
        wavs, segs, labels, _, lengths, _ = samp
        # Flatten (start, end) segment tuples into a list of boundary frames:
        # both boundaries of the first segment, then the end of each later one
        segs = [[*segs[i][0]] + [s[1] for s in segs[i][1:]] for i in range(len(segs))]
        wavs = wavs.to(device)
        labels = labels.to(device)
        results = model.extract_features(wavs, padding_mask=None)
        features = get_features(results, cfg.mode)
        preds = classifier(features).squeeze()
        preds = sigmoid(preds)
        preds = preds > 0.5
        # Keep predicted boundary indices within each sequence's true length
        preds = [
            torch.where(preds[i, :lengths[i]] == 1)[0].tolist() for i in range(preds.size(0))
        ]
        metric_tracker_harsh.update(segs, preds)
        metric_tracker_lenient.update(segs, preds)

    logger.info("Computing metrics with distance threshold of {} frames".format(cfg.label_dist_threshold))

    tracker_metrics_harsh = metric_tracker_harsh.get_stats()
    tracker_metrics_lenient = metric_tracker_lenient.get_stats()

    logger.info(f"{'SCORES:':<15} {'Lenient':>10} {'Harsh':>10}")
    for k in tracker_metrics_harsh.keys():
        logger.info("{:<15} {:>10.4f} {:>10.4f}".format(k+":", tracker_metrics_lenient[k], tracker_metrics_harsh[k]))

    return tracker_metrics_harsh, tracker_metrics_lenient
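
`get_optimizer`, `save_checkpoint`, and `load_from_checkpoint` are imported from `utils/misc.py`, which is not included in this commit excerpt. Purely as a hypothetical sketch, a `get_optimizer` consistent with the config fields above (`optim_type`, `lr`, `beta1`, `beta2`, `momentum`, `weight_decay`) might look like:

```python
import itertools
import torch

def get_optimizer(cfg, params_dict):
    # Hypothetical sketch; the real implementation lives in utils/misc.py
    # and is not shown in this commit.
    params = itertools.chain(*params_dict.values())
    if cfg.optim_type == "adam":
        return torch.optim.Adam(
            params,
            lr=cfg.lr,
            betas=(cfg.beta1, cfg.beta2),
            weight_decay=cfg.weight_decay,
        )
    # Otherwise fall back to SGD, presumably what the momentum field is for
    return torch.optim.SGD(
        params,
        lr=cfg.lr,
        momentum=cfg.momentum,
        weight_decay=cfg.weight_decay,
    )
```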

fairseq

Submodule fairseq added at 976e383
Binary file (2.32 KB) not shown.

models/classifier.py

Lines changed: 62 additions & 0 deletions

import torch.nn as nn
import torch


class Classifier(nn.Module):
    def __init__(
        self,
        mode="finetune",
        n_layers=12
    ):
        super(Classifier, self).__init__()
        self.mode = mode

        if self.mode == "readout":
            # Learned scalar weight per transformer layer, initialized uniformly
            self.n_weights = n_layers
            self.weight = nn.parameter.Parameter(torch.ones(self.n_weights, 1, 1, 1) / self.n_weights)
            # One temporal convolution per layer, applied before the weighted sum
            self.layerwise_convolutions = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(768, 768, kernel_size=9, padding=4, stride=1),
                    nn.ReLU(),
                ) for _ in range(self.n_weights)
            ])
            self.network = nn.Sequential(
                nn.Conv1d(768, 512, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(512, 256, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(256, 128, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(128, 64, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.Conv1d(64, 32, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            )
            self.out = nn.Linear(32, 1)
        elif self.mode == "finetune":
            # Fine-tuning only needs a linear projection of the final features
            self.out = nn.Linear(768, 1)

    def forward(self, x):
        if self.mode == "readout":
            # x: (n_layers, batch, frames, 768); Conv1d expects (batch, channels, frames)
            layers = []
            for i in range(x.size(0)):
                layers.append(self.layerwise_convolutions[i](x[i, :, :, :].permute(0, 2, 1)).permute(0, 2, 1))
            x = torch.stack(layers, dim=0)
            # Weighted sum over layers, then the shared convolutional stack
            x = torch.mul(x, self.weight).sum(0)
            x = x.permute(0, 2, 1)
            x = self.network(x)
            x = x.permute(0, 2, 1)

        out = self.out(x)
        return out


def get_features(results, mode):
    if mode == "finetune":
        return results["x"]
    elif mode == "readout":
        # Stack per-layer transformer outputs, substituting zeros for any
        # layer that did not return features
        zeros = torch.zeros_like(results["x"])
        results = [r for r in results["layer_results"]]
        features = [r[0].permute(1, 0, 2) if r[0] is not None else zeros.clone() for r in results]
        features = torch.stack(features, dim=0).squeeze(0)
        return features
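
As a quick sanity check (hypothetical usage, not part of this commit), the readout classifier maps stacked per-layer features to one boundary logit per frame:

```python
import torch
from models.classifier import Classifier

clf = Classifier(mode="readout", n_layers=12)
feats = torch.randn(12, 2, 100, 768)  # (layers, batch, frames, channels)
logits = clf(feats)
print(logits.shape)  # torch.Size([2, 100, 1]): one logit per frame
```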
