
Commit 4c5e5a3

TroyGarden authored and PaulZhang12 committed
add test for torch.float16 and torch.bfloat16 (#2301)
Summary:
Pull Request resolved: #2301

# context

* We found that the new operator `permute_multi_embedding` can't support `torch.float16` in an inference test.
* Added a test to cover the dtype support.
* Before the operator change, we saw the following error:

```
Failures:
  1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype
    1) RuntimeError: expected scalar type Float but found Half
      File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype
        outputs = torch.ops.fbgemm.permute_multi_embedding(
      File "torch/_ops.py", line 1113, in __call__
        return self._op(*args, **(kwargs or {}))
```

* The suspicion is that the CPU operator accesses tensor data via `data_ptr<float>`, which limits the supported dtype to `float32`:

```
auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
```

# changes

* Use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to the template parameter `scalar_t` (see the sketch after this message).
* After the change, the operator supports `float16` and `bfloat16`.

WARNING: somehow this operator still can't support `int` types.

Reviewed By: dstaay-fb

Differential Revision: D56051305

fbshipit-source-id: 9b15f82ff11c77fbdbfb59afc170698d5b0c826d
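For illustration, here is a minimal, hypothetical C++ sketch of the dispatch pattern the changes describe. `FBGEMM_DISPATCH_FLOATING_TYPES` follows the same convention as PyTorch's `AT_DISPATCH_*` macros (dtype, op name, lambda); the surrounding names (`inputs`, `outputs`, `out_tensor`, `in_tensor`, `b`, the offsets, and `length`) are stand-ins taken from the snippet above, not the actual FBGEMM operator source.

```
// Hypothetical sketch, not the actual FBGEMM source: dispatch on the
// runtime dtype so the copy is templated on scalar_t instead of
// hard-coding data_ptr<float>.
FBGEMM_DISPATCH_FLOATING_TYPES(
    inputs[0].scalar_type(), "permute_multi_embedding_cpu", [&] {
      // Inside the lambda, scalar_t is bound to the concrete element type
      // (float, at::Half, or at::BFloat16), so data_ptr<scalar_t>() is
      // valid for every supported floating dtype.
      auto outp = outputs[out_tensor][b].data_ptr<scalar_t>() + out_offset;
      auto inp = inputs[in_tensor][b].data_ptr<scalar_t>() + in_offset;
      for (int64_t j = 0; j < length; ++j) {
        outp[j] = inp[j];  // illustrative element-wise copy
      }
    });
```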
1 parent 865cf55 commit 4c5e5a3

File tree

1 file changed: +41 -0 lines changed

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 41 additions & 0 deletions
```
@@ -2775,6 +2775,47 @@ def test_multi_permute_forward(self, device_str: str, batch_size: int) -> None:
         for out, ref in zip(outputs, refs):
             torch.testing.assert_close(out, ref)

+    @repeat_test(
+        device_str=["meta", "cpu", "cuda"],
+        dtype=[
+            torch.float,
+            torch.float32,
+            torch.float16,
+            torch.bfloat16,
+        ],
+    )
+    def test_multi_permute_dtype(self, device_str: str, dtype: torch.dtype) -> None:
+        if device_str == "cuda" and not torch.cuda.is_available():
+            return
+        else:
+            device = torch.device(device_str)
+        batch_size = 4
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(L), device=device, dtype=dtype) for L in lengths
+        ]
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
+            values[0], keys, lengths, groups
+        )
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, in_shapes, out_shapes, out_lengths
+        )
+
+        if device_str == "meta":
+            for out, ref in zip(outputs, out_lengths):
+                self.assertEqual(out.shape, (batch_size, ref))
+        else:
+            refs = [[] for _ in groups]
+            for i in range(permutes.size(0)):
+                in_idx, out, in_start, _, length, _ = permutes[i].tolist()
+                refs[out].append(values[in_idx][:, in_start : (in_start + length)])
+            refs = [torch.cat(ref, dim=1) for ref in refs]
+            for out, ref in zip(outputs, refs):
+                torch.testing.assert_close(out, ref)
+                self.assertEqual(out.dtype, ref.dtype)
+
     @repeat_test(
         ["cpu", 32, [[3, 4], [5, 6, 7], [8]]],
         ["cuda", 128, [[96, 256], [512, 128, 768], [1024]]],
```
