Skip to content

Insert call instructions (without any caching) #1276

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
4dde0d3
use `args_in_result=:all` instead of `result_and_mutated`. In the fut…
jumerckx May 5, 2025
50f6d88
fix implementation of `Base.similar(::Broadcasted, ...)` to call `Ops…
jumerckx May 5, 2025
a72a1a1
add `make_tracer` for `Memory`
jumerckx May 5, 2025
454e280
fix for `TracedToType`
jumerckx May 5, 2025
6a0fc38
allow `path=nothing` in `make_tracer` to keep path as is and just rec…
jumerckx May 5, 2025
885b09c
`make_mlir_fn` fixes
jumerckx May 5, 2025
8d64d4d
WIP automatic function call insertion
jumerckx May 5, 2025
cc64c9e
add `mutate_args`
jumerckx May 12, 2025
caf0fe8
TracedToTypes for TracedRArray should look at objectid.
jumerckx May 12, 2025
fb60071
with automatic call tracing, it turns out that `tobatch` can be !isno…
jumerckx May 12, 2025
f14c755
automatic function call insertion but don't cache at all
jumerckx May 12, 2025
3990b0e
disable const dedup test for now.
jumerckx May 12, 2025
68e7079
Merge remote-tracking branch 'origin/main' into jm/make_mlir_fn
jumerckx May 12, 2025
264d3dd
fix?
jumerckx May 12, 2025
7c732dc
disable 1.10 CI until I fix it.
jumerckx May 13, 2025
0eabdc7
fix `seen_args` error
jumerckx May 13, 2025
886b16c
change `should_trace_call` to `trace_call_within` to still generate a…
jumerckx May 13, 2025
31a55b8
add try-finally in generated code
jumerckx May 13, 2025
1735b43
Merge branch 'main' into jm/make_mlir_fn
jumerckx May 13, 2025
0ad0017
`set_mlir_data!` for `MissingTracedValue`
jumerckx May 13, 2025
02ca88f
Revert "disable 1.10 CI until I fix it."
jumerckx May 13, 2025
efee7ff
fallback for `set_mlir_data!`
jumerckx May 14, 2025
364c9b9
don't trace through functions that capture values (sizeof != 0) or Br…
jumerckx May 17, 2025
12b689e
also take into acount resarg
jumerckx May 17, 2025
5fd1b4c
fix
jumerckx May 17, 2025
899bf9f
remove Dict typing in `LinearAlgebra._diagm` as this gives trouble wh…
jumerckx May 17, 2025
4e0af6f
formatting
jumerckx May 21, 2025
9f15feb
Merge branch 'main' into jm/make_mlir_fn
jumerckx May 21, 2025
aeb068c
guard `make_tracer` for Memory for v1.11
jumerckx May 21, 2025
c680db1
Merge branch 'main' into jm/make_mlir_fn
jumerckx Jun 15, 2025
0905b42
bail is_traced when a field is undefined
jumerckx Jun 18, 2025
c0448bb
don't trace calls that have no traced arguments
jumerckx Jun 18, 2025
9fd2246
insert function name in verify_arg_names when call is wrapped in `Rea…
jumerckx Jun 25, 2025
02fe9b1
locally disable call tracing in `Base._cat_t(..., ::TracedRArray)`
jumerckx Jun 25, 2025
6270b53
disable call tracing in `from_locals` expression in `trace_while`
jumerckx Jun 25, 2025
49c1345
don't blacklist methods from Reactant modules now that methods are no…
jumerckx Jun 25, 2025
5f91692
disable for some control flow tests when call tracing is enabled.
jumerckx Jun 25, 2025
1936393
Merge remote-tracking branch 'origin/main' into jm/make_mlir_fn
jumerckx Jul 16, 2025
6b15ca2
fix
jumerckx Jul 17, 2025
c288fce
formatting
jumerckx Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export @trace, within_compile, MissingTracedValue, promote_to_traced
function is_traced((@nospecialize x::T), seen=Base.IdSet()) where {T}
if !isprimitivetype(x)
for fn in fieldnames(T)
isdefined(x, fn) || continue
f = getfield(x, fn)
if !(f in seen)
push!(seen, f)
Expand Down Expand Up @@ -247,7 +248,10 @@ function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothi
$body_fn_sym = $(arg_syms) -> begin
$(to_locals...)
$body
temp = Reactant.TRACE_CALLS[]
Reactant.TRACE_CALLS[] = false
$(from_locals...)
Reactant.TRACE_CALLS[] = temp
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be changed because we can't depend on Reactant here. I could add a macro @notrace_calls in ReactantCore?

