reflect change in rotate! from LinearAlgebra.jl #603

Open · wants to merge 4 commits into master
4 changes: 2 additions & 2 deletions src/host/linalg.jl
@@ -687,8 +687,8 @@ function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Numb
i = @index(Global, Linear)
@inbounds xi = x[i]
@inbounds yi = y[i]
- @inbounds x[i] = c * xi + s * yi
- @inbounds y[i] = -conj(s) * xi + c * yi
+ @inbounds x[i] = s*yi + c *xi
+ @inbounds y[i] = c*yi - conj(s)*xi
end
rotate_kernel!(get_backend(x))(x, y, c, s; ndrange = size(x))
return x, y
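The reordering mirrors the upstream change in LinearAlgebra.jl and is what the strong-zero tests added below exercise: in the old form, `-conj(s) * xi` turns a Bool `s` into an Int via unary minus, so a `false` coefficient no longer erases a NaN input, whereas `c*yi - conj(s)*xi` keeps the Bool multiplication intact. A minimal sketch in plain Julia (not part of the diff):

false * NaN                 # 0.0 -- Bool false is a strong zero, the NaN is wiped out
(-conj(false)) * NaN        # NaN -- `-false` is the Int 0, and 0 * NaN propagates the NaN
0.0 - conj(false) * NaN     # 0.0 -- the new formulation subtracts instead of negating
false*NaN + false*NaN       # 0.0 -- so x[i] = s*yi + c*xi stays NaN-free when c == s == false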
14 changes: 13 additions & 1 deletion test/testsuite.jl
@@ -46,14 +46,24 @@ function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...)
end

function compare(f, AT::Type{<:Array}, xs...; kwargs...)
- # no need to actually run this tests: we have nothing to compoare against,
+ # no need to actually run this tests: we have nothing to compare against,
# and we'll run it on a CPU array anyhow when comparing to a GPU array.
#
# this method exists so that we can at least run the test suite with Array,
# and make sure we cover other tests (that don't call `compare`) too.
return true
end

has_NaNs(a::AbstractArray) = isfloattype(eltype(a)) && any(isnan, collect(a))
has_NaNs(as::NTuple) = any(a -> has_NaNs(a), as)

out_has_NaNs(f, AT::Type{<:Array}, xs...) = false # we do not test stdlibs/LinAlg for NaNs (maybe they should?)
function out_has_NaNs(f, AT::Type{<:AbstractGPUArray}, xs...)
arg_in = map(x -> isa(x, Base.RefValue) ? x[] : adapt(AT, x), xs)
arg_out = f(arg_in...)
return has_NaNs(arg_out)
end

# element types that are supported by the array type
supported_eltypes(AT, test) = supported_eltypes(AT)
supported_eltypes(AT) = supported_eltypes()
@@ -67,6 +77,8 @@ isrealtype(T) = T <: Real
iscomplextype(T) = T <: Complex
isrealfloattype(T) = T <: AbstractFloat
isfloattype(T) = T <: AbstractFloat || T <: Complex{<:AbstractFloat}
NaN_T(T::Type{<:AbstractFloat}) = T(NaN)
NaN_T(T::Type{<:Complex{<:AbstractFloat}}) = T(NaN, NaN)

# list of tests
const tests = Dict()
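For orientation, a hedged sketch of how the new helpers compose, assuming `NaN_T` and `out_has_NaNs` are in scope and using `JLArray` from JLArrays.jl as a stand-in for `AT`; neither the array type nor the call below appears in the diff itself:

using Adapt, JLArrays, LinearAlgebra

T  = Float32
xs = (fill(NaN_T(T), 5), Ref(false))   # NaN-filled vector plus a Bool scale factor
# out_has_NaNs adapts array arguments to AT, unwraps the Refs, runs the function,
# and reports whether any NaN survived in the output:
out_has_NaNs(rmul!, JLArray, xs...)    # expected to be false: rmul!(x, false) zeroes x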
27 changes: 27 additions & 0 deletions test/testsuite/linalg.jl
@@ -282,11 +282,19 @@
@testset "lmul! and rmul!" for (a,b) in [((3,4),(4,3)), ((3,), (1,3)), ((1,3), (3))], T in eltypes
@test compare(rmul!, AT, rand(T, a), Ref(rand(T)))
@test compare(lmul!, AT, Ref(rand(T)), rand(T, b))
if isfloattype(T)
@test compare(rmul!, AT, fill(NaN_T(T), a), Ref(false))
@test compare(lmul!, AT, Ref(false), fill(NaN_T(T), b))
end
end

@testset "axp{b}y" for T in eltypes
@test compare(axpby!, AT, Ref(rand(T)), rand(T,5), Ref(rand(T)), rand(T,5))
@test compare(axpy!, AT, Ref(rand(T)), rand(T,5), rand(T,5))
if isfloattype(T)
@test compare(axpby!, AT, Ref(false), fill(NaN_T(T), 5), Ref(false), fill(NaN_T(T), 5))
@test compare(axpy!, AT, Ref(false), fill(NaN_T(T), 5), rand(T, 5))
end
end

@testset "dot" for T in eltypes
@@ -295,10 +303,18 @@

@testset "rotate!" for T in eltypes
@test compare(rotate!, AT, rand(T,5), rand(T,5), Ref(rand(real(T))), Ref(rand(T)))
if isfloattype(T)
# skip compare until https://github.com/JuliaLang/LinearAlgebra.jl/pull/1323 is released and only check correct strong zero behaviour of AbstractGPUArray
# @test compare(rotate!, AT, fill(NaN_T(T), 5), fill(NaN_T(T), 5), Ref(false), Ref(false))
@test !out_has_NaNs(rotate!, AT, fill(NaN_T(T), 5), fill(NaN_T(T), 5), Ref(false), Ref(false))
end
end

@testset "reflect!" for T in eltypes
@test compare(reflect!, AT, rand(T,5), rand(T,5), Ref(rand(real(T))), Ref(rand(T)))
if isfloattype(T)
@test compare(reflect!, AT, fill(NaN_T(T), 5), fill(NaN_T(T), 5), Ref(false), Ref(false))
end
end

@testset "iszero and isone" for T in eltypes
@@ -330,6 +346,13 @@ end
@test compare(*, AT, f(A), x)
@test compare(mul!, AT, y, f(A), x)
@test compare(mul!, AT, y, f(A), x, Ref(T(4)), Ref(T(5)))
if isfloattype(T)
y_NaN, A_NaN, x_NaN = fill(NaN_T(T), 4), fill(NaN_T(T), 4, 4), fill(NaN_T(T), 4)
if !(T==Float16) && !(T == ComplexF16) # skip Float16/ComplexF16 until https://github.com/JuliaLang/LinearAlgebra.jl/issues/1399 is fixed and only check correct strong zero behaviour of AbstractGPUArray
@test compare(mul!, AT, y_NaN, f(A_NaN), x_NaN, Ref(false), Ref(false))
end
@test !out_has_NaNs(mul!, AT, y_NaN, f(A_NaN), x_NaN, Ref(false), Ref(false))
end
@test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3))) <: AbstractVector

if f !== identity
@@ -348,6 +371,10 @@ end
@test compare(*, AT, f(A), g(B))
@test compare(mul!, AT, C, f(A), g(B))
@test compare(mul!, AT, C, f(A), g(B), Ref(T(4)), Ref(T(5)))
if isfloattype(T)
A_NaN, B_NaN, C_NaN = fill(NaN_T(T), 4, 4), fill(NaN_T(T), 4, 4), fill(NaN_T(T), 4, 4)
@test compare(mul!, AT, C_NaN, f(A_NaN), g(B_NaN), Ref(false), Ref(false))
end
@test typeof(AT(rand(T, 3, 3)) * AT(rand(T, 3, 3))) <: AbstractMatrix
end
end
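Spelled out, the new strong-zero tests assert that a `false` coefficient erases NaNs rather than propagating them. A sketch of the expected `rotate!` behaviour, again using `JLArray` as an assumed stand-in for a concrete AbstractGPUArray:

using JLArrays
import LinearAlgebra

x = JLArray(fill(NaN32, 5))
y = JLArray(fill(NaN32, 5))
LinearAlgebra.rotate!(x, y, false, false)   # dispatches to the GPUArrays kernel above
# no NaNs may survive; with c == s == false both vectors in fact become all zero
@assert !any(isnan, collect(x)) && !any(isnan, collect(y))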