Skip to content

perf: VQE ablation tests #1416

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
version = "0.2.137"
version = "0.2.138"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -90,7 +90,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.15"
Reactant_jll = "0.0.209"
Reactant_jll = "0.0.210"
ScopedValues = "1.3.0"
Scratch = "1.2"
Sockets = "1.10"
Expand Down
257 changes: 257 additions & 0 deletions perf/common.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
using BenchmarkTools: @benchmark
using Reactant, Enzyme, PrettyTables, Statistics

"""
    simple_mse_loss(model, x, z, ps, st)

Run `model` forward on input `x` via `Lux.apply` with parameters `ps` and
state `st`, and return the mean-squared error of the prediction against the
target `z`. The updated Lux state is discarded.
"""
function simple_mse_loss(model, x, z, ps, st)
    prediction = first(Lux.apply(model, x, ps, st))
    return MSELoss()(prediction, z)
end

"""
    simple_mse_loss_gradient(model, x, z, ps, st)

Reverse-mode Enzyme gradient of `simple_mse_loss` with respect to the
parameters `ps` only; `model`, `x`, `z`, and `st` are marked `Const` so no
derivatives are taken with respect to them.
"""
function simple_mse_loss_gradient(model, x, z, ps, st)
    grads = Enzyme.gradient(
        Reverse, simple_mse_loss, Const(model), Const(x), Const(z), ps, Const(st)
    )
    return grads
end

"""
    benchmark_nn_primal(model, x, z, ps, st;
                        disable_scatter_gather_bench=true, disable_pad_bench=true)

Benchmark the forward pass of `simple_mse_loss` under several Reactant
compilation configurations: XLA-only (the baseline), the default full
pipeline, and — when the corresponding keyword is `true` — pipelines with
scatter/gather and/or pad optimization passes disabled (plus the combination
of both when both keywords are `true`).

Returns a `Vector` of `(mode, pass_description, median_time, std_time,
relative)` tuples sorted ascending by median time. Times are as reported by
BenchmarkTools (nanoseconds); `relative` is the median time divided by the
XLA-only baseline, so the baseline row itself carries `1.0`.
"""
function benchmark_nn_primal(
    model, x, z, ps, st; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    # Rows of (mode, passes, median time, std time, relative-to-baseline).
    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()

    # Only XLA: skip Reactant's own optimization passes; this row defines the
    # 1.0 baseline every other configuration is measured against.
    compiled_fwd_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
        sync=true
    ) simple_mse_loss(model, x, z, ps, st)
    # Force a full GC before each sample so collection pauses from previous
    # samples do not pollute the timing.
    bench = @benchmark $compiled_fwd_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    push!(results, ("Primal", "Only XLA", median(bench).time, std(bench).time, 1.0))
    baseline = median(bench).time

    # Default: the full Reactant optimization pipeline.
    compiled_fwd = @compile sync = true simple_mse_loss(model, x, z, ps, st)
    bench = @benchmark $compiled_fwd($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    push!(
        results,
        (
            "Primal",
            "All",
            median(bench).time,
            std(bench).time,
            median(bench).time / baseline,
        ),
    )

    # Ablation: disable only the scatter/gather optimization passes.
    if disable_scatter_gather_bench
        compiled_fwd_no_scatter = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true, sync=true
        ) simple_mse_loss(model, x, z, ps, st)
        bench = @benchmark $compiled_fwd_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
        ))

        push!(
            results,
            (
                "Primal",
                "No Scatter/Gather Optimizations",
                median(bench).time,
                std(bench).time,
                median(bench).time / baseline,
            ),
        )
    end

    # Ablation: disable only the pad optimization passes.
    if disable_pad_bench
        compiled_fwd_no_pad = @compile compile_options = CompileOptions(;
            disable_pad_optimization_passes=true, sync=true
        ) simple_mse_loss(model, x, z, ps, st)
        bench = @benchmark $compiled_fwd_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
        ))

        push!(
            results,
            (
                "Primal",
                "No Pad Optimizations",
                median(bench).time,
                std(bench).time,
                median(bench).time / baseline,
            ),
        )
    end

    # Ablation: disable both scatter/gather and pad passes together.
    if disable_scatter_gather_bench && disable_pad_bench
        compiled_fwd_no_scatter_pad = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            disable_pad_optimization_passes=true,
            sync=true,
        ) simple_mse_loss(model, x, z, ps, st)
        bench = @benchmark $compiled_fwd_no_scatter_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
        ))

        push!(
            results,
            (
                "Primal",
                "No Scatter/Gather and Pad Optimizations",
                median(bench).time,
                std(bench).time,
                median(bench).time / baseline,
            ),
        )
    end

    # Fastest configuration first.
    sort!(results; by=x -> x[3])
    return results
end

"""
    benchmark_nn_gradient(model, x, z, ps, st; kwargs...)

Benchmark `simple_mse_loss_gradient` under each Enzyme optimization mode
(`:all`, `:before_enzyme`, `:after_enzyme`) and return the concatenation of
the per-mode result rows. Keyword arguments are forwarded unchanged to
`benchmark_nn_gradient_internal`.
"""
function benchmark_nn_gradient(model, x, z, ps, st; kwargs...)
    # mapreduce(..., vcat, ...) concatenates the per-mode result vectors
    # without splatting an intermediate collection into varargs (the old
    # `vcat([...]...)` pattern), which is wasteful and scales poorly.
    return mapreduce(vcat, (:all, :before_enzyme, :after_enzyme)) do mode
        benchmark_nn_gradient_internal(model, x, z, ps, st, mode; kwargs...)
    end
end

