Performance issue in Zygote AD for TensorOperations.@tensor on CUDA: costly CPU copies in rrule #235

@XingyuZhang2018

I am benchmarking reverse-mode AD with Zygote on CUDA tensors and observe that contractions written using TensorOperations.@tensor are significantly slower than equivalent implementations using OMEinsum or explicit reshape + matmul.

The slowdown seems to originate from the AD rule (_rrule_tensorcontract!) in TensorOperationsChainRulesCoreExt, which appears to introduce expensive CPU copies during the backward pass.

Here is the minimal working example:

using TensorOperations
using CUDA  # needed for CUDA.rand and CUDA.@sync below
using cuTENSOR
using OMEinsum
using Zygote
using Test
using LinearAlgebra
using BenchmarkTools

@testset "ad" begin
    D = 2^5
    A = [CUDA.rand(ComplexF64, D,D,D) for _ in 1:10]
    B = [CUDA.rand(ComplexF64, D,D) for _ in 1:10]

    function foo1(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            @tensor C[i][1,2,4] := A[i][1,2,3] * B[i][3,4]
        end
        return real(dot(C, C))
    end

    function foo2(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            C[i] = ein"abc,cd->abd"(A[i], B[i])
        end
        return real(dot(C, C))
    end

    function foo3(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            C[i] =  reshape(reshape(A[i], D^2, D) * B[i], D, D, D)
        end
        return real(dot(C, C))
    end

    @btime CUDA.@sync Zygote.gradient($foo1, $A)
    @btime CUDA.@sync Zygote.gradient($foo2, $A)
    @btime CUDA.@sync Zygote.gradient($foo3, $A)
end


foo1 (@tensor):     370.627 ms (41219 allocations: 947.14 KiB)
foo2 (OMEinsum):     91.993 ms (25300 allocations: 749.11 KiB)
foo3 (reshape+*):  107.602 ms (14940 allocations: 400.62 KiB)
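
For readers who want to confirm that the reshape + matmul form in foo3 computes the same contraction as the index expression in foo1, here is a minimal CPU-only sketch. It assumes plain `Array`s and a small dimension (`D = 4` is an arbitrary choice, not the benchmark size) and uses only Base Julia. The explicit loop is a reference written for illustration, not how any of the libraries implement the contraction.

```julia
# CPU-only sanity check (illustrative; not part of the benchmark above).
D = 4                                   # small size chosen for the sketch
A = rand(ComplexF64, D, D, D)
B = rand(ComplexF64, D, D)

# Reference: C[a,b,d] = sum over c of A[a,b,c] * B[c,d]
C_ref = zeros(ComplexF64, D, D, D)
for a in 1:D, b in 1:D, c in 1:D, d in 1:D
    C_ref[a, b, d] += A[a, b, c] * B[c, d]
end

# Same contraction as in foo3: merge the first two indices of A
# (column-major layout), multiply by B, then split the indices again.
C_mat = reshape(reshape(A, D^2, D) * B, D, D, D)

@assert C_mat ≈ C_ref
```

Since the forward results agree, the timing gap reported above would come from how each formulation is differentiated and executed, not from a difference in what is computed.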
