Skip to content

Improved support for Bfloat16 #1309

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,10 @@ cpp_torch_float16 <- function() {
.Call(`_torch_cpp_torch_float16`)
}

cpp_torch_bfloat16 <- function() {
.Call(`_torch_cpp_torch_bfloat16`)
}

cpp_torch_uint8 <- function() {
.Call(`_torch_cpp_torch_uint8`)
}
Expand Down
3 changes: 3 additions & 0 deletions R/dtype.R
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ torch_float16 <- function() torch_dtype$new(cpp_torch_float16())
#' @rdname torch_dtype
#' @export
torch_half <- function() torch_dtype$new(cpp_torch_float16())
#' @rdname torch_dtype
#' @export
torch_bfloat16 <- function() torch_dtype$new(cpp_torch_bfloat16())

#' @rdname torch_dtype
#' @export
Expand Down
3 changes: 3 additions & 0 deletions inst/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ LANTERN_OPTIONAL_DECLS(string_view)
HOST_API void * lantern_Dtype_float64() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float64(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_float16)();
HOST_API void * lantern_Dtype_float16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float16(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_bfloat16)();
HOST_API void * lantern_Dtype_bfloat16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_bfloat16(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_uint8)();
HOST_API void * lantern_Dtype_uint8() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_uint8(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_int8)();
Expand Down Expand Up @@ -10530,6 +10532,7 @@ bool lanternInit(const std::string &libPath, std::string *pError)
LOAD_SYMBOL(_lantern_Dtype_float32);
LOAD_SYMBOL(_lantern_Dtype_float64);
LOAD_SYMBOL(_lantern_Dtype_float16);
LOAD_SYMBOL(_lantern_Dtype_bfloat16);
LOAD_SYMBOL(_lantern_Dtype_uint8);
LOAD_SYMBOL(_lantern_Dtype_int8);
LOAD_SYMBOL(_lantern_Dtype_int16);
Expand Down
11 changes: 11 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,16 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// cpp_torch_bfloat16
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_bfloat16();
RcppExport SEXP _torch_cpp_torch_bfloat16() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(cpp_torch_bfloat16());
return rcpp_result_gen;
END_RCPP
}
// cpp_torch_uint8
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_uint8();
RcppExport SEXP _torch_cpp_torch_uint8() {
Expand Down Expand Up @@ -48231,6 +48241,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_torch_cpp_torch_float32", (DL_FUNC) &_torch_cpp_torch_float32, 0},
{"_torch_cpp_torch_float64", (DL_FUNC) &_torch_cpp_torch_float64, 0},
{"_torch_cpp_torch_float16", (DL_FUNC) &_torch_cpp_torch_float16, 0},
{"_torch_cpp_torch_bfloat16", (DL_FUNC) &_torch_cpp_torch_bfloat16, 0},
{"_torch_cpp_torch_uint8", (DL_FUNC) &_torch_cpp_torch_uint8, 0},
{"_torch_cpp_torch_int8", (DL_FUNC) &_torch_cpp_torch_int8, 0},
{"_torch_cpp_torch_int16", (DL_FUNC) &_torch_cpp_torch_int16, 0},
Expand Down
5 changes: 5 additions & 0 deletions src/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ Rcpp::XPtr<XPtrTorchDtype> cpp_torch_float16() {
return make_xptr<XPtrTorchDtype>(lantern_Dtype_float16());
}

// [[Rcpp::export]]
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_bfloat16() {
return make_xptr<XPtrTorchDtype>(lantern_Dtype_bfloat16());
}

// [[Rcpp::export]]
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_uint8() {
return make_xptr<XPtrTorchDtype>(lantern_Dtype_uint8());
Expand Down
3 changes: 3 additions & 0 deletions src/lantern/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ LANTERN_OPTIONAL_DECLS(string_view)
HOST_API void * lantern_Dtype_float64() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float64(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_float16)();
HOST_API void * lantern_Dtype_float16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float16(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_bfloat16)();
HOST_API void * lantern_Dtype_bfloat16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_bfloat16(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_uint8)();
HOST_API void * lantern_Dtype_uint8() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_uint8(); LANTERN_HOST_HANDLER return ret;}
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_int8)();
Expand Down Expand Up @@ -10530,6 +10532,7 @@ bool lanternInit(const std::string &libPath, std::string *pError)
LOAD_SYMBOL(_lantern_Dtype_float32);
LOAD_SYMBOL(_lantern_Dtype_float64);
LOAD_SYMBOL(_lantern_Dtype_float16);
LOAD_SYMBOL(_lantern_Dtype_bfloat16);
LOAD_SYMBOL(_lantern_Dtype_uint8);
LOAD_SYMBOL(_lantern_Dtype_int8);
LOAD_SYMBOL(_lantern_Dtype_int16);
Expand Down
1 change: 1 addition & 0 deletions src/lantern/src/Dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
LANTERN_DTYPE_FUN(float16, kFloat16)
LANTERN_DTYPE_FUN(float32, kFloat32)
LANTERN_DTYPE_FUN(float64, kFloat64)
LANTERN_DTYPE_FUN(bfloat16, kBFloat16)
LANTERN_DTYPE_FUN(int8, kInt8)
LANTERN_DTYPE_FUN(int16, kInt16)
LANTERN_DTYPE_FUN(int32, kInt32)
Expand Down
2 changes: 2 additions & 0 deletions tests/testthat/test-dtype.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ test_that("Can create dtypes", {
expect_s3_class(torch_double(), "torch_dtype")
expect_s3_class(torch_float16(), "torch_dtype")
expect_s3_class(torch_half(), "torch_dtype")
expect_s3_class(torch_bfloat16(), "torch_dtype")
expect_s3_class(torch_uint8(), "torch_dtype")
expect_s3_class(torch_int8(), "torch_dtype")
expect_s3_class(torch_int16(), "torch_dtype")
Expand Down Expand Up @@ -47,6 +48,7 @@ test_that("can set select devices using strings", {
"double" = torch_double(),
"float16" = torch_float16(),
"half" = torch_half(),
"bfloat16" = torch_bfloat16(),
"uint8" = torch_uint8(),
"int8" = torch_int8(),
"int16" = torch_int16(),
Expand Down
Loading