add test for torch.float16 and torch.bfloat16 (#2301)
Summary:
Pull Request resolved: #2301
# context
* We found in an inference test that the new operator `permute_multi_embedding` doesn't support `torch.float16`
* added a test to cover dtype support
* before the operator change, the test fails with the following error:
```
Failures:
1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype
1) RuntimeError: expected scalar type Float but found Half
File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype
outputs = torch.ops.fbgemm.permute_multi_embedding(
File "torch/_ops.py", line 1113, in __call__
return self._op(*args, **(kwargs or {}))
```
* the suspected cause is that the CPU operator accesses tensor data with `data_ptr<float>`, which restricts the supported dtype to `float32`:
```
auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
```
# changes
* use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to the template parameter `scalar_t`
* after the change, the operator supports `float16` and `bfloat16`
WARNING: somehow this operator still can't support `int` types.
Reviewed By: dstaay-fb
Differential Revision: D56051305
fbshipit-source-id: 9b15f82ff11c77fbdbfb59afc170698d5b0c826d