-
Notifications
You must be signed in to change notification settings - Fork 93
Add rules and tests for kron
#741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5bb9766
f0902e3
7c53f4b
c1226eb
236daf1
b2d4f4a
8b94cfc
b71b8ef
2ad5473
fde509e
4386143
1b97828
5b74071
72060d0
6650adc
f104172
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -394,3 +394,64 @@ function rrule( | |
end | ||
return Ω, lyap_pullback | ||
end | ||
|
||
##### | ||
##### `kron` | ||
##### | ||
|
||
@static if VERSION ≥ v"1.9.0-DEV.1267" | ||
"""
    frule((_, Δx, Δy), ::typeof(kron), x, y)

Forward-mode rule for `kron(x, y)` on numeric vectors or matrices.

Return the primal `kron(x, y)` together with its tangent. `kron` is bilinear,
so the tangent follows the product rule: `kron(Δx, y) + kron(x, Δy)`.
"""
function frule(
    (_, Δx, Δy),
    ::typeof(kron),
    x::AbstractVecOrMat{<:Number},
    y::AbstractVecOrMat{<:Number},
)
    Ω = kron(x, y)
    # One term per argument, each with the other argument held fixed.
    ∂Ω = kron(Δx, y) + kron(x, Δy)
    return Ω, ∂Ω
end
|
||
"""
    rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number})

Reverse-mode rule for `kron` of two vectors.

Return `kron(x, y)` and a pullback mapping the output cotangent `z̄` to
`(NoTangent(), x̄, ȳ)`. Both cotangents are lazy thunks and are projected
back onto the respective input with `ProjectTo`.
"""
function rrule(
    ::typeof(kron), x::AbstractVector{<:Number}, y::AbstractVector{<:Number}
)
    proj_x = ProjectTo(x)
    proj_y = ProjectTo(y)
    function kron_pullback(z̄)
        # View the cotangent as a length(y) × length(x) matrix, so that entry
        # (j, i) is the one paired with the product x[i] * y[j].
        z̄mat = reshape(unthunk(z̄), length(y), length(x))
        x̄ = @thunk proj_x(conj.(z̄mat' * y))
        ȳ = @thunk proj_y(z̄mat * conj.(x))
        return NoTangent(), x̄, ȳ
    end
    return kron(x, y), kron_pullback
end
|
||
"""
    rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number})

Reverse-mode rule for `kron` of a matrix with a vector.

Return `kron(x, y)` and a pullback mapping the output cotangent `z̄` to
`(NoTangent(), x̄, ȳ)`. Both cotangents are lazy thunks and are projected
back onto the respective input with `ProjectTo`.
"""
function rrule(
    ::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}
)
    proj_x = ProjectTo(x)
    proj_y = ProjectTo(y)
    function kron_pullback(z̄)
        m, n = size(x)
        # 3-d view of the cotangent: axis 1 runs over `y`, axes 2 and 3 over `x`.
        z̄arr = reshape(unthunk(z̄), length(y), m, n)
        # One `dot` with `y` per entry of `x` (slices fix axes 2 and 3).
        x̄ = @thunk proj_x(dot.(Ref(y), eachslice(z̄arr; dims=(2, 3))))
        # One `dot` with `x` per entry of `y` (slices fix axis 1).
        ȳ = @thunk proj_y(dot.(Ref(x), eachslice(z̄arr; dims=1)))
        return NoTangent(), x̄, ȳ
    end
    return kron(x, y), kron_pullback
end
|
||
"""
    rrule(::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number})

Reverse-mode rule for `kron` of a vector with a matrix.

Return `kron(x, y)` and a pullback mapping the output cotangent `z̄` to
`(NoTangent(), x̄, ȳ)`. Both cotangents are lazy thunks and are projected
back onto the respective input with `ProjectTo`.
"""
function rrule(
    ::typeof(kron), x::AbstractVector{<:Number}, y::AbstractMatrix{<:Number}
)
    proj_x = ProjectTo(x)
    proj_y = ProjectTo(y)
    function kron_pullback(z̄)
        # 3-d view of the cotangent: axes 1 and 3 run over `y`, axis 2 over `x`.
        z̄arr = reshape(unthunk(z̄), size(y, 1), length(x), size(y, 2))
        # One `dot` with `y` per entry of `x` (slices fix axis 2).
        x̄ = @thunk proj_x(dot.(Ref(y), eachslice(z̄arr; dims=2)))
        # One `dot` with `x` per entry of `y` (slices fix axes 1 and 3).
        ȳ = @thunk proj_y(dot.(Ref(x), eachslice(z̄arr; dims=(1, 3))))
        return NoTangent(), x̄, ȳ
    end
    return kron(x, y), kron_pullback
end
|
||
function rrule(::typeof(kron), x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number}) | ||
project_x = ProjectTo(x) | ||
project_y = ProjectTo(y) | ||
function kron_pullback(z̄) | ||
simsurace marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dz = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2)) | ||
x̄ = @thunk(project_x(_dot_collect.(Ref(y), eachslice(dz; dims = (2, 4))))) | ||
ȳ = @thunk(project_y(_dot_collect.(Ref(x), eachslice(dz; dims = (1, 3))))) | ||
Comment on lines
+446
to
+449
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I was wondering if you have to make slices, given that kron is just reshape and […]
using ChainRulesCore
# Pullback as implemented in https://github.com/JuliaDiff/ChainRules.jl/pull/741
# (matrix–matrix case). Returns only the pullback closure, not the primal, so
# that it can be benchmarked directly.
function pr_rule(x::AbstractMatrix{<:Number}, y::AbstractMatrix{<:Number})
    proj_x = ProjectTo(x)
    proj_y = ProjectTo(y)
    function kron_pullback(z̄)
        # 4-d view of the cotangent: axes 1 and 3 run over `y`, axes 2 and 4 over `x`.
        z̄arr = reshape(unthunk(z̄), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
        x̄ = @thunk proj_x(dot.(Ref(y), eachslice(z̄arr; dims=(2, 4))))
        ȳ = @thunk proj_y(dot.(Ref(x), eachslice(z̄arr; dims=(1, 3))))
        return NoTangent(), x̄, ȳ
    end
    return kron_pullback
end
# using TensorCast
# mykron(x,y) = @cast z[(a,b), (c,d)] := x[b,d] * y[a,c]
# @pretty @cast z[(a,b), (c,d)] := x[b,d] * y[a,c]
# Alternative pullback for `kron(x, y)` built from reshape + broadcast + `sum`
# instead of broadcasting `dot` over slices. Returns only the pullback closure.
#
# The cotangent is viewed as a 4-d array with axes (y-row, x-row, y-col, x-col);
# `x` and `y` are reshaped with singleton axes so they broadcast against it.
function shape_rule(x::AbstractMatrix, y::AbstractMatrix)
    function back(dz)
        x4 = reshape(x, 1, size(x, 1), 1, size(x, 2))
        y4 = reshape(y, size(y, 1), 1, size(y, 2), 1)
        dz4 = reshape(unthunk(dz), size(y, 1), size(x, 1), size(y, 2), size(x, 2))
        # The other factor is conjugated so that complex inputs give the same
        # cotangent as a `dot`-based rule (`dot` conjugates its first argument).
        # For real inputs `conj` is the identity, so results are unchanged.
        dx = @thunk ProjectTo(x)(reshape(sum(dz4 .* conj.(y4); dims=(1, 3)), size(x)))
        dy = @thunk ProjectTo(y)(reshape(sum(dz4 .* conj.(x4); dims=(2, 4)), size(y)))
        # First slot is a placeholder for the tangent w.r.t. the function itself.
        return 0, dx, dy
    end
    return back
end
# Compare the two pullbacks on small dense random matrices.
let x = rand(10,20), y = rand(30,10)
# b1 / b2 are the pullback closures returned by the two implementations.
b1 = pr_rule(x, y)
b2 = shape_rule(x, y)
# The primal output doubles as a cotangent of the right shape.
z = kron(x,y)
# `map(unthunk, …)` forces the lazy cotangents so the real work is timed.
_, dx1, _ = @btime map(unthunk, $b1($z))
_, dx2, _ = @btime map(unthunk, $b2($z))
# Value of the block: whether both implementations agree on the x-cotangent.
dx1 ≈ dx2
end
# min 181.458 μs, mean 185.668 μs (4 allocations, 4.39 KiB)
# min 80.583 μs, mean 169.305 μs (32 allocations, 943.05 KiB)
# true
It's a pity to allocate these big arrays […]
bc = Broadcast.instantiate(Broadcast.broadcasted(*, [1 2 3], [4, 5]));
sum(bc) # OK
sum(bc; dims=1) # ERROR: MethodError: no method matching reducedim_init(::typeof(identity), ::typeof(Base.add_sum), ::Base.Broadcast.Broadcasted{…}, ::Int64)
sum!([0 0 0], bc) # ERROR: MethodError: no method matching sum!(::Matrix{Int64}, ::Base.Broadcast.Broadcasted
sum(bc; dims=1, init=0.0) # OK, not sure if it's fast or not
On StaticArrays (mentioned above) both at present make a SizedMatrix, which I think is ProjectTo's attempt to fix things up. Surely this reshaping could be done in a static-friendly way, but I don't know exactly how.
julia> let x = @SMatrix(rand(5,5)), y = @SMatrix(rand(5,5))
b1 = pr_rule(x, y)
b2 = shape_rule(x, y)
z = kron(x,y)
_, dx1, _ = @btime map(unthunk, $b1($z))
_, dx2, _ = @btime map(unthunk, $b2($z))
dx1 ≈ dx2
end
min 2.458 μs, mean 2.558 μs (2 allocations, 512 bytes)
min 4.006 μs, mean 5.198 μs (22 allocations, 11.38 KiB)
true
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Does this result scale to larger arrays?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Result meaning speed difference? It will vary with size & machine. On very small arrays reshaping is […] Issues with StaticArrays will be similar at all sizes. I think broadcasting over slices will work badly on CuArrays, and tend to make Arrays. But right now neither idea seems to work, not sure why.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
If the […]
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
A bit curious at what sizes it's slower for you? But mainly I think the issue is less about the race than that simple solid-array operations have a better chance of behaving well with StaticArrays, and CuArrays. I haven't taken another pass to see if the first draft can be improved on.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I haven't benchmarked anything myself yet. I will give it a go later.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Hmm, results seem to be mixed. For larger sizes the allocations are taking their toll:
let x = rand(100,200), y = rand(300,100)
b1 = pr_rule(x, y)
b2 = shape_rule(x, y)
z = kron(x,y)
_, dx1, _ = @btime map(unthunk, $b1($z))
_, dx2, _ = @btime map(unthunk, $b2($z))
dx1 ≈ dx2
end
# 3.376 s (6 allocations: 390.84 KiB)
# 3.797 s (34 allocations: 8.94 GiB)
# true
I would suggest staying with the current implementation.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
One way to ensure any implementation isn't excluding all GPU array types would be to toss a […] |
||
return NoTangent(), x̄, ȳ | ||
end | ||
return kron(x, y), kron_pullback | ||
end | ||
|
||
# `dot` of a matrix with an array slice; the generic case forwards to `dot`.
function _dot_collect(A::AbstractMatrix, B::SubArray)
    return dot(A, B)
end

# `Diagonal` case: the slice is copied into a dense array before `dot`.
# NOTE(review): presumably works around `dot(::Diagonal, ::SubArray)` — confirm.
function _dot_collect(A::Diagonal, B::SubArray)
    dense_B = collect(B)
    return dot(A, dense_B)
end
end |
Uh oh!
There was an error while loading. Please reload this page.