I am benchmarking reverse-mode AD with Zygote on CUDA tensors and observe that contractions written with `TensorOperations.@tensor` are significantly slower than equivalent implementations using OMEinsum or an explicit reshape + matmul.
The slowdown seems to originate from the AD rule (`_rrule_tensorcontract!`) in `TensorOperationsChainRulesCoreExt`, which appears to introduce expensive CPU copies during the backward pass.
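One way to probe whether the backward pass leaves the device is to inspect the cotangent types directly. A minimal sketch (assuming the same packages as the MWE below, plus a CUDA-capable GPU; the single-contraction function `f` here is my own reduction of the problem, not library code):

```julia
using CUDA, cuTENSOR, TensorOperations, Zygote

A = CUDA.rand(ComplexF64, 4, 4, 4)
B = CUDA.rand(ComplexF64, 4, 4)

# A single contraction, isolated from the benchmark loop below.
f(A, B) = @tensor C[a, b, d] := A[a, b, c] * B[c, d]

y, back = Zygote.pullback(f, A, B)
dA, dB = back(one.(y))   # seed the backward pass with ones

# If the rrule stays on the device these should be CuArrays;
# a plain Array here would point at host copies in the backward pass.
@show typeof(dA) typeof(dB)
```

Profiling the `back` call with `CUDA.@profile` should likewise reveal any host/device memcpys, if I understand the tooling correctly.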
Here is a minimal working example:

```julia
using CUDA
using cuTENSOR
using TensorOperations
using OMEinsum
using Zygote
using Test
using LinearAlgebra
using BenchmarkTools

@testset "ad" begin
    D = 2^5
    A = [CUDA.rand(ComplexF64, D, D, D) for _ in 1:10]
    B = [CUDA.rand(ComplexF64, D, D) for _ in 1:10]

    # Contraction via TensorOperations.@tensor
    function foo1(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            @tensor C[i][1, 2, 4] := A[i][1, 2, 3] * B[i][3, 4]
        end
        return real(dot(C, C))
    end

    # Same contraction via OMEinsum
    function foo2(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            C[i] = ein"abc,cd->abd"(A[i], B[i])
        end
        return real(dot(C, C))
    end

    # Same contraction via explicit reshape + matmul
    function foo3(A)
        C = Zygote.Buffer(A)
        for i in 1:length(A)
            C[i] = reshape(reshape(A[i], D^2, D) * B[i], D, D, D)
        end
        return real(dot(C, C))
    end

    @btime CUDA.@sync Zygote.gradient($foo1, $A)
    @btime CUDA.@sync Zygote.gradient($foo2, $A)
    @btime CUDA.@sync Zygote.gradient($foo3, $A)
end
```
Timings on my machine:

```
foo1 (@tensor):   370.627 ms (41219 allocations: 947.14 KiB)
foo2 (OMEinsum):   91.993 ms (25300 allocations: 749.11 KiB)
foo3 (reshape+*): 107.602 ms (14940 allocations: 400.62 KiB)
```
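For reference, all three formulations compute the same contraction `C[a,b,d] = Σ_c A[a,b,c] * B[c,d]`, so the gap is purely in the AD rules, not the math. A CPU-only sanity check of the equivalence (no GPU required; plain `Array`s instead of `CuArray`s):

```julia
using TensorOperations, LinearAlgebra, Test

D = 4
A = rand(ComplexF64, D, D, D)
B = rand(ComplexF64, D, D)

# TensorOperations formulation
@tensor C1[a, b, d] := A[a, b, c] * B[c, d]

# reshape + matmul formulation: merge (a, b) into one leg,
# multiply over c, then split the legs again (column-major layout)
C2 = reshape(reshape(A, D^2, D) * B, D, D, D)

@test C1 ≈ C2
```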