diff --git a/cubed/__init__.py b/cubed/__init__.py index 9134619d0..a87e912fb 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -337,6 +337,6 @@ # extensions -from .array_api import linalg +from .array_api import fft, linalg -__all__ += ["linalg"] +__all__ += ["fft", "linalg"] diff --git a/cubed/array_api/fft.py b/cubed/array_api/fft.py new file mode 100644 index 000000000..d690e4221 --- /dev/null +++ b/cubed/array_api/fft.py @@ -0,0 +1,39 @@ +from cubed.backend_array_api import namespace as nxp +from cubed.core.ops import map_blocks + + +def fft(x, /, *, n=None, axis=-1, norm="backward"): + return fft_1d(nxp.fft.fft, x, n=n, axis=axis, norm=norm) + + +def ifft(x, /, *, n=None, axis=-1, norm="backward"): + return fft_1d(nxp.fft.ifft, x, n=n, axis=axis, norm=norm) + + +def fft_1d(fft_func, x, /, *, n=None, axis=-1, norm="backward"): + if x.numblocks[axis] > 1: + raise ValueError( + "FFT can only be applied along axes with a single chunk. " + # TODO: give details about what was tried and mention rechunking (see qr message) + ) + + if n is None: + chunks = x.chunks + else: + chunks = list(x.chunks) + chunks[axis] = (n,) + + return map_blocks( + _fft, + x, + dtype=nxp.complex128, + chunks=chunks, + fft_func=fft_func, + n=n, + axis=axis, + norm=norm, + ) + + +def _fft(a, fft_func=None, n=None, axis=None, norm=None): + return fft_func(a, n=n, axis=axis, norm=norm) diff --git a/cubed/tests/test_fft.py b/cubed/tests/test_fft.py new file mode 100644 index 000000000..3b408538d --- /dev/null +++ b/cubed/tests/test_fft.py @@ -0,0 +1,28 @@ +import pytest +from numpy.testing import assert_allclose + +import cubed +import cubed.array_api as xp +from cubed.backend_array_api import namespace as nxp + + +@pytest.mark.parametrize("funcname", ["fft", "ifft"]) +@pytest.mark.parametrize("n", [None, 5, 13]) +def test_fft(funcname, n): + nxp_fft = getattr(nxp.fft, funcname) + cb_fft = getattr(xp.fft, funcname) + + an = nxp.arange(100).reshape(10, 10) + bn = nxp_fft(an, n=n) + a = cubed.from_array(an, chunks=(1, 10)) + b = cb_fft(a, n=n) + + assert_allclose(b.compute(), bn) + + +def test_fft_chunked_axis_fails(): + an = nxp.arange(100).reshape(10, 10) + a = cubed.from_array(an, chunks=(1, 10)) + + with pytest.raises(ValueError): + xp.fft.fft(a, axis=0)