From 35c505c279b731a9b2ebafff19e2e00124fe1827 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 1 Oct 2024 18:56:03 +0200 Subject: [PATCH] Optimize `to_rarray` to avoid recursion for simple inputs --- src/Tracing.jl | 4 ++++ test/Project.toml | 1 + test/tracing.jl | 11 ++++++++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/Tracing.jl b/src/Tracing.jl index ae4f3b4c66..ab491e0082 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -445,3 +445,7 @@ end @inline function to_rarray(@nospecialize(x)) return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete) end + +to_rarray(x::Number) = x # TODO: should this be a `ConcreteRArray{_,0}`? +to_rarray(x::ConcreteRArray) = x +to_rarray(x::AbstractArray{<:Number}) = ConcreteRArray(x) diff --git a/test/Project.toml b/test/Project.toml index 937025d11c..7b107b97d6 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/test/tracing.jl b/test/tracing.jl index d75a435988..01e1245983 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -1,6 +1,7 @@ using Reactant -using Reactant: traced_type, ConcreteRArray, TracedRArray, ConcreteToTraced +using Reactant: to_rarray, traced_type, ConcreteRArray, TracedRArray, ConcreteToTraced using Test +using JET: @test_opt @testset "Tracing" begin @testset "trace_type" begin @@ -100,4 +101,12 @@ using Test end end end + @testset "to_rarray" begin + @test to_rarray(1.0) isa Float64 + @test to_rarray([1.0]) isa ConcreteRArray{Float64,1} + @test to_rarray(ConcreteRArray([1.0])) isa ConcreteRArray{Float64,1} + @test_opt to_rarray(1.0) + @test_opt to_rarray([1.0]) + @test_opt to_rarray(ConcreteRArray([1.0])) + end end