```python
if grad_weight is not None and input_requires_grad:
    grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
```
Looking at the `grad_weight` computation in `fused_linear_cross_entropy_forward()`, there is an inefficiency that can be improved.
The line `grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()` could instead be written as `torch.addmm(grad_weight, grad_logits_chunk.t(), _input_chunk, out=grad_weight, out_dtype=torch.float32)`, which is both faster and more memory efficient.
Memory Efficiency
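`grad_logits_chunk` is `[chunk, V]` and `_input_chunk` is `[chunk, H]`, so the product `grad_logits_chunk.t() @ _input_chunk` is `[V, H]`.

Original: `grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()`
- `torch.mm` allocates a `[V, H]` tensor in the input dtype (bf16): V × H × 2 bytes.
- `.float()` allocates a second `[V, H]` tensor in fp32: V × H × 4 bytes.
- Both are alive at the same time while the cast runs, because the bf16 product cannot be freed until `.float()` has returned.

Peak intermediate memory: V × H × 6 bytes.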
addmm: `torch.addmm(grad_weight, grad_logits_chunk.t(), _input_chunk, out=grad_weight, out_dtype=torch.float32)`

The product is accumulated directly into the existing fp32 `grad_weight`, so no `[V, H]` intermediate tensor is allocated.

Peak intermediate memory: 0 bytes.

For V = 128k and H = 4096, this saves roughly 3 GB of peak memory.
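The saving quoted above can be checked with a few lines of arithmetic. This sketch only counts the `[V, H]` temporaries (it ignores the caching allocator and any cuBLAS workspace), and `peak_intermediate_bytes` is a hypothetical helper. It assumes "128k" means 128 × 1024 = 131072, for which the saving is exactly 3 GiB; with V = 128,000 it is about 2.9 GiB (3.1 GB).

```python
def peak_intermediate_bytes(V: int, H: int, fused: bool) -> int:
    """Peak bytes of [V, H] temporaries created by one grad_weight update."""
    if fused:
        # addmm with out=grad_weight accumulates in place: no [V, H] temporary.
        return 0
    bf16_product = V * H * 2  # torch.mm output in the input dtype (bf16)
    fp32_copy = V * H * 4     # .float() result; the bf16 product is still alive
    return bf16_product + fp32_copy


V, H = 128 * 1024, 4096
saved = peak_intermediate_bytes(V, H, fused=False) - peak_intermediate_bytes(V, H, fused=True)
print(f"{saved / 2**30:.1f} GiB")  # 3.0 GiB
```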
Compute Efficiency
Original: `grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()`

This requires three kernel launches: the `mm`, the `.float()` cast, and the `+=` add. Each one writes a full `[V, H]` matrix to memory, and the cast and the add must also read those matrices back.

addmm: `torch.addmm(grad_weight, grad_logits_chunk.t(), _input_chunk, out=grad_weight, out_dtype=torch.float32)`

`addmm` is a single kernel that reads and writes `grad_weight` once, so the `[V, H]` matrix makes only one round trip to memory.
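To put a rough number on the round-trip argument, here is an idealised count of `[V, H]`-sized memory traffic per update. It is an estimate under stated assumptions, not a measurement: it assumes each kernel streams its `[V, H]` operands through global memory exactly once, ignores reads of the input chunks, cache reuse and launch overhead, and assumes the fused GEMM reads `grad_weight` once (beta = 1) and writes it once. `vh_traffic_bytes_per_element` is a hypothetical helper.

```python
def vh_traffic_bytes_per_element(fused: bool) -> int:
    """Idealised bytes of [V, H]-sized global-memory traffic per grad_weight element."""
    if fused:
        # addmm with beta=1: read the fp32 grad_weight once, write it once.
        return 4 + 4
    mm = 2           # torch.mm writes the bf16 product
    cast = 2 + 4     # .float() reads bf16, writes fp32
    add = 4 + 4 + 4  # += reads grad_weight and the fp32 copy, writes grad_weight
    return mm + cast + add


V, H = 128 * 1024, 4096
for fused in (False, True):
    gib = V * H * vh_traffic_bytes_per_element(fused) / 2**30
    print(f"fused={fused}: {gib:.1f} GiB per update")
# fused=False: 10.0 GiB per update
# fused=True: 4.0 GiB per update
```

Under this model the fused call moves about 2.5× less `[V, H]` data, on top of saving two kernel launches.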
However, `torch.addmm` with `out_dtype` support isn't available in the minimum version of torch supported by liger-kernel, 2.1.2. An alternative could be to wrap `grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()` in `torch.compile()`.
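One way to keep supporting older torch while still using the fused call where it exists is a small guarded helper. This is a minimal sketch, not existing liger-kernel code: `accumulate_grad_weight` is a hypothetical name, and the sketch assumes (1) that a torch build without `out_dtype` on `torch.addmm` rejects the keyword with a `TypeError`, and (2) that the `out_dtype` path is only supported for CUDA tensors, so it is attempted only when `grad_weight.is_cuda`.

```python
import torch


def accumulate_grad_weight(
    grad_weight: torch.Tensor,        # [V, H], fp32 accumulator
    grad_logits_chunk: torch.Tensor,  # [chunk, V], e.g. bf16
    input_chunk: torch.Tensor,        # [chunk, H], same dtype as grad_logits_chunk
) -> torch.Tensor:
    """Hypothetical helper: grad_weight += grad_logits_chunk.T @ input_chunk, in place."""
    if grad_weight.is_cuda:
        try:
            # Fused path (assumes a torch whose addmm accepts out_dtype):
            # one GEMM that accumulates the low-precision product straight
            # into the fp32 buffer, with no [V, H] temporaries.
            torch.addmm(
                grad_weight,
                grad_logits_chunk.t(),
                input_chunk,
                out_dtype=torch.float32,
                out=grad_weight,
            )
            return grad_weight
        except TypeError:
            # Assumed behaviour of older torch (e.g. 2.1.2): the out_dtype
            # keyword is rejected, so fall through to the original pattern.
            pass
    # Original pattern: bf16 [V, H] product, fp32 [V, H] copy, then in-place add.
    grad_weight += torch.mm(grad_logits_chunk.t(), input_chunk).float()
    return grad_weight
```

On CPU, or on a torch without `out_dtype`, the helper runs the same expression as the current line, so it would not force a bump of the minimum torch version.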
I think this could be a good beginner issue.
Referenced code: `Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py`, lines 211 to 212 in 547cf4c.