Skip to content

Commit cf7f7d0

Browse files
authored
Extend kron support (#1458)
* Bump patch * Generalise kron implementation
1 parent 415ec0a commit cf7f7d0

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Zygote"
22
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
3-
version = "0.6.64"
3+
version = "0.6.65"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/lib/array.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,10 @@ function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
368368
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
369369
end
370370
_kron(a::AbstractVector, b::AbstractVector) = vec(_kron(reshape(a, :, 1), reshape(b, :, 1)))
371+
_kron(a::AbstractVector, b::AbstractMatrix) = _kron(reshape(a, :, 1), b)
372+
_kron(a::AbstractMatrix, b::AbstractVector) = _kron(a, reshape(b, :, 1))
371373

372-
function _pullback(cx::AContext, ::typeof(kron), a::AbstractVector, b::AbstractVector)
373-
res, back = _pullback(cx, _kron, a, b)
374-
return res, back unthunk_tangent
375-
end
376-
function _pullback(cx::AContext, ::typeof(kron), a::AbstractMatrix, b::AbstractMatrix)
374+
function _pullback(cx::AContext, ::typeof(kron), a::AbstractVecOrMat, b::AbstractVecOrMat)
377375
res, back = _pullback(cx, _kron, a, b)
378376
return res, back unthunk_tangent
379377
end

test/gradcheck.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ end
275275
@test gradtest(kron, rand(5,1), rand(3,1))
276276
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
277277
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
278+
@test gradtest(kron, rand(5), rand(3, 2))
279+
@test gradtest(kron, rand(3, 2), rand(5))
278280

279281
for mapfunc in [map,pmap]
280282
@testset "$mapfunc" begin

0 commit comments

Comments
 (0)