Skip to content

Commit b744db7

Browse files
CoreyGilesdfalbel
authored andcommitted
Improved support for bfloat16 dtype
1 parent 1b0433f commit b744db7

File tree

7 files changed

+30
-0
lines changed

7 files changed

+30
-0
lines changed

R/RcppExports.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,10 @@ cpp_torch_float16 <- function() {
329329
.Call(`_torch_cpp_torch_float16`)
330330
}
331331

332+
cpp_torch_bfloat16 <- function() {
333+
.Call(`_torch_cpp_torch_bfloat16`)
334+
}
335+
332336
cpp_torch_uint8 <- function() {
333337
.Call(`_torch_cpp_torch_uint8`)
334338
}

R/dtype.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ torch_float16 <- function() torch_dtype$new(cpp_torch_float16())
109109
#' @rdname torch_dtype
110110
#' @export
111111
torch_half <- function() torch_dtype$new(cpp_torch_float16())
112+
#' @rdname torch_dtype
113+
#' @export
114+
torch_bfloat16 <- function() torch_dtype$new(cpp_torch_bfloat16())
112115

113116
#' @rdname torch_dtype
114117
#' @export

inst/include/lantern/lantern.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ LANTERN_OPTIONAL_DECLS(string_view)
256256
HOST_API void * lantern_Dtype_float64() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float64(); LANTERN_HOST_HANDLER return ret;}
257257
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_float16)();
258258
HOST_API void * lantern_Dtype_float16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float16(); LANTERN_HOST_HANDLER return ret;}
259+
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_bfloat16)();
260+
HOST_API void * lantern_Dtype_bfloat16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_bfloat16(); LANTERN_HOST_HANDLER return ret;}
259261
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_uint8)();
260262
HOST_API void * lantern_Dtype_uint8() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_uint8(); LANTERN_HOST_HANDLER return ret;}
261263
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_int8)();
@@ -10530,6 +10532,7 @@ bool lanternInit(const std::string &libPath, std::string *pError)
1053010532
LOAD_SYMBOL(_lantern_Dtype_float32);
1053110533
LOAD_SYMBOL(_lantern_Dtype_float64);
1053210534
LOAD_SYMBOL(_lantern_Dtype_float16);
10535+
LOAD_SYMBOL(_lantern_Dtype_bfloat16);
1053310536
LOAD_SYMBOL(_lantern_Dtype_uint8);
1053410537
LOAD_SYMBOL(_lantern_Dtype_int8);
1053510538
LOAD_SYMBOL(_lantern_Dtype_int16);

src/RcppExports.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,16 @@ BEGIN_RCPP
910910
return rcpp_result_gen;
911911
END_RCPP
912912
}
913+
// cpp_torch_bfloat16
914+
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_bfloat16();
915+
RcppExport SEXP _torch_cpp_torch_bfloat16() {
916+
BEGIN_RCPP
917+
Rcpp::RObject rcpp_result_gen;
918+
Rcpp::RNGScope rcpp_rngScope_gen;
919+
rcpp_result_gen = Rcpp::wrap(cpp_torch_bfloat16());
920+
return rcpp_result_gen;
921+
END_RCPP
922+
}
913923
// cpp_torch_uint8
914924
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_uint8();
915925
RcppExport SEXP _torch_cpp_torch_uint8() {
@@ -48231,6 +48241,7 @@ static const R_CallMethodDef CallEntries[] = {
4823148241
{"_torch_cpp_torch_float32", (DL_FUNC) &_torch_cpp_torch_float32, 0},
4823248242
{"_torch_cpp_torch_float64", (DL_FUNC) &_torch_cpp_torch_float64, 0},
4823348243
{"_torch_cpp_torch_float16", (DL_FUNC) &_torch_cpp_torch_float16, 0},
48244+
{"_torch_cpp_torch_bfloat16", (DL_FUNC) &_torch_cpp_torch_bfloat16, 0},
4823448245
{"_torch_cpp_torch_uint8", (DL_FUNC) &_torch_cpp_torch_uint8, 0},
4823548246
{"_torch_cpp_torch_int8", (DL_FUNC) &_torch_cpp_torch_int8, 0},
4823648247
{"_torch_cpp_torch_int16", (DL_FUNC) &_torch_cpp_torch_int16, 0},

src/dtype.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ Rcpp::XPtr<XPtrTorchDtype> cpp_torch_float16() {
2020
return make_xptr<XPtrTorchDtype>(lantern_Dtype_float16());
2121
}
2222

23+
// [[Rcpp::export]]
24+
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_bfloat16() {
25+
return make_xptr<XPtrTorchDtype>(lantern_Dtype_bfloat16());
26+
}
27+
2328
// [[Rcpp::export]]
2429
Rcpp::XPtr<XPtrTorchDtype> cpp_torch_uint8() {
2530
return make_xptr<XPtrTorchDtype>(lantern_Dtype_uint8());

src/lantern/include/lantern/lantern.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ LANTERN_OPTIONAL_DECLS(string_view)
256256
HOST_API void * lantern_Dtype_float64() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float64(); LANTERN_HOST_HANDLER return ret;}
257257
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_float16)();
258258
HOST_API void * lantern_Dtype_float16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_float16(); LANTERN_HOST_HANDLER return ret;}
259+
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_bfloat16)();
260+
HOST_API void * lantern_Dtype_bfloat16() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_bfloat16(); LANTERN_HOST_HANDLER return ret;}
259261
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_uint8)();
260262
HOST_API void * lantern_Dtype_uint8() {LANTERN_CHECK_LOADED void * ret = _lantern_Dtype_uint8(); LANTERN_HOST_HANDLER return ret;}
261263
LANTERN_API void *(LANTERN_PTR _lantern_Dtype_int8)();
@@ -10530,6 +10532,7 @@ bool lanternInit(const std::string &libPath, std::string *pError)
1053010532
LOAD_SYMBOL(_lantern_Dtype_float32);
1053110533
LOAD_SYMBOL(_lantern_Dtype_float64);
1053210534
LOAD_SYMBOL(_lantern_Dtype_float16);
10535+
LOAD_SYMBOL(_lantern_Dtype_bfloat16);
1053310536
LOAD_SYMBOL(_lantern_Dtype_uint8);
1053410537
LOAD_SYMBOL(_lantern_Dtype_int8);
1053510538
LOAD_SYMBOL(_lantern_Dtype_int16);

src/lantern/src/Dtype.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
LANTERN_DTYPE_FUN(float16, kFloat16)
1414
LANTERN_DTYPE_FUN(float32, kFloat32)
1515
LANTERN_DTYPE_FUN(float64, kFloat64)
16+
LANTERN_DTYPE_FUN(bfloat16, kBFloat16)
1617
LANTERN_DTYPE_FUN(int8, kInt8)
1718
LANTERN_DTYPE_FUN(int16, kInt16)
1819
LANTERN_DTYPE_FUN(int32, kInt32)

0 commit comments

Comments
 (0)