Skip to content

Commit ae425b5

Browse files
Update GPU syntax to allow for backend choices
1 parent 79f83a6 commit ae425b5

File tree

3 files changed

+735
-696
lines changed

3 files changed

+735
-696
lines changed

R/diffeqr.R

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,32 @@ jitoptimize_sde <- function (de,prob){
126126
#' }
127127
#'
128128
#' @export
129-
diffeqgpu_setup <- function (){
129+
diffeqgpu_setup <- function (backend){
130130
JuliaCall::julia_install_package_if_needed("DiffEqGPU")
131131
JuliaCall::julia_library("DiffEqGPU")
132132
functions <- JuliaCall::julia_eval("filter(isascii, replace.(string.(propertynames(DiffEqGPU)),\"!\"=>\"_bang\"))")
133133
degpu <- julia_pkg_import("DiffEqGPU",functions)
134+
135+
if (backend == "CUDA") {
136+
JuliaCall::julia_install_package_if_needed("CUDA")
137+
JuliaCall::julia_library("CUDA")
138+
backend <- julia_pkg_import("CUDA",c("CUDABackend"))
139+
} else if (backend == "AMDGPU") {
140+
JuliaCall::julia_install_package_if_needed("AMDGPU")
141+
JuliaCall::julia_library("AMDGPU")
142+
backend <- julia_pkg_import("AMDGPU",c("AMDGPUBackend"))
143+
} else if (backend == "Metal") {
144+
JuliaCall::julia_install_package_if_needed("Metal")
145+
JuliaCall::julia_library("Metal")
146+
backend <- julia_pkg_import("Metal",c("MetalBackend"))
147+
} else if (backend == "oneAPI") {
148+
JuliaCall::julia_install_package_if_needed("oneAPI")
149+
JuliaCall::julia_library("oneAPI")
150+
backend <- julia_pkg_import("oneAPI",c("oneAPIBackend"))
151+
} else {
152+
stop(paste("Illegal backend choice found. Allowed choices: CUDA, AMDGPU, Metal, and oneAPI. Chosen backend: ", backend)
153+
}
154+
list(degpu, backend)
134155
}
135156

136157
julia_function <- function(func_name, pkg_name = "Main",

0 commit comments

Comments
 (0)