Skip to content

Commit 8c93470

Browse files
Merge pull request #154 from avik-pal/ap/ambiguous
Fix more ambiguity
2 parents 1e7fca1 + 5147dff commit 8c93470

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
version:
2323
- '1.6' # LTS
2424
- '1'
25+
- '~1.10.0-0'
2526
- 'nightly'
2627
os:
2728
- ubuntu-latest

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Tracker"
22
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3-
version = "0.2.28"
3+
version = "0.2.30"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/lib/array.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,26 +164,44 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
164164
end
165165

166166
for (T, S) in [(:TrackedArray, :TrackedArray), (:TrackedArray, :AbstractArray), (:AbstractArray, :TrackedArray)]
167-
@eval Base.vcat(A::$T, B::$S, Cs::AbstractArray...) = track(vcat, A, B, Cs...)
168-
@eval Base.hcat(A::$T, B::$S, Cs::AbstractArray...) = track(hcat, A, B, Cs...)
167+
for op in (:vcat, :hcat)
168+
@eval Base.$(op)(A::$T, B::$S, Cs::AbstractArray...) = track($(op), A, B, Cs...)
169+
@eval Base.$(op)(A::$T, B::$S, Cs::TrackedArray...) = track($(op), A, B, Cs...)
170+
@eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::AbstractArray{<:Number}...) = track($(op), A, B, Cs...)
171+
@eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::TrackedArray{<:Number}...) = track($(op), A, B, Cs...)
172+
@eval Base.$(op)(A::$T, B::$S) = track($(op), A, B)
173+
end
169174
end
170175
for (T, S) in [(:TrackedVector, :TrackedVector), (:TrackedVector, :AbstractVector), (:AbstractVector, :TrackedVector)]
171-
@eval Base.vcat(A::$T, B::$S, Cs::AbstractVector...) = track(vcat, A, B, Cs...)
176+
for op in (:vcat, :hcat)
177+
@eval Base.$(op)(A::$T, B::$S, Cs::AbstractVector...) = track($(op), A, B, Cs...)
178+
@eval Base.$(op)(A::$T, B::$S, Cs::TrackedVector...) = track($(op), A, B, Cs...)
179+
@eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::AbstractVector{<:Number}...) = track($(op), A, B, Cs...)
180+
@eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::TrackedVector{<:Number}...) = track($(op), A, B, Cs...)
181+
@eval Base.$(op)(A::$T, B::$S) = track($(op), A, B)
182+
end
172183
end
173184
for (T, S) in [(:TrackedVecOrMat, :TrackedVecOrMat), (:TrackedVecOrMat, :AbstractVecOrMat), (:AbstractVecOrMat, :TrackedVecOrMat)]
174-
@eval Base.vcat(A::$T, B::$S, Cs::AbstractVecOrMat...) = track(vcat, A, B, Cs...)
175-
@eval Base.hcat(A::$T, B::$S, Cs::AbstractVecOrMat...) = track(hcat, A, B, Cs...)
185+
for op in (:vcat, :hcat)
186+
@eval Base.$(op)(A::$T, B::$S, Cs::AbstractVecOrMat...) = track($(op), A, B, Cs...)
187+
@eval Base.$(op)(A::$T, B::$S, Cs::TrackedVecOrMat...) = track($(op), A, B, Cs...)
188+
@eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::AbstractVecOrMat{<:Number}...) = track($(op), A, B, Cs...)
189+
@eval Base.$(op)(A::$T{<:Number}, B::$S{<:Number}, Cs::TrackedVecOrMat{<:Number}...) = track($(op), A, B, Cs...)
190+
@eval Base.$(op)(A::$T, B::$S) = track($(op), A, B)
191+
end
176192
end
177193
for (T, S) in [(:TrackedArray, :Real), (:Real, :TrackedArray), (:TrackedArray, :TrackedArray)]
178-
@eval Base.vcat(A::$T, B::$S, Cs::Union{AbstractArray, Real}...) = track(vcat, A, B, Cs...)
179-
@eval Base.hcat(A::$T, B::$S, Cs::Union{AbstractArray, Real}...) = track(hcat, A, B, Cs...)
194+
@eval Base.vcat(A::$T, B::$S, Cs::Union{TrackedArray, AbstractArray, Real}...) = track(vcat, A, B, Cs...)
195+
@eval Base.hcat(A::$T, B::$S, Cs::Union{TrackedArray, AbstractArray, Real}...) = track(hcat, A, B, Cs...)
196+
if T == :Real || S == :Real
197+
@eval Base.vcat(A::$T, B::$S) = track(vcat, A, B)
198+
@eval Base.hcat(A::$T, B::$S) = track(hcat, A, B)
199+
end
180200
end
181201
for (T, S) in [(:TrackedReal, :Real), (:Real, :TrackedReal), (:TrackedReal, :TrackedReal)]
182202
@eval Base.vcat(A::$T, B::$S, Cs::Real...) = track(vcat, A, B, Cs...)
183203
@eval Base.hcat(A::$T, B::$S, Cs::Real...) = track(hcat, A, B, Cs...)
184204
end
185-
Base.vcat(A::TrackedVecOrMat{T1, <:AbstractArray}, B::TrackedVecOrMat{T2, <:AbstractArray}) where {T1, T2} = track(vcat, A, B)
186-
Base.hcat(A::TrackedVecOrMat{T1, <:AbstractArray}, B::TrackedVecOrMat{T2, <:AbstractArray}) where {T1, T2} = track(hcat, A, B)
187205

188206
Base.vcat(A::TrackedArray) = track(vcat, A)
189207
Base.hcat(A::TrackedArray) = track(hcat, A)

test/tracker.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ end
9292
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
9393
@test gradtest(hcatf, rand(5)', rand(1,3))
9494
@test gradtest(hcatf, rand(5), rand(5,2))
95-
end
95+
end
9696

9797
@testset "1-arg $catf" for catf in [vcat, cat1, rvcat, hcat, cat2, rhcat, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
9898
@test gradtest(catf, rand(5))

0 commit comments

Comments
 (0)