"""
    benchmark_nn_gradient_internal(model, x, z, ps, st, mode;
                                   disable_scatter_gather_bench=true,
                                   disable_pad_bench=true)

Benchmark `simple_mse_loss_gradient` for a single Enzyme optimization `mode`
(e.g. `:all`, `:before_enzyme`, `:after_enzyme`) across the same compilation
ablations as `benchmark_nn_primal`: XLA-only baseline, full pipeline, and —
per the keywords — scatter/gather and/or pad passes disabled.

Returns a `Vector` of `(mode_label, pass_description, median_time, std_time,
relative)` tuples sorted ascending by median time; times are BenchmarkTools
nanoseconds and `relative` is measured against the XLA-only baseline.
"""
function benchmark_nn_gradient_internal(
    model, x, z, ps, st, mode; disable_scatter_gather_bench=true, disable_pad_bench=true
)
    # NOTE(review): `Meta.quot(mode)` interpolates as `:(:all)` rather than
    # `:all`; plain `:$mode` may be the intended rendering — confirm.
    @info "Benchmarking gradient with mode: $(Meta.quot(mode))"

    # Rows of (mode label, passes, median time, std time, relative).
    results = Vector{Tuple{String,String,Float64,Float64,Float64}}()

    # Only XLA: no Reactant optimization passes; defines the 1.0 baseline.
    compiled_grad_xla = @compile compile_options = Reactant.DefaultXLACompileOptions(;
        sync=true
    ) simple_mse_loss_gradient(model, x, z, ps, st)
    # GC before each sample so collection pauses do not pollute the timing.
    bench = @benchmark $compiled_grad_xla($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    push!(
        results, ("Gradient ($mode)", "Only XLA", median(bench).time, std(bench).time, 1.0)
    )
    baseline = median(bench).time

    # Echo each row as it is produced so long runs show incremental progress.
    display(results[end])

    # Default: full pipeline with the requested Enzyme optimize mode.
    compiled_grad = @compile sync = true optimize = mode simple_mse_loss_gradient(
        model, x, z, ps, st
    )
    bench = @benchmark $compiled_grad($model, $x, $z, $ps, $st) setup = (GC.gc(true))
    push!(
        results,
        (
            "Gradient ($mode)",
            "All",
            median(bench).time,
            std(bench).time,
            median(bench).time / baseline,
        ),
    )

    display(results[end])

    # Ablation: disable only the scatter/gather optimization passes.
    if disable_scatter_gather_bench
        compiled_grad_no_scatter = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            optimization_passes=mode,
            sync=true,
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        bench = @benchmark $compiled_grad_no_scatter($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
        ))

        push!(
            results,
            (
                "Gradient ($mode)",
                "No Scatter/Gather Optimizations",
                median(bench).time,
                std(bench).time,
                median(bench).time / baseline,
            ),
        )

        display(results[end])
    end

    # Ablation: disable only the pad optimization passes.
    if disable_pad_bench
        compiled_grad_no_pad = @compile compile_options = CompileOptions(;
            disable_pad_optimization_passes=true, optimization_passes=mode, sync=true
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        bench = @benchmark $compiled_grad_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
        ))

        push!(
            results,
            (
                "Gradient ($mode)",
                "No Pad Optimizations",
                median(bench).time,
                std(bench).time,
                median(bench).time / baseline,
            ),
        )

        display(results[end])
    end

    # Ablation: disable both scatter/gather and pad passes together.
    if disable_scatter_gather_bench && disable_pad_bench
        compiled_grad_no_scatter_no_pad = @compile compile_options = CompileOptions(;
            disable_scatter_gather_optimization_passes=true,
            disable_pad_optimization_passes=true,
            optimization_passes=mode,
            sync=true,
        ) simple_mse_loss_gradient(model, x, z, ps, st)
        bench = @benchmark $compiled_grad_no_scatter_no_pad($model, $x, $z, $ps, $st) setup = (GC.gc(
            true
        ))

        push!(
            results,
            (
                "Gradient ($mode)",
                "No Scatter/Gather/Pad Optimizations",
                median(bench).time,
                std(bench).time,
                median(bench).time / baseline,
            ),
        )

        display(results[end])
    end

    # Fastest configuration first.
    sort!(results; by=x -> x[3])
    return results
end

"""
    pretty_print_table(results)

Display a benchmark `results` table with PrettyTables. Columns are expected
to be (mode, pass description, median time, std time, relative timing) —
assumes `results` supports matrix-style column indexing; TODO confirm the
caller converts the tuple vectors accordingly. The two time columns are
converted from nanoseconds to seconds for display, and the relative-timing
column is highlighted red when slower than the XLA baseline (> 1.0) and
green when faster (< 1.0).
"""
function pretty_print_table(results)
    header = (
        ["Mode", "Optimization Passes", "Median Time", "Std. Dev. Time", "Relative Timing"],
        ["", "", "s", "s", "Time / XLA Time"],
    )

    # Operate on a copy so the caller's table keeps its nanosecond timings.
    table = copy(results)
    table[:, 3] ./= 1e9
    table[:, 4] ./= 1e9

    slower = Highlighter((data, i, j) -> j == 5 && data[i, j] > 1.0, crayon"bold red")
    faster = Highlighter((data, i, j) -> j == 5 && data[i, j] < 1.0, crayon"bold green")

    display(
        pretty_table(
            table;
            header,
            header_crayon=crayon"yellow bold",
            highlighters=(slower, faster),
            tf=tf_unicode_rounded,
        ),
    )
    return nothing
end
21 changes: 21 additions & 0 deletions perf/neuraloperators/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NeuralOperators = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"

[sources]
Reactant = {path = "../../"}

[compat]
BenchmarkTools = "1.6"
CSV = "0.10.15"
Lux = "1.13.4"
NeuralOperators = "0.6"
PrettyTables = "2.4.0"
Random = "1.11"
julia = "1.11"
Loading