Description
The spec requires that `matmul` follows the type promotion rules for the arguments, but PyTorch requires that the dtypes match:
```python
In [3]: import array_api_strict as xp

In [5]: xp.ones(3, dtype=xp.float32) @ xp.ones(3, dtype=xp.float64)
Out[5]: Array(3., dtype=array_api_strict.float64)

In [6]: torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 torch.ones(3, dtype=torch.float32) @ torch.ones(3, dtype=torch.float64)

RuntimeError: dot : expected both vectors to have same dtype, but found Float and Double
```
It's not immediately clear to me whether we want to paper over it in compat, or leave the conversion to end users: it's easy to imagine a use case where the copying overhead is significant.
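For concreteness, a minimal sketch of what the "paper over it" option could look like. The wrapper below is hypothetical (not array-api-compat's actual code) and relies on PyTorch's own promotion rules via `torch.result_type`, which may not match the spec in every corner case:

```python
import torch

def matmul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    """Hypothetical compat wrapper: promote to a common dtype, then matmul."""
    common = torch.result_type(x1, x2)
    # .to() returns the tensor unchanged (no copy) when the dtype already
    # matches, so overhead only appears for genuinely mixed-dtype calls.
    return torch.matmul(x1.to(common), x2.to(common))

matmul(torch.ones(3, dtype=torch.float32), torch.ones(3, dtype=torch.float64))
# tensor(3., dtype=torch.float64)
```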
Activity
rgommers commented on Jan 30, 2025
There's no good way to override `@` behavior I think. For `matmul` we can do same-kind type promotion I think, there shouldn't be extra overhead - no other library has mixed-dtype implementations either AFAIK (e.g., see `np.matmul.types`).
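For reference (an illustrative check, not part of the original comment): `np.matmul` is a ufunc, and its registered inner loops all take two inputs of the same dtype:

```python
import numpy as np

# Each entry pairs the input dtypes with the output dtype, e.g. 'ff->f',
# 'dd->d'; there are no mixed-dtype loops (the exact list depends on the
# NumPy version).
print(np.matmul.types)
```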
ev-br commented on Jan 30, 2025
NumPy seems to do it:

Cross-ref https://discuss.pytorch.org/t/matmul-mixed-dtypes/216044 for a PyTorch Discourse question.
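For illustration (added here, not part of the original comment), the promotion being referred to:

```python
import numpy as np

out = np.ones(3, dtype=np.float32) @ np.ones(3, dtype=np.float64)
print(out.dtype)  # float64 - NumPy accepts the mixed dtypes and promotes
```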
rgommers commented on Jan 30, 2025
Yeah I know, I didn't say it doesn't - I meant it does internal upcasting and then calls a routine with both dtypes being the same.
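In other words, roughly this (a sketch of the equivalent explicit steps; NumPy's actual upcasting happens inside the ufunc machinery in C):

```python
import numpy as np

a = np.ones(3, dtype=np.float32)
b = np.ones(3, dtype=np.float64)

# Mixed-dtype call: NumPy resolves a common dtype before running the
# single-dtype 'dd->d' loop.
mixed = np.matmul(a, b)

# Explicit equivalent: upcast first, then call the same-dtype routine.
common = np.result_type(a, b)  # float64
explicit = np.matmul(a.astype(common), b.astype(common))

assert mixed.dtype == explicit.dtype == np.float64
```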