Jax float64 precision issues do not play ball with hypothesis #368

Open
@ev-br

A typical failure, here from test_diff:

self = <hypothesis.extra.array_api.ArrayStrategy object at 0x7e6a6cf6c990>, val = 2.112233982580733, val_0d = Array(2.1122339, dtype=float32)
strategy = FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308)

    def check_set_value(self, val, val_0d, strategy):
        if val == val and self.builtin(val_0d) != val:
            if self.builtin is float:
                assert self.finfo is not None  # for mypy
                try:
                    is_subnormal = 0 < abs(val) < self.finfo.smallest_normal
                except Exception:
                    # val may be a non-float that does not support the
                    # operations __lt__ and __abs__
                    is_subnormal = False
                if is_subnormal:
                    raise InvalidArgument(
                        f"Generated subnormal float {val} from strategy "
                        f"{strategy} resulted in {val_0d!r}, probably "
                        f"as a result of array module {self.xp.__name__} "
                        "being built with flush-to-zero compiler options. "
                        "Consider passing allow_subnormal=False."
                    )
>           raise InvalidArgument(
                f"Generated array element {val!r} from strategy {strategy} "
                f"cannot be represented with dtype {self.dtype}. "
                f"Array module {self.xp.__name__} instead "
                f"represents the element as {val_0d}. "
                "Consider using a more precise elements strategy, "
                "for example passing the width argument to floats()."
            )
E           hypothesis.errors.InvalidArgument: Generated array element 2.112233982580733 from strategy FloatStrategy(min_value=2.0, max_value=64.0, allow_nan=False, smallest_nonzero_magnitude=2.2250738585072014e-308) cannot be represented with dtype <class 'jax.numpy.float64'>. Array module jax.numpy instead represents the element as 2.112233877182007. Consider using a more precise elements strategy, for example passing the width argument to floats().
E           while generating 'x' from sampled_from((<class 'jax.numpy.uint8'>, <class 'jax.numpy.int8'>, <class 'jax.numpy.int16'>, <class 'jax.numpy.int32'>, <class 'jax.numpy.float32'>, <class 'jax.numpy.float64'>, <class 'jax.numpy.complex64'>, <class 'jax.numpy.complex128'>)).flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
E           Explanation:
E               These lines were always and only run by failing examples:
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:328
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/jax/_src/array.py:651
E                   /home/ev-br/.conda/envs/array-api/lib/python3.11/site-packages/numpy/_core/getlimits.py:609
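The two values in the message differ in the way one would expect if the element had been rounded to single precision, which matches the `val_0d = Array(2.1122339, dtype=float32)` shown at the top of the traceback. A minimal sketch using only the standard library (the helper name `round_to_float32` is made up for illustration and is not part of Hypothesis or JAX):

```python
import struct


def round_to_float32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value.

    Hypothetical helper for illustration: packing with the "f" format
    stores the value as a 4-byte IEEE 754 single, and unpacking widens
    it back to a Python float without further change.
    """
    return struct.unpack("f", struct.pack("f", x))[0]


val = 2.112233982580733         # element reported in the error message
stored = round_to_float32(val)  # what a float32 array would hold

print(stored)         # the single-precision neighbour of val
print(stored == val)  # the comparison that check_set_value trips on
```

If the printed value agrees with the "represents the element as" part of the message, the array was created as float32 even though float64 was requested.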

jakevdp commented on May 5, 2025

@jakevdp
Contributor

JAX is only compliant with the Array API dtype semantics when jax_enable_x64 is set to true. Any testing would have to take that into account.
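For illustration, a minimal sketch of what the flag changes, assuming JAX is installed. It sets the option in-process with `jax.config.update` at startup instead of through the environment, and reuses the value from the traceback above:

```python
import jax

# Enable 64-bit types; this should run at startup, before any arrays exist.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

val = 2.112233982580733  # element from the failing example
x = jnp.asarray(val, dtype=jnp.float64)

print(x.dtype)          # dtype actually used for the array
print(float(x) == val)  # round trip through the array
```

With the flag left at its default, the same request for float64 would be expected to produce a float32 array, which is the mismatch Hypothesis reports.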

ev-br commented on May 9, 2025

@ev-br
Member, Author

Thanks Jake!
So for completeness, the stanza to locally run a test from the test suite is

$ JAX_ENABLE_X64=true ARRAY_API_TESTS_VERSION="2024.12" ARRAY_API_TESTS_MODULE=jax.numpy pytest path/to/test

(EDITED to account for the correction below.)

jakevdp commented on May 9, 2025

@jakevdp
Contributor

Thanks Jake! So for completeness, the stanza to locally run a test from the test suite is

$ JAX_ENABLE_FLOAT64=True ARRAY_API_TESTS_VERSION="2024.12" ARRAY_API_TESTS_MODULE=jax.numpy pytest path/to/test

Almost – the env variable is JAX_ENABLE_X64

      Jax float64 precision issues do not play ball with hypothesis · Issue #368 · data-apis/array-api-tests