Description
DAC doesn't stop the gradient when updating the residual (`residual = residual - z_q_i`, with no `.detach()`).
descript-audio-codec/dac/nn/quantize.py
Line 186 in c7cfc5d
residual = residual - z_q_i
However, vector_quantize_pytorch does stop the gradient at this step:
https://github.com/lucidrains/vector-quantize-pytorch/blob/976335f03b259536a06ed88a7adb7248cdb3d17c/vector_quantize_pytorch/residual_vq.py#L386
RVQ is trained greedily, stage by stage: each quantizer should only be responsible for correcting the residual left by the previous quantizers, without those earlier quantizers being updated again.
If the gradient is not stopped, earlier quantizers receive gradients derived from the later quantizers' errors.
That breaks the stage-wise assumption and entangles learning across stages, which can destabilise training and lead to inefficient or overlapping codebooks.
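A minimal sketch of the difference (hypothetical and heavily simplified: each "quantizer" here is just a learnable scalar offset standing in for a real codebook lookup, not DAC's actual `VectorQuantize` module):

```python
import torch

def rvq(x, stages, detach_residual):
    """Toy 2-stage residual quantization; each stage subtracts its
    stand-in codebook value from the running residual."""
    residual = x
    for q in stages:
        z_q = q * torch.ones_like(x)  # stand-in for a codebook entry
        # DAC keeps the graph here; vector_quantize_pytorch detaches
        residual = residual - (z_q.detach() if detach_residual else z_q)
    return residual

x = torch.tensor([1.0])
q0 = torch.tensor([0.6], requires_grad=True)
q1 = torch.tensor([0.3], requires_grad=True)

# Without detach: a loss on the final residual back-propagates through
# every subtraction, so the first-stage quantizer q0 also receives
# gradient from the later stage's error.
res = rvq(x, [q0, q1], detach_residual=False)
res.pow(2).sum().backward()
g0, g1 = q0.grad.clone(), q1.grad.clone()  # both nonzero

q0.grad = q1.grad = None
# With detach: the residual carries no graph, so each stage can only be
# trained on its own error, preserving the stage-wise assumption.
res = rvq(x, [q0, q1], detach_residual=True)
print(g0, g1, res.requires_grad)
```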
What do you think?