Skip to content

Commit cf11e07

Browse files
CoreyGilesdfalbel
authored andcommitted
Included bfloat16 in unit tests for dtype
1 parent b744db7 commit cf11e07

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

tests/testthat/test-dtype.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ test_that("Can create dtypes", {
77
expect_s3_class(torch_double(), "torch_dtype")
88
expect_s3_class(torch_float16(), "torch_dtype")
99
expect_s3_class(torch_half(), "torch_dtype")
10+
expect_s3_class(torch_bfloat16(), "torch_dtype")
1011
expect_s3_class(torch_uint8(), "torch_dtype")
1112
expect_s3_class(torch_int8(), "torch_dtype")
1213
expect_s3_class(torch_int16(), "torch_dtype")
@@ -47,6 +48,7 @@ test_that("can set select devices using strings", {
4748
"double" = torch_double(),
4849
"float16" = torch_float16(),
4950
"half" = torch_half(),
51+
"bfloat16" = torch_bfloat16(),
5052
"uint8" = torch_uint8(),
5153
"int8" = torch_int8(),
5254
"int16" = torch_int16(),

0 commit comments

Comments
 (0)