Skip to content

Commit 240eccf

Browse files
committed
feat: type-restrict arrays
1 parent bf01b7f commit 240eccf

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

src/Reactant.jl

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,26 @@ include("OrderedIdDict.jl")
77

88
using Enzyme
99

10-
abstract type RArray{T,N} <: AbstractArray{T,N} end
11-
abstract type RNumber{T} <: Number end
10+
const ReactantPrimitives = Union{
11+
Bool,
12+
Int8,
13+
UInt8,
14+
Int16,
15+
UInt16,
16+
Int32,
17+
UInt32,
18+
Int64,
19+
UInt64,
20+
Float16,
21+
Float32,
22+
# BFloat16,
23+
Float64,
24+
Complex{Float32},
25+
Complex{Float64},
26+
}
27+
28+
abstract type RArray{T<:ReactantPrimitives,N} <: AbstractArray{T,N} end
29+
abstract type RNumber{T<:ReactantPrimitives} <: Number end
1230

1331
function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
1432
return reshape(A, Base._reshape_uncolon(A, dims))

src/TracedRArray.jl

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,7 @@ end
1919

2020
TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
2121

22-
const ReactantPrimitives = Union{
23-
Bool,
24-
Int8,
25-
UInt8,
26-
Int16,
27-
UInt16,
28-
Int32,
29-
UInt32,
30-
Int64,
31-
UInt64,
32-
Float16,
33-
Float32,
34-
# BFloat16,
35-
Float64,
36-
Complex{Float32},
37-
Complex{Float64},
38-
}
39-
40-
# `<: ReactantPrimitives` ensures we don't end up with nested `TracedRNumber`s
41-
mutable struct TracedRNumber{T<:ReactantPrimitives} <: RNumber{T}
22+
mutable struct TracedRNumber{T} <: RNumber{T}
4223
paths::Tuple
4324
mlir_data::Union{Nothing,MLIR.IR.Value}
4425

0 commit comments

Comments
 (0)