nothing
end

Expand All @@ -256,7 +260,6 @@ function trace_while(expr; track_numbers, mincut, checkpointing, first_arg=nothi
else
($(QuoteNode.(args_names.args)...),)
end

$(ReactantCore).traced_while(
$(cond_fn_sym),
$(body_fn_sym),
Expand Down
6 changes: 5 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,8 @@ function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
end

function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
temp = Reactant.TRACE_CALLS[]
Reactant.TRACE_CALLS[] = false
dims = dispatch_val(dims)
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."

Expand All @@ -860,7 +862,7 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
# convert to the target eltype
X = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{RT,length(shape)}), X)

return TracedRArray{RT,length(shape)}(
result = TracedRArray{RT,length(shape)}(
(),
MLIR.IR.result(
# TODO maybe we should do some conversion?
Expand All @@ -873,6 +875,8 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
),
shape,
)
Reactant.TRACE_CALLS[] = temp
return result
end

for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber))
Expand Down
26 changes: 22 additions & 4 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,11 @@ function set_mlir_data!(x::AnyTracedRArray{T}, data) where {T}
return x
end

function set_mlir_data!(x::T, data) where {T}
@warn "Setting mlir data on a $T is a no-op."
return x
end

get_ancestor_indices(::TracedRArray, indices) = indices
get_ancestor_indices(::TracedRArray, indices, args...) = (indices, args...)

Expand Down Expand Up @@ -279,6 +284,10 @@ function make_mlir_fn(
optimize_then_pad::Bool=true,
)
if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction
if !isnothing(verify_arg_names)
verify_arg_names = (nameof(f), verify_arg_names...)
end

