@@ -126,11 +126,32 @@ jitoptimize_sde <- function (de,prob){
126
126
# ' }
127
127
# '
128
128
# ' @export
129
- diffeqgpu_setup <- function (){
129
+ diffeqgpu_setup <- function (backend ){
130
130
JuliaCall :: julia_install_package_if_needed(" DiffEqGPU" )
131
131
JuliaCall :: julia_library(" DiffEqGPU" )
132
132
functions <- JuliaCall :: julia_eval(" filter(isascii, replace.(string.(propertynames(DiffEqGPU)),\" !\" =>\" _bang\" ))" )
133
133
degpu <- julia_pkg_import(" DiffEqGPU" ,functions )
134
+
135
+ if (backend == " CUDA" ) {
136
+ JuliaCall :: julia_install_package_if_needed(" CUDA" )
137
+ JuliaCall :: julia_library(" CUDA" )
138
+ backend <- julia_pkg_import(" CUDA" ,c(" CUDABackend" ))
139
+ } else if (backend == " AMDGPU" ) {
140
+ JuliaCall :: julia_install_package_if_needed(" AMDGPU" )
141
+ JuliaCall :: julia_library(" AMDGPU" )
142
+ backend <- julia_pkg_import(" AMDGPU" ,c(" AMDGPUBackend" ))
143
+ } else if (backend == " Metal" ) {
144
+ JuliaCall :: julia_install_package_if_needed(" Metal" )
145
+ JuliaCall :: julia_library(" Metal" )
146
+ backend <- julia_pkg_import(" Metal" ,c(" MetalBackend" ))
147
+ } else if (backend == " oneAPI" ) {
148
+ JuliaCall :: julia_install_package_if_needed(" oneAPI" )
149
+ JuliaCall :: julia_library(" oneAPI" )
150
+ backend <- julia_pkg_import(" oneAPI" ,c(" oneAPIBackend" ))
151
+ } else {
152
+ stop(paste(" Illegal backend choice found. Allowed choices: CUDA, AMDGPU, Metal, and oneAPI. Chosen backend: " , backend )
153
+ }
154
+ list (degpu , backend )
134
155
}
135
156
136
157
julia_function <- function (func_name , pkg_name = " Main" ,
0 commit comments