
Memory and compute inefficiency in fused_linear_cross_entropy_forward #1232

@tyler-romero

Description


```python
if grad_weight is not None and input_requires_grad:
    grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
```
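To make the update concrete, here is a pure-Python sketch of what it computes. Nested lists stand in for tensors, and `matmul_t` and `accumulate_chunked` are hypothetical helpers, not Liger-Kernel code. Summing each chunk's `grad_logits_chunk.t() @ _input_chunk` into `grad_weight` gives the same [V, H] gradient as one product over all rows. However, every chunk materializes its own full [V, H] partial product, and that temporary is what this issue is about.

```python
# Pure-Python illustration (no torch): accumulating G_c^T @ X_c over row chunks
# equals one full G^T @ X. Shapes follow the issue: G (grad_logits) is [N, V],
# X (_input) is [N, H], and the accumulated result (grad_weight) is [V, H].

def matmul_t(a, b):
    """Return a.T @ b for a: [n, v] and b: [n, h] (nested lists) -> [v, h]."""
    n, v, h = len(a), len(a[0]), len(b[0])
    return [[sum(a[k][i] * b[k][j] for k in range(n)) for j in range(h)]
            for i in range(v)]

def accumulate_chunked(g, x, chunk_size):
    """Mimic `grad_weight += mm(grad_logits_chunk.t(), _input_chunk)` per chunk."""
    v, h = len(g[0]), len(x[0])
    grad_weight = [[0.0] * h for _ in range(v)]
    for start in range(0, len(g), chunk_size):
        g_chunk = g[start:start + chunk_size]
        x_chunk = x[start:start + chunk_size]
        partial = matmul_t(g_chunk, x_chunk)  # the full [V, H] temporary, once per chunk
        for i in range(v):
            for j in range(h):
                grad_weight[i][j] += partial[i][j]
    return grad_weight

g = [[1.0, 2.0, 0.5], [0.0, -1.0, 3.0], [2.0, 1.0, 1.0], [-1.0, 0.5, 2.0]]  # N=4, V=3
x = [[1.0, 0.0], [2.0, 1.0], [0.5, -1.0], [1.0, 3.0]]                        # N=4, H=2

full = matmul_t(g, x)
chunked = accumulate_chunked(g, x, chunk_size=2)
print(chunked == full)  # True
```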

The grad_weight computation in fused_linear_cross_entropy_forward() has an inefficiency that can be removed.

The line grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float() could instead be written as torch.addmm(grad_weight, grad_logits_chunk.t(), _input_chunk, out=grad_weight, out_dtype=torch.float32), which is both faster and more memory efficient.

Memory Efficiency

grad_logits_chunk is [chunk, V] and _input_chunk is [chunk, H].

Original: grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

torch.mm allocates a [V, H] tensor in the input dtype (bf16) → V × H × 2 bytes
.float() allocates a second [V, H] tensor in fp32 → V × H × 4 bytes
Both are alive at the same time while the cast runs (the bf16 tensor is only freed once .float() returns)

Peak intermediate memory: V × H × 6 bytes.

addmm: torch.addmm(grad_weight, grad_logits_chunk.t(), _input_chunk, out=grad_weight, out_dtype=torch.float32)

The product is accumulated directly into the existing fp32 grad_weight, so no [V, H] intermediate tensor is allocated.
Peak intermediate memory: 0 bytes.

For V ≈ 128k and H = 4096, this saves roughly 3 GB of peak memory (V × H × 6 bytes).
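The memory figure can be checked with a back-of-the-envelope script. This is a sketch under stated assumptions: the mm output is bf16, grad_weight is fp32, and allocator rounding and any cuBLAS workspace are ignored. The helper names are hypothetical. Because "128k" could mean 128,000 or 131,072, the script evaluates both.

```python
# Peak-memory model for the [V, H] temporaries only.
# Assumptions: bf16 mm output (2 bytes/elem), fp32 grad_weight (4 bytes/elem),
# no allocator rounding, no cuBLAS workspace.

BF16_BYTES = 2
FP32_BYTES = 4

def peak_temp_bytes_original(vocab, hidden):
    """bf16 mm result plus its fp32 .float() copy, alive together during the cast."""
    return vocab * hidden * (BF16_BYTES + FP32_BYTES)

def peak_temp_bytes_addmm(vocab, hidden):
    """addmm writes straight into grad_weight, so no [V, H] temporary exists."""
    return 0

def saved_gib(vocab, hidden):
    saved = peak_temp_bytes_original(vocab, hidden) - peak_temp_bytes_addmm(vocab, hidden)
    return saved / 2**30

for vocab in (128_000, 131_072):  # two readings of "128k"
    print(f"V={vocab}: saves {saved_gib(vocab, 4096):.2f} GiB")
# V=128000: saves 2.93 GiB
# V=131072: saves 3.00 GiB
```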

Compute Efficiency

Original: grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()

This requires three kernel launches: mm, the .float() cast, and the += add. Each one writes a full [V, H] result to memory, and the next kernel has to read it back.

addmm: torch.addmm(grad_weight, grad_logits_chunk.t(), _input_chunk, out=grad_weight, out_dtype=torch.float32)

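addmm launches a single GEMM that computes grad_weight + grad_logits_chunk.t() @ _input_chunk in one pass, so grad_weight makes only one round trip to memory (one read, one write).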
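A simplified traffic model puts numbers on the kernel-count argument. It counts only [V, H]-sized reads and writes, assumes a bf16 mm output and an fp32 grad_weight, and ignores the reads of the [chunk, V] and [chunk, H] inputs, which both versions pay equally. The function names are hypothetical, and the result is a model, not a measurement.

```python
# Global-memory traffic model (bytes) for the [V, H]-sized operands only.
# Assumptions: bf16 mm output, fp32 grad_weight, no fusion beyond what is listed,
# input-chunk reads ignored because both versions pay them equally.

BF16, FP32 = 2, 4

def traffic_original(vocab, hidden):
    n = vocab * hidden
    mm_write = n * BF16               # torch.mm writes its bf16 result
    cast = n * BF16 + n * FP32        # .float() reads bf16, writes fp32
    add = 2 * n * FP32 + n * FP32     # += reads grad_weight and the temp, writes grad_weight
    return mm_write + cast + add

def traffic_addmm(vocab, hidden):
    n = vocab * hidden
    return n * FP32 + n * FP32        # GEMM reads grad_weight as C, writes it back

V, H = 131_072, 4096
print(traffic_original(V, H) / traffic_addmm(V, H))  # 2.5
```

Under this model the original line moves 20 bytes per grad_weight element versus 8 for addmm. Because the update is performed for every chunk, the difference is paid on each chunk.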


However, torch.addmm with out_dtype support isn't available in the minimum torch version supported by liger-kernel (2.1.2). An alternative would be to wrap grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float() in torch.compile(), which may let the compiler fuse the cast and the accumulation.

I think this could be a good beginner issue.