mlir_fn_res = make_mlir_fn(
Reactant.apply,
(f, args...),
Expand Down Expand Up @@ -307,6 +316,7 @@ function make_mlir_fn(
args,
name,
concretein,
false, # mutate_args
toscalar,
argprefix,
runtime,
Expand Down Expand Up @@ -347,7 +357,7 @@ function make_mlir_fn(
end
end

(func2, traced_result, ret, linear_args, in_tys, linear_results, skipped_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
(; func2, traced_result, ret, linear_args, in_tys, linear_results, skipped_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
result,
traced_args,
linear_args,
Expand Down Expand Up @@ -409,6 +419,7 @@ function prepare_mlir_fn_args(
args,
name,
concretein,
mutate_args,
toscalar,
argprefix,
runtime,
Expand All @@ -423,7 +434,11 @@ function prepare_mlir_fn_args(
@assert !toscalar
Reactant.ConcreteToTraced
else
Reactant.TracedSetPath
if mutate_args
Reactant.TracedTrack
else
Reactant.TracedSetPath
end
end
fnbody = MLIR.IR.Block(MLIR.IR.Type[], MLIR.IR.Location[])
MLIR.IR.activate!(fnbody)
Expand Down Expand Up @@ -836,9 +851,11 @@ function finalize_mlir_fn(
MLIR.IR.deactivate!(fnbody)
end

f_name = __lookup_unique_name_in_module(mod, name)

func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=__lookup_unique_name_in_module(mod, name),
sym_name=f_name,
function_type=MLIR.IR.FunctionType(in_tys, out_tys),
body=MLIR.IR.Region(),
arg_attrs=MLIR.IR.attr(func, "arg_attrs"),
Expand Down Expand Up @@ -969,8 +986,9 @@ function finalize_mlir_fn(
MLIR.API.mlirOperationDestroy(func.operation)
func.operation = MLIR.API.MlirOperation(C_NULL)

return (
return (;
func2,
f_name,
traced_result,
ret,
linear_args,
Expand Down
75 changes: 61 additions & 14 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,10 @@ Base.@nospecializeinfer function traced_type_inner(
}
end
error("Unsupported runtime $runtime")
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
elseif mode == TracedTrack ||
mode == NoStopTracedTrack ||
mode == TracedSetPath ||
mode == TracedToTypes
return T
else
throw("Abstract RArray cannot be made concrete in mode $mode")
Expand Down Expand Up @@ -427,7 +430,10 @@ Base.@nospecializeinfer function traced_type_inner(
}
end
error("Unsupported runtime $runtime")
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
elseif mode == TracedTrack ||
mode == NoStopTracedTrack ||
mode == TracedSetPath ||
mode == TracedToTypes
return T
else
throw("Abstract RNumber cannot be made concrete in mode $mode")
Expand Down Expand Up @@ -921,6 +927,7 @@ function make_tracer(
return prev
end
append_path(@nospecialize(path), i) = (path..., i)
append_path(::Nothing, i) = nothing

Base.@nospecializeinfer function make_tracer_via_immutable_constructor(
seen,
Expand Down Expand Up @@ -1148,6 +1155,23 @@ function make_tracer(
)
end

@static if VERSION >= v"1.11.0"
Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Memory),
@nospecialize(path),
mode;
@nospecialize(sharding = Sharding.NoSharding()),
kwargs...,
)
if mode == TracedToTypes
return nothing
end
# TODO: does anything more need to be done here?
return prev
end
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::ConcretePJRTArray{T,N}),
Expand Down Expand Up @@ -1239,31 +1263,41 @@ Base.@nospecializeinfer function make_tracer(
throw("Cannot trace existing trace type")
end
if mode == TracedToTypes
push!(path, MLIR.IR.type(prev.mlir_data))
# for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
# i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
if haskey(seen, objectid(prev))
push!(path, seen[objectid(prev)])
else
push!(path, MLIR.IR.type(prev.mlir_data))
seen[objectid(prev)] = VisitedObject(length(seen) + 1)
end
return nothing
end
if mode == TracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
return seen[prev] = prev
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if haskey(seen, prev)
return seen[prev]
end
res = if toscalar
TracedRNumber{T}((path,), nothing)
elseif tobatch !== nothing
elseif tobatch !== nothing && prev.shape != tobatch
error("This should not happen...")
else
TracedRArray{T,N}((path,), prev.mlir_data, size(prev))
Expand Down Expand Up @@ -1317,25 +1351,35 @@ Base.@nospecializeinfer function make_tracer(
throw("Cannot trace existing trace type")
end
if mode == TracedToTypes
push!(path, MLIR.IR.type(prev.mlir_data))
# for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
# i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
if haskey(seen, objectid(prev))
push!(path, seen[objectid(prev)])
else
push!(path, MLIR.IR.type(prev.mlir_data))
seen[objectid(prev)] = VisitedObject(length(seen) + 1)
end
return nothing
end
if mode == TracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
return seen[prev] = prev
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if haskey(seen, prev)
return seen[prev]
end
Expand Down Expand Up @@ -1390,21 +1434,24 @@ Base.@nospecializeinfer function make_tracer(
throw("Cannot have MissingTracedValue as function call argument.")
end
if mode == TracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
return seen[prev] = prev
end
return prev
end
if mode == NoStopTracedTrack
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
if !haskey(seen, prev)
seen[prev] = prev # don't return!
end
return prev
end
if mode == TracedSetPath
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
!isnothing(path) &&
TracedUtils.set_paths!(prev, (TracedUtils.get_paths(prev)..., path))
haskey(seen, prev) && return seen[prev]
res = MissingTracedValue((path,))
seen[res] = res
Expand Down
2 changes: 1 addition & 1 deletion src/stdlibs/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ function LinearAlgebra._diagm(
m, n = LinearAlgebra.diagm_size(shape, kv...)

# For repeated indices we need to aggregate the values
kv_updated = Dict{Integer,AnyTracedRArray{T,1}}()
kv_updated = Dict()
for (k, v) in kv
if haskey(kv_updated, k)
kv_updated[k] = kv_updated[k] + v
Expand Down
Loading
Loading