diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..e354eaae --- /dev/null +++ b/.flake8 @@ -0,0 +1,2 @@ +[flake8] +select = F diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ded906e6..9b49c09b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -6,15 +6,11 @@ jobs: build: runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.8, 3.9] - steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: "3.10" - name: Run pre-commit hook - uses: pre-commit/action@v2.0.3 + uses: pre-commit/action@v3.0.1 diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml deleted file mode 100644 index 74eeccf4..00000000 --- a/.github/workflows/numpy.yml +++ /dev/null @@ -1,78 +0,0 @@ -name: NumPy Array API - -on: [push, pull_request] - -jobs: - build: - - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.8, 3.9] - - steps: - - name: Checkout array-api-tests - uses: actions/checkout@v1 - with: - submodules: 'true' - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install numpy==1.22.1 - python -m pip install -r requirements.txt - - name: Run the test suite - env: - ARRAY_API_TESTS_MODULE: numpy.array_api - run: | - # Skip testing functions with known issues - cat << EOF >> skips.txt - - # copy not implemented - array_api_tests/test_creation_functions.py::test_asarray_arrays - # https://github.com/numpy/numpy/issues/20870 - array_api_tests/test_data_type_functions.py::test_can_cast - # The return dtype for trace is not consistent in the spec - # https://github.com/data-apis/array-api/issues/202#issuecomment-952529197 - array_api_tests/test_linalg.py::test_trace - # waiting on NumPy to allow/revert distinct NaNs for np.unique - # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 - array_api_tests/test_set_functions.py - - # https://github.com/numpy/numpy/issues/21373 - array_api_tests/test_array_object.py::test_getitem - - # missing copy arg - array_api_tests/test_signatures.py::test_func_signature[reshape] - - # https://github.com/numpy/numpy/issues/21211 - array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] - # https://github.com/numpy/numpy/issues/21213 - array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] - array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] - # noted diversions from spec - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] - array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and 
x2_i is -infinity) -> -0]
- array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
- array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
- array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
- array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
- array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
- array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
- array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
- array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
- array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
- array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
- array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
- array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
- array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
-
- EOF
-
- pytest -v -rxXfE --ci
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 00000000..2fab2072
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,34 @@
+name: Test Array API Strict
+
+on: [push, pull_request]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10", "3.11"]
+
+    steps:
+    - name: Checkout array-api-tests
+      uses: actions/checkout@v1
+      with:
+        submodules: 'true'
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v1
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip install array-api-strict
+        python -m pip install -r requirements.txt
+    - name: Run the test suite
+      env:
+        ARRAY_API_TESTS_MODULE: array_api_strict
+        ARRAY_API_STRICT_API_VERSION: 2024.12
+      run: |
+        pytest -v -rxXfE --skips-file array-api-strict-skips.txt array_api_tests/
+        # We also have internal tests that aren't really necessary for adopters
+        pytest -v -rxXfE meta_tests/
diff --git a/.gitignore b/.gitignore
index 49d7dca5..fc5b8b8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -117,6 +117,10 @@ venv.bak/
 # Rope project settings
 .ropeproject
 
+# IDE
+.idea/
+.vscode/
+
 # mkdocs documentation
 /site
 
diff --git a/README.md b/README.md
index 9eebc397..fa17b763 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
 This is the test suite for array libraries adopting the [Python Array API
 standard](https://data-apis.org/array-api/latest).
 
-Note the suite is still a **work in progress**. Feedback and contributions are
-welcome!
+Keeping full coverage of the spec is an ongoing priority as the Array API evolves.
+Feedback and contributions are welcome!
 
 ## Quickstart
 
@@ -33,17 +33,29 @@ You need to specify the array library to test.
 It can be specified via the `ARRAY_API_TESTS_MODULE` environment variable, e.g.
 
 ```bash
-$ export ARRAY_API_TESTS_MODULE=numpy.array_api
+$ export ARRAY_API_TESTS_MODULE=array_api_strict
 ```
 
-Alternately, change the `array_module` variable in `array_api_tests/_array_module.py`
-line, e.g.
+To specify a runtime-defined module, define `xp` using the `exec('...')` syntax:
 
-```diff
-- array_module = None
-+ import numpy.array_api as array_module
+```bash
+$ export ARRAY_API_TESTS_MODULE="exec('import quantity_array, numpy; xp = quantity_array.quantity_namespace(numpy)')"
 ```
 
+Alternatively, import or define the `xp` variable in `array_api_tests/__init__.py`.
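+
+For instance, the local-edit route is a one-line change. A sketch (any
+conforming namespace can stand in for `array_api_strict` here):
+
+```python
+# array_api_tests/__init__.py -- in place of the ARRAY_API_TESTS_MODULE lookup
+import array_api_strict as xp
+```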
+
+### Specifying the API version
+
+You can specify the API version to use when testing via the
+`ARRAY_API_TESTS_VERSION` environment variable, e.g.
+
+```bash
+$ export ARRAY_API_TESTS_VERSION="2023.12"
+```
+
+Currently this defaults to the array module's `__array_api_version__` value, and
+if that attribute doesn't exist then we fall back to `"2024.12"`.
+
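+This resolution order mirrors the lookup in `array_api_tests/__init__.py`
+(shown later in this diff):
+
+```python
+api_version = os.getenv(
+    "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2024.12")
+)
+```
+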
 ### Run the suite
 
 Simply run `pytest` against the `array_api_tests/` folder to run the full suite.
@@ -144,9 +156,9 @@ issues](https://github.com/data-apis/array-api-tests/issues/) to us.
 
 ## Running on CI
 
-See our existing [GitHub Actions workflow for
-Numpy](https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/numpy.yml)
-for an example of using the test suite on CI.
+See our existing [GitHub Actions workflow for `array-api-strict`](https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/test.yml)
+for an example of using the test suite on CI. Note [`array-api-strict`](https://github.com/data-apis/array-api-strict)
+is an implementation of the array API that uses NumPy under the hood.
 
 ### Releases
 
@@ -160,12 +172,6 @@ library to fail.
 
 ### Configuration
 
-#### CI flag
-
-Use the `--ci` flag to run only the primary and special cases tests. You can
-ignore the other test cases as they are redundant for the purposes of checking
-compliance.
-
 #### Data-dependent shapes
 
 Use the `--disable-data-dependent-shapes` flag to skip testing functions which have
@@ -178,13 +184,22 @@ By default, tests for the optional Array API extensions such as
 will be skipped if not present in the specified array module. You can purposely
 skip testing extension(s) via the `--disable-extension` option.
 
-#### Skip test cases
+#### Skip or XFAIL test cases
+
+Test cases you want to skip can be specified in a skips or XFAILs file. The
+difference between skip and XFAIL is that XFAIL tests are still run and
+reported as XPASS if they pass.
 
-Test cases you want to skip can be specified in a `skips.txt` file in the root
-of this repository, e.g.:
+By default, the skips and xfails files are `skips.txt` and `xfails.txt` in the root
+of this repository, but any file can be specified with the `--skips-file` and
+`--xfails-file` command line flags.
+
+The files should list the test ids to be skipped/xfailed. Empty lines and
+lines starting with `#` are ignored. The test id can be any substring of the
+test ids to skip/xfail.
 
 ```
-# ./skips.txt
+# skips.txt or xfails.txt
 # Line comments can be denoted with the hash symbol (#)
 
 # Skip specific test case, e.g. when argsort() does not respect relative order
@@ -200,39 +215,114 @@
 array_api_tests/test_add[__iadd__(x, s)]
 array_api_tests/test_set_functions.py
 ```
 
-For GitHub Actions, you might like to keep everything in the workflow config
-instead of having a seperate `skips.txt` file, e.g.:
+Here is an example GitHub Actions workflow file, where the xfails are stored
+in `array-api-tests-xfails.txt` in the base of the `your-array-library` repo.
+
+If you want, you can use `-o xfail_strict=True`, which causes XPASS tests (XFAIL
+tests that actually pass) to fail the test suite. However, be aware that
+XFAILs can be flaky (see below), so this may not be a good idea unless you
+use some other mitigation of such flakiness.
+
+If you don't want this behavior, you can remove it, or use `--skips-file`
+instead of `--xfails-file`.
 
 ```yaml
 # ./.github/workflows/array_api.yml
-...
-    ...
-    - name: Run the test suite
+jobs:
+  tests:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ['3.8', '3.9', '3.10', '3.11']
+
+    steps:
+    - name: Checkout
+      uses: actions/checkout@v3
+      with:
+        path: your-array-library
+
+    - name: Checkout array-api-tests
+      uses: actions/checkout@v3
+      with:
+        repository: data-apis/array-api-tests
+        submodules: 'true'
+        path: array-api-tests
+
+    - name: Run the array API test suite
       env:
         ARRAY_API_TESTS_MODULE: your.array.api.namespace
       run: |
-        # Skip test cases with known issues
-        cat << EOF >> skips.txt
-
-        # Comments can still work here
-        array_api_tests/test_sorting_functions.py::test_argsort
-        array_api_tests/test_add[__iadd__(x1, x2)]
-        array_api_tests/test_add[__iadd__(x, s)]
-        array_api_tests/test_set_functions.py
-
-        EOF
-
-        pytest -v -rxXfE --ci
+        export PYTHONPATH="${GITHUB_WORKSPACE}/your-array-library"
+        cd ${GITHUB_WORKSPACE}/array-api-tests
+        pytest -v -rxXfE --xfails-file ${GITHUB_WORKSPACE}/your-array-library/array-api-tests-xfails.txt array_api_tests/
 ```
 
+> **Warning**
+>
+> XFAIL tests that use Hypothesis (basically every test in the test suite except
+> those in test_has_names.py) can be flaky, due to the fact that Hypothesis
+> might not always run the test with an input that causes the test to fail.
+> There are several ways to avoid this problem:
+>
+> - Increase the maximum number of examples, e.g., by adding `--max-examples
+>   200` to the test command (the default is `20`, see below). This will
+>   make it more likely that the failing case will be found, but it will also
+>   make the tests take longer to run.
+> - Don't use `-o xfail_strict=True`. This will make it so that if an XFAIL
+>   test passes, it will alert you in the test summary but will not cause the
+>   test run to register as failed.
+> - Use skips instead of XFAILS. The difference between XFAIL and skip is that
+>   a skipped test is never run at all, whereas an XFAIL test is always run
+>   but ignored if it fails.
+> - Save the [Hypothesis examples
+>   database](https://hypothesis.readthedocs.io/en/latest/database.html)
+>   persistently on CI. That way as soon as a run finds one failing example,
+>   it will always re-run future runs with that example. But note that the
+>   Hypothesis examples database may be cleared when a new version of
+>   Hypothesis or the test suite is released.
+
 #### Max examples
 
 The tests make heavy use of [Hypothesis](https://hypothesis.readthedocs.io/en/latest/). You can configure
-how many examples are generated using the `--max-examples` flag, which defaults
-to 100. Lower values can be useful for quick checks, and larger values should
-result in more rigorous runs. For example, `--max-examples 10_000` may find bugs
-where default runs don't but will take much longer to run.
+how many examples are generated using the `--max-examples` flag, which
+defaults to `20`. Lower values can be useful for quick checks, and larger
+values should result in more rigorous runs. For example, `--max-examples
+10_000` may find bugs where default runs don't but will take much longer to
+run.
+
+#### Skipping Dtypes
+
+The test suite will automatically skip testing of inessential dtypes if they
+are not present on the array module namespace, but dtypes can also be skipped
+manually by setting the environment variable `ARRAY_API_TESTS_SKIP_DTYPES` to
+a comma-separated list of dtypes to skip. For example:
+
+```
+ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
+```
+
+Note that skipping certain essential dtypes such as `bool` and the default
+floating-point dtype is not supported.
+
+#### Turning xfails into skips
+
+Keeping a large number of ``xfails`` can have drastic effects on the run time. This is due
+to the way `hypothesis` works: when it detects a failure, it does a large amount
+of work to simplify the failing example.
+If the run time of the test suite becomes a problem, you can use the
+``ARRAY_API_TESTS_XFAIL_MARK`` environment variable: setting it to ``skip`` skips the
+entries from the ``xfails.txt`` file instead of xfailing them. Anecdotally, we saw
+speed-ups by a factor of 4-5, which allowed us to use 4-5x larger values of
+``--max-examples`` within the same time budget.
+
+#### Limiting the array sizes
+
+The test suite generates random arrays as inputs to functions it tests.
+"Unvectorized" tests iterate over elements of arrays, which might be slow. If the run time becomes
+a problem, you can limit the maximum number of elements in generated arrays by
+setting the environment variable ``ARRAY_API_TESTS_MAX_ARRAY_SIZE`` to the
+desired value. By default, it is set to 1024.
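+
+Internally this is a plain environment lookup in
+`array_api_tests/hypothesis_helpers.py` (see the hunk later in this diff), so
+the value must parse as an integer:
+
+```python
+MAX_ARRAY_SIZE = int(os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE", 1024))
+SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))  # size used for 2-dim arrays
+```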
 
 ## Contributing
 
@@ -303,26 +393,6 @@
 into a release. If you want, you can add release notes, which GitHub can
 generate for you.
 
-## Future plans
-
-Keeping full coverage of the spec is an on-going priority as the Array API
-evolves.
-
-Additionally, we have features and general improvements planned. Work on such
-functionality is guided primarily by the concerete needs of developers
-implementing and using the Array API—be sure to [let us
-know](https://github.com/data-apis/array-api-tests/issues) any limitations you
-come across.
-
-* A dependency graph for every test case, which could be used to modify pytest's
-  collection so that low-dependency tests are run first, and tests with faulty
-  dependencies would skip/xfail.
-
-* In some tests we've found it difficult to find appropaite assertion parameters
-  for output values (particularly epsilons for floating-point outputs), so we
-  need to review these and either implement assertions or properly note the lack
-  thereof.
-
 ---
 
 <sup>1</sup>The only exceptions to having just one primary test per function are:
diff --git a/array-api b/array-api
index c5808f2b..772fb461 160000
--- a/array-api
+++ b/array-api
@@ -1 +1 @@
-Subproject commit c5808f2b173ea52d813c450bec7b1beaf2973299
+Subproject commit 772fb461da6ff904ecfcac4a24676e40efcbdb0c
diff --git a/array-api-strict-skips.txt b/array-api-strict-skips.txt
new file mode 100644
index 00000000..afc1b845
--- /dev/null
+++ b/array-api-strict-skips.txt
@@ -0,0 +1,34 @@
+# Known special case issue in NumPy.
Not worth working around here +# https://github.com/numpy/numpy/issues/21213 +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] + +# The test suite is incorrectly checking sums that have loss of significance +# (https://github.com/data-apis/array-api-tests/issues/168) +array_api_tests/test_statistical_functions.py::test_sum + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, all libraries do just that +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] +array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# FIXME needs array-api-strict >=2.3.2 +array_api_tests/test_data_type_functions.py::test_finfo +array_api_tests/test_data_type_functions.py::test_finfo_dtype +array_api_tests/test_data_type_functions.py::test_iinfo +array_api_tests/test_data_type_functions.py::test_iinfo_dtype diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index c472b862..d01af52d 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -1,11 +1,55 @@ +import os from functools import wraps +from importlib import import_module from hypothesis import 
strategies as st from hypothesis.extra import array_api -from ._array_module import mod as _xp +from . import _version -__all__ = ["xps"] +__all__ = ["xp", "api_version", "xps"] + + +# You can comment the following out and instead import the specific array module +# you want to test, e.g. `import array_api_strict as xp`. +if "ARRAY_API_TESTS_MODULE" in os.environ: + env_var = os.environ["ARRAY_API_TESTS_MODULE"] + if env_var.startswith("exec('") and env_var.endswith("')"): + script = env_var[6:][:-2] + namespace = {} + exec(script, namespace) + xp = namespace["xp"] + xp_name = xp.__name__ + else: + xp_name = env_var + _module, _sub = xp_name, None + if "." in xp_name: + _module, _sub = xp_name.split(".", 1) + xp = import_module(_module) + if _sub: + try: + xp = getattr(xp, _sub) + except AttributeError: + # _sub may be a submodule that needs to be imported. We can't + # do this in every case because some array modules are not + # submodules that can be imported (like mxnet.nd). + xp = import_module(xp_name) +else: + raise RuntimeError( + "No array module specified - either edit __init__.py or set the " + "ARRAY_API_TESTS_MODULE environment variable." + ) + + +# If xp.bool is not available, like in some versions of NumPy and CuPy, try +# patching in xp.bool_. +try: + xp.bool +except AttributeError as e: + if hasattr(xp, "bool_"): + xp.bool = xp.bool_ + else: + raise e # We monkey patch floats() to always disable subnormals as they are out-of-scope @@ -41,9 +85,9 @@ def _from_dtype(*a, **kw): pass -xps = array_api.make_strategies_namespace(_xp, api_version="2021.12") - - -from . import _version +api_version = os.getenv( + "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2024.12") +) +xps = array_api.make_strategies_namespace(xp, api_version=api_version) __version__ = _version.get_versions()["version"] diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py index e83cd6ca..1c52a983 100644 --- a/array_api_tests/_array_module.py +++ b/array_api_tests/_array_module.py @@ -1,35 +1,5 @@ -import os -from importlib import import_module +from . import stubs, xp -from . import stubs - -# Replace this with a specific array module to test it, for example, -# -# import numpy as array_module -array_module = None - -if array_module is None: - if 'ARRAY_API_TESTS_MODULE' in os.environ: - mod_name = os.environ['ARRAY_API_TESTS_MODULE'] - _module, _sub = mod_name, None - if '.' in mod_name: - _module, _sub = mod_name.split('.', 1) - mod = import_module(_module) - if _sub: - try: - mod = getattr(mod, _sub) - except AttributeError: - # _sub may be a submodule that needs to be imported. WE can't - # do this in every case because some array modules are not - # submodules that can be imported (like mxnet.nd). - mod = import_module(mod_name) - else: - raise RuntimeError("No array module specified. Either edit _array_module.py or set the ARRAY_API_TESTS_MODULE environment variable") -else: - mod = array_module - mod_name = mod.__name__ -# Names from the spec. This is what should actually be imported from this -# file. 
 class _UndefinedStub:
     """
@@ -45,7 +15,7 @@ def __init__(self, name):
         self.name = name
 
     def _raise(self, *args, **kwargs):
-        raise AssertionError(f"{self.name} is not defined in {mod_name}")
+        raise AssertionError(f"{self.name} is not defined in {xp.__name__}")
 
     def __repr__(self):
        return f"<undefined stub for {self.name!r}>"
@@ -58,13 +28,15 @@ def __repr__(self):
     "uint8", "uint16", "uint32", "uint64",
     "int8", "int16", "int32", "int64",
     "float32", "float64",
+    "complex64", "complex128",
 ]
 _constants = ["e", "inf", "nan", "pi"]
 _funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
-_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS
+_funcs += ["take", "isdtype", "conj", "imag", "real"]  # TODO: bump spec and update array-api-tests to new spec layout
+_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]
 
 for attr in _top_level_attrs:
     try:
-        globals()[attr] = getattr(mod, attr)
+        globals()[attr] = getattr(xp, attr)
     except AttributeError:
         globals()[attr] = _UndefinedStub(attr)
diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py
index ef4f719a..a74dab24 100644
--- a/array_api_tests/array_helpers.py
+++ b/array_api_tests/array_helpers.py
@@ -1,5 +1,5 @@
 from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
-                            logical_or, isfinite, greater, less, less_equal,
+                            logical_or, isfinite, greater, less_equal,
                             zeros, ones, full, bool, int8, int16, int32, int64,
                             uint8, uint16, uint32, uint64, float32, float64,
                             nan, inf, pi, remainder, divide, isinf,
@@ -7,9 +7,9 @@
 # These are exported here so that they can be included in the special cases
 # tests from this file.
 from ._array_module import logical_not, subtract, floor, ceil, where
+from . import _array_module as xp
 from . import dtype_helpers as dh
 
-
 __all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
            'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
            'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
@@ -164,7 +164,17 @@ def notequal(x, y):
 
     return not_equal(x, y)
 
-def assert_exactly_equal(x, y):
+def less(x, y):
+    """
+    Same as xp.less(x, y) except it allows comparing uint64 with signed int dtypes
+    """
+    if x.dtype == uint64 and dh.dtype_signed[y.dtype]:
+        return xp.where(y < 0, xp.asarray(False), xp.less(x, xp.astype(y, uint64)))
+    if y.dtype == uint64 and dh.dtype_signed[x.dtype]:
+        return xp.where(x < 0, xp.asarray(True), xp.less(xp.astype(x, uint64), y))
+    return xp.less(x, y)
+
+def assert_exactly_equal(x, y, msg_extra=None):
     """
     Test that the arrays x and y are exactly equal.
@@ -172,11 +182,13 @@
     equal.
""" - assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape})" + extra = '' if not msg_extra else f' ({msg_extra})' + + assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}" - assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype})" + assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}" - assert all(exactly_equal(x, y)), "The input arrays have different values" + assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}" def assert_finite(x): """ @@ -306,3 +318,13 @@ def same_sign(x, y): def assert_same_sign(x, y): assert all(same_sign(x, y)), "The input arrays do not have the same sign" +def _matrix_transpose(x): + if not isinstance(xp.matrix_transpose, xp._UndefinedStub): + return xp.matrix_transpose(x) + if hasattr(x, 'mT'): + return x.mT + if not isinstance(xp.permute_dims, xp._UndefinedStub): + perm = list(range(x.ndim)) + perm[-1], perm[-2] = perm[-2], perm[-1] + return xp.permute_dims(x, axes=tuple(perm)) + raise NotImplementedError("No way to compute matrix transpose") diff --git a/array_api_tests/conftest.py b/array_api_tests/conftest.py deleted file mode 100644 index 092a2961..00000000 --- a/array_api_tests/conftest.py +++ /dev/null @@ -1,16 +0,0 @@ -import re - -import pytest - -r_int_pow_promotion = re.compile(r"test.+promotion\[.*pow.*\(u?int.+\]") - - -def pytest_collection_modifyitems(config, items): - """Skips the faulty integer type promotion tests for pow-related functions""" - for item in items: - if r_int_pow_promotion.match(item.name): - item.add_marker( - pytest.mark.skip( - reason="faulty test logic - negative exponents generated" - ) - ) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index 1527611c..f7fa306b 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -1,35 +1,49 @@ +import os import re +from collections import defaultdict from collections.abc import Mapping from functools import lru_cache -from inspect import signature -from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union +from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union from warnings import warn -from . import _array_module as xp -from ._array_module import _UndefinedStub +from . import api_version +from . 
import xp from .stubs import name_to_func from .typing import DataType, ScalarType __all__ = [ + "uint_names", + "int_names", + "all_int_names", + "real_float_names", + "real_names", + "complex_names", + "numeric_names", + "dtype_names", "int_dtypes", "uint_dtypes", "all_int_dtypes", - "float_dtypes", + "real_float_dtypes", + "real_dtypes", "numeric_dtypes", "all_dtypes", - "dtype_to_name", + "all_float_dtypes", "bool_and_all_int_dtypes", - "dtype_to_scalars", + "dtype_to_name", + "kind_to_dtypes", "is_int_dtype", "is_float_dtype", "get_scalar_type", + "is_scalar", "dtype_ranges", "default_int", "default_uint", "default_float", + "default_complex", "promotion_table", "dtype_nbits", "dtype_signed", + "dtype_components", "func_in_dtypes", "func_returns_bool", "binary_op_to_symbol", @@ -83,92 +97,226 @@ def __repr__(self): return f"EqualityMapping({self})" -_uint_names = ("uint8", "uint16", "uint32", "uint64") -_int_names = ("int8", "int16", "int32", "int64") -_float_names = ("float32", "float64") -_dtype_names = ("bool",) + _uint_names + _int_names + _float_names +uint_names = ("uint8", "uint16", "uint32", "uint64") +int_names = ("int8", "int16", "int32", "int64") +all_int_names = uint_names + int_names +real_float_names = ("float32", "float64") +real_names = uint_names + int_names + real_float_names +complex_names = ("complex64", "complex128") +numeric_names = real_names + complex_names +dtype_names = ("bool",) + numeric_names + +_skip_dtypes = os.getenv("ARRAY_API_TESTS_SKIP_DTYPES", '') +_skip_dtypes = _skip_dtypes.split(',') +skip_dtypes = [] +for dtype in _skip_dtypes: + if dtype and dtype not in dtype_names: + raise ValueError(f"Invalid dtype name in ARRAY_API_TESTS_SKIP_DTYPES: {dtype}") + skip_dtypes.append(dtype) + +_name_to_dtype = {} +for name in dtype_names: + if name in skip_dtypes: + continue + try: + dtype = getattr(xp, name) + except AttributeError: + continue + _name_to_dtype[name] = dtype +dtype_to_name = EqualityMapping([(d, n) for n, d in _name_to_dtype.items()]) -uint_dtypes = tuple(getattr(xp, name) for name in _uint_names) -int_dtypes = tuple(getattr(xp, name) for name in _int_names) -float_dtypes = tuple(getattr(xp, name) for name in _float_names) +def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]: + dtypes = [] + for name in names: + try: + dtype = _name_to_dtype[name] + except KeyError: + continue + dtypes.append(dtype) + return tuple(dtypes) + + +uint_dtypes = _make_dtype_tuple_from_names(uint_names) +int_dtypes = _make_dtype_tuple_from_names(int_names) +real_float_dtypes = _make_dtype_tuple_from_names(real_float_names) all_int_dtypes = uint_dtypes + int_dtypes -numeric_dtypes = all_int_dtypes + float_dtypes +real_dtypes = all_int_dtypes + real_float_dtypes +complex_dtypes = _make_dtype_tuple_from_names(complex_names) +numeric_dtypes = real_dtypes +if api_version > "2021.12": + numeric_dtypes += complex_dtypes all_dtypes = (xp.bool,) + numeric_dtypes +all_float_dtypes = real_float_dtypes +if api_version > "2021.12": + all_float_dtypes += complex_dtypes bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes -dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names]) - +kind_to_dtypes = { + "bool": [xp.bool], + "signed integer": int_dtypes, + "unsigned integer": uint_dtypes, + "integral": all_int_dtypes, + "real floating": real_float_dtypes, + "complex floating": complex_dtypes, + "numeric": numeric_dtypes, +} -dtype_to_scalars = EqualityMapping( - [ - (xp.bool, [bool]), - *[(d, [int]) for d in all_int_dtypes], - *[(d, 
[int, float]) for d in float_dtypes], - ] -) +def available_kinds(): + return { + kind for kind, dtypes in kind_to_dtypes.items() if dtypes + } def is_int_dtype(dtype): return dtype in all_int_dtypes -def is_float_dtype(dtype): +def is_float_dtype(dtype, *, include_complex=True): # None equals NumPy's xp.float64 object, so we specifically check it here. # xp.float64 is in fact an alias of np.dtype('float64'), and its equality # with None is meant to be deprecated at some point. # See https://github.com/numpy/numpy/issues/18434 if dtype is None: return False - return dtype in float_dtypes - + valid_dtypes = real_float_dtypes + if api_version > "2021.12" and include_complex: + valid_dtypes += complex_dtypes + return dtype in valid_dtypes def get_scalar_type(dtype: DataType) -> ScalarType: - if is_int_dtype(dtype): + if dtype in all_int_dtypes: return int - elif is_float_dtype(dtype): + elif dtype in real_float_dtypes: return float + elif dtype in complex_dtypes: + return complex else: return bool +def is_scalar(x): + return isinstance(x, (int, float, complex, bool)) + + +def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping: + dtype_value_pairs = [] + for name, value in mapping.items(): + assert isinstance(name, str) and name in dtype_names # sanity check + if name in _name_to_dtype: + dtype = _name_to_dtype[name] + else: + continue + dtype_value_pairs.append((dtype, value)) + return EqualityMapping(dtype_value_pairs) + class MinMax(NamedTuple): min: Union[int, float] max: Union[int, float] - -dtype_ranges = EqualityMapping( - [ - (xp.int8, MinMax(-128, +127)), - (xp.int16, MinMax(-32_768, +32_767)), - (xp.int32, MinMax(-2_147_483_648, +2_147_483_647)), - (xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)), - (xp.uint8, MinMax(0, +255)), - (xp.uint16, MinMax(0, +65_535)), - (xp.uint32, MinMax(0, +4_294_967_295)), - (xp.uint64, MinMax(0, +18_446_744_073_709_551_615)), - (xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)), - (xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)), - ] + def __contains__(self, other): + assert isinstance(other, (int, float)) + return self.min <= other <= self.max + +dtype_ranges = _make_dtype_mapping_from_names( + { + "int8": MinMax(-128, +127), + "int16": MinMax(-32_768, +32_767), + "int32": MinMax(-2_147_483_648, +2_147_483_647), + "int64": MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807), + "uint8": MinMax(0, +255), + "uint16": MinMax(0, +65_535), + "uint32": MinMax(0, +4_294_967_295), + "uint64": MinMax(0, +18_446_744_073_709_551_615), + "float32": MinMax(-3.4028234663852886e38, 3.4028234663852886e38), + "float64": MinMax(-1.7976931348623157e308, 1.7976931348623157e308), + } ) -dtype_nbits = EqualityMapping( - [(d, 8) for d in [xp.int8, xp.uint8]] - + [(d, 16) for d in [xp.int16, xp.uint16]] - + [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]] - + [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]] + +r_nbits = re.compile(r"[a-z]+([0-9]+)") +_dtype_nbits: Dict[str, int] = {} +for name in numeric_names: + m = r_nbits.fullmatch(name) + assert m is not None # sanity check / for mypy + _dtype_nbits[name] = int(m.group(1)) +dtype_nbits = _make_dtype_mapping_from_names(_dtype_nbits) + + +dtype_signed = _make_dtype_mapping_from_names( + {**{name: True for name in int_names}, **{name: False for name in uint_names}} ) -dtype_signed = EqualityMapping( - [(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes] +dtype_components = _make_dtype_mapping_from_names( 
+ {"complex64": xp.float32, "complex128": xp.float64} ) +def as_real_dtype(dtype): + """ + Return the corresponding real dtype for a given floating-point dtype. + """ + if dtype in real_float_dtypes: + return dtype + elif dtype_to_name[dtype] in complex_names: + return dtype_components[dtype] + else: + raise ValueError("as_real_dtype requires a floating-point dtype") + +def accumulation_result_dtype(x_dtype, dtype_kwarg): + """ + Result dtype logic for sum(), prod(), and trace() + + Note: may return None if a default uint cannot exist (e.g., for pytorch + which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106 -if isinstance(xp.asarray, _UndefinedStub): + """ + if dtype_kwarg is None: + if is_int_dtype(x_dtype): + if x_dtype in uint_dtypes: + default_dtype = default_uint + else: + default_dtype = default_int + if default_dtype is None: + _dtype = None + else: + m, M = dtype_ranges[x_dtype] + d_m, d_M = dtype_ranges[default_dtype] + if m < d_m or M > d_M: + _dtype = x_dtype + else: + _dtype = default_dtype + elif api_version >= '2023.12': + # Starting in 2023.12, floats should not promote with dtype=None + _dtype = x_dtype + elif is_float_dtype(x_dtype, include_complex=False): + if dtype_nbits[x_dtype] > dtype_nbits[default_float]: + _dtype = x_dtype + else: + _dtype = default_float + elif api_version > "2021.12": + # Complex dtype + if dtype_nbits[x_dtype] > dtype_nbits[default_complex]: + _dtype = x_dtype + else: + _dtype = default_complex + else: + raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.") + else: + _dtype = dtype_kwarg + + return _dtype + +if not hasattr(xp, "asarray"): default_int = xp.int32 default_float = xp.float32 + # TODO: when api_version > '2021.12', just assign to xp.complex64, + # otherwise default to None. Need array-api spec to be bumped first (#187). + try: + default_complex = xp.complex64 + except AttributeError: + default_complex = None warn( "array module does not have attribute asarray. 
" "default int is assumed int32, default float is assumed float32" @@ -178,59 +326,80 @@ class MinMax(NamedTuple): if default_int not in int_dtypes: warn(f"inferred default int is {default_int!r}, which is not an int") default_float = xp.asarray(float()).dtype - if default_float not in float_dtypes: + if default_float not in real_float_dtypes: warn(f"inferred default float is {default_float!r}, which is not a float") + if api_version > "2021.12" and ({'complex64', 'complex128'} - set(skip_dtypes)): + default_complex = xp.asarray(complex()).dtype + if default_complex not in complex_dtypes: + warn( + f"inferred default complex is {default_complex!r}, " + "which is not a complex" + ) + else: + default_complex = None + if dtype_nbits[default_int] == 32: - default_uint = xp.uint32 + default_uint = _name_to_dtype.get("uint32") else: - default_uint = xp.uint64 - + default_uint = _name_to_dtype.get("uint64") -_numeric_promotions = [ +_promotion_table: Dict[Tuple[str, str], str] = { + ("bool", "bool"): "bool", # ints - ((xp.int8, xp.int8), xp.int8), - ((xp.int8, xp.int16), xp.int16), - ((xp.int8, xp.int32), xp.int32), - ((xp.int8, xp.int64), xp.int64), - ((xp.int16, xp.int16), xp.int16), - ((xp.int16, xp.int32), xp.int32), - ((xp.int16, xp.int64), xp.int64), - ((xp.int32, xp.int32), xp.int32), - ((xp.int32, xp.int64), xp.int64), - ((xp.int64, xp.int64), xp.int64), + ("int8", "int8"): "int8", + ("int8", "int16"): "int16", + ("int8", "int32"): "int32", + ("int8", "int64"): "int64", + ("int16", "int16"): "int16", + ("int16", "int32"): "int32", + ("int16", "int64"): "int64", + ("int32", "int32"): "int32", + ("int32", "int64"): "int64", + ("int64", "int64"): "int64", # uints - ((xp.uint8, xp.uint8), xp.uint8), - ((xp.uint8, xp.uint16), xp.uint16), - ((xp.uint8, xp.uint32), xp.uint32), - ((xp.uint8, xp.uint64), xp.uint64), - ((xp.uint16, xp.uint16), xp.uint16), - ((xp.uint16, xp.uint32), xp.uint32), - ((xp.uint16, xp.uint64), xp.uint64), - ((xp.uint32, xp.uint32), xp.uint32), - ((xp.uint32, xp.uint64), xp.uint64), - ((xp.uint64, xp.uint64), xp.uint64), + ("uint8", "uint8"): "uint8", + ("uint8", "uint16"): "uint16", + ("uint8", "uint32"): "uint32", + ("uint8", "uint64"): "uint64", + ("uint16", "uint16"): "uint16", + ("uint16", "uint32"): "uint32", + ("uint16", "uint64"): "uint64", + ("uint32", "uint32"): "uint32", + ("uint32", "uint64"): "uint64", + ("uint64", "uint64"): "uint64", # ints and uints (mixed sign) - ((xp.int8, xp.uint8), xp.int16), - ((xp.int8, xp.uint16), xp.int32), - ((xp.int8, xp.uint32), xp.int64), - ((xp.int16, xp.uint8), xp.int16), - ((xp.int16, xp.uint16), xp.int32), - ((xp.int16, xp.uint32), xp.int64), - ((xp.int32, xp.uint8), xp.int32), - ((xp.int32, xp.uint16), xp.int32), - ((xp.int32, xp.uint32), xp.int64), - ((xp.int64, xp.uint8), xp.int64), - ((xp.int64, xp.uint16), xp.int64), - ((xp.int64, xp.uint32), xp.int64), + ("int8", "uint8"): "int16", + ("int8", "uint16"): "int32", + ("int8", "uint32"): "int64", + ("int16", "uint8"): "int16", + ("int16", "uint16"): "int32", + ("int16", "uint32"): "int64", + ("int32", "uint8"): "int32", + ("int32", "uint16"): "int32", + ("int32", "uint32"): "int64", + ("int64", "uint8"): "int64", + ("int64", "uint16"): "int64", + ("int64", "uint32"): "int64", # floats - ((xp.float32, xp.float32), xp.float32), - ((xp.float32, xp.float64), xp.float64), - ((xp.float64, xp.float64), xp.float64), -] -_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions] -_promotion_table = list(set(_numeric_promotions)) -_promotion_table.insert(0, 
((xp.bool, xp.bool), xp.bool))
-promotion_table = EqualityMapping(_promotion_table)
+    ("float32", "float32"): "float32",
+    ("float32", "float64"): "float64",
+    ("float64", "float64"): "float64",
+    # complex
+    ("complex64", "complex64"): "complex64",
+    ("complex64", "complex128"): "complex128",
+    ("complex128", "complex128"): "complex128",
+}
+_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
+_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
+for (in_name1, in_name2), res_name in _promotion_table.items():
+    if in_name1 not in _name_to_dtype or in_name2 not in _name_to_dtype or res_name not in _name_to_dtype:
+        continue
+    in_dtype1 = _name_to_dtype[in_name1]
+    in_dtype2 = _name_to_dtype[in_name2]
+    res_dtype = _name_to_dtype[res_name]
+
+    _promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
+promotion_table = EqualityMapping(_promotion_table_pairs)
 
 
 def result_type(*dtypes: DataType):
@@ -253,22 +422,22 @@
 category_to_dtypes = {
     "boolean": (xp.bool,),
     "integer": all_int_dtypes,
-    "floating-point": float_dtypes,
+    "floating-point": real_float_dtypes,
+    "real-valued": real_float_dtypes,
+    "real-valued floating-point": real_float_dtypes,
+    "complex floating-point": complex_dtypes,
     "numeric": numeric_dtypes,
     "integer or boolean": bool_and_all_int_dtypes,
 }
 
-func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
+func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
 for name, func in name_to_func.items():
+    assert func.__doc__ is not None  # for mypy
     if m := r_in_dtypes.search(func.__doc__):
         dtype_category = m.group(1)
         if dtype_category == "numeric" and r_int_note.search(func.__doc__):
             dtype_category = "floating-point"
         dtypes = category_to_dtypes[dtype_category]
         func_in_dtypes[name] = dtypes
-    elif any("x" in name for name in signature(func).parameters.keys()):
-        func_in_dtypes[name] = all_dtypes
-# See https://github.com/data-apis/array-api/pull/413
-func_in_dtypes["expm1"] = float_dtypes
 
 
 func_returns_bool = {
@@ -393,11 +562,10 @@
 }
 
+# Construct func_in_dtypes and func_returns_bool
 for op, elwise_func in op_to_func.items():
     func_in_dtypes[op] = func_in_dtypes[elwise_func]
     func_returns_bool[op] = func_returns_bool[elwise_func]
-
-
 inplace_op_to_symbol = {}
 for op, symbol in binary_op_to_symbol.items():
     if op == "__matmul__" or func_returns_bool[op]:
@@ -406,12 +574,10 @@
         inplace_op_to_symbol[iop] = f"{symbol}="
         func_in_dtypes[iop] = func_in_dtypes[op]
         func_returns_bool[iop] = func_returns_bool[op]
-
-
 func_in_dtypes["__bool__"] = (xp.bool,)
 func_in_dtypes["__int__"] = all_int_dtypes
 func_in_dtypes["__index__"] = all_int_dtypes
-func_in_dtypes["__float__"] = float_dtypes
+func_in_dtypes["__float__"] = real_float_dtypes
 func_in_dtypes["from_dlpack"] = numeric_dtypes
 func_in_dtypes["__dlpack__"] = numeric_dtypes
diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py
index 20cc0e03..e1df108c 100644
--- a/array_api_tests/hypothesis_helpers.py
+++ b/array_api_tests/hypothesis_helpers.py
@@ -1,52 +1,97 @@
-import itertools
-from functools import reduce
-from math import sqrt
-from operator import mul
-from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
+from __future__ import annotations
 
-from hypothesis import assume
+import os
+import re
+from contextlib import contextmanager
+from functools import wraps
+import math
+import
struct +from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union + +from hypothesis import assume, reject from hypothesis.strategies import (SearchStrategy, booleans, composite, floats, - integers, just, lists, none, one_of, - sampled_from, shared) + integers, complex_numbers, just, lists, none, one_of, + sampled_from, shared, builds, nothing, permutations) -from . import _array_module as xp +from . import _array_module as xp, api_version +from . import array_helpers as ah from . import dtype_helpers as dh from . import shape_helpers as sh from . import xps from ._array_module import _UndefinedStub from ._array_module import bool as bool_dtype -from ._array_module import broadcast_to, eye, float32, float64, full +from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128 from .stubs import category_to_funcs from .pytest_helpers import nargs -from .typing import Array, DataType, Shape +from .typing import Array, DataType, Scalar, Shape -# Set this to True to not fail tests just because a dtype isn't implemented. -# If no compatible dtype is implemented for a given test, the test will fail -# with a hypothesis health check error. Note that this functionality will not -# work for floating point dtypes as those are assumed to be defined in other -# places in the tests. -FILTER_UNDEFINED_DTYPES = True -integer_dtypes = sampled_from(dh.all_int_dtypes) -floating_dtypes = sampled_from(dh.float_dtypes) -numeric_dtypes = sampled_from(dh.numeric_dtypes) -integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes) -boolean_dtypes = just(xp.bool) -dtypes = sampled_from(dh.all_dtypes) - -if FILTER_UNDEFINED_DTYPES: - integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - floating_dtypes = floating_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - numeric_dtypes = numeric_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - integer_or_boolean_dtypes = integer_or_boolean_dtypes.filter(lambda x: not - isinstance(x, _UndefinedStub)) - boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub)) - -shared_dtypes = shared(dtypes, key="dtype") -shared_floating_dtypes = shared(floating_dtypes, key="dtype") - -_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes] +def _float32ify(n: Union[int, float]) -> float: + n = float(n) + return struct.unpack("!f", struct.pack("!f", n))[0] + + +@wraps(xps.from_dtype) +def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]: + """xps.from_dtype() without the crazy large numbers.""" + if dtype == xp.bool: + return xps.from_dtype(dtype, **kwargs) + + if dtype in dh.complex_dtypes: + component_dtype = dh.dtype_components[dtype] + else: + component_dtype = dtype + + min_, max_ = dh.dtype_ranges[component_dtype] + + if "min_value" not in kwargs.keys() and min_ != 0: + assert min_ < 0 # sanity check + min_value = -1 * math.floor(math.sqrt(abs(min_))) + if component_dtype == xp.float32: + min_value = _float32ify(min_value) + kwargs["min_value"] = min_value + if "max_value" not in kwargs.keys(): + assert max_ > 0 # sanity check + max_value = math.floor(math.sqrt(max_)) + if component_dtype == xp.float32: + max_value = _float32ify(max_value) + kwargs["max_value"] = max_value + + if dtype in dh.complex_dtypes: + component_strat = xps.from_dtype(dh.dtype_components[dtype], **kwargs) + return builds(complex, component_strat, component_strat) + else: 
+        return xps.from_dtype(dtype, **kwargs)
+
+
+@wraps(xps.arrays)
+def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
+    """xps.arrays() without the crazy large numbers."""
+    if isinstance(dtype, SearchStrategy):
+        return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
+
+    if elements is None:
+        elements = from_dtype(dtype)
+    elif isinstance(elements, Mapping):
+        elements = from_dtype(dtype, **elements)
+
+    return xps.arrays(dtype, *args, elements=elements, **kwargs)
+
+
+def _f(a, flag):
+    return a[()] if a.ndim == 0 and flag else a
+
+
+@wraps(xps.arrays)
+def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
+    """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
+
+    Is only relevant for numpy: on all other libraries, array[()] is a no-op.
+    """
+    return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans())
+
+
+_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
 _sorted_dtypes = [d for category in _dtype_categories for d in category]
 
 def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
@@ -67,11 +112,10 @@
     return key
 
 _promotable_dtypes = list(dh.promotion_table.keys())
-if FILTER_UNDEFINED_DTYPES:
-    _promotable_dtypes = [
-        (d1, d2) for d1, d2 in _promotable_dtypes
-        if not isinstance(d1, _UndefinedStub) or not isinstance(d2, _UndefinedStub)
-    ]
+_promotable_dtypes = [
+    (d1, d2) for d1, d2 in _promotable_dtypes
+    if not isinstance(d1, _UndefinedStub) or not isinstance(d2, _UndefinedStub)
+]
 promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(_promotable_dtypes, key=_dtypes_sorter)
 
 def mutually_promotable_dtypes(
@@ -79,9 +123,8 @@
     *,
     dtypes: Sequence[DataType] = dh.all_dtypes,
 ) -> SearchStrategy[Tuple[DataType, ...]]:
-    if FILTER_UNDEFINED_DTYPES:
-        dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
-        assert len(dtypes) > 0, "all dtypes undefined"  # sanity check
+    dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
+    assert len(dtypes) > 0, "all dtypes undefined"  # sanity check
     if max_size == 2:
         return sampled_from(
             [(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
@@ -106,6 +149,74 @@
     return one_of(strats).map(tuple)
 
 
+@composite
+def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes):
+    sample = draw(mutually_promotable_dtypes(max_size, dtypes=dtypes))
+    permuted = draw(permutations(sample))
+    return sample, tuple(permuted)
+
+
+class OnewayPromotableDtypes(NamedTuple):
+    input_dtype: DataType
+    result_dtype: DataType
+
+
+@composite
+def oneway_promotable_dtypes(
+    draw, dtypes: Sequence[DataType]
+) -> OnewayPromotableDtypes:
+    """Return a strategy for input dtypes that promote to result dtypes."""
+    d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes))
+    result_dtype = dh.result_type(d1, d2)
+    if d1 == result_dtype:
+        return OnewayPromotableDtypes(d2, d1)
+    elif d2 == result_dtype:
+        return OnewayPromotableDtypes(d1, d2)
+    else:
+        reject()
+
+
+class OnewayBroadcastableShapes(NamedTuple):
+    input_shape: Shape
+    result_shape: Shape
+
+
+@composite
+def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
+    """Return a strategy for input shapes that broadcast to result shapes."""
+    result_shape = draw(shapes(min_side=1))
+    input_shape = draw(
+        xps.broadcastable_shapes(
+            result_shape,
+            # Override defaults so bad shapes are less likely to be generated.
+            max_side=None if result_shape == () else max(result_shape),
+            max_dims=len(result_shape),
+        ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
+    )
+    return OnewayBroadcastableShapes(input_shape, result_shape)
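+# A hypothetical, standalone sketch (not part of the suite) of how the one-way
+# strategies compose. In-place operators require the result dtype and shape to
+# match the left operand, so x1 draws the "result" side and x2 draws the side
+# that promotes and broadcasts into it:
+#
+#     from hypothesis import given, strategies as st
+#
+#     @given(data=st.data())
+#     def test_iadd_keeps_dtype_and_shape(data):  # hypothetical test name
+#         dtypes = data.draw(oneway_promotable_dtypes(dh.all_int_dtypes))
+#         shapes_pair = data.draw(oneway_broadcastable_shapes())
+#         x1 = data.draw(arrays(dtypes.result_dtype, shapes_pair.result_shape))
+#         x2 = data.draw(arrays(dtypes.input_dtype, shapes_pair.input_shape))
+#         x1 += x2
+#         assert x1.dtype == dtypes.result_dtype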
+
+
+# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
+# ARRAY_API_TESTS_SKIP_DTYPES
+all_dtypes = sampled_from(_sorted_dtypes)
+int_dtypes = sampled_from(dh.all_int_dtypes)
+uint_dtypes = sampled_from(dh.uint_dtypes)
+real_dtypes = sampled_from(dh.real_dtypes)
+# Warning: The hypothesis "floating_dtypes" is what we call
+# "real_floating_dtypes"
+floating_dtypes = sampled_from(dh.all_float_dtypes)
+real_floating_dtypes = sampled_from(dh.real_float_dtypes)
+numeric_dtypes = sampled_from(dh.numeric_dtypes)
+# Note: this always returns complex dtypes, even if api_version < 2022.12
+complex_dtypes: SearchStrategy[Any] = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else nothing()
+
+def all_floating_dtypes() -> SearchStrategy[DataType]:
+    strat = floating_dtypes
+    if api_version >= "2022.12" and not complex_dtypes.is_empty:
+        strat |= complex_dtypes
+    return strat
+
+
 # shared() allows us to draw either the function or the function name and they
 # will both correspond to the same function.
@@ -122,13 +233,10 @@
                                       lambda i: getattr(xp, i))
 
 # Limit the total size of an array shape
-MAX_ARRAY_SIZE = 10000
+MAX_ARRAY_SIZE = int(os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE", 1024))
 # Size to use for 2-dim arrays
-SQRT_MAX_ARRAY_SIZE = int(sqrt(MAX_ARRAY_SIZE))
+SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
 
-# np.prod and others have overflow and math.prod is Python 3.8+ only
-def prod(seq):
-    return reduce(mul, seq, 1)
 
 # hypothesis.strategies.tuples only generates tuples of a fixed size
 def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False):
@@ -137,13 +245,76 @@
 
 # Use this to avoid memory errors with NumPy.
 # See https://github.com/numpy/numpy/issues/15753
+# Note, the hypothesis default for max_dims is min_dims + 2 (i.e., 0 + 2)
 def shapes(**kw):
     kw.setdefault('min_dims', 0)
     kw.setdefault('min_side', 0)
     return xps.array_shapes(**kw).filter(
-        lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
+        lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
     )
 
+def _factorize(n: int) -> List[int]:
+    # Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
+    factors = []
+    while n % 2 == 0:
+        factors.append(2)
+        n //= 2
+
+    for i in range(3, int(math.sqrt(n)) + 1, 2):
+        while n % i == 0:
+            factors.append(i)
+            n //= i
+
+    if n > 1:  # n is a prime number greater than 2
+        factors.append(n)
+
+    return factors
+
+MAX_SIDE = MAX_ARRAY_SIZE // 64
+# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
+MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)
+
+
+@composite
+def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
+    """
+    Generate shape tuples whose product equals the product of arr_shape.
+ """ + shape = draw(arr_shape) + + array_size = math.prod(shape) + + n_dims = draw(ndims) + + # Handle special cases + if array_size == 0: + # Generate a random tuple, and ensure at least one of the entries is 0 + result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims))) + pos = draw(integers(0, n_dims - 1)) + result[pos] = 0 + return tuple(result) + + if array_size == 1: + return tuple(1 for _ in range(n_dims)) + + # Get prime factorization + factors = _factorize(array_size) + + # Distribute prime factors randomly + result = [1] * n_dims + for factor in factors: + pos = draw(integers(0, n_dims - 1)) + result[pos] *= factor + + assert math.prod(result) == array_size + + # An element of the reshape tuple can be -1, which means it is a stand-in + # for the remaining factors. + if draw(booleans()): + pos = draw(integers(0, n_dims - 1)) + result[pos] = -1 + + return tuple(result) one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE) @@ -153,14 +324,14 @@ def matrix_shapes(draw, stack_shapes=shapes()): stack_shape = draw(stack_shapes) mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2)) shape = stack_shape + mat_shape - assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE) + assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE) return shape square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2]) @composite def finite_matrices(draw, shape=matrix_shapes()): - return draw(xps.arrays(dtype=xps.floating_dtypes(), + return draw(arrays(dtype=floating_dtypes, shape=shape, elements=dict(allow_nan=False, allow_infinity=False))) @@ -169,7 +340,7 @@ def finite_matrices(draw, shape=matrix_shapes()): # Should we set a max_value here? _rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0) rtols = one_of(floats(**_rtol_float_kw), - xps.arrays(dtype=xps.floating_dtypes(), + arrays(dtype=real_floating_dtypes, shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]), elements=_rtol_float_kw)) @@ -198,52 +369,66 @@ def mutually_broadcastable_shapes( ) .map(lambda BS: BS.input_shapes) .filter(lambda shapes: all( - prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes + math.prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes )) ) two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2) -# Note: This should become hermitian_matrices when complex dtypes are added +# TODO: Add support for complex Hermitian matrices @composite -def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True): +def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.): + # for now, only generate elements from (1, bound); TODO: restore + # generating from (-bound, -1/bound).or.(1/bound, bound) + # Note that using `assume` triggers a HealthCheck for filtering too much. 
shape = draw(square_matrix_shapes) dtype = draw(dtypes) - elements = {'allow_nan': False, 'allow_infinity': False} if finite else None - a = draw(xps.arrays(dtype=dtype, shape=shape, elements=elements)) - upper = xp.triu(a) - lower = xp.triu(a, k=1).mT - return upper + lower + if not isinstance(finite, bool): + finite = draw(finite) + if finite: + elements = {'allow_nan': False, 'allow_infinity': False, + 'min_value': 1, 'max_value': bound} + else: + elements = None + a = draw(arrays(dtype=dtype, shape=shape, elements=elements)) + at = ah._matrix_transpose(a) + H = (a + at)*0.5 + if finite: + assume(not xp.any(xp.isinf(H))) + return H @composite -def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()): +def positive_definite_matrices(draw, dtypes=floating_dtypes): # For now just generate stacks of identity matrices # TODO: Generate arbitrary positive definite matrices, for instance, by # using something like # https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351. - n = draw(integers(0)) - shape = draw(shapes()) + (n, n) - assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE) + base_shape = draw(shapes()) + n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value + shape = base_shape + (n, n) + assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE) dtype = draw(dtypes) return broadcast_to(eye(n, dtype=dtype), shape) @composite -def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()): +def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()): # For now, just generate stacks of diagonal matrices. - n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),) stack_shape = draw(stack_shapes) - shape = stack_shape + (n, n) - d = draw(xps.arrays(dtypes, shape=n*prod(stack_shape), - elements=dict(allow_nan=False, allow_infinity=False))) + n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),) + dtype = draw(dtypes) + elements = one_of( + from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False), + from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False), + ) + d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements)) + # Functions that require invertible matrices may do anything when it is # singular, including raising an exception, so we make sure the diagonals # are sufficiently nonzero to avoid any numerical issues. - assume(xp.all(xp.abs(d) > 0.5)) + assert xp.all(xp.abs(d) >= 0.5) - a = xp.zeros(shape) - for j, (idx, i) in enumerate(itertools.product(sh.ndindex(stack_shape), range(n))): - a[idx + (i, i)] = d[j] - return a + diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1)) + return xp.where(diag_mask, d, xp.zeros_like(d)) # TODO: Better name @composite @@ -259,32 +444,44 @@ def two_broadcastable_shapes(draw): sizes = integers(0, MAX_ARRAY_SIZE) sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE) -numeric_arrays = xps.arrays( - dtype=shared(xps.floating_dtypes(), key='dtypes'), +numeric_arrays = arrays( + dtype=shared(floating_dtypes, key='dtypes'), shape=shared(xps.array_shapes(), key='shapes'), ) @composite -def scalars(draw, dtypes, finite=False): +def scalars(draw, dtypes, finite=False, **kwds): """ Strategy to generate a scalar that matches a dtype strategy dtypes should be one of the shared_* dtypes strategies. 
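For example (illustrative): scalars(just(xp.int8)) draws a Python int in [-128, 127], while scalars(just(xp.float32), finite=True) draws a finite Python float.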
""" dtype = draw(dtypes) - if dtype in dh.dtype_ranges: - m, M = dh.dtype_ranges[dtype] + mM = kwds.pop('mM', None) + if dh.is_int_dtype(dtype): + if mM is None: + m, M = dh.dtype_ranges[dtype] + else: + m, M = mM return draw(integers(m, M)) elif dtype == bool_dtype: return draw(booleans()) elif dtype == float64: if finite: - return draw(floats(allow_nan=False, allow_infinity=False)) - return draw(floats()) + return draw(floats(allow_nan=False, allow_infinity=False, **kwds)) + return draw(floats(), **kwds) elif dtype == float32: if finite: - return draw(floats(width=32, allow_nan=False, allow_infinity=False)) - return draw(floats(width=32)) + return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds)) + return draw(floats(width=32, **kwds)) + elif dtype == complex64: + if finite: + return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False)) + return draw(complex_numbers(width=32)) + elif dtype == complex128: + if finite: + return draw(complex_numbers(allow_nan=False, allow_infinity=False)) + return draw(complex_numbers()) else: raise ValueError(f"Unrecognized dtype {dtype}") @@ -304,7 +501,7 @@ def python_integer_indices(draw, sizes): def integer_indices(draw, sizes): # Return either a Python integer or a 0-D array with some integer dtype idx = draw(python_integer_indices(sizes)) - dtype = draw(integer_dtypes) + dtype = draw(int_dtypes | uint_dtypes) m, M = dh.dtype_ranges[dtype] if m <= idx <= M: return draw(one_of(just(idx), @@ -380,21 +577,44 @@ def two_mutual_arrays( ) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: if not isinstance(dtypes, Sequence): raise TypeError(f"{dtypes=} not a sequence") - if FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] - assert len(dtypes) > 0 # sanity check + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] + assert len(dtypes) > 0 # sanity check mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes)) mutual_shapes = shared(two_shapes) - arrays1 = xps.arrays( + arrays1 = arrays( dtype=mutual_dtypes.map(lambda pair: pair[0]), shape=mutual_shapes.map(lambda pair: pair[0]), ) - arrays2 = xps.arrays( + arrays2 = arrays( dtype=mutual_dtypes.map(lambda pair: pair[1]), shape=mutual_shapes.map(lambda pair: pair[1]), ) return arrays1, arrays2 + +@composite +def array_and_py_scalar(draw, dtypes, mM=None, positive=False): + """Draw a pair: (array, scalar) or (scalar, array).""" + dtype = draw(sampled_from(dtypes)) + + scalar_var = draw(scalars(just(dtype), finite=True, mM=mM)) + if positive: + assume (scalar_var > 0) + + elements={} + if dtype in dh.real_float_dtypes: + elements = {'allow_nan': False, 'allow_infinity': False, + 'min_value': 1.0 / (2<<5), 'max_value': 2<<5} + if positive: + elements = {'min_value': 0} + array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements)) + + if draw(booleans()): + return scalar_var, array_var + else: + return array_var, scalar_var + + @composite def kwargs(draw, **kw): """ @@ -444,3 +664,14 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]: axes_strats.append(integers(-ndim, ndim - 1)) axes_strats.append(xps.valid_tuple_axes(ndim)) return one_of(axes_strats) + + +@contextmanager +def reject_overflow(): + try: + yield + except Exception as e: + if isinstance(e, OverflowError) or re.search("[Oo]verflow", str(e)): + reject() + else: + raise e diff --git a/array_api_tests/meta/test_array_helpers.py b/array_api_tests/meta/test_array_helpers.py deleted file mode 100644 index 
68f96910..00000000 --- a/array_api_tests/meta/test_array_helpers.py +++ /dev/null @@ -1,19 +0,0 @@ -from .. import _array_module as xp -from ..array_helpers import exactly_equal, notequal - -# TODO: These meta-tests currently only work with NumPy - -def test_exactly_equal(): - a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) - b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) - - res = xp.asarray([True, False, True, False, True, False, True, False]) - assert xp.all(xp.equal(exactly_equal(a, b), res)) - -def test_notequal(): - a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1]) - b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2]) - - res = xp.asarray([False, True, False, False, False, True, False, True]) - assert xp.all(xp.equal(notequal(a, b), res)) - diff --git a/array_api_tests/meta/test_pytest_helpers.py b/array_api_tests/meta/test_pytest_helpers.py deleted file mode 100644 index a6851a15..00000000 --- a/array_api_tests/meta/test_pytest_helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -from pytest import raises - -from .. import _array_module as xp -from .. import pytest_helpers as ph - - -def test_assert_dtype(): - ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16) - with raises(AssertionError): - ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32) - ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool) - ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8) - ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool) - - -def test_assert_array_elements(): - ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0)) - ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0)) - with raises(AssertionError): - ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0)) - with raises(AssertionError): - ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0)) diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 051a063f..f6b7ae25 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -1,3 +1,4 @@ +import cmath import math from inspect import getfullargspec from typing import Any, Dict, Optional, Sequence, Tuple, Union @@ -6,6 +7,7 @@ from . import dtype_helpers as dh from . import shape_helpers as sh from . import stubs +from . 
import xp as _xp from .typing import Array, DataType, Scalar, ScalarType, Shape __all__ = [ @@ -81,10 +83,10 @@ def is_neg_zero(n: float) -> bool: def assert_dtype( func_name: str, + *, in_dtype: Union[DataType, Sequence[DataType]], out_dtype: DataType, expected: Optional[DataType] = None, - *, repr_name: str = "out.dtype", ): """ @@ -95,7 +97,7 @@ def assert_dtype( >>> x = xp.arange(5, dtype=xp.uint8) >>> out = xp.abs(x) - >>> assert_dtype('abs', x.dtype, out.dtype) + >>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype) is equivalent to @@ -107,7 +109,7 @@ def assert_dtype( >>> x1 = xp.arange(5, dtype=xp.uint8) >>> x2 = xp.arange(5, dtype=xp.uint16) >>> out = xp.add(x1, x2) - >>> assert_dtype('add', [x1.dtype, x2.dtype], out.dtype) + >>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) is equivalent to @@ -118,9 +120,10 @@ def assert_dtype( >>> x = xp.arange(5, dtype=xp.int8) >>> out = xp.sum(x) >>> default_int = xp.asarray(0).dtype - >>> assert_dtype('sum', x, out.dtype, default_int) + >>> assert_dtype('sum', in_dtype=x.dtype, out_dtype=out.dtype, expected=default_int) """ + __tracebackhide__ = True in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype] f_in_dtypes = dh.fmt_types(tuple(in_dtypes)) f_out_dtype = dh.dtype_to_name[out_dtype] @@ -134,15 +137,49 @@ def assert_dtype( assert out_dtype == expected, msg -def assert_kw_dtype(func_name: str, kw_dtype: DataType, out_dtype: DataType): +def assert_float_to_complex_dtype( + func_name: str, *, in_dtype: DataType, out_dtype: DataType +): + if in_dtype == xp.float32: + expected = xp.complex64 + else: + assert in_dtype == xp.float64 # sanity check + expected = xp.complex128 + assert_dtype( + func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected + ) + + +def assert_complex_to_float_dtype( + func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype" +): + if in_dtype == xp.complex64: + expected = xp.float32 + elif in_dtype == xp.complex128: + expected = xp.float64 + else: + assert in_dtype in (xp.float32, xp.float64) # sanity check + expected = in_dtype + assert_dtype( + func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name + ) + + +def assert_kw_dtype( + func_name: str, + *, + kw_dtype: DataType, + out_dtype: DataType, +): """ Assert the output dtype is the passed keyword dtype, e.g. >>> kw = {'dtype': xp.uint8} - >>> out = xp.ones(5, **kw) + >>> out = xp.ones(5, **kw) + >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype) """ + __tracebackhide__ = True f_kw_dtype = dh.dtype_to_name[kw_dtype] f_out_dtype = dh.dtype_to_name[out_dtype] msg = ( @@ -160,6 +197,7 @@ def assert_default_float(func_name: str, out_dtype: DataType): >>> assert_default_float('ones', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_float] msg = ( @@ -169,6 +207,24 @@ def assert_default_float(func_name: str, out_dtype: DataType): assert out_dtype == dh.default_float, msg +def assert_default_complex(func_name: str, out_dtype: DataType): + """ + Assert the output dtype is the default complex, e.g.
+ + >>> out = xp.asarray(4+2j) + >>> assert_default_complex('asarray', out.dtype) + + """ + __tracebackhide__ = True + f_dtype = dh.dtype_to_name[out_dtype] + f_default = dh.dtype_to_name[dh.default_complex] + msg = ( + f"out.dtype={f_dtype}, should be default " + f"complex dtype {f_default} [{func_name}()]" + ) + assert out_dtype == dh.default_complex, msg + + def assert_default_int(func_name: str, out_dtype: DataType): """ Assert the output dtype is the default int, e.g. @@ -177,6 +233,7 @@ def assert_default_int(func_name: str, out_dtype: DataType): >>> assert_default_int('full', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] f_default = dh.dtype_to_name[dh.default_int] msg = ( @@ -194,6 +251,7 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty >>> assert_default_int('argmax', out.dtype) """ + __tracebackhide__ = True f_dtype = dh.dtype_to_name[out_dtype] msg = ( f"{repr_name}={f_dtype}, should be the default index dtype, " @@ -204,19 +262,20 @@ def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dty def assert_shape( func_name: str, + *, out_shape: Union[int, Shape], expected: Union[int, Shape], - /, repr_name="out.shape", - **kw, + kw: dict = {}, ): """ Assert the output shape is as expected, e.g. >>> out = xp.ones((3, 3, 3)) - >>> assert_shape('ones', out.shape, (3, 3, 3)) + >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3)) """ + __tracebackhide__ = True if isinstance(out_shape, int): out_shape = (out_shape,) if isinstance(expected, int): @@ -231,11 +290,10 @@ def assert_result_shape( func_name: str, in_shapes: Sequence[Shape], out_shape: Shape, - /, expected: Optional[Shape] = None, *, repr_name="out.shape", - **kw, + kw: dict = {}, ): """ Assert the output shape is as expected. @@ -244,13 +302,14 @@ in_shapes, to test against out_shape, e.g. >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3))) - >>> assert_shape('add', [(3, 1), (1, 3)], out.shape) + >>> assert_result_shape('add', in_shapes=[(3, 1), (1, 3)], out_shape=out.shape) is equivalent to >>> assert out.shape == (3, 3) """ + __tracebackhide__ = True if expected is None: expected = sh.broadcast_shapes(*in_shapes) f_in_shapes = " . ".join(str(s) for s in in_shapes) @@ -263,12 +322,12 @@ def assert_keepdimable_shape( func_name: str, + *, in_shape: Shape, out_shape: Shape, axes: Tuple[int, ...], keepdims: bool, - /, - **kw, + kw: dict = {}, ): """ Assert the output shape from a keepdimable function is as expected, e.g.
@@ -276,8 +335,8 @@ def assert_keepdimable_shape( >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) >>> out1 = xp.max(x, keepdims=False) >>> out2 = xp.max(x, keepdims=True) - >>> assert_keepdimable_shape('max', x.shape, out1.shape, (0, 1), False) - >>> assert_keepdimable_shape('max', x.shape, out2.shape, (0, 1), True) + >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False) + >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True) is equivalent to @@ -285,29 +344,38 @@ >>> assert out2.shape == (1, 1) """ + __tracebackhide__ = True if keepdims: shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape)) else: shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes) - assert_shape(func_name, out_shape, shape, **kw) + assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw) def assert_0d_equals( - func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw + func_name: str, + *, + x_repr: str, + x_val: Array, + out_repr: str, + out_val: Array, + kw: dict = {}, ): """ Assert a 0d array is as expected, e.g. >>> x = xp.asarray([0, 1, 2]) - >>> res = xp.asarray(x, copy=True) + >>> kw = {'copy': True} + >>> res = xp.asarray(x, **kw) >>> res[0] = 42 - >>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0]) + >>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='x[0]', out_val=res[0], kw=kw) is equivalent to >>> assert res[0] == x[0] """ + __tracebackhide__ = True msg = ( f"{out_repr}={out_val}, but should be {x_repr}={x_val} " f"[{func_name}({fmt_kw(kw)})]" ) @@ -320,99 +388,214 @@ def assert_scalar_equals( func_name: str, + *, type_: ScalarType, idx: Shape, out: Scalar, expected: Scalar, - /, repr_name: str = "out", - **kw, + kw: dict = {}, ): """ - Assert a 0d array, convered to a scalar, is as expected, e.g. + Assert a 0d array, converted to a scalar, is as expected, e.g. >>> x = xp.ones(5, dtype=xp.uint8) >>> out = xp.sum(x) - >>> assert_scalar_equals('sum', int, (), int(out), 5) + >>> assert_scalar_equals('sum', type_=int, idx=(), out=int(out), expected=5) is equivalent to >>> assert int(out) == 5 + NOTE: This function does *exact* comparison, even for floats. For + approximate float comparisons use assert_scalar_isclose. """ + __tracebackhide__ = True repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" f_func = f"{func_name}({fmt_kw(kw)})" if type_ in [bool, int]: msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" assert out == expected, msg - elif math.isnan(expected): + elif cmath.isnan(expected): msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" - assert math.isnan(out), msg + assert cmath.isnan(out), msg else: - msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" - assert math.isclose(out, expected, rel_tol=0.25, abs_tol=1), msg + msg = f"{repr_name}={out}, but should be {expected} [{f_func}]" + assert out == expected, msg + + +def assert_scalar_isclose( + func_name: str, + *, + rel_tol: float = 0.25, + abs_tol: float = 1, + type_: ScalarType, + idx: Shape, + out: Scalar, + expected: Scalar, + repr_name: str = "out", + kw: dict = {}, +): + """ + Assert a 0d array, converted to a scalar, is close to the expected value, e.g. + + >>> x = xp.ones(5, dtype=xp.float64) + >>> out = xp.sum(x) + >>> assert_scalar_isclose('sum', type_=float, idx=(), out=float(out), expected=5.)
+ + is equivalent to + + >>> assert math.isclose(float(out), 5.) + + """ + __tracebackhide__ = True + repr_name = repr_name if idx == () else f"{repr_name}[{idx}]" + f_func = f"{func_name}({fmt_kw(kw)})" + msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]" + assert type_ in [float, complex] # Sanity check + assert cmath.isclose(out, expected, rel_tol=rel_tol, abs_tol=abs_tol), msg def assert_fill( - func_name: str, fill_value: Scalar, dtype: DataType, out: Array, /, **kw + func_name: str, + *, + fill_value: Scalar, + dtype: DataType, + out: Array, + kw: dict = {}, ): """ Assert all elements of an array are as expected, e.g. >>> out = xp.full(5, 42, dtype=xp.uint8) - >>> assert_fill('full', 42, xp.uint8, out, 5) + >>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5)) is equivalent to >>> assert xp.all(out == 42) """ + __tracebackhide__ = True msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}" - if math.isnan(fill_value): + if cmath.isnan(fill_value): assert xp.all(xp.isnan(out)), msg else: assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg +def _has_functional_signbit() -> bool: + # signbit can be available but not implemented (e.g., in array-api-strict) + if not hasattr(_xp, "signbit"): + return False + try: + assert _xp.all(_xp.signbit(_xp.asarray(0.0)) == False) + except Exception: + return False + return True + +def _real_float_strict_equals(out: Array, expected: Array) -> bool: + nan_mask = xp.isnan(out) + if not xp.all(nan_mask == xp.isnan(expected)): + return False + ignore_mask = nan_mask + + # Test sign of zeros if xp.signbit() is available, otherwise ignore as it's + # not worth the performance cost. + if _has_functional_signbit(): + out_zero_mask = out == 0 + out_sign_mask = _xp.signbit(out) + out_pos_zero_mask = out_zero_mask & ~out_sign_mask + out_neg_zero_mask = out_zero_mask & out_sign_mask + expected_zero_mask = expected == 0 + expected_sign_mask = _xp.signbit(expected) + expected_pos_zero_mask = expected_zero_mask & ~expected_sign_mask + expected_neg_zero_mask = expected_zero_mask & expected_sign_mask + pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask + neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask + if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)): + return False + ignore_mask |= out_zero_mask + + replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself + assert replacement == replacement # sanity check + match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected) + return xp.all(match) + + +def _assert_float_element(at_out: Array, at_expected: Array, msg: str): + if xp.isnan(at_expected): + assert xp.isnan(at_out), msg + elif at_expected == 0.0 or at_expected == -0.0: + scalar_at_expected = float(at_expected) + scalar_at_out = float(at_out) + if is_pos_zero(scalar_at_expected): + assert is_pos_zero(scalar_at_out), msg + else: + assert is_neg_zero(scalar_at_expected) # sanity check + assert is_neg_zero(scalar_at_out), msg + else: + assert at_out == at_expected, msg + + def assert_array_elements( - func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw + func_name: str, + *, + out: Array, + expected: Array, + out_repr: str = "out", + kw: dict = {}, ): """ Assert array elements are (strictly) as expected, e.g.
>>> x = xp.arange(5) >>> out = xp.asarray(x) - >>> assert_array_elements('asarray', out, x) + >>> assert_array_elements('asarray', out=out, expected=x) is equivalent to >>> assert xp.all(out == x) """ + __tracebackhide__ = True dh.result_type(out.dtype, expected.dtype) # sanity check - assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check + assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check f_func = f"[{func_name}({fmt_kw(kw)})]" - if dh.is_float_dtype(out.dtype): + + # First we try to short-circuit for a successful assertion by using vectorised checks. + if out.dtype in dh.real_float_dtypes: + if _real_float_strict_equals(out, expected): + return + elif out.dtype in dh.complex_dtypes: + real_match = _real_float_strict_equals(xp.real(out), xp.real(expected)) + imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected)) + if real_match and imag_match: + return + else: + match = out == expected + if xp.all(match): + return + + # In case of mismatch, generate a more helpful error. Cycling through all indices is + # costly in some array API implementations, so we only do this in the case of a failure. + msg_template = "{}={}, but should be {} " + f_func + if out.dtype in dh.real_float_dtypes: for idx in sh.ndindex(out.shape): at_out = out[idx] at_expected = expected[idx] - msg = ( - f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} " - f"{f_func}" - ) - if xp.isnan(at_expected): - assert xp.isnan(at_out), msg - elif at_expected == 0.0 or at_expected == -0.0: - scalar_at_expected = float(at_expected) - scalar_at_out = float(at_out) - if is_pos_zero(scalar_at_expected): - assert is_pos_zero(scalar_at_out), msg - else: - assert is_neg_zero(scalar_at_expected) # sanity check - assert is_neg_zero(scalar_at_out), msg - else: - assert at_out == at_expected, msg + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) + _assert_float_element(at_out, at_expected, msg) + elif out.dtype in dh.complex_dtypes: + assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes) + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) + _assert_float_element(xp.real(at_out), xp.real(at_expected), msg) + _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg) else: - assert xp.all( - out == expected - ), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}" + for idx in sh.ndindex(out.shape): + at_out = out[idx] + at_expected = expected[idx] + msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected) + assert at_out == at_expected, msg diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index ba7d994e..52c6f3fc 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, Optional, Sequence, Tuple, Union from ndindex import iter_indices as _iter_indices @@ -8,7 +8,7 @@ __all__ = [ "broadcast_shapes", - "normalise_axis", + "normalize_axis", "ndindex", "axis_ndindex", "axes_ndindex", @@ -65,11 +65,13 @@ def broadcast_shapes(*shapes: Shape): return result -def normalise_axis( - axis: Optional[Union[int, Tuple[int, ...]]], ndim: int +def normalize_axis( + axis: Optional[Union[int, Sequence[int]]], ndim: int ) -> Tuple[int, ...]: if axis is None: return
tuple(range(ndim)) + elif isinstance(axis, Sequence) and not isinstance(axis, tuple): + axis = tuple(axis) axes = axis if isinstance(axis, tuple) else (axis,) axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes) return axes diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py index 3fba33fc..9025c461 100644 --- a/array_api_tests/stubs.py +++ b/array_api_tests/stubs.py @@ -6,6 +6,8 @@ from types import FunctionType, ModuleType from typing import Dict, List +from . import api_version + __all__ = [ "name_to_func", "array_methods", @@ -15,20 +17,21 @@ "extension_to_funcs", ] +spec_module = "_" + api_version.replace('.', '_') -spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / "API_specification" +spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification" assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`" -sigs_dir = spec_dir / "signatures" +sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module assert sigs_dir.exists() -spec_abs_path: str = str(spec_dir.resolve()) -sys.path.append(spec_abs_path) -assert find_spec("signatures") is not None +sigs_abs_path: str = str(sigs_dir.parent.parent.resolve()) +sys.path.append(sigs_abs_path) +assert find_spec(f"array_api_stubs.{spec_module}") is not None name_to_mod: Dict[str, ModuleType] = {} for path in sigs_dir.glob("*.py"): name = path.name.replace(".py", "") - name_to_mod[name] = import_module(f"signatures.{name}") + name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}") array = name_to_mod["array_object"].array array_methods = [ @@ -36,8 +39,7 @@ if n != "__init__" # probably exists for Sphinx ] array_attributes = [ - n for n, f in inspect.getmembers(array, predicate=lambda x: not inspect.isfunction(x)) - if n != "__init__" # probably exists for Sphinx + n for n, f in inspect.getmembers(array, predicate=lambda x: isinstance(x, property)) ] category_to_funcs: Dict[str, List[FunctionType]] = {} @@ -53,7 +55,26 @@ all_funcs.extend(funcs) name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs} -EXTENSIONS: str = ["linalg"] +info_funcs = [] +if api_version >= "2023.12": + # The info functions in the stubs are in info.py, but this is not a name + # in the standard. + info_mod = name_to_mod["info"] + + # Note that __array_namespace_info__ is in info.__all__ but it is in the + # top-level namespace, not the info namespace. 
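+ # (For 2023.12 these are capabilities, default_device, default_dtypes, + # devices and dtypes.)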
+ info_funcs = [getattr(info_mod, name) for name in info_mod.__all__ + if name != '__array_namespace_info__'] + assert all(isinstance(f, FunctionType) for f in info_funcs) + name_to_func.update({f.__name__: f for f in info_funcs}) + + all_funcs.append(info_mod.__array_namespace_info__) + name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__ + category_to_funcs['info'] = [info_mod.__array_namespace_info__] + +EXTENSIONS: List[str] = ["linalg"] +if api_version >= "2022.12": + EXTENSIONS.append("fft") extension_to_funcs: Dict[str, List[FunctionType]] = {} for ext in EXTENSIONS: mod = name_to_mod[ext] @@ -71,3 +92,7 @@ for func in funcs: if func.__name__ not in name_to_func.keys(): name_to_func[func.__name__] = func + +# sanity check public attributes are not empty +for attr in __all__: + assert len(locals()[attr]) != 0, f"{attr} is empty" diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index df3edb88..4d4af350 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,3 +1,4 @@ +import cmath import math from itertools import product from typing import List, Sequence, Tuple, Union, get_args @@ -12,25 +13,22 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Index, Param, Scalar, ScalarType, Shape -pytestmark = pytest.mark.ci - def scalar_objects( dtype: DataType, shape: Shape ) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]: """Generates scalars or nested sequences which are valid for xp.asarray()""" size = math.prod(shape) - return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( + return st.lists(hh.from_dtype(dtype), min_size=size, max_size=size).map( lambda l: sh.reshape(l, shape) ) -def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]: +def normalize_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]: """ - Normalise an indexing key. + Normalize an indexing key. * If a non-tuple index, wrap as a tuple. * Represent ellipsis as equivalent slices. @@ -50,7 +48,7 @@ def get_indexed_axes_and_out_shape( key: Tuple[Union[int, slice, None], ...], shape: Shape ) -> Tuple[Tuple[Sequence[int], ...], Shape]: """ - From the (normalised) key and input shape, calculates: + From the (normalized) key and input shape, calculates: * indexed_axes: For each dimension, the axes which the key indexes. 
* out_shape: The resulting shape of indexing an array (of the input shape) @@ -76,7 +74,7 @@ def get_indexed_axes_and_out_shape( return tuple(axes_indices), tuple(out_shape) -@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data()) +@given(shape=hh.shapes(), dtype=hh.all_dtypes, data=st.data()) def test_getitem(shape, dtype, data): zero_sided = any(side == 0 for side in shape) if zero_sided: @@ -89,11 +87,11 @@ def test_getitem(shape, dtype, data): out = x[key] - ph.assert_dtype("__getitem__", x.dtype, out.dtype) - _key = normalise_key(key, shape) - axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) - ph.assert_shape("__getitem__", out.shape, out_shape) - out_zero_sided = any(side == 0 for side in out_shape) + ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) + _key = normalize_key(key, shape) + axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape) + ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) + out_zero_sided = any(side == 0 for side in expected_shape) if not zero_sided and not out_zero_sided: out_obj = [] for idx in product(*axes_indices): @@ -101,14 +99,15 @@ def test_getitem(shape, dtype, data): for i in idx: val = val[i] out_obj.append(val) - out_obj = sh.reshape(out_obj, out_shape) + out_obj = sh.reshape(out_obj, expected_shape) expected = xp.asarray(out_obj, dtype=dtype) - ph.assert_array_elements("__getitem__", out, expected) + ph.assert_array_elements("__getitem__", out=out, expected=expected) +@pytest.mark.unvectorized @given( shape=hh.shapes(), - dtypes=oneway_promotable_dtypes(dh.all_dtypes), + dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), data=st.data(), ) def test_setitem(shape, dtypes, data): @@ -120,39 +119,44 @@ def test_setitem(shape, dtypes, data): x = xp.asarray(obj, dtype=dtypes.result_dtype) note(f"{x=}") key = data.draw(xps.indices(shape=shape), label="key") - _key = normalise_key(key, shape) + _key = normalize_key(key, shape) axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape) - value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape) + value_strat = hh.arrays(dtype=dtypes.result_dtype, shape=out_shape) if out_shape == (): # We can pass scalars if we're only indexing one element - value_strat |= xps.from_dtype(dtypes.result_dtype) + value_strat |= hh.from_dtype(dtypes.result_dtype) value = data.draw(value_strat, label="value") res = xp.asarray(x, copy=True) res[key] = value - ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") - ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape") + ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape") f_res = sh.fmt_idx("x", key) if isinstance(value, get_args(Scalar)): msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]" - if math.isnan(value): + if cmath.isnan(value): assert xp.isnan(res[key]), msg else: assert res[key] == value, msg else: - ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res) + ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res) unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices)) for idx in unaffected_indices: ph.assert_0d_equals( - "__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx] + "__setitem__", + x_repr=f"old {f_res}", + x_val=x[idx], + out_repr=f"modified {f_res}", + out_val=res[idx], ) 
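+# A sketch of the boolean-mask semantics exercised below: for x = [[1, 2], [3, 4]] +# and key = [[True, False], [False, True]], x[key] is the 1-D array [1, 4], i.e. +# the selected elements in row-major order.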
+@pytest.mark.unvectorized @pytest.mark.data_dependent_shapes @given(hh.shapes(), st.data()) def test_getitem_masking(shape, data): - x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") + x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x") mask_shapes = st.one_of( st.sampled_from([x.shape, ()]), st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map( @@ -160,7 +164,7 @@ def test_getitem_masking(shape, data): ), hh.shapes(), ) - key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key") + key = data.draw(hh.arrays(dtype=xp.bool, shape=mask_shapes), label="key") if key.ndim > x.ndim or not all( ks in (xs, 0) for xs, ks in zip(x.shape, key.shape) @@ -171,14 +175,14 @@ def test_getitem_masking(shape, data): out = x[key] - ph.assert_dtype("__getitem__", x.dtype, out.dtype) + ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype) if key.ndim == 0: - out_shape = (1,) if key else (0,) - out_shape += x.shape + expected_shape = (1,) if key else (0,) + expected_shape += x.shape else: size = int(xp.sum(xp.astype(key, xp.uint8))) - out_shape = (size,) + x.shape[key.ndim :] - ph.assert_shape("__getitem__", out.shape, out_shape) + expected_shape = (size,) + x.shape[key.ndim :] + ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape) if not any(s == 0 for s in key.shape): assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios out_indices = sh.ndindex(out.shape) @@ -187,64 +191,149 @@ def test_getitem_masking(shape, data): out_idx = next(out_indices) ph.assert_0d_equals( "__getitem__", - f"x[{x_idx}]", - x[x_idx], - f"out[{out_idx}]", - out[out_idx], + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], ) +@pytest.mark.unvectorized @given(hh.shapes(), st.data()) def test_setitem_masking(shape, data): - x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x") - key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key") + x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x") + key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key") value = data.draw( - xps.from_dtype(x.dtype) | xps.arrays(dtype=x.dtype, shape=()), label="value" + hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value" ) res = xp.asarray(x, copy=True) res[key] = value - ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype") - ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.dtype") + ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype") + ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype") scalar_type = dh.get_scalar_type(x.dtype) for idx in sh.ndindex(x.shape): if key[idx]: if isinstance(value, get_args(Scalar)): ph.assert_scalar_equals( "__setitem__", - scalar_type, - idx, - scalar_type(res[idx]), - value, + type_=scalar_type, + idx=idx, + out=scalar_type(res[idx]), + expected=value, repr_name="modified x", ) else: ph.assert_0d_equals( - "__setitem__", "value", value, f"modified x[{idx}]", res[idx] + "__setitem__", + x_repr="value", + x_val=value, + out_repr=f"modified x[{idx}]", + out_val=res[idx] ) else: ph.assert_0d_equals( - "__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx] + "__setitem__", + x_repr=f"old x[{idx}]", + x_val=x[idx], + out_repr=f"modified x[{idx}]", + out_val=res[idx] ) -def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param: +# ### Fancy indexing ### + 
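+# A sketch of the integer-array indexing semantics tested here: index arrays +# broadcast against each other, so for x of shape (3, 4), x[xp.asarray([0, 2]), 1] +# has shape (2,) and holds x[0, 1] and x[2, 1].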
+@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@pytest.mark.parametrize("idx_max_dims", [1, None]) +@given(shape=hh.shapes(min_dims=2), data=st.data()) +def test_getitem_arrays_and_ints_1(shape, data, idx_max_dims): + # min_dims=2 : test multidim `x` arrays + # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None + _test_getitem_arrays_and_ints(shape, data, idx_max_dims) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@pytest.mark.parametrize("idx_max_dims", [1, None]) +@given(shape=hh.shapes(min_dims=1), data=st.data()) +def test_getitem_arrays_and_ints_2(shape, data, idx_max_dims): + # min_dims=1 : favor 1D `x` arrays + # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None + _test_getitem_arrays_and_ints(shape, data, idx_max_dims) + + +def _test_getitem_arrays_and_ints(shape, data, idx_max_dims): + assume((len(shape) > 0) and all(sh > 0 for sh in shape)) + + dtype = xp.int32 + obj = data.draw(scalar_objects(dtype, shape), label="obj") + x = xp.asarray(obj, dtype=dtype) + + # draw a mix of ints and index arrays + arr_index = [data.draw(st.booleans()) for _ in range(len(shape))] + assume(sum(arr_index) > 0) + + # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY + # max_dims=None ==> multidim indexing arrays + if sum(arr_index) > 0: + index_shapes = data.draw( + hh.mutually_broadcastable_shapes( + sum(arr_index), min_dims=1, max_dims=idx_max_dims, min_side=1 + ) + ) + index_shapes = list(index_shapes) + + # prepare the indexing tuple, a mix of integer indices and index arrays + key = [] + for i,typ in enumerate(arr_index): + if typ: + # draw an array index + this_idx = data.draw( + xps.arrays( + dtype, + shape=index_shapes.pop(), + elements=st.integers(0, shape[i]-1) + ) + ) + key.append(this_idx) + + else: + # draw an integer + key.append(data.draw(st.integers(-shape[i], shape[i]-1))) + + key = tuple(key) + out = x[key] + + arrays = [xp.asarray(k) for k in key] + bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays]) + bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays] + + for idx in sh.ndindex(bcast_shape): + tpl = tuple(k[idx] for k in bcast_key) + assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }" + + +def make_scalar_casting_param( + method_name: str, dtype: DataType, stype: ScalarType +) -> Param: + dtype_name = dh.dtype_to_name[dtype] return pytest.param( - method_name, dtype, stype, id=f"{method_name}({dh.dtype_to_name[dtype]})" + method_name, dtype, stype, id=f"{method_name}({dtype_name})" ) @pytest.mark.parametrize( "method_name, dtype, stype", - [make_param("__bool__", xp.bool, bool)] - + [make_param("__int__", d, int) for d in dh.all_int_dtypes] - + [make_param("__index__", d, int) for d in dh.all_int_dtypes] - + [make_param("__float__", d, float) for d in dh.float_dtypes], + [make_scalar_casting_param("__bool__", xp.bool, bool)] + + [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_dtypes] + + [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_dtypes] + + [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_dtypes], ) @given(data=st.data()) def test_scalar_casting(method_name, dtype, stype, data): - x = data.draw(xps.arrays(dtype, shape=()), label="x") + x = data.draw(hh.arrays(dtype, shape=()), label="x") method = getattr(x, method_name) out = method() assert isinstance( diff --git a/array_api_tests/test_constants.py b/array_api_tests/test_constants.py index 606bf897..145a2736 100644 --- 
a/array_api_tests/test_constants.py +++ b/array_api_tests/test_constants.py @@ -4,11 +4,9 @@ import pytest from . import dtype_helpers as dh -from ._array_module import mod as xp +from . import xp from .typing import Array -pytestmark = pytest.mark.ci - def assert_scalar_float(name: str, c: Any): assert isinstance(c, SupportsFloat), f"{name}={c!r} does not look like a float" @@ -51,3 +49,8 @@ def test_nan(): x = xp.asarray(xp.nan) assert_0d_float("nan", x) assert xp.isnan(x), "xp.isnan(xp.asarray(xp.nan))=False" + + +def test_newaxis(): + assert hasattr(xp, "newaxis") + assert xp.newaxis is None diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 4ae92f1f..1f144c72 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -1,8 +1,8 @@ +import cmath import math from itertools import count from typing import Iterator, NamedTuple, Union -import pytest from hypothesis import assume, given, note from hypothesis import strategies as st @@ -12,11 +12,8 @@ from . import pytest_helpers as ph from . import shape_helpers as sh from . import xps -from .test_operators_and_elementwise_functions import oneway_promotable_dtypes from .typing import DataType, Scalar -pytestmark = pytest.mark.ci - class frange(NamedTuple): start: float @@ -79,14 +76,14 @@ def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float] ) -@given(dtype=st.none() | hh.numeric_dtypes, data=st.data()) +@given(dtype=st.none() | hh.real_dtypes, data=st.data()) def test_arange(dtype, data): if dtype is None or dh.is_float_dtype(dtype): start = data.draw(reals(), label="start") stop = data.draw(reals() | st.none(), label="stop") else: - start = data.draw(xps.from_dtype(dtype), label="start") - stop = data.draw(xps.from_dtype(dtype), label="stop") + start = data.draw(hh.from_dtype(dtype), label="start") + stop = data.draw(hh.from_dtype(dtype), label="stop") if stop is None: _start = 0 _stop = start @@ -106,9 +103,9 @@ def test_arange(dtype, data): step_strats = [] if dtype in dh.int_dtypes: step_min = min(math.floor(-tol), -1) - step_strats.append(xps.from_dtype(dtype, max_value=step_min)) + step_strats.append(hh.from_dtype(dtype, max_value=step_min)) step_max = max(math.ceil(tol), 1) - step_strats.append(xps.from_dtype(dtype, min_value=step_max)) + step_strats.append(hh.from_dtype(dtype, min_value=step_max)) step = data.draw(st.one_of(step_strats), label="step") assert step != 0, "step must not equal 0" # sanity check @@ -128,6 +125,12 @@ def test_arange(dtype, data): assert m <= _start <= M assert m <= _stop <= M assert m <= step <= M + # Ignore ridiculous distances so we don't fail like + # + # >>> torch.arange(9132051521638391890, 0, -91320515216383920) + # RuntimeError: invalid size, possible overflow? 
+ # + assume(abs(_start - _stop) < M // 2) r = frange(_start, _stop, step) size = len(r) @@ -152,7 +155,7 @@ else: ph.assert_default_float("arange", out.dtype) else: - ph.assert_kw_dtype("arange", dtype, out.dtype) + ph.assert_kw_dtype("arange", kw_dtype=dtype, out_dtype=out.dtype) f_sig = ", ".join(str(n) for n in args) if len(kwargs) > 0: f_sig += f", {ph.fmt_kw(kwargs)}" @@ -175,16 +178,17 @@ # min_size = math.floor(size * 0.9) max_size = max(math.ceil(size * 1.1), 1) + out_size = math.prod(out.shape) assert ( - min_size <= out.size <= max_size - ), f"{out.size=}, but should be roughly {size} {f_func}" + min_size <= out_size <= max_size + ), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}" if dh.is_int_dtype(_dtype): elements = list(r) - assume(out.size == len(elements)) - ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype)) + assume(out_size == len(elements)) + ph.assert_array_elements("arange", out=out, expected=xp.asarray(elements, dtype=_dtype)) else: - assume(out.size == size) - if out.size > 0: + assume(out_size == size) + if out_size > 0: assert xp.equal( out[0], xp.asarray(_start, dtype=out.dtype) ), f"out[0]={out[0]}, but should be {_start} {f_func}" @@ -193,7 +197,7 @@ @given(shape=hh.shapes(min_side=1), data=st.data()) def test_asarray_scalars(shape, data): kw = data.draw( - hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw" + hh.kwargs(dtype=st.none() | hh.all_dtypes, copy=st.none()), label="kw" ) dtype = kw.get("dtype", None) if dtype is None: @@ -207,11 +211,11 @@ else: _dtype = dtype if dh.is_float_dtype(_dtype): - elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32) + elements_strat = hh.from_dtype(_dtype) | hh.from_dtype(xp.int32) elif dh.is_int_dtype(_dtype): - elements_strat = xps.from_dtype(_dtype) | st.booleans() + elements_strat = hh.from_dtype(_dtype) | st.booleans() else: - elements_strat = xps.from_dtype(_dtype) + elements_strat = hh.from_dtype(_dtype) size = math.prod(shape) obj_strat = st.lists(elements_strat, min_size=size, max_size=size) scalar_type = dh.get_scalar_type(_dtype) @@ -239,27 +243,28 @@ assert out.dtype in dtype_family, msg else: assert kw["dtype"] == _dtype # sanity check - ph.assert_kw_dtype("asarray", _dtype, out.dtype) - ph.assert_shape("asarray", out.shape, shape) + ph.assert_kw_dtype("asarray", kw_dtype=_dtype, out_dtype=out.dtype) + ph.assert_shape("asarray", out_shape=out.shape, expected=shape) for idx, v_expect in zip(sh.ndindex(out.shape), _obj): v = scalar_type(out[idx]) - ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw) + ph.assert_scalar_equals("asarray", type_=scalar_type, idx=idx, out=v, expected=v_expect, kw=kw) def scalar_eq(s1: Scalar, s2: Scalar) -> bool: - if math.isnan(s1): - return math.isnan(s2) + if cmath.isnan(s1): + return cmath.isnan(s2) else: return s1 == s2 @given( shape=hh.shapes(), - dtypes=oneway_promotable_dtypes(dh.all_dtypes), + dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes), data=st.data(), ) def test_asarray_arrays(shape, dtypes, data): - x = data.draw(xps.arrays(dtype=dtypes.input_dtype, shape=shape), label="x") + # generate arrays only since we draw the copy= kwd below (and np.asarray(scalar, copy=False) errors out) + x = data.draw(hh.arrays_no_scalars(dtype=dtypes.input_dtype, shape=shape), label="x") dtypes_strat =
st.just(dtypes.input_dtype) if dtypes.input_dtype == dtypes.result_dtype: dtypes_strat |= st.none() @@ -272,17 +277,17 @@ def test_asarray_arrays(shape, dtypes, data): dtype = kw.get("dtype", None) if dtype is None: - ph.assert_dtype("asarray", x.dtype, out.dtype) + ph.assert_dtype("asarray", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("asarray", dtype, out.dtype) - ph.assert_shape("asarray", out.shape, x.shape) - ph.assert_array_elements("asarray", out, x, **kw) + ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape) + ph.assert_array_elements("asarray", out=out, expected=x, kw=kw) copy = kw.get("copy", None) if copy is not None: stype = dh.get_scalar_type(x.dtype) idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") old_value = stype(x[idx]) - scalar_strat = xps.from_dtype(dtypes.input_dtype).filter( + scalar_strat = hh.from_dtype(dtypes.input_dtype).filter( lambda n: not scalar_eq(n, old_value) ) value = data.draw( @@ -293,7 +298,7 @@ def test_asarray_arrays(shape, dtypes, data): note(f"mutated {x=}") # sanity check ph.assert_scalar_equals( - "__setitem__", stype, idx, stype(x[idx]), value, repr_name="x" + "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x" ) new_out_value = stype(out[idx]) f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}" @@ -307,27 +312,27 @@ def test_asarray_arrays(shape, dtypes, data): ), f"{f_out}, but should be {value} after x was mutated" -@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes)) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) def test_empty(shape, kw): out = xp.empty(shape, **kw) if kw.get("dtype", None) is None: ph.assert_default_float("empty", out.dtype) else: - ph.assert_kw_dtype("empty", kw["dtype"], out.dtype) - ph.assert_shape("empty", out.shape, shape, shape=shape) + ph.assert_kw_dtype("empty", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("empty", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), - kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), + x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), + kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), ) def test_empty_like(x, kw): out = xp.empty_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("empty_like", x.dtype, out.dtype) + ph.assert_dtype("empty_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("empty_like", kw["dtype"], out.dtype) - ph.assert_shape("empty_like", out.shape, x.shape) + ph.assert_kw_dtype("empty_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("empty_like", out_shape=out.shape, expected=x.shape) @given( @@ -335,7 +340,7 @@ def test_empty_like(x, kw): n_cols=st.none() | hh.sqrt_sizes, kw=hh.kwargs( k=st.integers(), - dtype=xps.numeric_dtypes(), + dtype=hh.numeric_dtypes, ), ) def test_eye(n_rows, n_cols, kw): @@ -343,17 +348,17 @@ def test_eye(n_rows, n_cols, kw): if kw.get("dtype", None) is None: ph.assert_default_float("eye", out.dtype) else: - ph.assert_kw_dtype("eye", kw["dtype"], out.dtype) + ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype) _n_cols = n_rows if n_cols is None else n_cols - ph.assert_shape("eye", out.shape, (n_rows, _n_cols), n_rows=n_rows, n_cols=n_cols) - f_func = f"[eye({n_rows=}, {n_cols=})]" - for i in range(n_rows): - for j in range(_n_cols): - f_indexed_out = f"out[{i}, {j}]={out[i, j]}" - 
if j - i == kw.get("k", 0): - assert out[i, j] == 1, f"{f_indexed_out}, should be 1 {f_func}" - else: - assert out[i, j] == 0, f"{f_indexed_out}, should be 0 {f_func}" + ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols)) + k = kw.get("k", 0) + expected = xp.asarray( + [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)], + dtype=out.dtype # Note: dtype already checked above. + ) + if 0 in expected.shape: + expected = xp.reshape(expected, (n_rows, _n_cols)) + ph.assert_array_elements("eye", out=out, expected=expected, kw=kw) default_unsafe_dtypes = [xp.uint64] @@ -361,71 +366,77 @@ def test_eye(n_rows, n_cols, kw): default_unsafe_dtypes.extend([xp.uint32, xp.int64]) if dh.default_float == xp.float32: default_unsafe_dtypes.append(xp.float64) -default_safe_dtypes: st.SearchStrategy = xps.scalar_dtypes().filter( +if dh.default_complex == xp.complex64: + default_unsafe_dtypes.append(xp.complex64) +default_safe_dtypes: st.SearchStrategy = hh.all_dtypes.filter( lambda d: d not in default_unsafe_dtypes ) @st.composite -def full_fill_values(draw) -> st.SearchStrategy[float]: +def full_fill_values(draw) -> Union[bool, int, float, complex]: kw = draw( - st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw") + st.shared(hh.kwargs(dtype=st.none() | hh.all_dtypes), key="full_kw") ) dtype = kw.get("dtype", None) or draw(default_safe_dtypes) - return draw(xps.from_dtype(dtype)) + return draw(hh.from_dtype(dtype)) @given( shape=hh.shapes(), fill_value=full_fill_values(), - kw=st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_kw"), + kw=st.shared(hh.kwargs(dtype=st.none() | hh.all_dtypes), key="full_kw"), ) def test_full(shape, fill_value, kw): - out = xp.full(shape, fill_value, **kw) + with hh.reject_overflow(): + out = xp.full(shape, fill_value, **kw) if kw.get("dtype", None): dtype = kw["dtype"] elif isinstance(fill_value, bool): dtype = xp.bool elif isinstance(fill_value, int): dtype = dh.default_int - else: + elif isinstance(fill_value, float): dtype = dh.default_float + else: + assert isinstance(fill_value, complex) # sanity check + dtype = dh.default_complex + # Ignore large components so we don't fail like + # + # >>> torch.fill(complex(0.0, 3.402823466385289e+38)) + # RuntimeError: value cannot be converted to complex without overflow + # + M = dh.dtype_ranges[dh.dtype_components[dtype]].max + assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag])) if kw.get("dtype", None) is None: if isinstance(fill_value, bool): - pass # TODO + assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]" elif isinstance(fill_value, int): ph.assert_default_int("full", out.dtype) - else: + elif isinstance(fill_value, float): ph.assert_default_float("full", out.dtype) + else: + assert isinstance(fill_value, complex) # sanity check + ph.assert_default_complex("full", out.dtype) else: - ph.assert_kw_dtype("full", kw["dtype"], out.dtype) - ph.assert_shape("full", out.shape, shape, shape=shape) - ph.assert_fill("full", fill_value, dtype, out, fill_value=fill_value) + ph.assert_kw_dtype("full", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("full", out_shape=out.shape, expected=shape, kw=dict(shape=shape)) + ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) -@st.composite -def full_like_fill_values(draw): - kw = draw( - st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw") - ) - 
dtype = kw.get("dtype", None) or draw(hh.shared_dtypes) - return draw(xps.from_dtype(dtype)) - - -@given( - x=xps.arrays(dtype=hh.shared_dtypes, shape=hh.shapes()), - fill_value=full_like_fill_values(), - kw=st.shared(hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), key="full_like_kw"), -) -def test_full_like(x, fill_value, kw): +@given(kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), data=st.data()) +def test_full_like(kw, data): + dtype = kw.get("dtype", None) or data.draw(hh.all_dtypes, label="dtype") + x = data.draw(hh.arrays(dtype=dtype, shape=hh.shapes()), label="x") + fill_value = data.draw(hh.from_dtype(dtype), label="fill_value") out = xp.full_like(x, fill_value, **kw) dtype = kw.get("dtype", None) or x.dtype if kw.get("dtype", None) is None: - ph.assert_dtype("full_like", x.dtype, out.dtype) + ph.assert_dtype("full_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("full_like", kw["dtype"], out.dtype) - ph.assert_shape("full_like", out.shape, x.shape) - ph.assert_fill("full_like", fill_value, dtype, out, fill_value=fill_value) + ph.assert_kw_dtype("full_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("full_like", out_shape=out.shape, expected=x.shape) + ph.assert_fill("full_like", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value)) finite_kw = {"allow_nan": False, "allow_infinity": False} @@ -433,18 +444,21 @@ def test_full_like(x, fill_value, kw): @given( num=hh.sizes, - dtype=st.none() | xps.floating_dtypes(), + dtype=st.none() | hh.real_floating_dtypes, endpoint=st.booleans(), data=st.data(), ) def test_linspace(num, dtype, endpoint, data): _dtype = dh.default_float if dtype is None else dtype - start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start") - stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop") + start = data.draw(hh.from_dtype(_dtype, **finite_kw), label="start") + stop = data.draw(hh.from_dtype(_dtype, **finite_kw), label="stop") # avoid overflow errors assume(not xp.isnan(xp.asarray(stop - start, dtype=_dtype))) assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype))) + # avoid generating very large distances + # https://github.com/data-apis/array-api-tests/issues/125 + assume(abs(stop - start) < math.sqrt(dh.dtype_ranges[_dtype].max)) kw = data.draw( hh.specified_kwargs( @@ -458,8 +472,8 @@ def test_linspace(num, dtype, endpoint, data): if dtype is None: ph.assert_default_float("linspace", out.dtype) else: - ph.assert_kw_dtype("linspace", dtype, out.dtype) - ph.assert_shape("linspace", out.shape, num, start=stop, stop=stop, num=num) + ph.assert_kw_dtype("linspace", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("linspace", out_shape=out.shape, expected=num, kw=dict(start=start, stop=stop, num=num)) f_func = f"[linspace({start}, {stop}, {num})]" if num > 0: assert xp.equal( @@ -475,29 +489,30 @@ def test_linspace(num, dtype, endpoint, data): # the first num elements when endpoint=False expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True) expected = expected[:-1] - ph.assert_array_elements("linspace", out, expected) + ph.assert_array_elements("linspace", out=out, expected=expected) -@given(dtype=xps.numeric_dtypes(), data=st.data()) +@given(dtype=hh.numeric_dtypes, data=st.data()) def test_meshgrid(dtype, data): # The number and size of generated arrays is arbitrarily limited to prevent # meshgrid() running out of memory. 
shapes = data.draw( st.integers(1, 5).flatmap( lambda n: hh.mutually_broadcastable_shapes( - n, min_dims=1, max_dims=1, max_side=5 + n, min_dims=1, max_dims=1, max_side=4 ) ), label="shapes", ) arrays = [] for i, shape in enumerate(shapes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) - assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check + # sanity check + # assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE out = xp.meshgrid(*arrays) for i, x in enumerate(out): - ph.assert_dtype("meshgrid", dtype, x.dtype, repr_name=f"out[{i}].dtype") + ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype") def make_one(dtype: DataType) -> Scalar: @@ -509,31 +524,33 @@ def make_one(dtype: DataType) -> Scalar: return True -@given(hh.shapes(), hh.kwargs(dtype=st.none() | xps.scalar_dtypes())) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) def test_ones(shape, kw): out = xp.ones(shape, **kw) if kw.get("dtype", None) is None: ph.assert_default_float("ones", out.dtype) else: - ph.assert_kw_dtype("ones", kw["dtype"], out.dtype) - ph.assert_shape("ones", out.shape, shape, shape=shape) + ph.assert_kw_dtype("ones", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("ones", out_shape=out.shape, expected=shape, + kw={'shape': shape, **kw}) dtype = kw.get("dtype", None) or dh.default_float - ph.assert_fill("ones", make_one(dtype), dtype, out) + ph.assert_fill("ones", fill_value=make_one(dtype), dtype=dtype, out=out, kw=kw) @given( - x=xps.arrays(dtype=hh.dtypes, shape=hh.shapes()), - kw=hh.kwargs(dtype=st.none() | xps.scalar_dtypes()), + x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), + kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), ) def test_ones_like(x, kw): out = xp.ones_like(x, **kw) if kw.get("dtype", None) is None: - ph.assert_dtype("ones_like", x.dtype, out.dtype) + ph.assert_dtype("ones_like", in_dtype=x.dtype, out_dtype=out.dtype) else: - ph.assert_kw_dtype("ones_like", kw["dtype"], out.dtype) - ph.assert_shape("ones_like", out.shape, x.shape) + ph.assert_kw_dtype("ones_like", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("ones_like", out_shape=out.shape, expected=x.shape, kw=kw) dtype = kw.get("dtype", None) or x.dtype - ph.assert_fill("ones_like", make_one(dtype), dtype, out) + ph.assert_fill("ones_like", fill_value=make_one(dtype), dtype=dtype, + out=out, kw=kw) def make_zero(dtype: DataType) -> Scalar: @@ -545,28 +562,31 @@ def make_zero(dtype: DataType) -> Scalar: return False -@given(hh.shapes(), hh.kwargs(dtype=st.none() | xps.scalar_dtypes())) +@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes)) def test_zeros(shape, kw): out = xp.zeros(shape, **kw) if kw.get("dtype", None) is None: - ph.assert_default_float("zeros", out.dtype) + ph.assert_default_float("zeros", out_dtype=out.dtype) else: - ph.assert_kw_dtype("zeros", kw["dtype"], out.dtype) - ph.assert_shape("zeros", out.shape, shape, shape=shape) + ph.assert_kw_dtype("zeros", kw_dtype=kw["dtype"], out_dtype=out.dtype) + ph.assert_shape("zeros", out_shape=out.shape, expected=shape, kw={'shape': shape, **kw}) dtype = kw.get("dtype", None) or dh.default_float - ph.assert_fill("zeros", make_zero(dtype), dtype, out) + ph.assert_fill("zeros", fill_value=make_zero(dtype), dtype=dtype, out=out, + kw=kw) @given( - x=xps.arrays(dtype=hh.dtypes, shape=hh.shapes()), - kw=hh.kwargs(dtype=st.none() | 
xps.scalar_dtypes()),
+    x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
+    kw=hh.kwargs(dtype=st.none() | hh.all_dtypes),
 )
 def test_zeros_like(x, kw):
     out = xp.zeros_like(x, **kw)
     if kw.get("dtype", None) is None:
-        ph.assert_dtype("zeros_like", x.dtype, out.dtype)
+        ph.assert_dtype("zeros_like", in_dtype=x.dtype, out_dtype=out.dtype)
     else:
-        ph.assert_kw_dtype("zeros_like", kw["dtype"], out.dtype)
-    ph.assert_shape("zeros_like", out.shape, x.shape)
+        ph.assert_kw_dtype("zeros_like", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+    ph.assert_shape("zeros_like", out_shape=out.shape, expected=x.shape,
+                    kw=kw)
     dtype = kw.get("dtype", None) or x.dtype
-    ph.assert_fill("zeros_like", make_zero(dtype), dtype, out)
+    ph.assert_fill("zeros_like", fill_value=make_zero(dtype), dtype=dtype,
+                   out=out, kw=kw)
diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py
index 763c71a4..84e6f34c 100644
--- a/array_api_tests/test_data_type_functions.py
+++ b/array_api_tests/test_data_type_functions.py
@@ -2,7 +2,7 @@
 from typing import Union
 
 import pytest
-from hypothesis import given
+from hypothesis import given, assume
 from hypothesis import strategies as st
 
 from . import _array_module as xp
@@ -13,34 +13,59 @@
 from . import xps
 from .typing import DataType
 
-pytestmark = pytest.mark.ci
+
+# TODO: test with complex dtypes
+def non_complex_dtypes():
+    return xps.boolean_dtypes() | hh.real_dtypes
 
 
 def float32(n: Union[int, float]) -> float:
     return struct.unpack("!f", struct.pack("!f", float(n)))[0]
 
 
+def _float_match_complex(complex_dtype):
+    if complex_dtype == xp.complex64:
+        return xp.float32
+    elif complex_dtype == xp.complex128:
+        return xp.float64
+    else:
+        return dh.default_float
+
+
 @given(
-    x_dtype=xps.scalar_dtypes(),
-    dtype=xps.scalar_dtypes(),
+    x_dtype=hh.all_dtypes,
+    dtype=hh.all_dtypes,
     kw=hh.kwargs(copy=st.booleans()),
     data=st.data(),
 )
 def test_astype(x_dtype, dtype, kw, data):
+    _complex_dtypes = (xp.complex64, xp.complex128)
+
     if xp.bool in (x_dtype, dtype):
-        elements_strat = xps.from_dtype(x_dtype)
+        elements_strat = hh.from_dtype(x_dtype)
     else:
-        m1, M1 = dh.dtype_ranges[x_dtype]
-        m2, M2 = dh.dtype_ranges[dtype]
+
         if dh.is_int_dtype(x_dtype):
             cast = int
-        elif x_dtype == xp.float32:
+        elif x_dtype in (xp.float32, xp.complex64):
            cast = float32
         else:
             cast = float
+
+        real_dtype = x_dtype
+        if x_dtype in _complex_dtypes:
+            real_dtype = _float_match_complex(x_dtype)
+        m1, M1 = dh.dtype_ranges[real_dtype]
+
+        real_dtype = dtype
+        if dtype in _complex_dtypes:
+            real_dtype = _float_match_complex(dtype)
+        m2, M2 = dh.dtype_ranges[real_dtype]
+
         min_value = cast(max(m1, m2))
         max_value = cast(min(M1, M2))
-    elements_strat = xps.from_dtype(
+
+        elements_strat = hh.from_dtype(
             x_dtype,
             min_value=min_value,
             max_value=max_value,
@@ -48,13 +73,18 @@ def test_astype(x_dtype, dtype, kw, data):
         allow_infinity=False,
     )
     x = data.draw(
-        xps.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
+        hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
     )
+    # according to the spec, "Casting a complex floating-point array to a real-valued
+    # data type should not be permitted."
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype + assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes))) + out = xp.astype(x, dtype, **kw) - ph.assert_kw_dtype("astype", dtype, out.dtype) - ph.assert_shape("astype", out.shape, x.shape) + ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype) + ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw) # TODO: test values # TODO: test copy @@ -65,27 +95,30 @@ def test_astype(x_dtype, dtype, kw, data): def test_broadcast_arrays(shapes, data): arrays = [] for c, shape in enumerate(shapes, 1): - x = data.draw(xps.arrays(dtype=xps.scalar_dtypes(), shape=shape), label=f"x{c}") + x = data.draw(hh.arrays(dtype=hh.all_dtypes, shape=shape), label=f"x{c}") arrays.append(x) out = xp.broadcast_arrays(*arrays) - out_shape = sh.broadcast_shapes(*shapes) + expected_shape = sh.broadcast_shapes(*shapes) for i, x in enumerate(arrays): ph.assert_dtype( - "broadcast_arrays", x.dtype, out[i].dtype, repr_name=f"out[{i}].dtype" + "broadcast_arrays", + in_dtype=x.dtype, + out_dtype=out[i].dtype, + repr_name=f"out[{i}].dtype" ) ph.assert_result_shape( "broadcast_arrays", - shapes, - out[i].shape, - out_shape, + in_shapes=shapes, + out_shape=out[i].shape, + expected=expected_shape, repr_name=f"out[{i}].shape", ) # TODO: test values -@given(x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), data=st.data()) +@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data()) def test_broadcast_to(x, data): shape = data.draw( hh.mutually_broadcastable_shapes(1, base_shape=x.shape) @@ -96,74 +129,161 @@ def test_broadcast_to(x, data): out = xp.broadcast_to(x, shape) - ph.assert_dtype("broadcast_to", x.dtype, out.dtype) - ph.assert_shape("broadcast_to", out.shape, shape) + ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape) # TODO: test values -@given(_from=xps.scalar_dtypes(), to=xps.scalar_dtypes(), data=st.data()) -def test_can_cast(_from, to, data): - from_ = data.draw( - st.just(_from) | xps.arrays(dtype=_from, shape=hh.shapes()), label="from_" - ) +@given(_from=hh.all_dtypes, to=hh.all_dtypes) +def test_can_cast(_from, to): + out = xp.can_cast(_from, to) - out = xp.can_cast(from_, to) + expected = False + for other in dh.all_dtypes: + if dh.promotion_table.get((_from, other)) == to: + expected = True + break f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]" - assert isinstance(out, bool), f"{type(out)=}, but should be bool {f_func}" - if _from == xp.bool: - expected = to == xp.bool - else: - for dtypes in [dh.all_int_dtypes, dh.float_dtypes]: - if _from in dtypes: - same_family = to in dtypes - break - if same_family: - from_min, from_max = dh.dtype_ranges[_from] - to_min, to_max = dh.dtype_ranges[to] - expected = from_min >= to_min and from_max <= to_max - else: - expected = False - assert out == expected, f"{out=}, but should be {expected} {f_func}" + if expected: + # cross-kind casting is not explicitly disallowed. We can only test + # the cases where it should return True. TODO: if expected=False, + # check that the array library actually allows such casts. 
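+        # For example, can_cast(int8, int16) must return True, since int8
+        # promotes with int16 to int16; nothing promotes with int16 to int8,
+        # so for can_cast(int16, int8) expected is False and we make no
+        # assertion either way.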
+ assert out == expected, f"{out=}, but should be {expected} {f_func}" -def make_dtype_id(dtype: DataType) -> str: - return dh.dtype_to_name[dtype] +@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes) +def test_finfo(dtype): + for arg in ( + dtype, + xp.asarray(1, dtype=dtype), + # np.float64 and np.asarray(1, dtype=np.float64).dtype are different + xp.asarray(1, dtype=dtype).dtype, + ): + out = xp.finfo(arg) + assert isinstance(out.bits, int) + assert isinstance(out.eps, float) + assert isinstance(out.max, float) + assert isinstance(out.min, float) + assert isinstance(out.smallest_normal, float) -@pytest.mark.parametrize("dtype", dh.float_dtypes, ids=make_dtype_id) -def test_finfo(dtype): +@pytest.mark.min_version("2022.12") +@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes) +def test_finfo_dtype(dtype): out = xp.finfo(dtype) - f_func = f"[finfo({dh.dtype_to_name[dtype]})]" - for attr, stype in [ - ("bits", int), - ("eps", float), - ("max", float), - ("min", float), - ("smallest_normal", float), - ]: - assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}" - value = getattr(out, attr) - assert isinstance( - value, stype - ), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}" - # TODO: test values + + if dtype == xp.complex64: + assert out.dtype == xp.float32 + elif dtype == xp.complex128: + assert out.dtype == xp.float64 + else: + assert out.dtype == dtype + + # Guard vs. numpy.dtype.__eq__ lax comparison + assert not isinstance(out.dtype, str) + assert out.dtype is not float + assert out.dtype is not complex -@pytest.mark.parametrize("dtype", dh.all_int_dtypes, ids=make_dtype_id) +@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes) def test_iinfo(dtype): + for arg in ( + dtype, + xp.asarray(1, dtype=dtype), + # np.int64 and np.asarray(1, dtype=np.int64).dtype are different + xp.asarray(1, dtype=dtype).dtype, + ): + out = xp.iinfo(arg) + assert isinstance(out.bits, int) + assert isinstance(out.max, int) + assert isinstance(out.min, int) + + +@pytest.mark.min_version("2022.12") +@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes) +def test_iinfo_dtype(dtype): out = xp.iinfo(dtype) - f_func = f"[iinfo({dh.dtype_to_name[dtype]})]" - for attr in ["bits", "max", "min"]: - assert hasattr(out, attr), f"out has no attribute '{attr}' {f_func}" - value = getattr(out, attr) - assert isinstance( - value, int - ), f"type(out.{attr})={type(value)!r}, but should be int {f_func}" - # TODO: test values + assert out.dtype == dtype + # Guard vs. 
numpy.dtype.__eq__ lax comparison + assert not isinstance(out.dtype, str) + assert out.dtype is not int + + +def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]: + return hh.all_dtypes | st.sampled_from(list(dh.kind_to_dtypes.keys())) + + +@pytest.mark.min_version("2022.12") +@given( + dtype=hh.all_dtypes, + kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple), +) +def test_isdtype(dtype, kind): + out = xp.isdtype(dtype, kind) + + assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]" + _kinds = kind if isinstance(kind, tuple) else (kind,) + expected = False + for _kind in _kinds: + if isinstance(_kind, str): + if dtype in dh.kind_to_dtypes[_kind]: + expected = True + break + else: + if dtype == _kind: + expected = True + break + assert out == expected, f"{out=}, but should be {expected} [isdtype()]" + + +@pytest.mark.min_version("2024.12") +class TestResultType: + @given(dtypes=hh.mutually_promotable_dtypes(None)) + def test_result_type(self, dtypes): + out = xp.result_type(*dtypes) + ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out") + + @given(pair=hh.pair_of_mutually_promotable_dtypes(None)) + def test_shuffled(self, pair): + """Test that result_type is insensitive to the order of arguments.""" + s1, s2 = pair + out1 = xp.result_type(*s1) + out2 = xp.result_type(*s2) + assert out1 == out2 + + @given(pair=hh.pair_of_mutually_promotable_dtypes(2), data=st.data()) + def test_arrays_and_dtypes(self, pair, data): + s1, s2 = pair + a2 = tuple(xp.empty(1, dtype=dt) for dt in s2) + a_and_dt = data.draw(st.permutations(s1 + a2)) + out = xp.result_type(*a_and_dt) + ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out") + + @given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data()) + def test_with_scalars(self, dtypes, data): + out = xp.result_type(*dtypes) + + if out == xp.bool: + scalars = [True] + elif out in dh.all_int_dtypes: + scalars = [1] + elif out in dh.real_dtypes: + scalars = [1, 1.0] + elif out in dh.numeric_dtypes: + scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types + else: + raise ValueError(f"unknown dtype {out = }.") + + scalar = data.draw(st.sampled_from(scalars)) + inputs = data.draw(st.permutations(dtypes + (scalar,))) + + out_scalar = xp.result_type(*inputs) + assert out_scalar == out + # retry with arrays + arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes) + inputs = data.draw(st.permutations(arrays + (scalar,))) + out_scalar = xp.result_type(*inputs) + assert out_scalar == out -@given(hh.mutually_promotable_dtypes(None)) -def test_result_type(dtypes): - out = xp.result_type(*dtypes) - ph.assert_dtype("result_type", dtypes, out, repr_name="out") diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py new file mode 100644 index 00000000..358a8eef --- /dev/null +++ b/array_api_tests/test_fft.py @@ -0,0 +1,307 @@ +import math +from typing import List, Optional + +import pytest +from hypothesis import assume, given +from hypothesis import strategies as st + +from array_api_tests.typing import Array + +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . import shape_helpers as sh +from . 
import xp + +pytestmark = [ + pytest.mark.xp_extension("fft"), + pytest.mark.min_version("2022.12"), +] + +fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1) + + +def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: + size = math.prod(x.shape) + n = data.draw( + st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n" + ) + axis = data.draw(st.integers(-1, x.ndim - 1), label="axis") + if size_gt_1: + _axis = x.ndim - 1 if axis == -1 else axis + assume(x.shape[_axis] > 1) + norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") + kwargs = data.draw( + hh.specified_kwargs( + ("n", n, None), + ("axis", axis, -1), + ("norm", norm, "backward"), + ), + label="kwargs", + ) + return n, axis, norm, kwargs + + +def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple: + all_axes = list(range(x.ndim)) + axes = data.draw( + st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True), + label="axes", + ) + _axes = all_axes if axes is None else axes + axes_sides = [x.shape[axis] for axis in _axes] + s_strat = st.tuples( + *[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides] + ) + if axes is None: + s_strat = st.none() | s_strat + s = data.draw(s_strat, label="s") + + # Using `axes is None and s is not None` is disallowed by the spec + assume(axes is not None or s is None) + + norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm") + kwargs = data.draw( + hh.specified_kwargs( + ("s", s, None), + ("axes", axes, None), + ("norm", norm, "backward"), + ), + label="kwargs", + ) + return s, axes, norm, kwargs + + +def assert_n_axis_shape( + func_name: str, + *, + x: Array, + n: Optional[int], + axis: int, + out: Array, +): + _axis = len(x.shape) - 1 if axis == -1 else axis + if n is None: + axis_side = x.shape[_axis] + else: + axis_side = n + expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape(func_name, out_shape=out.shape, expected=expected) + + +def assert_s_axes_shape( + func_name: str, + *, + x: Array, + s: Optional[List[int]], + axes: Optional[List[int]], + out: Array, +): + _axes = sh.normalize_axis(axes, x.ndim) + _s = x.shape if s is None else s + expected = [] + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + expected.append(side) + ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected)) + + +@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_fft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.fft(x, **kwargs) + + ph.assert_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out) + + +@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_ifft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.ifft(x, **kwargs) + + ph.assert_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype) + assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out) + + +@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_fftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.fftn(x, **kwargs) + + ph.assert_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out) + + 
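+# Note: the shape rule encoded by assert_s_axes_shape above is the usual
+# n-dimensional FFT one: each axis listed in `axes` takes its side from the
+# matching entry of `s`, and every other axis keeps the input's side. For
+# example, x of shape (4, 5, 6) transformed with s=(10, 2) and axes=(0, 2)
+# must come out with shape (10, 5, 2).
+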
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_ifftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.ifftn(x, **kwargs) + + ph.assert_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype) + assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out) + + +@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_rfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.rfft(x, **kwargs) + + ph.assert_float_to_complex_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype) + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = x.shape[_axis] // 2 + 1 + else: + axis_side = n // 2 + 1 + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("rfft", out_shape=out.shape, expected=expected_shape) + + +@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_irfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.irfft(x, **kwargs) + + ph.assert_dtype( + "irfft", + in_dtype=x.dtype, + out_dtype=out.dtype, + expected=dh.dtype_components[x.dtype], + ) + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = n + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape) + + +@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_rfftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.rfftn(x, **kwargs) + + ph.assert_float_to_complex_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype) + + _axes = sh.normalize_axis(axes, x.ndim) + _s = x.shape if s is None else s + expected = [] + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + expected.append(side) + expected[_axes[-1]] = _s[-1] // 2 + 1 + ph.assert_shape("rfftn", out_shape=out.shape, expected=tuple(expected)) + + +@given( + x=hh.arrays( + dtype=hh.complex_dtypes, shape=fft_shapes_strat.filter(lambda s: s[-1] > 1) + ), + data=st.data(), +) +def test_irfftn(x, data): + s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data) + + out = xp.fft.irfftn(x, **kwargs) + + ph.assert_dtype( + "irfftn", + in_dtype=x.dtype, + out_dtype=out.dtype, + expected=dh.dtype_components[x.dtype], + ) + + _axes = sh.normalize_axis(axes, x.ndim) + _s = x.shape if s is None else s + expected = [] + for i in range(x.ndim): + if i in _axes: + side = _s[_axes.index(i)] + else: + side = x.shape[i] + expected.append(side) + expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1] + ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected)) + + +@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_hfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True) + + out = xp.fft.hfft(x, **kwargs) + + ph.assert_dtype( + "hfft", + in_dtype=x.dtype, + out_dtype=out.dtype, + expected=dh.dtype_components[x.dtype], + ) + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = 2 * (x.shape[_axis] - 1) + else: + axis_side = n + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape) + + 
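+# Note: the axis-length bookkeeping in the surrounding real/Hermitian tests
+# follows from symmetry: a length-n real (or Hermitian) signal has n//2 + 1
+# independent complex bins, so rfft/ihfft default to that many outputs, while
+# irfft/hfft go the other way and default to 2*(m - 1) samples from a
+# length-m half-spectrum, e.g. rfft of 10 samples gives 6 bins, and irfft of
+# those 6 bins restores 10 samples.
+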
+@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data()) +def test_ihfft(x, data): + n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data) + + out = xp.fft.ihfft(x, **kwargs) + + ph.assert_float_to_complex_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype) + + _axis = x.ndim - 1 if axis == -1 else axis + if n is None: + axis_side = x.shape[_axis] // 2 + 1 + else: + axis_side = n // 2 + 1 + expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :] + ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape) + + +@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5))) +def test_fftfreq(n, kw): + out = xp.fft.fftfreq(n, **kw) + ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n}) + + +@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5))) +def test_rfftfreq(n, kw): + out = xp.fft.rfftfreq(n, **kw) + ph.assert_shape( + "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n} + ) + + +@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"]) +@given(x=hh.arrays(hh.floating_dtypes, fft_shapes_strat), data=st.data()) +def test_shift_func(func_name, x, data): + func = getattr(xp.fft, func_name) + axes = data.draw( + st.none() + | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True), + label="axes", + ) + out = func(x, axes=axes) + ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape) diff --git a/array_api_tests/test_has_names.py b/array_api_tests/test_has_names.py index d9194d82..8e934781 100644 --- a/array_api_tests/test_has_names.py +++ b/array_api_tests/test_has_names.py @@ -5,7 +5,7 @@ import pytest -from ._array_module import mod as xp, mod_name +from . import xp from .stubs import (array_attributes, array_methods, category_to_funcs, extension_to_funcs, EXTENSIONS) @@ -25,13 +25,13 @@ def test_has_names(category, name): if category in EXTENSIONS: ext_mod = getattr(xp, category) - assert hasattr(ext_mod, name), f"{mod_name} is missing the {category} extension function {name}()" + assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()" elif category.startswith('array_'): # TODO: This would fail if ones() is missing. arr = xp.ones((1, 1)) if category == 'array_attribute': - assert hasattr(arr, name), f"The {mod_name} array object is missing the attribute {name}" + assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}" else: - assert hasattr(arr, name), f"The {mod_name} array object is missing the method {name}()" + assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()" else: - assert hasattr(xp, name), f"{mod_name} is missing the {category} function {name}()" + assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()" diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py new file mode 100644 index 00000000..a599d218 --- /dev/null +++ b/array_api_tests/test_indexing_functions.py @@ -0,0 +1,122 @@ +import pytest +from hypothesis import given, note +from hypothesis import strategies as st + +from . import _array_module as xp +from . import dtype_helpers as dh +from . import hypothesis_helpers as hh +from . import pytest_helpers as ph +from . 
import shape_helpers as sh + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2022.12") +@given( + x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_take(x, data): + # TODO: + # * negative axis + # * negative indices + # * different dtypes for indices + axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis") + _indices = data.draw( + st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True), + label="_indices", + ) + indices = xp.asarray(_indices, dtype=dh.default_int) + note(f"{indices=}") + + out = xp.take(x, indices, axis=axis) + + ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape( + "take", + out_shape=out.shape, + expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :], + kw=dict( + x=x, + indices=indices, + axis=axis, + ), + ) + out_indices = sh.ndindex(out.shape) + axis_indices = list(sh.axis_ndindex(x.shape, axis)) + for axis_idx in axis_indices: + f_axis_idx = sh.fmt_idx("x", axis_idx) + for i in _indices: + f_take_idx = sh.fmt_idx(f_axis_idx, i) + indexed_x = x[axis_idx][i, ...] + for at_idx in sh.ndindex(indexed_x.shape): + out_idx = next(out_indices) + ph.assert_0d_equals( + "take", + x_repr=sh.fmt_idx(f_take_idx, at_idx), + x_val=indexed_x[at_idx], + out_repr=sh.fmt_idx("out", out_idx), + out_val=out[out_idx], + ) + # sanity check + with pytest.raises(StopIteration): + next(out_indices) + + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_take_along_axis(x, data): + # TODO + # 2. negative indices + # 3. different dtypes for indices + # 4. "broadcast-compatible" indices + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len") + idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:] + indices = data.draw( + hh.arrays( + shape=idx_shape, + dtype=dh.default_int, + elements={"min_value": 0, "max_value": x.shape[n_axis]-1} + ), + label="indices" + ) + note(f"{indices=} {idx_shape=}") + + out = xp.take_along_axis(x, indices, **axis_kw) + + ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape( + "take_along_axis", + out_shape=out.shape, + expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:], + kw=dict( + x=x, + indices=indices, + axis=axis, + ), + ) + + # value test: notation is from `np.take_along_axis` docstring + Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:] + for ii in sh.ndindex(Ni): + for kk in sh.ndindex(Nk): + a_1d = x[ii + (slice(None),) + kk] + i_1d = indices[ii + (slice(None),) + kk] + o_1d = out[ii + (slice(None),) + kk] + for j in range(new_len): + assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}' diff --git a/array_api_tests/test_inspection_functions.py b/array_api_tests/test_inspection_functions.py new file mode 100644 index 00000000..d210535e --- /dev/null +++ b/array_api_tests/test_inspection_functions.py @@ -0,0 +1,81 @@ +import pytest +from hypothesis import given, strategies as st +from array_api_tests.dtype_helpers import available_kinds, dtype_names + +from . 
import xp
+
+pytestmark = pytest.mark.min_version("2023.12")
+
+
+class TestInspection:
+    def test_capabilities(self):
+        out = xp.__array_namespace_info__()
+
+        capabilities = out.capabilities()
+        assert isinstance(capabilities, dict)
+
+        expected_attr = {"boolean indexing": bool, "data-dependent shapes": bool}
+        if xp.__array_api_version__ >= "2024.12":
+            expected_attr.update(**{"max dimensions": type(None) | int})
+
+        for attr, typ in expected_attr.items():
+            assert attr in capabilities, f'capabilities is missing "{attr}".'
+            assert isinstance(capabilities[attr], typ)
+
+        max_dims = capabilities.get("max dimensions", 100500)
+        assert (max_dims is None) or (max_dims > 0)
+
+    def test_devices(self):
+        out = xp.__array_namespace_info__()
+
+        assert hasattr(out, "devices")
+        assert hasattr(out, "default_device")
+
+        assert isinstance(out.devices(), list)
+        if out.default_device() is not None:
+            # Per https://github.com/data-apis/array-api/issues/923
+            # default_device() can return None. Otherwise, it must be a valid device.
+            assert out.default_device() in out.devices()
+
+    def test_default_dtypes(self):
+        out = xp.__array_namespace_info__()
+
+        for device in xp.__array_namespace_info__().devices():
+            default_dtypes = out.default_dtypes(device=device)
+            assert isinstance(default_dtypes, dict)
+            expected_subset = (
+                {"real floating", "complex floating", "integral"}
+                & available_kinds()
+                | {"indexing"}
+            )
+            assert expected_subset.issubset(set(default_dtypes.keys()))
+
+
+atomic_kinds = [
+    "bool",
+    "signed integer",
+    "unsigned integer",
+    "real floating",
+    "complex floating",
+]
+
+
+@given(
+    kind=st.one_of(
+        st.none(),
+        st.sampled_from(atomic_kinds + ["integral", "numeric"]),
+        st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple),
+    ),
+    device=st.one_of(
+        st.none(),
+        st.sampled_from(xp.__array_namespace_info__().devices())
+    )
+)
+def test_array_namespace_info_dtypes(kind, device):
+    out = xp.__array_namespace_info__().dtypes(kind=kind, device=device)
+    assert isinstance(out, dict)
+
+    for name, dtyp in out.items():
+        assert name in dtype_names
+        xp.empty(1, dtype=dtyp, device=device)  # check `dtyp` is a valid dtype
+
diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py
index 321263d3..6f4608da 100644
--- a/array_api_tests/test_linalg.py
+++ b/array_api_tests/test_linalg.py
@@ -12,38 +12,60 @@ required, but we don't yet have a clean way to disable only those tests (see
 https://github.com/data-apis/array-api-tests/issues/25).
""" - import pytest from hypothesis import assume, given -from hypothesis.strategies import (booleans, composite, none, tuples, integers, - shared, sampled_from, one_of, data, just) +from hypothesis.strategies import (booleans, composite, tuples, floats, + integers, shared, sampled_from, one_of, + data) from ndindex import iter_indices +import math +import itertools +from typing import Tuple + from .array_helpers import assert_exactly_equal, asarray -from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, - square_matrix_shapes, symmetric_matrices, +from .hypothesis_helpers import (arrays, all_floating_dtypes, all_dtypes, + numeric_dtypes, xps, shapes, kwargs, + matrix_shapes, square_matrix_shapes, + symmetric_matrices, SearchStrategy, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes, one_d_shapes, two_mutually_broadcastable_shapes, + mutually_broadcastable_shapes, SQRT_MAX_ARRAY_SIZE, finite_matrices, - rtol_shared_matrix_shapes, rtols) + rtol_shared_matrix_shapes, rtols, axes) from . import dtype_helpers as dh from . import pytest_helpers as ph from . import shape_helpers as sh +from . import api_version +from .typing import Array from . import _array_module from . import _array_module as xp from ._array_module import linalg -pytestmark = pytest.mark.ci -# Standin strategy for not yet implemented tests -todo = none() +def assert_equal(x, y, msg_extra=None): + extra = '' if not msg_extra else f' ({msg_extra})' + if x.dtype in dh.all_float_dtypes: + # It's too difficult to do an approximately equal test here because + # different routines can give completely different answers, and even + # when it does work, the elementwise comparisons are too slow. So for + # floating-point dtypes only test the shape and dtypes. + + # assert_allclose(x, y) + + assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}" + assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}" + else: + assert_exactly_equal(x, y, msg_extra=msg_extra) + def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), - assert_equal=assert_exactly_equal, **kw): + res_axes=None, + assert_equal=assert_equal, **kw): """ Test that f(*args, **kw) maps across stacks of matrices @@ -67,7 +89,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, # Assume the result is stacked along the last 'dims' axes of matrix_axes. 
# This holds for all the functions tested in this file
-    res_axes = matrix_axes[::-1][:dims]
+    if res_axes is None:
+        if not (isinstance(matrix_axes, tuple) and all(isinstance(x, int) for x in matrix_axes)):
+            raise ValueError("res_axes must be specified if matrix_axes is not a tuple of integers")
+        res_axes = matrix_axes[::-1][:dims]
 
     for (x_idxes, (res_idx,)) in zip(
         iter_indices(*shapes, skip_axes=matrix_axes),
@@ -78,9 +103,11 @@
         res_stack = res[res_idx]
         x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)]
         decomp_res_stack = f(*x_stacks, **kw)
-        assert_equal(res_stack, decomp_res_stack)
+        msg_extra = f'{x_idxes = }, {res_idx = }'
+        assert_equal(res_stack, decomp_res_stack, msg_extra)
         if true_val:
-            assert_equal(decomp_res_stack, true_val(*x_stacks))
+            assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)
+
 
 def _test_namedtuple(res, fields, func_name):
     """
@@ -91,11 +118,14 @@
     # a tuple subclass with the right fields in the right order.
 
     assert isinstance(res, tuple), f"{func_name}() did not return a tuple"
+    assert type(res) != tuple, f"{func_name}() did not return a namedtuple"
     assert len(res) == len(fields), f"{func_name}() result tuple not the correct length (should have {len(fields)} elements)"
     for i, field in enumerate(fields):
         assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
         assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
 
+
+@pytest.mark.unvectorized
 @pytest.mark.xp_extension('linalg')
 @given(
     x=positive_definite_matrices(),
@@ -104,8 +134,9 @@
 def test_cholesky(x, kw):
     res = linalg.cholesky(x, **kw)
 
-    assert res.shape == x.shape, "cholesky() did not return the correct shape"
-    assert res.dtype == x.dtype, "cholesky() did not return the correct dtype"
+    ph.assert_dtype("cholesky", in_dtype=x.dtype, out_dtype=res.dtype)
+    ph.assert_result_shape("cholesky", in_shapes=[x.shape],
+                           out_shape=res.shape, expected=x.shape)
 
     _test_stacks(linalg.cholesky, x, **kw, res=res)
 
@@ -117,7 +148,7 @@
 
 @composite
-def cross_args(draw, dtype_objects=dh.numeric_dtypes):
+def cross_args(draw, dtype_objects=dh.real_dtypes):
     """
     cross() requires two arrays with a size 3 in the 'axis' dimension
 
     To test this, we generate a shape and an axis but change the shape to be 3
     in the drawn axis.
""" - shape = list(draw(shapes())) - size = len(shape) - assume(size > 0) + shape1, shape2 = draw(two_mutually_broadcastable_shapes) + min_ndim = min(len(shape1), len(shape2)) + assume(min_ndim > 0) - kw = draw(kwargs(axis=integers(-size, size-1))) + kw = draw(kwargs(axis=integers(-min_ndim, -1))) axis = kw.get('axis', -1) - shape[axis] = 3 - shape = tuple(shape) + if draw(booleans()): + # Sometimes allow invalid inputs to test it errors + shape1 = list(shape1) + shape1[axis] = 3 + shape1 = tuple(shape1) + shape2 = list(shape2) + shape2[axis] = 3 + shape2 = tuple(shape2) mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects)) - arrays1 = xps.arrays( + arrays1 = arrays( dtype=mutual_dtypes.map(lambda pair: pair[0]), - shape=shape, + shape=shape1, ) - arrays2 = xps.arrays( + arrays2 = arrays( dtype=mutual_dtypes.map(lambda pair: pair[1]), - shape=shape, + shape=shape2, ) return draw(arrays1), draw(arrays2), kw +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( cross_args() @@ -153,46 +191,52 @@ def test_cross(x1_x2_kw): x1, x2, kw = x1_x2_kw axis = kw.get('axis', -1) - err = "test_cross produced invalid input. This indicates a bug in the test suite." - assert x1.shape == x2.shape, err - shape = x1.shape - assert x1.shape[axis] == x2.shape[axis] == 3, err + if not (x1.shape[axis] == x2.shape[axis] == 3): + ph.raises(Exception, lambda: xp.cross(x1, x2, **kw), + "cross did not raise an exception for invalid shapes") + return res = linalg.cross(x1, x2, **kw) - assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype" - assert res.shape == shape, "cross() did not return the correct shape" + broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) + + ph.assert_dtype("cross", in_dtype=[x1.dtype, x2.dtype], + out_dtype=res.dtype) + ph.assert_result_shape("cross", in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=broadcasted_shape) def exact_cross(a, b): assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." - return asarray([ + return asarray(xp.stack([ a[1]*b[2] - a[2]*b[1], a[2]*b[0] - a[0]*b[2], a[0]*b[1] - a[1]*b[0], - ], dtype=res.dtype) + ]), dtype=res.dtype) # We don't want to pass in **kw here because that would pass axis to # cross() on a single stack, but the axis is not meaningful on unstacked # vectors. _test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross) +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes), + x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes), ) def test_det(x): res = linalg.det(x) - assert res.dtype == x.dtype, "det() did not return the correct dtype" - assert res.shape == x.shape[:-2], "det() did not return the correct shape" + ph.assert_dtype("det", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("det", in_shapes=[x.shape], out_shape=res.shape, + expected=x.shape[:-2]) _test_stacks(linalg.det, x, res=res, dims=0) # TODO: Test that res actually corresponds to the determinant of x +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), + x=arrays(dtype=all_dtypes, shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. 
kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
@@ -200,7 +244,7 @@ def test_det(x):
 def test_diagonal(x, kw):
     res = linalg.diagonal(x, **kw)
 
-    assert res.dtype == x.dtype, "diagonal() returned the wrong dtype"
+    ph.assert_dtype("diagonal", in_dtype=x.dtype, out_dtype=res.dtype)
 
     n, m = x.shape[-2:]
     offset = kw.get('offset', 0)
@@ -214,18 +258,20 @@ def test_diagonal(x, kw):
     else:
         diag_size = min(n, m, max(m - offset, 0))
 
-    assert res.shape == (*x.shape[:-2], diag_size), "diagonal() returned the wrong shape"
+    expected_shape = (*x.shape[:-2], diag_size)
+    ph.assert_result_shape("diagonal", in_shapes=[x.shape],
+                           out_shape=res.shape, expected=expected_shape)
 
-    def true_diag(x_stack):
+    def true_diag(x_stack, offset=0):
         if offset >= 0:
             x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
         else:
             x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
-        return asarray(x_stack_diag, dtype=x.dtype)
+        return asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype)
 
     _test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
 
-@pytest.mark.skip(reason="Inputs need to be restricted")  # TODO
+@pytest.mark.unvectorized
 @pytest.mark.xp_extension('linalg')
 @given(x=symmetric_matrices(finite=True))
 def test_eigh(x):
@@ -236,50 +282,72 @@ def test_eigh(x):
     eigenvalues = res.eigenvalues
     eigenvectors = res.eigenvectors
 
-    assert eigenvalues.dtype == x.dtype, "eigh().eigenvalues did not return the correct dtype"
-    assert eigenvalues.shape == x.shape[:-1], "eigh().eigenvalues did not return the correct shape"
-
-    assert eigenvectors.dtype == x.dtype, "eigh().eigenvectors did not return the correct dtype"
-    assert eigenvectors.shape == x.shape, "eigh().eigenvectors did not return the correct shape"
-
+    ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvalues.dtype,
+                    expected=x.dtype, repr_name="eigenvalues.dtype")
+    ph.assert_result_shape("eigh", in_shapes=[x.shape],
+                           out_shape=eigenvalues.shape,
+                           expected=x.shape[:-1],
+                           repr_name="eigenvalues.shape")
+
+    ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvectors.dtype,
+                    expected=x.dtype, repr_name="eigenvectors.dtype")
+    ph.assert_result_shape("eigh", in_shapes=[x.shape],
+                           out_shape=eigenvectors.shape, expected=x.shape,
+                           repr_name="eigenvectors.shape")
+
+    # Note: _test_stacks here is only testing the shape and dtype. The actual
+    # eigenvalues and eigenvectors may not be equal at all, since there are no
+    # requirements about how eigh computes an eigenbasis, or about the order
+    # of the eigenvalues
     _test_stacks(lambda x: linalg.eigh(x).eigenvalues, x,
                  res=eigenvalues, dims=1)
+
+    # TODO: Test that eigenvectors are orthonormal.
+
     _test_stacks(lambda x: linalg.eigh(x).eigenvectors, x,
                  res=eigenvectors, dims=2)
 
     # TODO: Test that res actually corresponds to the eigenvalues and
     # eigenvectors of x
 
+@pytest.mark.unvectorized
 @pytest.mark.xp_extension('linalg')
 @given(x=symmetric_matrices(finite=True))
 def test_eigvalsh(x):
     res = linalg.eigvalsh(x)
 
-    assert res.dtype == x.dtype, "eigvalsh() did not return the correct dtype"
-    assert res.shape == x.shape[:-1], "eigvalsh() did not return the correct shape"
+    ph.assert_dtype("eigvalsh", in_dtype=x.dtype, out_dtype=res.dtype)
+    ph.assert_result_shape("eigvalsh", in_shapes=[x.shape],
+                           out_shape=res.shape, expected=x.shape[:-1])
 
+    # Note: _test_stacks here is only testing the shape and dtype. The actual
+    # eigenvalues may not be equal at all, since there are no requirements
+    # about the order of the eigenvalues, and the stacking code may use a
+    # different code path.
     _test_stacks(linalg.eigvalsh, x, res=res, dims=1)
 
     # TODO: Should we test that the result is the same as eigh(x).eigenvalues?
+    # (probably not, because the spec doesn't actually require that)
 
     # TODO: Test that res actually corresponds to the eigenvalues of x
 
+@pytest.mark.unvectorized
 @pytest.mark.xp_extension('linalg')
 @given(x=invertible_matrices())
 def test_inv(x):
     res = linalg.inv(x)
 
-    assert res.shape == x.shape, "inv() did not return the correct shape"
-    assert res.dtype == x.dtype, "inv() did not return the correct dtype"
+    ph.assert_dtype("inv", in_dtype=x.dtype, out_dtype=res.dtype)
+    ph.assert_result_shape("inv", in_shapes=[x.shape], out_shape=res.shape,
+                           expected=x.shape)
 
     _test_stacks(linalg.inv, x, res=res)
 
     # TODO: Test that the result is actually the inverse
 
-@given(
-    *two_mutual_arrays(dh.numeric_dtypes)
-)
-def test_matmul(x1, x2):
+def _test_matmul(namespace, x1, x2):
+    matmul = namespace.matmul
+
     # TODO: Make this also test the @ operator
     if (x1.shape == () or x2.shape == ()
             or len(x1.shape) == len(x2.shape) == 1
             and x1.shape != x2.shape
@@ -292,25 +360,47 @@
             "matmul did not raise an exception for invalid shapes")
         return
     else:
-        res = _array_module.matmul(x1, x2)
+        res = matmul(x1, x2)
 
-    ph.assert_dtype("matmul", [x1.dtype, x2.dtype], res.dtype)
+    ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
 
     if len(x1.shape) == len(x2.shape) == 1:
-        assert res.shape == ()
+        ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+                               out_shape=res.shape, expected=())
     elif len(x1.shape) == 1:
-        assert res.shape == x2.shape[:-2] + x2.shape[-1:]
-        _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
+        ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+                               out_shape=res.shape,
+                               expected=x2.shape[:-2] + x2.shape[-1:])
+        _test_stacks(matmul, x1, x2, res=res, dims=1,
+                     matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
     elif len(x2.shape) == 1:
-        assert res.shape == x1.shape[:-1]
-        _test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
+        ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+                               out_shape=res.shape, expected=x1.shape[:-1])
+        _test_stacks(matmul, x1, x2, res=res, dims=1,
+                     matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
     else:
         stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
-        assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
-        _test_stacks(_array_module.matmul, x1, x2, res=res)
+        ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+                               out_shape=res.shape,
+                               expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
+        _test_stacks(matmul, x1, x2, res=res)
 
-matrix_norm_shapes = shared(matrix_shapes())
 
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+    *two_mutual_arrays(dh.real_dtypes)
+)
+def test_linalg_matmul(x1, x2):
+    return _test_matmul(linalg, x1, x2)
+
+
+@pytest.mark.unvectorized
+@given(
+    *two_mutual_arrays(dh.real_dtypes)
+)
+def test_matmul(x1, x2):
+    return _test_matmul(_array_module, x1, x2)
+
+@pytest.mark.unvectorized
 @pytest.mark.xp_extension('linalg')
 @given(
     x=finite_matrices(),
@@ -328,26 +418,30 @@ def test_matrix_norm(x, kw):
         expected_shape = x.shape[:-2] + (1, 1)
     else:
         expected_shape = x.shape[:-2]
-    assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape"
-    assert res.dtype == x.dtype, "matrix_norm() 
did not return the correct dtype" + ph.assert_complex_to_float_dtype("matrix_norm", in_dtype=x.dtype, + out_dtype=res.dtype) + ph.assert_result_shape("matrix_norm", in_shapes=[x.shape], + out_shape=res.shape, expected=expected_shape) _test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0, res=res) -matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n') +matrix_power_n = shared(integers(-100, 100), key='matrix_power n') +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( # Generate any square matrix if n >= 0 but only invertible matrices if n < 0 x=matrix_power_n.flatmap(lambda n: invertible_matrices() if n < 0 else - xps.arrays(dtype=xps.floating_dtypes(), + arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes)), n=matrix_power_n, ) def test_matrix_power(x, n): res = linalg.matrix_power(x, n) - assert res.shape == x.shape, "matrix_power() did not return the correct shape" - assert res.dtype == x.dtype, "matrix_power() did not return the correct dtype" + ph.assert_dtype("matrix_power", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("matrix_power", in_shapes=[x.shape], + out_shape=res.shape, expected=x.shape) if n == 0: true_val = lambda x: _array_module.eye(x.shape[0], dtype=x.dtype) @@ -357,6 +451,7 @@ def test_matrix_power(x, n): func = lambda x: linalg.matrix_power(x, n) _test_stacks(func, x, res=res, true_val=true_val) +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( x=finite_matrices(shape=rtol_shared_matrix_shapes), @@ -365,11 +460,9 @@ def test_matrix_power(x, n): def test_matrix_rank(x, kw): linalg.matrix_rank(x, **kw) -@given( - x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), -) -def test_matrix_transpose(x): - res = _array_module.matrix_transpose(x) +def _test_matrix_transpose(namespace, x): + matrix_transpose = namespace.matrix_transpose + res = matrix_transpose(x) true_val = lambda a: _array_module.asarray([[a[i, j] for i in range(a.shape[0])] for j in range(a.shape[1])], @@ -377,14 +470,30 @@ def test_matrix_transpose(x): shape = list(x.shape) shape[-1], shape[-2] = shape[-2], shape[-1] shape = tuple(shape) - assert res.shape == shape, "matrix_transpose() did not return the correct shape" - assert res.dtype == x.dtype, "matrix_transpose() did not return the correct dtype" + ph.assert_dtype("matrix_transpose", in_dtype=x.dtype, out_dtype=res.dtype) + ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape], + out_shape=res.shape, expected=shape) + + _test_stacks(matrix_transpose, x, res=res, true_val=true_val) - _test_stacks(_array_module.matrix_transpose, x, res=res, true_val=true_val) +@pytest.mark.unvectorized +@pytest.mark.xp_extension('linalg') +@given( + x=arrays(dtype=all_dtypes, shape=matrix_shapes()), +) +def test_linalg_matrix_transpose(x): + return _test_matrix_transpose(linalg, x) + +@pytest.mark.unvectorized +@given( + x=arrays(dtype=all_dtypes, shape=matrix_shapes()), +) +def test_matrix_transpose(x): + return _test_matrix_transpose(_array_module, x) @pytest.mark.xp_extension('linalg') @given( - *two_mutual_arrays(dtypes=dh.numeric_dtypes, + *two_mutual_arrays(dtypes=dh.real_dtypes, two_shapes=tuples(one_d_shapes, one_d_shapes)) ) def test_outer(x1, x2): @@ -393,8 +502,9 @@ def test_outer(x1, x2): res = linalg.outer(x1, x2) shape = (x1.shape[0], x2.shape[0]) - assert res.shape == shape, "outer() did not return the correct shape" - assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "outer() did not return the correct dtype" + ph.assert_dtype("outer", 
in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype) + ph.assert_result_shape("outer", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, expected=shape) if 0 in shape: true_res = _array_module.empty(shape, dtype=res.dtype) @@ -414,9 +524,10 @@ def test_outer(x1, x2): def test_pinv(x, kw): linalg.pinv(x, **kw) +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=matrix_shapes()), + x=arrays(dtype=all_floating_dtypes(), shape=matrix_shapes()), kw=kwargs(mode=sampled_from(['reduced', 'complete'])) ) def test_qr(x, kw): @@ -430,17 +541,23 @@ def test_qr(x, kw): Q = res.Q R = res.R - assert Q.dtype == x.dtype, "qr().Q did not return the correct dtype" + ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=Q.dtype, + expected=x.dtype, repr_name="Q.dtype") if mode == 'complete': - assert Q.shape == x.shape[:-2] + (M, M), "qr().Q did not return the correct shape" + expected_Q_shape = x.shape[:-2] + (M, M) else: - assert Q.shape == x.shape[:-2] + (M, K), "qr().Q did not return the correct shape" + expected_Q_shape = x.shape[:-2] + (M, K) + ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=Q.shape, + expected=expected_Q_shape, repr_name="Q.shape") - assert R.dtype == x.dtype, "qr().R did not return the correct dtype" + ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=R.dtype, + expected=x.dtype, repr_name="R.dtype") if mode == 'complete': - assert R.shape == x.shape[:-2] + (M, N), "qr().R did not return the correct shape" + expected_R_shape = x.shape[:-2] + (M, N) else: - assert R.shape == x.shape[:-2] + (K, N), "qr().R did not return the correct shape" + expected_R_shape = x.shape[:-2] + (K, N) + ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=R.shape, + expected=expected_R_shape, repr_name="R.shape") _test_stacks(lambda x: linalg.qr(x, **kw).Q, x, res=Q) _test_stacks(lambda x: linalg.qr(x, **kw).R, x, res=R) @@ -450,9 +567,10 @@ def test_qr(x, kw): # Check that R is upper-triangular. assert_exactly_equal(R, _array_module.triu(R)) +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes), + x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes), ) def test_slogdet(x): res = linalg.slogdet(x) @@ -461,11 +579,19 @@ def test_slogdet(x): sign, logabsdet = res - assert sign.dtype == x.dtype, "slogdet().sign did not return the correct dtype" - assert sign.shape == x.shape[:-2], "slogdet().sign did not return the correct shape" - assert logabsdet.dtype == x.dtype, "slogdet().logabsdet did not return the correct dtype" - assert logabsdet.shape == x.shape[:-2], "slogdet().logabsdet did not return the correct shape" - + ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=sign.dtype, + expected=x.dtype, repr_name="sign.dtype") + ph.assert_result_shape("slogdet", in_shapes=[x.shape], + out_shape=sign.shape, + expected=x.shape[:-2], + repr_name="sign.shape") + expected_dtype = dh.as_real_dtype(x.dtype) + ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=logabsdet.dtype, + expected=expected_dtype, repr_name="logabsdet.dtype") + ph.assert_result_shape("slogdet", in_shapes=[x.shape], + out_shape=logabsdet.shape, + expected=x.shape[:-2], + repr_name="logabsdet.shape") _test_stacks(lambda x: linalg.slogdet(x).sign, x, res=sign, dims=0) @@ -487,7 +613,7 @@ def test_slogdet(x): # TODO: Test this when we have tests for floating-point values. 
# assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps) -def solve_args(): +def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: """ Strategy for the x1 and x2 arguments to test_solve() @@ -495,26 +621,45 @@ def solve_args(): of shape (..., M, M), and x2 is either shape (M,) or (..., M, K), where the ... parts of x1 and x2 are broadcast compatible. """ + mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dh.all_float_dtypes)) + stack_shapes = shared(two_mutually_broadcastable_shapes) # Don't worry about dtypes since all floating dtypes are type promotable # with each other. - x1 = shared(invertible_matrices(stack_shapes=stack_shapes.map(lambda pair: - pair[0]))) + x1 = shared(invertible_matrices( + stack_shapes=stack_shapes.map(lambda pair: pair[0]), + dtypes=mutual_dtypes.map(lambda pair: pair[0]))) @composite def _x2_shapes(draw): - end = draw(integers(0, SQRT_MAX_ARRAY_SIZE)) - return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,) + base_shape = draw(stack_shapes)[1] + draw(x1).shape[-1:] + end = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(base_shape), 1))) + return base_shape + (end,) x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes()) - x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes) + x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1])) return x1, x2 +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given(*solve_args()) def test_solve(x1, x2): - linalg.solve(x1, x2) + res = linalg.solve(x1, x2) + + ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype) + if x2.ndim == 1: + expected_shape = x1.shape[:-2] + x2.shape[-1:] + _test_stacks(linalg.solve, x1, x2, res=res, dims=1, + matrix_axes=[(-2, -1), (0,)], res_axes=[-1]) + else: + stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2]) + expected_shape = stack_shape + x2.shape[-2:] + _test_stacks(linalg.solve, x1, x2, res=res, dims=2) + ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, expected=expected_shape) + +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( x=finite_matrices(), @@ -531,17 +676,31 @@ def test_svd(x, kw): U, S, Vh = res - assert U.dtype == x.dtype, "svd().U did not return the correct dtype" - assert S.dtype == x.dtype, "svd().S did not return the correct dtype" - assert Vh.dtype == x.dtype, "svd().Vh did not return the correct dtype" + ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=U.dtype, + expected=x.dtype, repr_name="U.dtype") + ph.assert_complex_to_float_dtype("svd", in_dtype=x.dtype, + out_dtype=S.dtype, repr_name="S.dtype") + ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=Vh.dtype, + expected=x.dtype, repr_name="Vh.dtype") if full_matrices: - assert U.shape == (*stack, M, M), "svd().U did not return the correct shape" - assert Vh.shape == (*stack, N, N), "svd().Vh did not return the correct shape" + expected_U_shape = (*stack, M, M) + expected_Vh_shape = (*stack, N, N) else: - assert U.shape == (*stack, M, K), "svd(full_matrices=False).U did not return the correct shape" - assert Vh.shape == (*stack, K, N), "svd(full_matrices=False).Vh did not return the correct shape" - assert S.shape == (*stack, K), "svd().S did not return the correct shape" + expected_U_shape = (*stack, M, K) + expected_Vh_shape = (*stack, K, N) + ph.assert_result_shape("svd", in_shapes=[x.shape], + out_shape=U.shape, + expected=expected_U_shape, + repr_name="U.shape") + ph.assert_result_shape("svd", in_shapes=[x.shape], + 
out_shape=Vh.shape, + expected=expected_Vh_shape, + repr_name="Vh.shape") + ph.assert_result_shape("svd", in_shapes=[x.shape], + out_shape=S.shape, + expected=(*stack, K), + repr_name="S.shape") # The values of s must be sorted from largest to smallest if K >= 1: @@ -551,6 +710,7 @@ def test_svd(x, kw): _test_stacks(lambda x: linalg.svd(x, **kw).S, x, dims=1, res=S) _test_stacks(lambda x: linalg.svd(x, **kw).Vh, x, res=Vh) +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( x=finite_matrices(), @@ -561,8 +721,11 @@ def test_svdvals(x): *stack, M, N = x.shape K = min(M, N) - assert res.dtype == x.dtype, "svdvals() did not return the correct dtype" - assert res.shape == (*stack, K), "svdvals() did not return the correct shape" + ph.assert_complex_to_float_dtype("svdvals", in_dtype=x.dtype, + out_dtype=res.dtype) + ph.assert_result_shape("svdvals", in_shapes=[x.shape], + out_shape=res.shape, + expected=(*stack, K)) # SVD values must be sorted from largest to smallest assert _array_module.all(res[..., :-1] >= res[..., 1:]), "svdvals() values are not sorted from largest to smallest" @@ -571,26 +734,137 @@ def test_svdvals(x): # TODO: Check that svdvals() is the same as svd().s. +_tensordot_pre_shapes = shared(two_mutually_broadcastable_shapes) -@given( - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), - shape=shapes(), - data=data(), -) -def test_tensordot(dtypes, shape, data): - # TODO: vary shapes, vary contracted axes, test different axes arguments - x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shape), label="x1") - x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shape), label="x2") +@composite +def _tensordot_axes(draw): + shape1, shape2 = draw(_tensordot_pre_shapes) + ndim1, ndim2 = len(shape1), len(shape2) + isint = draw(booleans()) + + if isint: + N = min(ndim1, ndim2) + return draw(integers(0, N)) + else: + if ndim1 < ndim2: + first = draw(xps.valid_tuple_axes(ndim1)) + second = draw(xps.valid_tuple_axes(ndim2, min_size=len(first), + max_size=len(first))) + else: + second = draw(xps.valid_tuple_axes(ndim2)) + first = draw(xps.valid_tuple_axes(ndim1, min_size=len(second), + max_size=len(second))) + return (tuple(first), tuple(second)) + +tensordot_kw = shared(kwargs(axes=_tensordot_axes())) + +@composite +def tensordot_shapes(draw): + _shape1, _shape2 = map(list, draw(_tensordot_pre_shapes)) + ndim1, ndim2 = len(_shape1), len(_shape2) + kw = draw(tensordot_kw) + if 'axes' not in kw: + assume(ndim1 >= 2 and ndim2 >= 2) + axes = kw.get('axes', 2) + + if isinstance(axes, int): + axes = [list(range(-axes, 0)), list(range(0, axes))] + + first, second = axes + for i, j in zip(first, second): + try: + if -ndim2 <= j < ndim2 and _shape2[j] != 1: + _shape1[i] = _shape2[j] + if -ndim1 <= i < ndim1 and _shape1[i] != 1: + _shape2[j] = _shape1[i] + except: + raise + + shape1, shape2 = map(tuple, [_shape1, _shape2]) + return (shape1, shape2) + +def _test_tensordot_stacks(x1, x2, kw, res): + """ + Variant of _test_stacks for tensordot + + tensordot doesn't stack directly along the non-contracted dimensions like + the other linalg functions. Rather, it is stacked along the product of + each non-contracted dimension. These dimensions are independent of one + another and do not broadcast. 
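+
+    For example (shapes only, with axes=1): for x1 of shape (2, 3, 4) and x2
+    of shape (4, 5), res has shape (2, 3, 5), and each scalar res[i, j, k]
+    should equal tensordot(x1[i, j, :], x2[:, k], axes=1), i.e. the dot
+    product of the corresponding vector stacks. The loop below decomposes the
+    inputs this way and compares against the full result.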
+ """ + shape1, shape2 = x1.shape, x2.shape + + axes = kw.get('axes', 2) + + if isinstance(axes, int): + res_axes = axes + axes = [list(range(-axes, 0)), list(range(0, axes))] + else: + # Convert something like (0, 4, 2) into (0, 2, 1) + res_axes = [] + for a, s in zip(axes, [shape1, shape2]): + indices = [range(len(s))[i] for i in a] + repl = dict(zip(sorted(indices), range(len(indices)))) + res_axes.append(tuple(repl[i] for i in indices)) + res_axes = tuple(res_axes) + + for ((i,), (j,)), (res_idx,) in zip( + itertools.product( + iter_indices(shape1, skip_axes=axes[0]), + iter_indices(shape2, skip_axes=axes[1])), + iter_indices(res.shape)): + i, j, res_idx = i.raw, j.raw, res_idx.raw + + res_stack = res[res_idx] + x1_stack = x1[i] + x2_stack = x2[j] + decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes) + assert_equal(res_stack, decomp_res_stack) - out = xp.tensordot(x1, x2, axes=len(shape)) +def _test_tensordot(namespace, x1, x2, kw): + tensordot = namespace.tensordot + res = tensordot(x1, x2, **kw) - ph.assert_dtype("tensordot", dtypes, out.dtype) - # TODO: assert shape and elements + ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype], + out_dtype=res.dtype) + axes = _axes = kw.get('axes', 2) + if isinstance(axes, int): + _axes = [list(range(-axes, 0)), list(range(0, axes))] + + _shape1 = list(x1.shape) + _shape2 = list(x2.shape) + for i, j in zip(*_axes): + _shape1[i] = _shape2[j] = None + _shape1 = tuple([i for i in _shape1 if i is not None]) + _shape2 = tuple([i for i in _shape2 if i is not None]) + result_shape = _shape1 + _shape2 + ph.assert_result_shape('tensordot', [x1.shape, x2.shape], res.shape, + expected=result_shape) + _test_tensordot_stacks(x1, x2, kw, res) + +@pytest.mark.unvectorized +@pytest.mark.xp_extension('linalg') +@given( + *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()), + tensordot_kw, +) +def test_linalg_tensordot(x1, x2, kw): + _test_tensordot(linalg, x1, x2, kw) + +@pytest.mark.unvectorized +@given( + *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()), + tensordot_kw, +) +def test_tensordot(x1, x2, kw): + _test_tensordot(_array_module, x1, x2, kw) + +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()), + x=arrays(dtype=numeric_dtypes, shape=matrix_shapes()), # offset may produce an overflow if it is too large. Supporting offsets # that are way larger than the array shape isn't very important. kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)) @@ -598,17 +872,21 @@ def test_tensordot(dtypes, shape, data): def test_trace(x, kw): res = linalg.trace(x, **kw) - # TODO: trace() should promote in some cases. See - # https://github.com/data-apis/array-api/issues/202. See also the dtype - # argument to sum() below. - - # assert res.dtype == x.dtype, "trace() returned the wrong dtype" + dtype = kw.get("dtype", None) + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. 
+ # See https://github.com/data-apis/array-api-tests/issues/160 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(res.dtype) # sanity check + elif api_version < "2023.12": # TODO: update dtype assertion for >2023.12 - see #234 + ph.assert_dtype("trace", in_dtype=x.dtype, out_dtype=res.dtype, expected=expected_dtype) n, m = x.shape[-2:] - offset = kw.get('offset', 0) - assert res.shape == x.shape[:-2], "trace() returned the wrong shape" + ph.assert_result_shape('trace', x.shape, res.shape, expected=x.shape[:-2]) - def true_trace(x_stack): + def true_trace(x_stack, offset=0): # Note: the spec does not specify that offset must be within the # bounds of the matrix. A large offset should just produce a size 0 # diagonal in the last dimension (trace 0). See test_diagonal(). @@ -623,33 +901,122 @@ def true_trace(x_stack): x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)] else: x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)] - return _array_module.sum(asarray(x_stack_diag, dtype=x.dtype), dtype=x.dtype) + result = xp.asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype) + return _array_module.sum(result) + _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace) +def _conj(x): + # XXX: replace with xp.dtype when all array libraries implement it + if x.dtype in (xp.complex64, xp.complex128): + return xp.conj(x) + else: + return x + + +def _test_vecdot(namespace, x1, x2, data): + vecdot = namespace.vecdot + broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape) + min_ndim = min(x1.ndim, x2.ndim) + ndim = len(broadcasted_shape) + kw = data.draw(kwargs(axis=integers(-min_ndim, -1))) + axis = kw.get('axis', -1) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + ph.raises(Exception, lambda: vecdot(x1, x2, **kw), + "vecdot did not raise an exception for invalid shapes") + return + expected_shape = list(broadcasted_shape) + expected_shape.pop(axis) + expected_shape = tuple(expected_shape) + + res = vecdot(x1, x2, **kw) + + ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype], + out_dtype=res.dtype) + ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape], + out_shape=res.shape, expected=expected_shape) + + def true_val(x, y, axis=-1): + return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype) + + _test_stacks(vecdot, x1, x2, res=res, dims=0, + matrix_axes=(axis,), true_val=true_val) + + +@pytest.mark.unvectorized +@pytest.mark.xp_extension('linalg') @given( - dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), - shape=shapes(min_dims=1), - data=data(), + *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), + data(), +) +def test_linalg_vecdot(x1, x2, data): + _test_vecdot(linalg, x1, x2, data) + + +@pytest.mark.unvectorized +@given( + *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)), + data(), ) -def test_vecdot(dtypes, shape, data): - # TODO: vary shapes, test different axis arguments - x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shape), label="x1") - x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shape), label="x2") - kw = data.draw(kwargs(axis=just(-1))) +def test_vecdot(x1, x2, data): + _test_vecdot(_array_module, x1, x2, data) + + +@pytest.mark.xp_extension('linalg') +def test_vecdot_conj(): + # no-hypothesis test to check that the 1st argument is in fact conjugated + x1 = xp.asarray([1j, 2j, 3j]) + x2 = xp.asarray([1, 2j, 3]) + + 
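+    # A worked expectation for the arrays below:
+    #     sum(conj(x1) * x2) = (-1j)*1 + (-2j)*(2j) + (-3j)*3
+    #                        = -1j + 4 - 9j
+    #                        = 4 - 10j
+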
import cmath + assert cmath.isclose(complex(xp.linalg.vecdot(x1, x2)), 4 - 10j) - out = xp.vecdot(x1, x2, **kw) - ph.assert_dtype("vecdot", dtypes, out.dtype) - # TODO: assert shape and elements +# Insanely large orders might not work. There isn't a limit specified in the +# spec, so we just limit to reasonable values here. +max_ord = 100 +@pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()), - kw=kwargs(axis=todo, keepdims=todo, ord=todo) + x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)), + data=data(), ) -def test_vector_norm(x, kw): - # res = linalg.vector_norm(x, **kw) - pass +def test_vector_norm(x, data): + kw = data.draw( + # We use data because axes is parameterized on x.ndim + kwargs(axis=axes(x.ndim), + keepdims=booleans(), + ord=one_of( + sampled_from([2, 1, 0, -1, -2, float("inf"), float("-inf")]), + integers(-max_ord, max_ord), + floats(-max_ord, max_ord), + )), label="kw") + + + res = linalg.vector_norm(x, **kw) + axis = kw.get('axis', None) + keepdims = kw.get('keepdims', False) + # TODO: Check that the ord values give the correct norms. + # ord = kw.get('ord', 2) + + _axes = sh.normalize_axis(axis, x.ndim) + + ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape, + in_shape=x.shape, axes=_axes, + keepdims=keepdims, kw=kw) + expected_dtype = dh.as_real_dtype(x.dtype) + ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype, + out_dtype=res.dtype, expected=expected_dtype) + + _kw = kw.copy() + _kw.pop('axis', None) + _test_stacks(linalg.vector_norm, x, res=res, + dims=x.ndim if keepdims else 0, + matrix_axes=_axes, **_kw + ) diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py index b9d9e03d..754b507d 100644 --- a/array_api_tests/test_manipulation_functions.py +++ b/array_api_tests/test_manipulation_functions.py @@ -14,11 +14,6 @@ from . 
import xps from .typing import Array, Shape -pytestmark = pytest.mark.ci - -MAX_SIDE = hh.MAX_ARRAY_SIZE // 64 -MAX_DIMS = min(hh.MAX_ARRAY_SIZE // MAX_SIDE, 32) # NumPy only supports up to 32 dims - def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: key = "shape" @@ -32,11 +27,11 @@ def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]: def assert_array_ndindex( func_name: str, x: Array, + *, x_indices: Iterable[Union[int, Shape]], out: Array, out_indices: Iterable[Union[int, Shape]], - /, - **kw, + kw: dict = {}, ): msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}" for x_idx, out_idx in zip(x_indices, out_indices): @@ -48,6 +43,7 @@ def assert_array_ndindex( assert out[out_idx] == x[x_idx], msg +@pytest.mark.unvectorized @given( dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes), base_shape=hh.shapes(), @@ -67,17 +63,17 @@ def test_concat(dtypes, base_shape, data): shape_strat = hh.shapes() else: _axis = axis if axis >= 0 else len(base_shape) + axis - shape_strat = st.integers(0, MAX_SIDE).map( + shape_strat = st.integers(0, hh.MAX_SIDE).map( lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :] ) arrays = [] for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") + x = data.draw(hh.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}") arrays.append(x) out = xp.concat(arrays, **kw) - ph.assert_dtype("concat", dtypes, out.dtype) + ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype) shapes = tuple(x.shape for x in arrays) if _axis is None: @@ -88,20 +84,20 @@ def test_concat(dtypes, base_shape, data): for other_shape in shapes[1:]: shape[_axis] += other_shape[_axis] shape = tuple(shape) - ph.assert_result_shape("concat", shapes, out.shape, shape, **kw) + ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw) if _axis is None: - out_indices = (i for i in range(out.size)) + out_indices = (i for i in range(math.prod(out.shape))) for x_num, x in enumerate(arrays, 1): for x_idx in sh.ndindex(x.shape): out_i = next(out_indices) ph.assert_0d_equals( "concat", - f"x{x_num}[{x_idx}]", - x[x_idx], - f"out[{out_i}]", - out[out_i], - **kw, + x_repr=f"x{x_num}[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_i}]", + out_val=out[out_i], + kw=kw, ) else: out_indices = sh.ndindex(out.shape) @@ -113,16 +109,17 @@ def test_concat(dtypes, base_shape, data): out_idx = next(out_indices) ph.assert_0d_equals( "concat", - f"x{x_num}[{f_idx}][{x_idx}]", - indexed_x[x_idx], - f"out[{out_idx}]", - out[out_idx], - **kw, + x_repr=f"x{x_num}[{f_idx}][{x_idx}]", + x_val=indexed_x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, ) +@pytest.mark.unvectorized @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), + x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), axis=shared_shapes().flatmap( # Generate both valid and invalid axis lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s)) @@ -136,22 +133,68 @@ def test_expand_dims(x, axis): out = xp.expand_dims(x, axis=axis) - ph.assert_dtype("expand_dims", x.dtype, out.dtype) + ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype) shape = [side for side in x.shape] index = axis if axis >= 0 else x.ndim + axis + 1 shape.insert(index, 1) shape = tuple(shape) - ph.assert_result_shape("expand_dims", [x.shape], out.shape, shape) + ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape) assert_array_ndindex( - 
"expand_dims", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape) + "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape) + ) + + +@pytest.mark.min_version("2023.12") +@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data()) +def test_moveaxis(x, data): + source = data.draw( + st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source" ) + if isinstance(source, int): + destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination") + else: + assert isinstance(source, tuple) # sanity check + destination = data.draw( + st.lists( + st.integers(-x.ndim, x.ndim - 1), + min_size=len(source), + max_size=len(source), + unique_by=lambda n: n if n >= 0 else x.ndim + n, + ).map(tuple), + label="destination" + ) + + out = xp.moveaxis(x, source, destination) + + ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype) + + _source = sh.normalize_axis(source, x.ndim) + _destination = sh.normalize_axis(destination, x.ndim) + new_axes = [n for n in range(x.ndim) if n not in _source] + + for dest, src in sorted(zip(_destination, _source)): + new_axes.insert(dest, src) + + expected_shape = tuple(x.shape[i] for i in new_axes) + + ph.assert_result_shape("moveaxis", in_shapes=[x.shape], + out_shape=out.shape, expected=expected_shape, + kw={"source": source, "destination": destination}) + + indices = list(sh.ndindex(x.shape)) + permuted_indices = [tuple(idx[axis] for axis in new_axes) for idx in indices] + assert_array_ndindex( + "moveaxis", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=permuted_indices + ) + +@pytest.mark.unvectorized @given( - x=xps.arrays( - dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) + x=hh.arrays( + dtype=hh.all_dtypes, shape=hh.shapes(min_side=1).filter(lambda s: 1 in s) ), data=st.data(), ) @@ -164,7 +207,7 @@ def test_squeeze(x, data): ) axes = (axis,) if isinstance(axis, int) else axis - axes = sh.normalise_axis(axes, x.ndim) + axes = sh.normalize_axis(axes, x.ndim) squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1] if any(i not in squeezable_axes for i in axes): @@ -174,20 +217,21 @@ def test_squeeze(x, data): out = xp.squeeze(x, axis) - ph.assert_dtype("squeeze", x.dtype, out.dtype) + ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype) shape = [] for i, side in enumerate(x.shape): if i not in axes: shape.append(side) shape = tuple(shape) - ph.assert_result_shape("squeeze", [x.shape], out.shape, shape, axis=axis) + ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis)) - assert_array_ndindex("squeeze", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) + assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) +@pytest.mark.unvectorized @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data(), ) def test_flip(x, data): @@ -201,16 +245,18 @@ def test_flip(x, data): out = xp.flip(x, **kw) - ph.assert_dtype("flip", x.dtype, out.dtype) + ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) for indices in sh.axes_ndindex(x.shape, _axes): reverse_indices = indices[::-1] - assert_array_ndindex("flip", x, indices, out, reverse_indices) + assert_array_ndindex("flip", x, x_indices=indices, 
out=out,
+                         out_indices=reverse_indices, kw=kw)


+@pytest.mark.unvectorized
 @given(
-    x=xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes(min_dims=1)),
+    x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
     axes=shared_shapes(min_dims=1).flatmap(
         lambda s: st.lists(
             st.integers(0, len(s) - 1),
@@ -223,41 +269,93 @@ def test_flip(x, data):
 def test_permute_dims(x, axes):
     out = xp.permute_dims(x, axes)

-    ph.assert_dtype("permute_dims", x.dtype, out.dtype)
+    ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype)

     shape = [None for _ in range(len(axes))]
     for i, dim in enumerate(axes):
         side = x.shape[dim]
         shape[i] = side
     shape = tuple(shape)
-    ph.assert_result_shape("permute_dims", [x.shape], out.shape, shape, axes=axes)
+    ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes))

     indices = list(sh.ndindex(x.shape))
     permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices]
-    assert_array_ndindex("permute_dims", x, indices, out, permuted_indices)
-
-
-@st.composite
-def reshape_shapes(draw, shape):
-    size = 1 if len(shape) == 0 else math.prod(shape)
-    rshape = draw(st.lists(st.integers(0)).filter(lambda s: math.prod(s) == size))
-    assume(all(side <= MAX_SIDE for side in rshape))
-    if len(rshape) != 0 and size > 0 and draw(st.booleans()):
-        index = draw(st.integers(0, len(rshape) - 1))
-        rshape[index] = -1
-    return tuple(rshape)
+    assert_array_ndindex("permute_dims", x, x_indices=indices, out=out,
+                         out_indices=permuted_indices)


+@pytest.mark.min_version("2023.12")
 @given(
-    x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(max_side=MAX_SIDE)),
+    x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
+    kw=hh.kwargs(
+        axis=st.none() | shared_shapes(min_dims=1).flatmap(
+            lambda s: st.integers(-len(s), len(s) - 1)
+        )
+    ),
     data=st.data(),
 )
-def test_reshape(x, data):
-    shape = data.draw(reshape_shapes(x.shape))
+def test_repeat(x, kw, data):
+    shape = x.shape
+    axis = kw.get("axis", None)
+    size = math.prod(shape) if axis is None else shape[axis]
+    repeat_strat = st.integers(1, 10)
+    repeats = data.draw(repeat_strat
+                        | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
+                                    shape=st.sampled_from([(1,), (size,)])),
+                        label="repeats")
+    if isinstance(repeats, int):
+        n_repetitions = size*repeats
+    else:
+        if repeats.shape == (1,):
+            n_repetitions = size*int(repeats[0])
+        else:
+            n_repetitions = int(xp.sum(repeats))
+
+    assume(n_repetitions <= hh.SQRT_MAX_ARRAY_SIZE)
+
+    out = xp.repeat(x, repeats, **kw)
+    ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
+    if axis is None:
+        expected_shape = (n_repetitions,)
+    else:
+        expected_shape = list(shape)
+        expected_shape[axis] = n_repetitions
+        expected_shape = tuple(expected_shape)
+    ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
+
+    # Test values
+
+    if isinstance(repeats, int):
+        repeats_array = xp.full(size, repeats, dtype=xp.int32)
+    else:
+        repeats_array = repeats
+
+    if kw.get("axis") is None:
+        x = xp.reshape(x, (-1,))
+        axis = 0
+
+    for idx, in sh.iter_indices(x.shape, skip_axes=axis):
+        x_slice = x[idx]
+        out_slice = out[idx]
+        start = 0
+        for i, count in enumerate(repeats_array):
+            end = start + count
+            ph.assert_array_elements("repeat", out=out_slice[start:end],
+                                     expected=xp.full((count,), x_slice[i], dtype=x.dtype),
+                                     kw=kw)
+            start = end
+
+reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
+
+@pytest.mark.unvectorized
+@given(
+    x=hh.arrays(dtype=hh.all_dtypes,
shape=reshape_shape), + shape=hh.reshape_shapes(reshape_shape), +) +def test_reshape(x, shape): out = xp.reshape(x, shape) - ph.assert_dtype("reshape", x.dtype, out.dtype) + ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype) _shape = list(shape) if any(side == -1 for side in shape): @@ -265,9 +363,9 @@ def test_reshape(x, data): rsize = math.prod(shape) * -1 _shape[shape.index(-1)] = size / rsize _shape = tuple(_shape) - ph.assert_result_shape("reshape", [x.shape], out.shape, _shape, shape=shape) + ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape)) - assert_array_ndindex("reshape", x, sh.ndindex(x.shape), out, sh.ndindex(out.shape)) + assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)) def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator[Shape]: @@ -279,7 +377,8 @@ def roll_ndindex(shape: Shape, shifts: Tuple[int], axes: Tuple[int]) -> Iterator yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape)) -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=shared_shapes()), st.data()) +@pytest.mark.unvectorized +@given(hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), st.data()) def test_roll(x, data): shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE) if x.ndim > 0: @@ -301,23 +400,24 @@ def test_roll(x, data): kw = {"shift": shift, **kw} # for error messages - ph.assert_dtype("roll", x.dtype, out.dtype) + ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype) - ph.assert_result_shape("roll", [x.shape], out.shape) + ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw) if kw.get("axis", None) is None: assert isinstance(shift, int) # sanity check indices = list(sh.ndindex(x.shape)) shifted_indices = deque(indices) shifted_indices.rotate(-shift) - assert_array_ndindex("roll", x, indices, out, shifted_indices, **kw) + assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw) else: shifts = (shift,) if isinstance(shift, int) else shift - axes = sh.normalise_axis(kw["axis"], x.ndim) + axes = sh.normalize_axis(kw["axis"], x.ndim) shifted_indices = roll_ndindex(x.shape, shifts, axes) - assert_array_ndindex("roll", x, sh.ndindex(x.shape), out, shifted_indices, **kw) + assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw) +@pytest.mark.unvectorized @given( shape=shared_shapes(min_dims=1), dtypes=hh.mutually_promotable_dtypes(None), @@ -331,12 +431,12 @@ def test_roll(x, data): def test_stack(shape, dtypes, kw, data): arrays = [] for i, dtype in enumerate(dtypes, 1): - x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}") + x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}") arrays.append(x) out = xp.stack(arrays, **kw) - ph.assert_dtype("stack", dtypes, out.dtype) + ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype) axis = kw.get("axis", 0) _axis = axis if axis >= 0 else len(shape) + axis + 1 @@ -344,22 +444,70 @@ def test_stack(shape, dtypes, kw, data): _shape.insert(_axis, len(arrays)) _shape = tuple(_shape) ph.assert_result_shape( - "stack", tuple(x.shape for x in arrays), out.shape, _shape, **kw + "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw ) out_indices = sh.ndindex(out.shape) for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis): f_idx = ", ".join(str(i) if isinstance(i, int) else ":" 
for i in idx) - print(f"{f_idx=}") for x_num, x in enumerate(arrays, 1): indexed_x = x[idx] for x_idx in sh.ndindex(indexed_x.shape): out_idx = next(out_indices) ph.assert_0d_equals( "stack", - f"x{x_num}[{f_idx}][{x_idx}]", - indexed_x[x_idx], - f"out[{out_idx}]", - out[out_idx], - **kw, + x_repr=f"x{x_num}[{f_idx}][{x_idx}]", + x_val=indexed_x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, ) + + +@pytest.mark.min_version("2023.12") +@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data()) +def test_tile(x, data): + repetitions = data.draw( + st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple), + label="repetitions" + ) + out = xp.tile(x, repetitions) + ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype) + # TODO: values testing + + # shape check; the notation is from the Array API docs + N, M = len(x.shape), len(repetitions) + if N > M: + S = x.shape + R = (1,)*(N - M) + repetitions + else: + S = (1,)*(M - N) + x.shape + R = repetitions + + assert out.shape == tuple(r*s for r, s in zip(R, S)) + + +@pytest.mark.min_version("2023.12") +@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data()) +def test_unstack(x, data): + axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis") + kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw") + out = xp.unstack(x, **kw) + + assert isinstance(out, tuple) + assert len(out) == x.shape[axis] + expected_shape = list(x.shape) + expected_shape.pop(axis) + expected_shape = tuple(expected_shape) + for i in range(x.shape[axis]): + arr = out[i] + ph.assert_result_shape("unstack", in_shapes=[x.shape], + out_shape=arr.shape, expected=expected_shape, + kw=kw, repr_name=f"out[{i}].shape") + + ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype, + repr_name=f"out[{i}].dtype") + + idx = [slice(None)] * x.ndim + idx[axis] = i + ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]") diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py index d4349372..84bcaa28 100644 --- a/array_api_tests/test_operators_and_elementwise_functions.py +++ b/array_api_tests/test_operators_and_elementwise_functions.py @@ -1,14 +1,19 @@ +""" +Test element-wise functions/operators against reference implementations. +""" +import cmath import math import operator +import builtins +from copy import copy from enum import Enum, auto from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union import pytest from hypothesis import assume, given from hypothesis import strategies as st -from hypothesis.control import reject -from . import _array_module as xp +from . import _array_module as xp, api_version from . import array_helpers as ah from . import dtype_helpers as dh from . import hypothesis_helpers as hh @@ -17,57 +22,11 @@ from . 
import xps from .typing import Array, DataType, Param, Scalar, ScalarType, Shape -pytestmark = pytest.mark.ci +pytestmark = pytest.mark.unvectorized -def all_integer_dtypes() -> st.SearchStrategy[DataType]: - """Returns a strategy for signed and unsigned integer dtype objects.""" - return xps.unsigned_integer_dtypes() | xps.integer_dtypes() - -def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]: - """Returns a strategy for boolean and all integer dtype objects.""" - return xps.boolean_dtypes() | all_integer_dtypes() - - -class OnewayPromotableDtypes(NamedTuple): - input_dtype: DataType - result_dtype: DataType - - -@st.composite -def oneway_promotable_dtypes( - draw, dtypes: Sequence[DataType] -) -> st.SearchStrategy[OnewayPromotableDtypes]: - """Return a strategy for input dtypes that promote to result dtypes.""" - d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes)) - result_dtype = dh.result_type(d1, d2) - if d1 == result_dtype: - return OnewayPromotableDtypes(d2, d1) - elif d2 == result_dtype: - return OnewayPromotableDtypes(d1, d2) - else: - reject() - - -class OnewayBroadcastableShapes(NamedTuple): - input_shape: Shape - result_shape: Shape - - -@st.composite -def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]: - """Return a strategy for input shapes that broadcast to result shapes.""" - result_shape = draw(hh.shapes(min_side=1)) - input_shape = draw( - xps.broadcastable_shapes( - result_shape, - # Override defaults so bad shapes are less likely to be generated. - max_side=None if result_shape == () else max(result_shape), - max_dims=len(result_shape), - ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape) - ) - return OnewayBroadcastableShapes(input_shape, result_shape) +EPS32 = xp.finfo(xp.float32).eps def mock_int_dtype(n: int, dtype: DataType) -> int: @@ -82,42 +41,46 @@ def mock_int_dtype(n: int, dtype: DataType) -> int: return n -# This module tests elementwise functions/operators against a reference -# implementation. We iterate through the input array(s) and resulting array, -# casting the indexed arrays to Python scalars and calculating the expected -# output with `refimpl` function. -# -# This is finicky to refactor, but possible and ultimately worthwhile - hence -# why these *_assert_again_refimpl() utilities exist. -# -# Values which are special-cased are generated and passed, but are filtered by -# the `filter_` callable before they can be asserted against `refimpl`. We -# automatically generate tests for special cases in the special_cases/ dir. We -# still pass them here so as to ensure their presence doesn't affect the outputs -# respective to non-special-cased elements. -# -# By default, results are casted to scalars the same way that the inputs are. -# You can specify a cast via `res_stype, i.e. when a function accepts numerical -# inputs but returns boolean arrays. -# -# By default, floating-point functions/methods are loosely asserted against. Use -# `strict_check=True` when they should be strictly asserted against, i.e. -# when a function should return intergrals. Likewise, use `strict_check=False` -# when integer function/methods should be loosely asserted against, i.e. when -# floats are used internally for optimisation or legacy reasons. 
-
-
-def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bool:
+def isclose(
+    a: float,
+    b: float,
+    maximum: float,
+    *,
+    rel_tol: float = 0.25,
+    abs_tol: float = 1,
+) -> bool:
     """Wraps math.isclose with very generous defaults.

     This is useful for many floating-point operations where the spec does not
     make accuracy requirements.
     """
-    if not (math.isfinite(a) and math.isfinite(b)):
-        raise ValueError(f"{a=} and {b=}, but input must be finite")
+    if math.isnan(a) or math.isnan(b):
+        raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
+    if math.isinf(a):
+        return math.isinf(b) or abs(b) > math.log(maximum)
+    elif math.isinf(b):
+        return math.isinf(a) or abs(a) > math.log(maximum)
     return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)


+def isclose_complex(
+    a: complex,
+    b: complex,
+    maximum: float,
+    *,
+    rel_tol: float = 0.25,
+    abs_tol: float = 1,
+) -> bool:
+    """Like isclose() but specifically for complex values."""
+    if cmath.isnan(a) or cmath.isnan(b):
+        raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
+    if cmath.isinf(a):
+        return cmath.isinf(b) or abs(b) > 2**(math.log2(maximum)//2)
+    elif cmath.isinf(b):
+        return cmath.isinf(a) or abs(a) > 2**(math.log2(maximum)//2)
+    return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
+
+
 def default_filter(s: Scalar) -> bool:
     """Returns False when s is a non-finite or a signed zero.

@@ -137,40 +100,178 @@ def unary_assert_against_refimpl(
     in_: Array,
     res: Array,
     refimpl: Callable[[T], T],
-    expr_template: Optional[str] = None,
+    *,
     res_stype: Optional[ScalarType] = None,
     filter_: Callable[[Scalar], bool] = default_filter,
     strict_check: Optional[bool] = None,
+    expr_template: Optional[str] = None,
 ):
+    """
+    Assert unary element-wise results are as expected.
+
+    We iterate through every element in the input and resulting arrays, casting
+    the respective elements (0-D arrays) to Python scalars, and assert against
+    the expected output specified by the passed reference implementation, e.g.
+
+    >>> x = xp.asarray([[0, 1], [2, 4]])
+    >>> out = xp.square(x)
+    >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
+
+    is equivalent to
+
+    >>> for idx in np.ndindex(x.shape):
+    ...     expected = int(x[idx]) ** 2
+    ...     assert int(out[idx]) == expected
+
+    Casting
+    -------
+
+    The input scalar type is inferred from the input array's dtype like so
+
+    Array dtypes      | Python builtin type
+    ----------------- | ---------------------
+    xp.bool           | bool
+    xp.int*, xp.uint* | int
+    xp.float*         | float
+    xp.complex*       | complex
+
+    If res_stype=None (the default), the result scalar type is the same as the
+    input scalar type. We can also specify the result scalar type ourselves, e.g.
+
+    >>> x = xp.asarray([42., xp.inf])
+    >>> out = xp.isinf(x)  # should be [False, True]
+    >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
+
+    Filtering special-cased values
+    ------------------------------
+
+    Values which are special-cased can be present in the input array, but get
+    filtered before they can be asserted against refimpl.
+
+    If filter_=default_filter (the default), all non-finite and floating zero
+    values are filtered, e.g.
+
+    >>> unary_assert_against_refimpl('sin', x, out, math.sin)
+
+    is equivalent to
+
+    >>> for idx in np.ndindex(x.shape):
+    ...     at_x = float(x[idx])
+    ...     if math.isfinite(at_x) and at_x != 0:
+    ...         expected = math.sin(at_x)
+    ...         assert math.isclose(float(out[idx]), expected)
+
+    We can also specify the filter function ourselves, e.g.
+
+    >>> def sqrt_filter(s: float) -> bool:
+    ...     return math.isfinite(s) and s >= 0
+    >>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter)
+
+    is equivalent to
+
+    >>> for idx in np.ndindex(x.shape):
+    ...     at_x = float(x[idx])
+    ...     if math.isfinite(at_x) and at_x >= 0:
+    ...         expected = math.sqrt(at_x)
+    ...         assert math.isclose(float(out[idx]), expected)
+
+    Note we leave special-cased values in the input arrays, so as to ensure
+    their presence doesn't affect the outputs respective to non-special-cased
+    elements. We specifically test special case behaviour in test_special_cases.py.
+
+    Assertion strictness
+    --------------------
+
+    If strict_check=None (the default), integer elements are strictly asserted
+    against, and floating elements are loosely asserted against, e.g.
+
+    >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
+
+    is equivalent to
+
+    >>> for idx in np.ndindex(x.shape):
+    ...     expected = in_stype(x[idx]) ** 2
+    ...     if in_stype == int:
+    ...         assert int(out[idx]) == expected
+    ...     else:  # in_stype == float
+    ...         assert math.isclose(float(out[idx]), expected)
+
+    Specifying strict_check as True or False will assert strictly/loosely
+    respectively, regardless of dtype. This is useful for testing functions that
+    have definitive outputs for floating inputs, e.g. rounding functions.
+
+    Expressions in errors
+    ---------------------
+
+    Assertion error messages include an expression, by default using func_name
+    like so
+
+    >>> x = xp.asarray([42., xp.inf])
+    >>> out = xp.isinf(x)
+    >>> out
+    [False, False]
+    >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
+    AssertionError: out[1]=False, but should be isinf(x[1])=True ...
+
+    We can specify the expression template ourselves, e.g.
+
+    >>> x = xp.asarray(True)
+    >>> out = xp.logical_not(x)
+    >>> out
+    True
+    >>> unary_assert_against_refimpl(
+    ...     'logical_not', x, out, operator.not_, expr_template='(not {})={}'
+    ... )
+    AssertionError: out=True, but should be (not True)=False ...
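+
+    Overflow handling
+    -----------------
+
+    Elements for which refimpl raises OverflowError, or whose expected result
+    falls outside the range of the result dtype, are skipped rather than
+    asserted against (see the range checks in the body below).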
+ + """ if in_.shape != res.shape: raise ValueError(f"{res.shape=}, but should be {in_.shape=}") if expr_template is None: expr_template = func_name + "({})={}" in_stype = dh.get_scalar_type(in_.dtype) if res_stype is None: - res_stype = in_stype - m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + res_stype = dh.get_scalar_type(res.dtype) + if res.dtype == xp.bool: + m, M = (None, None) + elif res.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] + else: + m, M = dh.dtype_ranges[res.dtype] + if in_.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) for idx in sh.ndindex(in_.shape): scalar_i = in_stype(in_[idx]) if not filter_(scalar_i): continue try: expected = refimpl(scalar_i) - except Exception: + except OverflowError: continue if res.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[idx]) f_i = sh.fmt_idx("x", idx) f_o = sh.fmt_idx("out", idx) expr = expr_template.format(f_i, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + # TODO: strict check floating results too + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_i}={scalar_i}" ) + if res.dtype in dh.complex_dtypes: + assert isclose_complex(scalar_o, expected, M), msg + else: + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -184,20 +285,36 @@ def binary_assert_against_refimpl( right: Array, res: Array, refimpl: Callable[[T, T], T], - expr_template: Optional[str] = None, + *, res_stype: Optional[ScalarType] = None, + filter_: Callable[[Scalar], bool] = default_filter, + strict_check: Optional[bool] = None, left_sym: str = "x1", right_sym: str = "x2", res_name: str = "out", - filter_: Callable[[Scalar], bool] = default_filter, - strict_check: Optional[bool] = None, + expr_template: Optional[str] = None, ): + """ + Assert binary element-wise results are as expected. + + See unary_assert_against_refimpl for more information. 
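+
+    Elements are paired up via broadcasting. As a sketch, for
+
+    >>> x1 = xp.asarray([[2], [3]])
+    >>> x2 = xp.asarray([10, 20])
+    >>> out = xp.add(x1, x2)
+    >>> binary_assert_against_refimpl('add', x1, x2, out, operator.add)
+
+    the check is roughly that int(out[i, j]) == int(x1[i, 0]) + int(x2[j])
+    for every i and j.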
+ """ if expr_template is None: expr_template = func_name + "({}, {})={}" in_stype = dh.get_scalar_type(left.dtype) + if res_stype is None: + res_stype = dh.get_scalar_type(left.dtype) if res_stype is None: res_stype = in_stype - m, M = dh.dtype_ranges.get(res.dtype, (None, None)) + if res.dtype == xp.bool: + m, M = (None, None) + elif res.dtype in dh.complex_dtypes: + m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]] + else: + m, M = dh.dtype_ranges[res.dtype] + if left.dtype in dh.complex_dtypes: + component_filter = copy(filter_) + filter_ = lambda s: component_filter(s.real) and component_filter(s.imag) for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape): scalar_l = in_stype(left[l_idx]) scalar_r = in_stype(right[r_idx]) @@ -205,22 +322,31 @@ def binary_assert_against_refimpl( continue try: expected = refimpl(scalar_l, scalar_r) - except Exception: + except OverflowError: continue if res.dtype != xp.bool: - assert m is not None and M is not None # for mypy - if expected <= m or expected >= M: - continue + if res.dtype in dh.complex_dtypes: + if expected.real <= m or expected.real >= M: + continue + if expected.imag <= m or expected.imag >= M: + continue + else: + if expected <= m or expected >= M: + continue scalar_o = res_stype(res[o_idx]) f_l = sh.fmt_idx(left_sym, l_idx) f_r = sh.fmt_idx(right_sym, r_idx) f_o = sh.fmt_idx(res_name, o_idx) expr = expr_template.format(f_l, f_r, expected) - if strict_check == False or dh.is_float_dtype(res.dtype): - assert isclose(scalar_o, expected), ( + if strict_check == False or res.dtype in dh.all_float_dtypes: + msg = ( f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n" f"{f_l}={scalar_l}, {f_r}={scalar_r}" ) + if res.dtype in dh.complex_dtypes: + assert isclose_complex(scalar_o, expected, M), msg + else: + assert isclose(scalar_o, expected, M), msg else: assert scalar_o == expected, ( f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n" @@ -234,40 +360,67 @@ def right_scalar_assert_against_refimpl( right: Scalar, res: Array, refimpl: Callable[[T, T], T], - expr_template: str = None, + *, res_stype: Optional[ScalarType] = None, - left_sym: str = "x1", - res_name: str = "out", filter_: Callable[[Scalar], bool] = default_filter, strict_check: Optional[bool] = None, + left_sym: str = "x1", + res_name: str = "out", + expr_template: str = None, ): + """ + Assert binary element-wise results from scalar operands are as expected. + + See unary_assert_against_refimpl for more information. 
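+
+    Here the right operand is a Python scalar rather than an array. As a
+    sketch, for
+
+    >>> x1 = xp.asarray([1., 2.])
+    >>> out = x1 + 3.
+    >>> right_scalar_assert_against_refimpl('__add__', x1, 3., out, operator.add)
+
+    the check is roughly that float(out[i]) == float(x1[i]) + 3. for every i.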
+    """
+    if expr_template is None:
+        expr_template = func_name + "({}, {})={}"
+    if left.dtype in dh.complex_dtypes:
+        component_filter = copy(filter_)
+        filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
-    if filter_(right):
+    if not filter_(right):
         return  # short-circuit here as there will be nothing to test
     in_stype = dh.get_scalar_type(left.dtype)
+    if res_stype is None:
+        res_stype = dh.get_scalar_type(left.dtype)
     if res_stype is None:
         res_stype = in_stype
-    m, M = dh.dtype_ranges.get(left.dtype, (None, None))
+    if res.dtype == xp.bool:
+        m, M = (None, None)
+    elif left.dtype in dh.complex_dtypes:
+        m, M = dh.dtype_ranges[dh.dtype_components[left.dtype]]
+    else:
+        m, M = dh.dtype_ranges[left.dtype]
     for idx in sh.ndindex(res.shape):
         scalar_l = in_stype(left[idx])
-        if not filter_(scalar_l):
+        if not (filter_(scalar_l) and filter_(right)):
            continue
         try:
             expected = refimpl(scalar_l, right)
-        except Exception:
+        except OverflowError:
             continue
         if left.dtype != xp.bool:
-            assert m is not None and M is not None  # for mypy
-            if expected <= m or expected >= M:
-                continue
+            if res.dtype in dh.complex_dtypes:
+                if expected.real <= m or expected.real >= M:
+                    continue
+                if expected.imag <= m or expected.imag >= M:
+                    continue
+            else:
+                if expected <= m or expected >= M:
+                    continue
         scalar_o = res_stype(res[idx])
         f_l = sh.fmt_idx(left_sym, idx)
         f_o = sh.fmt_idx(res_name, idx)
         expr = expr_template.format(f_l, right, expected)
-        if strict_check == False or dh.is_float_dtype(res.dtype):
-            assert isclose(scalar_o, expected), (
+        if strict_check == False or res.dtype in dh.all_float_dtypes:
+            msg = (
                 f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
                 f"{f_l}={scalar_l}"
             )
+            if res.dtype in dh.complex_dtypes:
+                assert isclose_complex(scalar_o, expected, M), msg
+            else:
+                assert isclose(scalar_o, expected, M), msg
         else:
             assert scalar_o == expected, (
                 f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
@@ -275,7 +428,7 @@ def right_scalar_assert_against_refimpl(
             )


-# When appropiate, this module tests operators alongside their respective
+# When appropriate, this module tests operators alongside their respective
 # elementwise methods. We do this by parametrizing a generalised test method
 # with every relevant method and operator.
 #
@@ -285,8 +438,8 @@ def right_scalar_assert_against_refimpl(
 # - The argument strategies, which can be used to draw arguments for the test
 #   case. They may require additional filtering for certain test cases.
 # - right_is_scalar (binary parameters only), which denotes if the right
-# argument is a scalar in a test case. This can be used to appropiately adjust
-# draw filtering and test logic.
+# argument is a scalar in a test case. This can be used to appropriately
+# adjust draw filtering and test logic.
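+#
+# For example, the parameters generated for "add" cover the xp.add function
+# itself as well as operator forms such as __add__ and the in-place __iadd__,
+# with the right_is_scalar cases exercising expressions like `x + 2`.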
func_to_op = {v: k for k, v in dh.op_to_func.items()} @@ -301,16 +454,24 @@ class UnaryParamContext(NamedTuple): @property def id(self) -> str: - return f"{self.func_name}" + return self.func_name def __repr__(self): return f"UnaryParamContext(<{self.id}>)" def make_unary_params( - elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType] + elwise_func_name: str, + dtypes: Sequence[DataType], + *, + min_version: str = "2021.12", ) -> List[Param[UnaryParamContext]]: - strat = xps.arrays(dtype=dtypes_strat, shape=hh.shapes()) + dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + assert len(dtypes) > 0 # sanity check + if api_version < "2022.12": + dtypes = [d for d in dtypes if d not in dh.complex_dtypes] + dtypes_strat = st.sampled_from(dtypes) + strat = hh.arrays(dtype=dtypes_strat, shape=hh.shapes()) func_ctx = UnaryParamContext( func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat ) @@ -318,7 +479,16 @@ def make_unary_params( op_ctx = UnaryParamContext( func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat ) - return [pytest.param(func_ctx, id=func_ctx.id), pytest.param(op_ctx, id=op_ctx.id)] + if api_version < min_version: + marks = pytest.mark.skip( + reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}" + ) + else: + marks = () + return [ + pytest.param(func_ctx, id=func_ctx.id, marks=marks), + pytest.param(op_ctx, id=op_ctx.id, marks=marks), + ] class FuncType(Enum): @@ -351,9 +521,9 @@ def __repr__(self): def make_binary_params( elwise_func_name: str, dtypes: Sequence[DataType] ) -> List[Param[BinaryParamContext]]: - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] - shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes)) + dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)] + assert len(dtypes) > 0 # sanity check + shared_oneway_dtypes = st.shared(hh.oneway_promotable_dtypes(dtypes)) left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype) right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype) @@ -368,16 +538,16 @@ def make_param( right_sym = "x2" if right_is_scalar: - left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw)) - right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw)) + left_strat = hh.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw)) + right_strat = right_dtypes.flatmap(lambda d: hh.from_dtype(d, **finite_kw)) else: if func_type is FuncType.IOP: - shared_oneway_shapes = st.shared(oneway_broadcastable_shapes()) - left_strat = xps.arrays( + shared_oneway_shapes = st.shared(hh.oneway_broadcastable_shapes()) + left_strat = hh.arrays( dtype=left_dtypes, shape=shared_oneway_shapes.map(lambda S: S.result_shape), ) - right_strat = xps.arrays( + right_strat = hh.arrays( dtype=right_dtypes, shape=shared_oneway_shapes.map(lambda S: S.input_shape), ) @@ -385,10 +555,10 @@ def make_param( mutual_shapes = st.shared( hh.mutually_broadcastable_shapes(2, **shapes_kw) ) - left_strat = xps.arrays( + left_strat = hh.arrays( dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0]) ) - right_strat = xps.arrays( + right_strat = hh.arrays( dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1]) ) @@ -459,7 +629,7 @@ def binary_param_assert_dtype( else: in_dtypes = [left.dtype, right.dtype] # type: ignore ph.assert_dtype( - ctx.func_name, in_dtypes, res.dtype, expected, repr_name=f"{ctx.res_name}.dtype" + ctx.func_name, in_dtype=in_dtypes, out_dtype=res.dtype, 
expected=expected, repr_name=f"{ctx.res_name}.dtype"
     )


@@ -475,7 +645,7 @@ def binary_param_assert_shape(
     else:
         in_shapes = [left.shape, right.shape]  # type: ignore
     ph.assert_result_shape(
-        ctx.func_name, in_shapes, res.shape, expected, repr_name=f"{ctx.res_name}.shape"
+        ctx.func_name, in_shapes=in_shapes, out_shape=res.shape, expected=expected, repr_name=f"{ctx.res_name}.shape"
     )


@@ -486,6 +656,7 @@ def binary_param_assert_against_refimpl(
     res: Array,
     op_sym: str,
     refimpl: Callable[[T, T], T],
+    *,
     res_stype: Optional[ScalarType] = None,
     filter_: Callable[[Scalar], bool] = default_filter,
     strict_check: Optional[bool] = None,
@@ -522,7 +693,43 @@ def binary_param_assert_against_refimpl(
     )


-@pytest.mark.parametrize("ctx", make_unary_params("abs", xps.numeric_dtypes()))
+def _convert_scalars_helper(x1, x2):
+    """Convert Python scalars to arrays, and record the shapes/dtypes of the
+    array arguments.
+
+    For inputs being scalars or arrays, return the dtypes and shapes of the
+    array arguments, and all arguments converted to arrays.
+
+    dtypes are kept separate to help distinguish between
+    `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
+    """
+    if dh.is_scalar(x1):
+        in_dtypes = [x2.dtype]
+        in_shapes = [x2.shape]
+        x1a, x2a = xp.asarray(x1), x2
+    elif dh.is_scalar(x2):
+        in_dtypes = [x1.dtype]
+        in_shapes = [x1.shape]
+        x1a, x2a = x1, xp.asarray(x2)
+    else:
+        in_dtypes = [x1.dtype, x2.dtype]
+        in_shapes = [x1.shape, x2.shape]
+        x1a, x2a = x1, x2
+
+    return in_dtypes, in_shapes, (x1a, x2a)
+
+
+def _assert_correctness_binary(
+    name, func, in_dtypes, in_shapes, in_arrs, out, expected_dtype=None, **kwargs
+):
+    x1a, x2a = in_arrs
+    ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype)
+    ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
+    check_values = kwargs.pop('check_values', None)
+    if check_values:
+        binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
+
+
+@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
 @given(data=st.data())
 def test_abs(ctx, data):
     x = data.draw(ctx.strat, label="x")
@@ -532,37 +739,45 @@ def test_abs(ctx, data):

     out = ctx.func(x)

-    ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
-    ph.assert_shape(ctx.func_name, out.shape, x.shape)
+    if x.dtype in dh.complex_dtypes:
+        assert out.dtype == dh.dtype_components[x.dtype]
+    else:
+        ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
     unary_assert_against_refimpl(
         ctx.func_name,
         x,
         out,
         abs,  # type: ignore
+        res_stype=float if x.dtype in dh.complex_dtypes else None,
         expr_template="abs({})={}",
-        filter_=lambda s: (
-            s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s))
-        ),
+        # filter_=lambda s: (
+        #     s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s))
+        # ),
    )


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_acos(x):
     out = xp.acos(x)
-    ph.assert_dtype("acos", x.dtype, out.dtype)
-    ph.assert_shape("acos", out.shape, x.shape)
+    ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("acos", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos
+    filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1
     unary_assert_against_refimpl(
-        "acos", x, out, math.acos, filter_=lambda s: default_filter(s) and -1 <= s <= 1
+        "acos", x, out, refimpl, filter_=filter_
     )


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_acosh(x):
     out = xp.acosh(x)
-    ph.assert_dtype("acosh", x.dtype, out.dtype)
-    ph.assert_shape("acosh", out.shape, x.shape)
+    ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh
+    filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1
     unary_assert_against_refimpl(
-        "acosh", x, out, math.acosh, filter_=lambda s: default_filter(s) and s >= 1
+        "acosh", x, out, refimpl, filter_=filter_
     )


@@ -572,61 +787,70 @@ def test_add(ctx, data):
     left = data.draw(ctx.left_strat, label=ctx.left_sym)
     right = data.draw(ctx.right_strat, label=ctx.right_sym)

-    try:
+    with hh.reject_overflow():
         res = ctx.func(left, right)
-    except OverflowError:
-        reject()

     binary_param_assert_dtype(ctx, left, right, res)
     binary_param_assert_shape(ctx, left, right, res)
     binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add)


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_asin(x):
     out = xp.asin(x)
-    ph.assert_dtype("asin", x.dtype, out.dtype)
-    ph.assert_shape("asin", out.shape, x.shape)
+    ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("asin", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin
+    filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1
     unary_assert_against_refimpl(
-        "asin", x, out, math.asin, filter_=lambda s: default_filter(s) and -1 <= s <= 1
+        "asin", x, out, refimpl, filter_=filter_
     )


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_asinh(x):
     out = xp.asinh(x)
-    ph.assert_dtype("asinh", x.dtype, out.dtype)
-    ph.assert_shape("asinh", out.shape, x.shape)
-    unary_assert_against_refimpl("asinh", x, out, math.asinh)
+    ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh
+    unary_assert_against_refimpl("asinh", x, out, refimpl)


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_atan(x):
     out = xp.atan(x)
-    ph.assert_dtype("atan", x.dtype, out.dtype)
-    ph.assert_shape("atan", out.shape, x.shape)
-    unary_assert_against_refimpl("atan", x, out, math.atan)
+    ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("atan", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan
+    unary_assert_against_refimpl("atan", x, out, refimpl)


-@given(*hh.two_mutual_arrays(dh.float_dtypes))
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
 def test_atan2(x1, x2):
     out = xp.atan2(x1, x2)
-    ph.assert_dtype("atan2", [x1.dtype, x2.dtype], out.dtype)
-    ph.assert_result_shape("atan2", [x1.shape, x2.shape], out.shape)
-    binary_assert_against_refimpl("atan2", x1, x2, out, math.atan2)
+    _assert_correctness_binary(
+        "atan2",
+        math.atan2,
+        in_dtypes=[x1.dtype, x2.dtype],
+        in_shapes=[x1.shape, x2.shape],
+        in_arrs=[x1, x2],
+        out=out,
+    )


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_atanh(x):
     out = xp.atanh(x)
-    ph.assert_dtype("atanh", x.dtype, out.dtype)
-    ph.assert_shape("atanh", out.shape, x.shape)
+    ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh
+    filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1
     unary_assert_against_refimpl(
         "atanh",
         x,
         out,
-        math.atanh,
-        filter_=lambda s: default_filter(s) and -1 <= s <= 1,
+        refimpl,
+        filter_=filter_,
     )


@@ -665,14 +889,14 @@ def test_bitwise_left_shift(ctx, data):

     binary_param_assert_dtype(ctx, left, right, res)
     binary_param_assert_shape(ctx, left, right, res)
-    nbits = res.dtype
+    nbits = dh.dtype_nbits[res.dtype]
     binary_param_assert_against_refimpl(
         ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0
     )


 @pytest.mark.parametrize(
-    "ctx", make_unary_params("bitwise_invert", boolean_and_all_integer_dtypes())
+    "ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes)
 )
 @given(data=st.data())
 def test_bitwise_invert(ctx, data):
@@ -680,8 +904,8 @@ def test_bitwise_invert(ctx, data):

     out = ctx.func(x)

-    ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
-    ph.assert_shape(ctx.func_name, out.shape, x.shape)
+    ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
     if x.dtype == xp.bool:
         refimpl = operator.not_
     else:
@@ -748,37 +972,186 @@ def test_bitwise_xor(ctx, data):
     binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl)


-@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()))
 def test_ceil(x):
     out = xp.ceil(x)
-    ph.assert_dtype("ceil", x.dtype, out.dtype)
-    ph.assert_shape("ceil", out.shape, x.shape)
+    ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape)
     unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)


-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@pytest.mark.min_version("2023.12")
+@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
+def test_clip(x, data):
+    # Ensure that if both min and max are arrays, then all three of x, min,
+    # and max are broadcast compatible.
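+    # For example, with x.shape == (2, 1) the draw below might give
+    # min.shape == (3,) and max.shape == (2, 3), all of which broadcast
+    # to (2, 3).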
+ shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2, + base_shape=x.shape), + label="min.shape, max.shape") + + min = data.draw(st.one_of( + st.none(), + hh.scalars(dtypes=st.just(x.dtype)), + hh.arrays(dtype=st.just(x.dtype), shape=shape1), + ), label="min") + max = data.draw(st.one_of( + st.none(), + hh.scalars(dtypes=st.just(x.dtype)), + hh.arrays(dtype=st.just(x.dtype), shape=shape2), + ), label="max") + + # Note1: min > max is undefined (but allow nans) + assume(min is None or max is None or not xp.any(ah.less(xp.asarray(max, dtype=x.dtype), xp.asarray(min, dtype=x.dtype)))) + + kw = data.draw( + hh.specified_kwargs( + ("min", min, None), + ("max", max, None)), + label="kwargs") + + out = xp.clip(x, **kw) + + # min and max do not participate in type promotion + ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype) + + shapes = [x.shape] + if min is not None and not dh.is_scalar(min): + shapes.append(min.shape) + if max is not None and not dh.is_scalar(max): + shapes.append(max.shape) + expected_shape = sh.broadcast_shapes(*shapes) + ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape) + + # This is based on right_scalar_assert_against_refimpl and + # binary_assert_against_refimpl. clip() is currently the only ternary + # elementwise function and the only function that supports arrays and + # scalars. However, where() (in test_searching_functions) is similar + # and if scalar support is added to it, we may want to factor out and + # reuse this logic. + + def refimpl(_x, _min, _max): + # Skip cases where _min and _max are integers whose values do not + # fit in the dtype of _x, since this behavior is unspecified. + if dh.is_int_dtype(x.dtype): + if _min is not None and _min not in dh.dtype_ranges[x.dtype]: + return None + if _max is not None and _max not in dh.dtype_ranges[x.dtype]: + return None + + # If min or max are float64 and x is float32, they will need to be + # downcast to float32. This could result in a round in the wrong + # direction meaning the resulting clipped value might not actually be + # between min and max. This behavior is unspecified, so skip any cases + # where x is within the rounding error of downcasting min or max. 
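+        # (Concrete instance of the hazard: the float64 value
+        # 1.0000000000000002 rounds to 1.0 in float32, so a float32 x == 1.0
+        # sits exactly on the rounded bound and could validly clip either way.)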
+ if x.dtype == xp.float32: + if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min): + _min_float32 = float(xp.asarray(_min, dtype=xp.float32)) + if math.isinf(_min_float32): + return None + tol = abs(_min - _min_float32) + if math.isclose(_min, _min_float32, abs_tol=tol): + return None + if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max): + _max_float32 = float(xp.asarray(_max, dtype=xp.float32)) + if math.isinf(_max_float32): + return None + tol = abs(_max - _max_float32) + if math.isclose(_max, _max_float32, abs_tol=tol): + return None + + if (math.isnan(_x) + or (_min is not None and math.isnan(_min)) + or (_max is not None and math.isnan(_max))): + return math.nan + if _min is _max is None: + return _x + if _max is None: + return builtins.max(_x, _min) + if _min is None: + return builtins.min(_x, _max) + return builtins.min(builtins.max(_x, _min), _max) + + stype = dh.get_scalar_type(x.dtype) + min_shape = () if min is None or dh.is_scalar(min) else min.shape + max_shape = () if max is None or dh.is_scalar(max) else max.shape + + for x_idx, min_idx, max_idx, o_idx in sh.iter_indices( + x.shape, min_shape, max_shape, out.shape): + x_val = stype(x[x_idx]) + if min is None or dh.is_scalar(min): + min_val = min + else: + min_val = stype(min[min_idx]) + if max is None or dh.is_scalar(max): + max_val = max + else: + max_val = stype(max[max_idx]) + expected = refimpl(x_val, min_val, max_val) + if expected is None: + continue + out_val = stype(out[o_idx]) + if math.isnan(expected): + assert math.isnan(out_val), ( + f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n" + f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" + ) + else: + if out.dtype == xp.float32: + # conversion to builtin float is prone to roundoff errors + close_enough = math.isclose(out_val, expected, rel_tol=EPS32) + else: + close_enough = out_val == expected + + assert close_enough, ( + f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n" + f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}" + ) + + +@pytest.mark.min_version("2022.12") +@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") +@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) +def test_conj(x): + out = xp.conj(x) + ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("conj", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate")) + + +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_copysign(x1, x2): + out = xp.copysign(x1, x2) + ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype) + ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape) + binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign) + + +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cos(x): out = xp.cos(x) - ph.assert_dtype("cos", x.dtype, out.dtype) - ph.assert_shape("cos", out.shape, x.shape) - unary_assert_against_refimpl("cos", x, out, math.cos) + ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cos", out_shape=out.shape, expected=x.shape) + refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos + unary_assert_against_refimpl("cos", x, out, refimpl) -@given(xps.arrays(dtype=xps.floating_dtypes(), 
shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_cosh(x): out = xp.cosh(x) - ph.assert_dtype("cosh", x.dtype, out.dtype) - ph.assert_shape("cosh", out.shape, x.shape) - unary_assert_against_refimpl("cosh", x, out, math.cosh) + ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh + unary_assert_against_refimpl("cosh", x, out, refimpl) -@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes)) @given(data=st.data()) def test_divide(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) if ctx.right_is_scalar: - assume + assume # TODO: assume what? res = ctx.func(left, right) @@ -791,7 +1164,7 @@ def test_divide(ctx, data): res, "/", operator.truediv, - filter_=lambda s: math.isfinite(s) and s != 0, + filter_=lambda s: cmath.isfinite(s) and s != 0, ) @@ -823,31 +1196,53 @@ def test_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_exp(x): out = xp.exp(x) - ph.assert_dtype("exp", x.dtype, out.dtype) - ph.assert_shape("exp", out.shape, x.shape) - unary_assert_against_refimpl("exp", x, out, math.exp) + ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("exp", out_shape=out.shape, expected=x.shape) + refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp + unary_assert_against_refimpl("exp", x, out, refimpl) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_expm1(x): out = xp.expm1(x) - ph.assert_dtype("expm1", x.dtype, out.dtype) - ph.assert_shape("expm1", out.shape, x.shape) - unary_assert_against_refimpl("expm1", x, out, math.expm1) + ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + # There's no cmath.expm1. Use + # + # exp(x+yi) - 1 + # = exp(x)exp(yi) - 1 + # = exp(x)(cos(y) + sin(y)i) - 1 + # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i + # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i + # + # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of + # significance near y = 0. 
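+            # (Computing cmath.exp(z) - 1 directly would cancel catastrophically
+            # for |z| << 1, losing roughly half the significant digits at
+            # |z| ~ 1e-8; hence the rearranged form below.)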
+ re, im = z.real, z.imag + return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im) + else: + refimpl = math.expm1 + unary_assert_against_refimpl("expm1", x, out, refimpl) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes())) def test_floor(x): out = xp.floor(x) - ph.assert_dtype("floor", x.dtype, out.dtype) - ph.assert_shape("floor", out.shape, x.shape) - unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True) + ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("floor", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + def refimpl(z): + return complex(math.floor(z.real), math.floor(z.imag)) + else: + refimpl = math.floor + unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True) -@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes)) @given(data=st.data()) def test_floor_divide(ctx, data): left = data.draw( @@ -866,7 +1261,7 @@ def test_floor_divide(ctx, data): binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv) -@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes)) @given(data=st.data()) def test_greater(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -886,7 +1281,7 @@ def test_greater(ctx, data): ) -@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes)) +@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes)) @given(data=st.data()) def test_greater_equal(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) @@ -906,31 +1301,58 @@ def test_greater_equal(ctx, data): ) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_hypot(x1, x2): + out = xp.hypot(x1, x2) + _assert_correctness_binary( + "hypot", + math.hypot, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out + ) + + +@pytest.mark.min_version("2022.12") +@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from") +@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes())) +def test_imag(x): + out = xp.imag(x) + ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype]) + ph.assert_shape("imag", out_shape=out.shape, expected=x.shape) + unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag")) + + +@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isfinite(x): out = xp.isfinite(x) - ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool) - ph.assert_shape("isfinite", out.shape, x.shape) - unary_assert_against_refimpl("isfinite", x, out, math.isfinite, res_stype=bool) + ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape) + refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite + unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool) -@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes())) def test_isinf(x): out = 
xp.isinf(x)
-    ph.assert_dtype("isfinite", x.dtype, out.dtype, xp.bool)
-    ph.assert_shape("isinf", out.shape, x.shape)
-    unary_assert_against_refimpl("isinf", x, out, math.isinf, res_stype=bool)
+    ph.assert_dtype("isinf", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+    ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf
+    unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool)
 
 
-@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
 def test_isnan(x):
     out = xp.isnan(x)
-    ph.assert_dtype("isnan", x.dtype, out.dtype, xp.bool)
-    ph.assert_shape("isnan", out.shape, x.shape)
-    unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)
+    ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+    ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan
+    unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool)
 
 
-@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes))
+@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes))
 @given(data=st.data())
 def test_less(ctx, data):
     left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -950,7 +1372,7 @@ def test_less(ctx, data):
     )
 
 
-@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes))
+@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes))
 @given(data=st.data())
 def test_less_equal(ctx, data):
     left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -970,95 +1392,157 @@ def test_less_equal(ctx, data):
     )
 
 
-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_log(x):
     out = xp.log(x)
-    ph.assert_dtype("log", x.dtype, out.dtype)
-    ph.assert_shape("log", out.shape, x.shape)
+    ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("log", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log
+    filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
     unary_assert_against_refimpl(
-        "log", x, out, math.log, filter_=lambda s: default_filter(s) and s >= 1
+        "log", x, out, refimpl, filter_=filter_
     )
 
 
-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_log1p(x):
     out = xp.log1p(x)
-    ph.assert_dtype("log1p", x.dtype, out.dtype)
-    ph.assert_shape("log1p", out.shape, x.shape)
+    ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape)
+    # There isn't a cmath.log1p, and implementing one isn't straightforward
+    # (see
+    # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy).
+    # For now, just use log(1+p) for complex inputs, which should hopefully be
+    # fine given the very loose numerical tolerances we use. If it isn't, we
+    # can try using something like a series expansion for small p.
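+    # (For real inputs, math.log1p(1e-10) is accurate to full precision, while
+    # math.log(1 + 1e-10) already loses several significant digits to the
+    # rounding of 1 + 1e-10.)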
+ if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(1+z) + else: + refimpl = math.log1p + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1 unary_assert_against_refimpl( - "log1p", x, out, math.log1p, filter_=lambda s: default_filter(s) and s >= 1 + "log1p", x, out, refimpl, filter_=filter_ ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log2(x): out = xp.log2(x) - ph.assert_dtype("log2", x.dtype, out.dtype) - ph.assert_shape("log2", out.shape, x.shape) + ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log2", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(2) + else: + refimpl = math.log2 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log2", x, out, math.log2, filter_=lambda s: default_filter(s) and s > 1 + "log2", x, out, refimpl, filter_=filter_ ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_log10(x): out = xp.log10(x) - ph.assert_dtype("log10", x.dtype, out.dtype) - ph.assert_shape("log10", out.shape, x.shape) + ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("log10", out_shape=out.shape, expected=x.shape) + if x.dtype in dh.complex_dtypes: + refimpl = lambda z: cmath.log(z)/math.log(10) + else: + refimpl = math.log10 + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0 unary_assert_against_refimpl( - "log10", x, out, math.log10, filter_=lambda s: default_filter(s) and s > 0 + "log10", x, out, refimpl, filter_=filter_ ) -def logaddexp(l: float, r: float) -> float: - return math.log(math.exp(l) + math.exp(r)) +def logaddexp_refimpl(l: float, r: float) -> float: + try: + return math.log(math.exp(l) + math.exp(r)) + except ValueError: # raised for log(0.) 
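+        # (Assumed behaviour: exp(l) + exp(r) underflows to 0.0 when both
+        # inputs are very negative, making math.log raise ValueError; raising
+        # OverflowError instead routes the case through the same rejection
+        # path as genuine overflows.)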
+ raise OverflowError -@given(*hh.two_mutual_arrays(dh.float_dtypes)) +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) def test_logaddexp(x1, x2): out = xp.logaddexp(x1, x2) - ph.assert_dtype("logaddexp", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logaddexp", [x1.shape, x2.shape], out.shape) - binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp) - - -@given(*hh.two_mutual_arrays([xp.bool])) -def test_logical_and(x1, x2): - out = xp.logical_and(x1, x2) - ph.assert_dtype("logical_and", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_and", [x1.shape, x2.shape], out.shape) - binary_assert_against_refimpl( - "logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}" + _assert_correctness_binary( + "logaddexp", + logaddexp_refimpl, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out ) -@given(xps.arrays(dtype=xp.bool, shape=hh.shapes())) +@given(hh.arrays(dtype=xp.bool, shape=hh.shapes())) def test_logical_not(x): out = xp.logical_not(x) - ph.assert_dtype("logical_not", x.dtype, out.dtype) - ph.assert_shape("logical_not", out.shape, x.shape) + ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl( "logical_not", x, out, operator.not_, expr_template="(not {})={}" ) +@given(*hh.two_mutual_arrays([xp.bool])) +def test_logical_and(x1, x2): + out = xp.logical_and(x1, x2) + _assert_correctness_binary( + "logical_and", + operator.and_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} and {})={}" + ) + + @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_or(x1, x2): out = xp.logical_or(x1, x2) - ph.assert_dtype("logical_or", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_or", [x1.shape, x2.shape], out.shape) - binary_assert_against_refimpl( - "logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}" + _assert_correctness_binary( + "logical_or", + operator.or_, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} or {})={}" ) @given(*hh.two_mutual_arrays([xp.bool])) def test_logical_xor(x1, x2): out = xp.logical_xor(x1, x2) - ph.assert_dtype("logical_xor", [x1.dtype, x2.dtype], out.dtype) - ph.assert_result_shape("logical_xor", [x1.shape, x2.shape], out.shape) - binary_assert_against_refimpl( - "logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}" + _assert_correctness_binary( + "logical_xor", + operator.xor, + in_dtypes=[x1.dtype, x2.dtype], + in_shapes=[x1.shape, x2.shape], + in_arrs=[x1, x2], + out=out, + expr_template="({} ^ {})={}" + ) + + +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_maximum(x1, x2): + out = xp.maximum(x1, x2) + _assert_correctness_binary( + "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True + ) + + +@pytest.mark.min_version("2023.12") +@given(*hh.two_mutual_arrays(dh.real_float_dtypes)) +def test_minimum(x1, x2): + out = xp.minimum(x1, x2) + _assert_correctness_binary( + "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True ) @@ -1076,9 +1560,7 @@ def test_multiply(ctx, data): # TODO: clarify if uints are acceptable, adjust accordingly -@pytest.mark.parametrize( - "ctx", 
make_unary_params("negative", xps.integer_dtypes() | xps.floating_dtypes())
-)
+@pytest.mark.parametrize("ctx", make_unary_params("negative", dh.numeric_dtypes))
 @given(data=st.data())
 def test_negative(ctx, data):
     x = data.draw(ctx.strat, label="x")
@@ -1088,8 +1570,8 @@ def test_negative(ctx, data):
 
     out = ctx.func(x)
 
-    ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
-    ph.assert_shape(ctx.func_name, out.shape, x.shape)
+    ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
     unary_assert_against_refimpl(
         ctx.func_name, x, out, operator.neg, expr_template="-({})={}"  # type: ignore
     )
@@ -1115,16 +1597,36 @@ def test_not_equal(ctx, data):
     )
 
 
-@pytest.mark.parametrize("ctx", make_unary_params("positive", xps.numeric_dtypes()))
+@pytest.mark.min_version("2024.12")
+@given(
+    shapes=hh.two_mutually_broadcastable_shapes,
+    dtype=hh.real_floating_dtypes,
+    data=st.data()
+)
+def test_nextafter(shapes, dtype, data):
+    x1 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x1")
+    x2 = data.draw(hh.arrays(dtype=dtype, shape=shapes[1]), label="x2")
+
+    out = xp.nextafter(x1, x2)
+    _assert_correctness_binary(
+        "nextafter",
+        math.nextafter,
+        in_dtypes=[x1.dtype, x2.dtype],
+        in_shapes=[x1.shape, x2.shape],
+        in_arrs=[x1, x2],
+        out=out
+    )
+
+@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
 @given(data=st.data())
 def test_positive(ctx, data):
     x = data.draw(ctx.strat, label="x")
 
     out = ctx.func(x)
 
-    ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
-    ph.assert_shape(ctx.func_name, out.shape, x.shape)
-    ph.assert_array_elements(ctx.func_name, out, x)
+    ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
+    ph.assert_array_elements(ctx.func_name, out=out, expected=x)
 
 
 @pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
@@ -1139,17 +1641,42 @@ def test_pow(ctx, data):
     if dh.is_int_dtype(right.dtype):
         assume(xp.all(right >= 0))
 
-    try:
+    with hh.reject_overflow():
         res = ctx.func(left, right)
-    except OverflowError:
-        reject()
 
     binary_param_assert_dtype(ctx, left, right, res)
     binary_param_assert_shape(ctx, left, right, res)
     # Values testing pow is too finicky
 
 
-@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes))
+@pytest.mark.min_version("2022.12")
+@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
+@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
+def test_real(x):
+    out = xp.real(x)
+    ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
+    ph.assert_shape("real", out_shape=out.shape, expected=x.shape)
+    unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
+
+
+@pytest.mark.min_version("2024.12")
+@given(hh.arrays(dtype=hh.floating_dtypes, shape=hh.shapes(), elements=finite_kw))
+def test_reciprocal(x):
+    out = xp.reciprocal(x)
+    ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape)
+    refimpl = lambda x: 1.0 / x
+    unary_assert_against_refimpl(
+        "reciprocal",
+        x,
+        out,
+        refimpl,
+        strict_check=True,
+    )
+
+
+@pytest.mark.skip(reason="flaky")
+@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
 @given(data=st.data())
 def test_remainder(ctx, data):
     left = data.draw(ctx.left_strat, label=ctx.left_sym)
@@ -1166,57 +1693,80 @@ def test_remainder(ctx, data):
     binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod)
 
 
-@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
 def test_round(x):
     out = xp.round(x)
-    ph.assert_dtype("round", x.dtype, out.dtype)
-    ph.assert_shape("round", out.shape, x.shape)
-    unary_assert_against_refimpl("round", x, out, round, strict_check=True)
+    ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("round", out_shape=out.shape, expected=x.shape)
+    if x.dtype in dh.complex_dtypes:
+        refimpl = lambda z: complex(round(z.real), round(z.imag))
+    else:
+        refimpl = round
+    unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True)
+
+
+@pytest.mark.min_version("2023.12")
+@given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()))
+def test_signbit(x):
+    out = xp.signbit(x)
+    ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+    ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape)
+    refimpl = lambda x: math.copysign(1.0, x) < 0
+    unary_assert_against_refimpl("signbit", x, out, refimpl, strict_check=True)
 
 
-@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes(), elements=finite_kw))
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes(), elements=finite_kw))
 def test_sign(x):
     out = xp.sign(x)
-    ph.assert_dtype("sign", x.dtype, out.dtype)
-    ph.assert_shape("sign", out.shape, x.shape)
+    ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("sign", out_shape=out.shape, expected=x.shape)
+    refimpl = lambda x: x / abs(x) if x != 0 else 0
     unary_assert_against_refimpl(
-        "sign", x, out, lambda s: math.copysign(1, s), filter_=lambda s: s != 0
+        "sign",
+        x,
+        out,
+        refimpl,
+        strict_check=True,
    )
 
 
-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_sin(x):
     out = xp.sin(x)
-    ph.assert_dtype("sin", x.dtype, out.dtype)
-    ph.assert_shape("sin", out.shape, x.shape)
-    unary_assert_against_refimpl("sin", x, out, math.sin)
+    ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("sin", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin
+    unary_assert_against_refimpl("sin", x, out, refimpl)
 
 
-@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
 def test_sinh(x):
     out = xp.sinh(x)
-    ph.assert_dtype("sinh", x.dtype, out.dtype)
-    ph.assert_shape("sinh", out.shape, x.shape)
-    unary_assert_against_refimpl("sinh", x, out, math.sinh)
+    ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape)
+    refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh
+    unary_assert_against_refimpl("sinh", x, out, refimpl)
 
 
-@given(xps.arrays(dtype=xps.numeric_dtypes(), shape=hh.shapes()))
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
 def test_square(x):
     out = xp.square(x)
-    ph.assert_dtype("square", x.dtype, out.dtype)
-    ph.assert_shape("square", out.shape, x.shape)
+    ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
+    ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
     unary_assert_against_refimpl(
-        "square", x, out, lambda s: s ** 2, expr_template="{}²={}"
+        "square", x, out, lambda s: s*s, 
expr_template="{}²={}" ) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_sqrt(x): out = xp.sqrt(x) - ph.assert_dtype("sqrt", x.dtype, out.dtype) - ph.assert_shape("sqrt", out.shape, x.shape) + ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape) + refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt + filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0 unary_assert_against_refimpl( - "sqrt", x, out, math.sqrt, filter_=lambda s: default_filter(s) and s >= 0 + "sqrt", x, out, refimpl, filter_=filter_ ) @@ -1226,35 +1776,180 @@ def test_subtract(ctx, data): left = data.draw(ctx.left_strat, label=ctx.left_sym) right = data.draw(ctx.right_strat, label=ctx.right_sym) - try: + with hh.reject_overflow(): res = ctx.func(left, right) - except OverflowError: - reject() binary_param_assert_dtype(ctx, left, right, res) binary_param_assert_shape(ctx, left, right, res) binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tan(x): out = xp.tan(x) - ph.assert_dtype("tan", x.dtype, out.dtype) - ph.assert_shape("tan", out.shape, x.shape) - unary_assert_against_refimpl("tan", x, out, math.tan) + ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tan", out_shape=out.shape, expected=x.shape) + refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan + unary_assert_against_refimpl("tan", x, out, refimpl) -@given(xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes())) +@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes())) def test_tanh(x): out = xp.tanh(x) - ph.assert_dtype("tanh", x.dtype, out.dtype) - ph.assert_shape("tanh", out.shape, x.shape) - unary_assert_against_refimpl("tanh", x, out, math.tanh) + ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape) + refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh + unary_assert_against_refimpl("tanh", x, out, refimpl) -@given(xps.arrays(dtype=hh.numeric_dtypes, shape=xps.array_shapes())) +@given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes())) def test_trunc(x): out = xp.trunc(x) - ph.assert_dtype("trunc", x.dtype, out.dtype) - ph.assert_shape("trunc", out.shape, x.shape) + ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype) + ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape) unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True) + + +def _check_binary_with_scalars(func_data, x1x2): + x1, x2 = x1x2 + func_name, refimpl, kwds, expected_dtype = func_data + func = getattr(xp, func_name) + out = func(x1, x2) + in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2) + _assert_correctness_binary( + func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds + ) + + +def _filter_zero(x): + return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0)) + + +@pytest.mark.min_version("2024.12") +@pytest.mark.parametrize('func_data', + # func_name, refimpl, kwargs, expected_dtype + [ + ("add", operator.add, {}, None), + ("atan2", math.atan2, {}, None), + ("copysign", math.copysign, {}, None), + ("divide", operator.truediv, {"filter_": lambda s: 
s != 0}, None),
+        ("hypot", math.hypot, {}, None),
+        ("logaddexp", logaddexp_refimpl, {}, None),
+        ("nextafter", math.nextafter, {}, None),
+        ("maximum", max, {'strict_check': True}, None),
+        ("minimum", min, {'strict_check': True}, None),
+        ("multiply", operator.mul, {}, None),
+        ("subtract", operator.sub, {}, None),
+
+        ("equal", operator.eq, {}, xp.bool),
+        ("not_equal", operator.ne, {}, xp.bool),
+        ("less", operator.lt, {}, xp.bool),
+        ("less_equal", operator.le, {}, xp.bool),
+        ("greater", operator.gt, {}, xp.bool),
+        ("greater_equal", operator.ge, {}, xp.bool),
+        ("pow", operator.pow, {'check_values': False}, None)  # value tests are too finicky for pow
+    ],
+    ids=lambda func_data: func_data[0]  # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
+def test_binary_with_scalars_real(func_data, x1x2):
+    _check_binary_with_scalars(func_data, x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+    # func_name, refimpl, kwargs, expected_dtype
+    [
+        ("logical_and", operator.and_, {"expr_template": "({} and {})={}"}, None),
+        ("logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
+        ("logical_xor", operator.xor, {"expr_template": "({} ^ {})={}"}, None),
+    ],
+    ids=lambda func_data: func_data[0]  # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.bool]))
+def test_binary_with_scalars_bool(func_data, x1x2):
+    _check_binary_with_scalars(func_data, x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+    # func_name, refimpl, kwargs, expected_dtype
+    [
+        ("floor_divide", operator.floordiv, {}, None),
+        ("remainder", operator.mod, {}, None),
+    ],
+    ids=lambda func_data: func_data[0]  # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.int64]))
+def test_binary_with_scalars_int(func_data, x1x2):
+    assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
+    _check_binary_with_scalars(func_data, x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+    # func_name, refimpl, kwargs, expected_dtype
+    [
+        ("bitwise_and", operator.and_, {}, None),
+        ("bitwise_or", operator.or_, {}, None),
+        ("bitwise_xor", operator.xor, {}, None),
+    ],
+    ids=lambda func_data: func_data[0]  # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.int32]))
+def test_binary_with_scalars_bitwise(func_data, x1x2):
+    func_name, refimpl, kwargs, expected = func_data
+    # repack the refimpl
+    refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32)
+    _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+    # func_name, refimpl, kwargs, expected_dtype
+    [
+        ("bitwise_left_shift", operator.lshift, {}, None),
+        ("bitwise_right_shift", operator.rshift, {}, None),
+    ],
+    ids=lambda func_data: func_data[0]  # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3)))
+def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
+    func_name, refimpl, kwargs, expected = func_data
+    # repack the refimpl
+    refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32)
+    _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@given(
+    x1x2=hh.array_and_py_scalar([xp.int32]),
+    data=st.data()
+)
+def test_where_with_scalars(x1x2, data):
+    x1, x2 = x1x2
+
+    if dh.is_scalar(x1):
+        dtype, shape = 
x2.dtype, x2.shape
+        x1_arr, x2_arr = xp.broadcast_to(xp.asarray(x1), shape), x2
+    else:
+        dtype, shape = x1.dtype, x1.shape
+        x1_arr, x2_arr = x1, xp.broadcast_to(xp.asarray(x2), shape)
+
+    condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool))
+
+    out = xp.where(condition, x1, x2)
+
+    assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}"
+    assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}"
+
+    # value test
+    for idx in sh.ndindex(shape):
+        if condition[idx]:
+            assert out[idx] == x1_arr[idx]
+        else:
+            assert out[idx] == x2_arr[idx]
+
+
diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py
index 6b134bb0..b72c8030 100644
--- a/array_api_tests/test_searching_functions.py
+++ b/array_api_tests/test_searching_functions.py
@@ -1,5 +1,7 @@
+import math
+
 import pytest
-from hypothesis import given
+from hypothesis import given, note, assume
 from hypothesis import strategies as st
 
 from . import _array_module as xp
@@ -9,12 +11,13 @@ from . import shape_helpers as sh
 from . import xps
 
-pytestmark = pytest.mark.ci
+
+pytestmark = pytest.mark.unvectorized
 
 
 @given(
-    x=xps.arrays(
-        dtype=xps.numeric_dtypes(),
+    x=hh.arrays(
+        dtype=hh.real_dtypes,
         shape=hh.shapes(min_dims=1, min_side=1),
         elements={"allow_nan": False},
     ),
@@ -28,13 +31,14 @@ def test_argmax(x, data):
         ),
         label="kw",
     )
+    keepdims = kw.get("keepdims", False)
 
     out = xp.argmax(x, **kw)
 
     ph.assert_default_index("argmax", out.dtype)
-    axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
+    axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
     ph.assert_keepdimable_shape(
-        "argmax", x.shape, out.shape, axes, kw.get("keepdims", False), **kw
+        "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
     )
     scalar_type = dh.get_scalar_type(x.dtype)
     for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
@@ -44,12 +48,13 @@ def test_argmax(x, data):
             s = scalar_type(x[idx])
             elements.append(s)
         expected = max(range(len(elements)), key=elements.__getitem__)
-        ph.assert_scalar_equals("argmax", int, out_idx, max_i, expected)
+        ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i,
+                                expected=expected, kw=kw)
 
 
 @given(
-    x=xps.arrays(
-        dtype=xps.numeric_dtypes(),
+    x=hh.arrays(
+        dtype=hh.real_dtypes,
         shape=hh.shapes(min_dims=1, min_side=1),
         elements={"allow_nan": False},
     ),
@@ -63,13 +68,14 @@ def test_argmin(x, data):
         ),
         label="kw",
     )
+    keepdims = kw.get("keepdims", False)
 
     out = xp.argmin(x, **kw)
 
     ph.assert_default_index("argmin", out.dtype)
-    axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
+    axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
     ph.assert_keepdimable_shape(
-        "argmin", x.shape, out.shape, axes, kw.get("keepdims", False), **kw
+        "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
     )
     scalar_type = dh.get_scalar_type(x.dtype)
     for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
@@ -79,23 +85,74 @@ def test_argmin(x, data):
             s = scalar_type(x[idx])
             elements.append(s)
         expected = min(range(len(elements)), key=elements.__getitem__)
-        ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected)
+        ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
+
+
+# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on
+# the problem is that for ints > iinfo(int32).max it runs into essentially this:
+# >>> jnp.asarray([2147483648], dtype=jnp.int64)
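+# (without JAX_ENABLE_X64, JAX silently downcasts int64 to int32, and
+# 2147483648 does not fit: iinfo(int32).max == 2147483647)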
+# .... https://github.com/jax-ml/jax/pull/6047 ... +# Explicitly limiting the range in elements(...) runs into problems with +# hypothesis where floating-point numbers are not exactly representable. +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays( + dtype=hh.all_dtypes, + shape=hh.shapes(min_dims=1, min_side=1), + elements={"allow_nan": False}, + ), + data=st.data(), +) +def test_count_nonzero(x, data): + kw = data.draw( + hh.kwargs( + axis=hh.axes(x.ndim), + keepdims=st.booleans(), + ), + label="kw", + ) + keepdims = kw.get("keepdims", False) + + assume(kw.get("axis", None) != ()) # TODO clarify in the spec + + out = xp.count_nonzero(x, **kw) + + ph.assert_default_index("count_nonzero", out.dtype) + axes = sh.normalize_axis(kw.get("axis", None), x.ndim) + ph.assert_keepdimable_shape( + "count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw + ) + scalar_type = dh.get_scalar_type(x.dtype) + + for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)): + count = int(out[out_idx]) + elements = [] + for idx in indices: + s = scalar_type(x[idx]) + elements.append(s) + expected = sum(el != 0 for el in elements) + ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected) + + +@given(hh.arrays(dtype=hh.all_dtypes, shape=())) +def test_nonzero_zerodim_error(x): + with pytest.raises(Exception): + xp.nonzero(x) @pytest.mark.data_dependent_shapes -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, min_side=1))) def test_nonzero(x): out = xp.nonzero(x) - if x.ndim == 0: - assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays" - else: - assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" - size = out[0].size + assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}" + out_size = math.prod(out[0].shape) for i in range(len(out)): assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1" - assert ( - out[i].size == size - ), f"out[{i}].size={x.size}, but should be out[0].size={size}" + size_at = math.prod(out[i].shape) + assert size_at == out_size, ( + f"prod(out[{i}].shape)={size_at}, " + f"but should be prod(out[0].shape)={out_size}" + ) ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype") indices = [] if x.dtype == xp.bool: @@ -107,11 +164,11 @@ def test_nonzero(x): if x[idx] != 0: indices.append(idx) if x.ndim == 0: - assert out[0].size == len( + assert out_size == len( indices - ), f"{out[0].size=}, but should be {len(indices)}" + ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}" else: - for i in range(size): + for i in range(out_size): idx = tuple(int(x[i]) for x in out) f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}" f_element = f"x[{idx}]={x[idx]}" @@ -127,14 +184,14 @@ def test_nonzero(x): data=st.data(), ) def test_where(shapes, dtypes, data): - cond = data.draw(xps.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") - x1 = data.draw(xps.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") - x2 = data.draw(xps.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") + cond = data.draw(hh.arrays(dtype=xp.bool, shape=shapes[0]), label="condition") + x1 = data.draw(hh.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1") + x2 = data.draw(hh.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2") out = xp.where(cond, x1, x2) shape = sh.broadcast_shapes(*shapes) - ph.assert_shape("where", 
out.shape, shape) + ph.assert_shape("where", out_shape=out.shape, expected=shape) # TODO: generate indices without broadcasting arrays _cond = xp.broadcast_to(cond, shape) _x1 = xp.broadcast_to(x1, shape) @@ -142,9 +199,51 @@ def test_where(shapes, dtypes, data): for idx in sh.ndindex(shape): if _cond[idx]: ph.assert_0d_equals( - "where", f"_x1[{idx}]", _x1[idx], f"out[{idx}]", out[idx] + "where", + x_repr=f"_x1[{idx}]", + x_val=_x1[idx], + out_repr=f"out[{idx}]", + out_val=out[idx] ) else: ph.assert_0d_equals( - "where", f"_x2[{idx}]", _x2[idx], f"out[{idx}]", out[idx] + "where", + x_repr=f"_x2[{idx}]", + x_val=_x2[idx], + out_repr=f"out[{idx}]", + out_val=out[idx] ) + + +@pytest.mark.min_version("2023.12") +@given(data=st.data()) +def test_searchsorted(data): + # TODO: test side="right" + # TODO: Allow different dtypes for x1 and x2 + _x1 = data.draw( + st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True), + label="_x1", + ) + x1 = xp.asarray(_x1, dtype=dh.default_float) + if data.draw(st.booleans(), label="use sorter?"): + sorter = xp.argsort(x1) + else: + sorter = None + x1 = xp.sort(x1) + note(f"{x1=}") + x2 = data.draw( + st.lists(st.sampled_from(_x1), unique=True, min_size=1).map( + lambda o: xp.asarray(o, dtype=dh.default_float) + ), + label="x2", + ) + + out = xp.searchsorted(x1, x2, sorter=sorter) + + ph.assert_dtype( + "searchsorted", + in_dtype=[x1.dtype, x2.dtype], + out_dtype=out.dtype, + expected=xp.__array_namespace_info__().default_dtypes()["indexing"], + ) + # TODO: shapes and values testing diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py index 5bae6147..c9abaad1 100644 --- a/array_api_tests/test_set_functions.py +++ b/array_api_tests/test_set_functions.py @@ -1,4 +1,5 @@ # TODO: disable if opted out, refactor things +import cmath import math from collections import Counter, defaultdict @@ -10,12 +11,11 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . 
import xps -pytestmark = [pytest.mark.ci, pytest.mark.data_dependent_shapes] +pytestmark = [pytest.mark.data_dependent_shapes, pytest.mark.unvectorized] -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_all(x): out = xp.unique_all(x) @@ -25,7 +25,7 @@ def test_unique_all(x): assert hasattr(out, "counts") ph.assert_dtype( - "unique_all", x.dtype, out.values.dtype, repr_name="out.values.dtype" + "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" ) ph.assert_default_index( "unique_all", out.indices.dtype, repr_name="out.indices.dtype" @@ -42,8 +42,8 @@ def test_unique_all(x): ), f"{out.indices.shape=}, but should be {out.values.shape=}" ph.assert_shape( "unique_all", - out.inverse_indices.shape, - x.shape, + out_shape=out.inverse_indices.shape, + expected=x.shape, repr_name="out.inverse_indices.shape", ) assert ( @@ -61,7 +61,7 @@ def test_unique_all(x): for idx in sh.ndindex(out.indices.shape): val = scalar_type(out.values[idx]) - if math.isnan(val): + if cmath.isnan(val): break i = int(out.indices[idx]) expected = firsts[val] @@ -88,7 +88,7 @@ def test_unique_all(x): for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 assert count == 1, ( f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " @@ -110,18 +110,18 @@ def test_unique_all(x): vals_idx[val] = idx if dh.is_float_dtype(out.values.dtype): - assume(x.size <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if math.isnan(k)) + assume(math.prod(x.shape) <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_counts(x): out = xp.unique_counts(x) assert hasattr(out, "values") assert hasattr(out, "counts") ph.assert_dtype( - "unique_counts", x.dtype, out.values.dtype, repr_name="out.values.dtype" + "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" ) ph.assert_default_index( "unique_counts", out.counts.dtype, repr_name="out.counts.dtype" @@ -136,7 +136,7 @@ def test_unique_counts(x): for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) count = int(out.counts[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 assert count == 1, ( f"out.counts[{idx}]={count} for out.values[{idx}]={val}, " @@ -157,18 +157,18 @@ def test_unique_counts(x): ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" vals_idx[val] = idx if dh.is_float_dtype(out.values.dtype): - assume(x.size <= 128) # may not be representable - expected = sum(v for k, v in counts.items() if math.isnan(k)) + assume(math.prod(x.shape) <= 128) # may not be representable + expected = sum(v for k, v in counts.items() if cmath.isnan(k)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_inverse(x): out = xp.unique_inverse(x) assert hasattr(out, "values") assert hasattr(out, "inverse_indices") ph.assert_dtype( - "unique_inverse", x.dtype, out.values.dtype, 
repr_name="out.values.dtype" + "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype" ) ph.assert_default_index( "unique_inverse", @@ -177,8 +177,8 @@ def test_unique_inverse(x): ) ph.assert_shape( "unique_inverse", - out.inverse_indices.shape, - x.shape, + out_shape=out.inverse_indices.shape, + expected=x.shape, repr_name="out.inverse_indices.shape", ) scalar_type = dh.get_scalar_type(out.values.dtype) @@ -187,7 +187,7 @@ def test_unique_inverse(x): nans = 0 for idx in sh.ndindex(out.values.shape): val = scalar_type(out.values[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 else: assert ( @@ -210,22 +210,22 @@ def test_unique_inverse(x): else: assert val == expected, msg if dh.is_float_dtype(out.values.dtype): - assume(x.size <= 128) # may not be representable + assume(math.prod(x.shape) <= 128) # may not be representable expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}" -@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1))) +@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1))) def test_unique_values(x): out = xp.unique_values(x) - ph.assert_dtype("unique_values", x.dtype, out.dtype) + ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype) scalar_type = dh.get_scalar_type(x.dtype) distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape)) vals_idx = {} nans = 0 for idx in sh.ndindex(out.shape): val = scalar_type(out[idx]) - if math.isnan(val): + if cmath.isnan(val): nans += 1 else: assert val in distinct, f"out[{idx}]={val}, but {val} not in input array" @@ -234,6 +234,6 @@ def test_unique_values(x): ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]" vals_idx[val] = idx if dh.is_float_dtype(out.dtype): - assume(x.size <= 128) # may not be representable + assume(math.prod(x.shape) <= 128) # may not be representable expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8)) assert nans == expected, f"{nans} NaNs in out, but should be {expected}" diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py index 2db804b1..1c9a8ef6 100644 --- a/array_api_tests/test_signatures.py +++ b/array_api_tests/test_signatures.py @@ -20,24 +20,19 @@ def squeeze(x, /, axis): ... """ +from collections import defaultdict +from copy import copy from inspect import Parameter, Signature, signature from types import FunctionType -from typing import Any, Callable, Dict, List, Literal, get_args +from typing import Any, Callable, Dict, Literal, get_args +from warnings import warn import pytest -from hypothesis import given, note, settings -from hypothesis import strategies as st -from hypothesis.strategies import DataObject from . import dtype_helpers as dh -from . import hypothesis_helpers as hh -from . import xps -from ._array_module import _UndefinedStub -from ._array_module import mod as xp -from .stubs import array_methods, category_to_funcs, extension_to_funcs -from .typing import Array, DataType - -pytestmark = pytest.mark.ci +from . 
import xp +from .stubs import (array_methods, category_to_funcs, extension_to_funcs, + name_to_func, info_funcs) ParameterKind = Literal[ Parameter.POSITIONAL_ONLY, @@ -93,6 +88,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature): stub_param.name in sig.parameters.keys() ), f"Argument '{stub_param.name}' missing from signature" param = next(p for p in params if p.name == stub_param.name) + f_stub_kind = kind_to_str[stub_param.kind] assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], ( f"{param.name} is a {kind_to_str[param.kind]}, " f"but should be a {f_stub_kind} " @@ -100,17 +96,7 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature): ) -def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]: - if func_name in dh.func_in_dtypes.keys(): - dtypes = dh.func_in_dtypes[func_name] - if hh.FILTER_UNDEFINED_DTYPES: - dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] - return st.sampled_from(dtypes) - else: - return xps.scalar_dtypes() - - -def make_pretty_func(func_name: str, *args: Any, **kwargs: Any): +def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str: f_sig = f"{func_name}(" f_sig += ", ".join(str(a) for a in args) if len(kwargs) != 0: @@ -121,96 +107,166 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any): return f_sig -matrixy_funcs: List[FunctionType] = [ - *category_to_funcs["linear_algebra"], - *extension_to_funcs["linalg"], +# We test uninspectable signatures by passing valid, manually-defined arguments +# to the signature's function/method. +# +# Arguments which require use of the array module are specified as string +# expressions to be eval()'d on runtime. This is as opposed to just using the +# array module whilst setting up the tests, which is prone to halt the entire +# test suite if an array module doesn't support a given expression. +func_to_specified_args = defaultdict( + dict, + { + "permute_dims": {"axes": 0}, + "reshape": {"shape": (1, 5)}, + "broadcast_to": {"shape": (1, 5)}, + "asarray": {"obj": [0, 1, 2, 3, 4]}, + "full_like": {"fill_value": 42}, + "matrix_power": {"n": 2}, + }, +) +func_to_specified_arg_exprs = defaultdict( + dict, + { + "stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"}, + "iinfo": {"type": "xp.int64"}, + "finfo": {"type": "xp.float64"}, + "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"}, + "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"}, + "solve": { + a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"] + }, + "outer": {"x1": "xp.ones((5,))", "x2": "xp.ones((5,))"}, + }, +) +# We default most array arguments heuristically. As functions/methods work only +# with arrays of certain dtypes and shapes, we specify only supported arrays +# respective to the function. 
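+# (So, for example, the __bool__-style casts get 0-d arrays, the linalg-style
+# functions get 3x3 matrices, and everything else gets 1-d arrays of five
+# elements; see the fallback selection loop below.)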
+casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
+matrixy_names = [
+    f.__name__
+    for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
 ]
-matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs]
 matrixy_names += ["__matmul__", "triu", "tril"]
+for func_name, func in name_to_func.items():
+    stub_sig = signature(func)
+    array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
+    if func in array_methods:
+        array_argnames.add("self")
+    array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
+    if len(array_argnames) > 0:
+        in_dtypes = dh.func_in_dtypes[func_name]
+        for dtype_name in ["float64", "bool", "int64", "complex128"]:
+            # We try float64 first because uninspectable numerical functions
+            # tend to support float inputs first-and-foremost (e.g. PyTorch)
+            try:
+                dtype = getattr(xp, dtype_name)
+            except AttributeError:
+                pass
+            else:
+                if dtype in in_dtypes:
+                    if func_name in casty_names:
+                        shape = ()
+                    elif func_name in matrixy_names:
+                        shape = (3, 3)
+                    else:
+                        shape = (5,)
+                    fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
+                    break
+        else:
+            warn(
+                f"dh.func_in_dtypes['{func_name}']={in_dtypes} seemingly does "
+                "not contain any assumed dtypes, so skipping specifying fallback array."
+            )
+            continue
+        for argname in array_argnames:
+            func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr
+
+def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
+    params = list(stub_sig.parameters.values())
 
-@given(data=st.data())
-@settings(max_examples=1)
-def _test_uninspectable_func(
-    func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject
-):
-    skip_msg = (
-        f"Signature for {func_name}() is not inspectable "
-        "and is too troublesome to test for otherwise"
+    if len(params) == 0:
+        func()
+        return
+
+    uninspectable_msg = (
+        f"Note {func_name}() is not inspectable so arguments are passed "
+        "manually to test the signature."
     )
-    if func_name in [
-        # 0d shapes
-        "__bool__",
-        "__int__",
-        "__index__",
-        "__float__",
-        # x2 elements must be >=0
-        "pow",
-        "bitwise_left_shift",
-        "bitwise_right_shift",
-        # axis default invalid with 0d shapes
-        "sort",
-        # shape requirements
-        *matrixy_names,
-    ]:
-        pytest.skip(skip_msg)
-
-    param_to_value: Dict[Parameter, Any] = {}
-    for param in stub_sig.parameters.values():
-        if param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
+
+    argname_to_arg = copy(func_to_specified_args[func_name])
+    argname_to_expr = func_to_specified_arg_exprs[func_name]
+    for argname, expr in argname_to_expr.items():
+        assert argname not in argname_to_arg.keys()  # sanity check
+        try:
+            argname_to_arg[argname] = eval(expr, {"xp": xp})
+        except Exception as e:
             pytest.skip(
-                skip_msg + f" (because '{param.name}' is a {kind_to_str[param.kind]})"
-            )
-        elif param.default != Parameter.empty:
-            value = param.default
-        elif param.name in ["x", "x1"]:
-            dtypes = get_dtypes_strategy(func_name)
-            value = data.draw(
-                xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name
             )
-        elif param.name in ["x2", "other"]:
-            if param.name == "x2":
-                assert "x1" in [p.name for p in param_to_value.keys()]  # sanity check
-                orig = next(v for p, v in param_to_value.items() if p.name == "x1")
+                f"Exception occurred when evaluating {argname}={expr}: {e}\n"
+                f"{uninspectable_msg}"
             )
+
+    posargs = []
+    posorkw_args = {}
+    kwargs = {}
+    no_arg_msg = (
+        "We have no argument specified for '{}'. 
Please ensure you're using " + "the latest version of array-api-tests, then open an issue if one " + f"doesn't already exist. {uninspectable_msg}" + ) + for param in params: + if param.kind == Parameter.POSITIONAL_ONLY: + try: + posargs.append(argname_to_arg[param.name]) + except KeyError: + pytest.skip(no_arg_msg.format(param.name)) + elif param.kind == Parameter.POSITIONAL_OR_KEYWORD: + if param.default == Parameter.empty: + try: + posorkw_args[param.name] = argname_to_arg[param.name] + except KeyError: + pytest.skip(no_arg_msg.format(param.name)) else: - assert array is not None # sanity check - orig = array - value = data.draw( - xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name - ) + assert argname_to_arg[param.name] + posorkw_args[param.name] = param.default + elif param.kind == Parameter.KEYWORD_ONLY: + assert param.default != Parameter.empty # sanity check + kwargs[param.name] = param.default else: - pytest.skip( - skip_msg + f" (because no default was found for argument {param.name})" - ) - param_to_value[param] = value - - args: List[Any] = [ - v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY - ] - kwargs: Dict[str, Any] = { - p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY - } - f_func = make_pretty_func(func_name, *args, **kwargs) - note(f"trying {f_func}") - func(*args, **kwargs) + assert param.kind in VAR_KINDS # sanity check + pytest.skip(no_arg_msg.format(param.name)) + if len(posorkw_args) == 0: + func(*posargs, **kwargs) + else: + posorkw_name_to_arg_pairs = list(posorkw_args.items()) + for i in range(len(posorkw_name_to_arg_pairs), -1, -1): + extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]] + extra_kwargs = dict(posorkw_name_to_arg_pairs[i:]) + func(*posargs, *extra_posargs, **kwargs, **extra_kwargs) -def _test_func_signature(func: Callable, stub: FunctionType, array=None): +def _test_func_signature(func: Callable, stub: FunctionType, is_method=False): stub_sig = signature(stub) # If testing against array, ignore 'self' arg in stub as it won't be present # in func (which should be a method). 
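As an aside, the descending-range loop at the end of _test_uninspectable_func above calls the function once per positional/keyword split of its positional-or-keyword arguments. A standalone sketch with a hypothetical two-argument function:

    def demo(a, b):  # stands in for the uninspectable function under test
        return a, b

    pairs = [("a", 1), ("b", 2)]
    for i in range(len(pairs), -1, -1):
        extra_posargs = [arg for _, arg in pairs[:i]]
        extra_kwargs = dict(pairs[i:])
        # i=2 -> demo(1, 2); i=1 -> demo(1, b=2); i=0 -> demo(a=1, b=2)
        demo(*extra_posargs, **extra_kwargs)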
-    if array is not None:
+    if is_method:
         stub_params = list(stub_sig.parameters.values())
-        del stub_params[0]
+        if stub_params[0].name == "self":
+            del stub_params[0]
         stub_sig = Signature(
             parameters=stub_params, return_annotation=stub_sig.return_annotation
         )
     try:
         sig = signature(func)
-        _test_inspectable_func(sig, stub_sig)
     except ValueError:
-        _test_uninspectable_func(stub.__name__, func, stub_sig, array)
+        try:
+            _test_uninspectable_func(stub.__name__, func, stub_sig)
+        except Exception as e:
+            raise e from None  # suppress parent exception for cleaner pytest output
+    else:
+        _test_inspectable_func(sig, stub_sig)
 
 
 @pytest.mark.parametrize(
@@ -244,11 +300,24 @@ def test_extension_func_signature(extension: str, stub: FunctionType):
 
 
 @pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
-@given(st.data())
-@settings(max_examples=1)
-def test_array_method_signature(stub: FunctionType, data: DataObject):
-    dtypes = get_dtypes_strategy(stub.__name__)
-    x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x")
+def test_array_method_signature(stub: FunctionType):
+    x_expr = func_to_specified_arg_exprs[stub.__name__]["self"]
+    try:
+        x = eval(x_expr, {"xp": xp})
+    except Exception as e:
+        pytest.skip(f"Exception occurred when evaluating x={x_expr}: {e}")
     assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
     method = getattr(x, stub.__name__)
-    _test_func_signature(method, stub, array=x)
+    _test_func_signature(method, stub, is_method=True)
+
+
+if info_funcs:  # pytest fails collecting if info_funcs is empty
+    @pytest.mark.min_version("2023.12")
+    @pytest.mark.parametrize("stub", info_funcs, ids=lambda f: f.__name__)
+    def test_info_func_signature(stub: FunctionType):
+        try:
+            info_namespace = xp.__array_namespace_info__()
+        except Exception as e:
+            raise AssertionError(f"Could not get info namespace from xp.__array_namespace_info__(): {e}")
+
+        func = getattr(info_namespace, stub.__name__)
+        _test_func_signature(func, stub)
diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py
index 7c5a1411..3d25798c 100644
--- a/array_api_tests/test_sorting_functions.py
+++ b/array_api_tests/test_sorting_functions.py
@@ -1,4 +1,4 @@
-import math
+import cmath
 from typing import Set
 
 import pytest
@@ -11,31 +11,28 @@
 from . import hypothesis_helpers as hh
 from . import pytest_helpers as ph
 from . import shape_helpers as sh
-from . 
import xps from .typing import Scalar, Shape -pytestmark = pytest.mark.ci - def assert_scalar_in_set( func_name: str, idx: Shape, out: Scalar, set_: Set[Scalar], - /, - **kw, + kw={}, ): out_repr = "out" if idx == () else f"out[{idx}]" - if math.isnan(out): + if cmath.isnan(out): raise NotImplementedError() msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]" assert out in set_, msg # TODO: Test with signed zeros and NaNs (and ignore them somehow) +@pytest.mark.unvectorized @given( - x=xps.arrays( - dtype=xps.scalar_dtypes(), + x=hh.arrays( + dtype=hh.real_dtypes, shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ), @@ -57,9 +54,9 @@ def test_argsort(x, data): out = xp.argsort(x, **kw) ph.assert_default_index("argsort", out.dtype) - ph.assert_shape("argsort", out.shape, x.shape, **kw) + ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw) axis = kw.get("axis", -1) - axes = sh.normalise_axis(axis, x.ndim) + axes = sh.normalize_axis(axis, x.ndim) scalar_type = dh.get_scalar_type(x.dtype) for indices in sh.axes_ndindex(x.shape, axes): elements = [scalar_type(x[idx]) for idx in indices] @@ -69,7 +66,7 @@ def test_argsort(x, data): ) if kw.get("stable", True): for idx, o in zip(indices, sorders): - ph.assert_scalar_equals("argsort", int, idx, int(out[idx]), o, **kw) + ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw) else: idx_elements = dict(zip(indices, elements)) idx_orders = dict(zip(indices, orders)) @@ -84,18 +81,19 @@ def test_argsort(x, data): out_o = int(out[idx]) if len(expected_orders) == 1: ph.assert_scalar_equals( - "argsort", int, idx, out_o, expected_orders[0], **kw + "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw ) else: assert_scalar_in_set( - "argsort", idx, out_o, set(expected_orders), **kw + "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw ) +@pytest.mark.unvectorized # TODO: Test with signed zeros and NaNs (and ignore them somehow) @given( - x=xps.arrays( - dtype=xps.scalar_dtypes(), + x=hh.arrays( + dtype=hh.real_dtypes, shape=hh.shapes(min_dims=1, min_side=1), elements={"allow_nan": False}, ), @@ -116,10 +114,10 @@ def test_sort(x, data): out = xp.sort(x, **kw) - ph.assert_dtype("sort", out.dtype, x.dtype) - ph.assert_shape("sort", out.shape, x.shape, **kw) + ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype) + ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw) axis = kw.get("axis", -1) - axes = sh.normalise_axis(axis, x.ndim) + axes = sh.normalize_axis(axis, x.ndim) scalar_type = dh.get_scalar_type(x.dtype) for indices in sh.axes_ndindex(x.shape, axes): elements = [scalar_type(x[idx]) for idx in indices] @@ -132,9 +130,9 @@ def test_sort(x, data): # TODO: error message when unstable should not imply just one idx ph.assert_0d_equals( "sort", - f"x[{x_idx}]", - x[x_idx], - f"out[{out_idx}]", - out[out_idx], - **kw, + x_repr=f"x[{x_idx}]", + x_val=x[x_idx], + out_repr=f"out[{out_idx}]", + out_val=out[out_idx], + kw=kw, ) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 9999d9b0..bf05a262 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -19,28 +19,21 @@ from dataclasses import dataclass, field from decimal import ROUND_HALF_EVEN, Decimal from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple -from warnings import warn +from typing import Any, 
Callable, Dict, List, Optional, Protocol, Tuple, Literal +from warnings import warn, filterwarnings, catch_warnings import pytest -from hypothesis import assume, given, note +from hypothesis import given, note, settings, assume from hypothesis import strategies as st +from hypothesis.errors import NonInteractiveExampleWarning from array_api_tests.typing import Array, DataType from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph -from . import shape_helpers as sh -from . import xps -from ._array_module import mod as xp +from . import xp, xps from .stubs import category_to_funcs -from .test_operators_and_elementwise_functions import ( - oneway_broadcastable_shapes, - oneway_promotable_dtypes, -) - -pytestmark = pytest.mark.ci UnaryCheck = Callable[[float], bool] BinaryCheck = Callable[[float, float], bool] @@ -130,6 +123,8 @@ def abs_cond(i: float) -> bool: "infinity": float("inf"), "0": 0.0, "1": 1.0, + "False": 0.0, + "True": 1.0, } r_value = re.compile(r"([+-]?)(.+)") r_pi = re.compile(r"(\d?)π(?:/(\d))?") @@ -162,7 +157,10 @@ def parse_value(value_str: str) -> float: if denominator := pi_m.group(2): value /= int(denominator) else: - value = repr_to_value[m.group(2)] + try: + value = repr_to_value[m.group(2)] + except KeyError as e: + raise ParseError(value_str) from e if sign := m.group(1): if sign == "-": value *= -1 @@ -263,7 +261,7 @@ class BoundFromDtype(FromDtypeFunc): def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]: assert len(kw) == 0 # sanity check - from_dtype = self.base_func or xps.from_dtype + from_dtype = self.base_func or hh.from_dtype strat = from_dtype(dtype, **self.kwargs) if self.filter_ is not None: strat = strat.filter(self.filter_) @@ -497,6 +495,7 @@ def check_result(result: float) -> bool: class Case(Protocol): cond_expr: str result_expr: str + raw_case: Optional[str] def cond(self, *args) -> bool: ... @@ -511,7 +510,10 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(<{self}>)" -r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters") +r_case_block = re.compile( + r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*" + r"(?:.+\n--+)?(?:\.\. versionchanged.*)?" 
+) r_case = re.compile(r"\s+-\s*(.*)\.") @@ -532,6 +534,7 @@ class UnaryCase(Case): cond_from_dtype: FromDtypeFunc cond: UnaryCheck check_result: UnaryResultCheck + raw_case: Optional[str] = field(default=None) r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)") @@ -542,6 +545,10 @@ class UnaryCase(Case): "If two integers are equally close to ``x_i``, " "the result is the even integer closest to ``x_i``" ) +r_nan_signbit = re.compile( + "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, " + "the result is ``(.+)``" +) def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]: @@ -597,6 +604,25 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]: ) +def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase: + if signbit: + sign = -1 + nan_expr = "-NaN" + float_arg = "-nan" + else: + sign = 1 + nan_expr = "+NaN" + float_arg = "nan" + + return UnaryCase( + cond_expr=f"x_i is {nan_expr}", + cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign, + cond_from_dtype=lambda _: st.just(float(float_arg)), + result_expr=str(expected), + check_result=lambda _, result: result == float(expected), + ) + + def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck: def check_result(i: float, result: float) -> bool: return check_just_result(result) @@ -604,7 +630,7 @@ def check_result(i: float, result: float) -> bool: return check_result -def parse_unary_case_block(case_block: str) -> List[UnaryCase]: +def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]: """ Parses a Sphinx-formatted docstring of a unary function to return a list of codified unary cases, e.g. @@ -635,7 +661,7 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]: ... ''' ... >>> case_block = r_case_block.search(sqrt.__doc__).group(1) - >>> unary_cases = parse_unary_case_block(case_block) + >>> unary_cases = parse_unary_case_block(case_block, 'sqrt') >>> for case in unary_cases: ... 
print(repr(case))
    UnaryCase(<x_i is NaN -> NaN>)
@@ -653,16 +679,20 @@ def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
     cases = []
     for case_m in r_case.finditer(case_block):
         case_str = case_m.group(1)
-        if m := r_already_int_case.search(case_str):
+        if r_already_int_case.search(case_str):
             cases.append(already_int_case)
-        elif m := r_even_round_halves_case.search(case_str):
+        elif r_even_round_halves_case.search(case_str):
             cases.append(even_round_halves_case)
+        elif m := r_nan_signbit.search(case_str):
+            signbit = parse_value(m.group(1))
+            expected = bool(parse_value(m.group(2)))
+            cases.append(make_nan_signbit_case(signbit, expected))
         elif m := r_unary_case.search(case_str):
             try:
                 cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
                 _check_result, result_expr = parse_result(m.group(2))
             except ParseError as e:
-                warn(f"not machine-readable: '{e.value}'")
+                warn(f"case for {func_name} not machine-readable: '{e.value}'")
                 continue
             cond_expr = cond_expr_template.replace("{}", "x_i")
             # Do not define check_result in this function's body - see
@@ -674,11 +704,12 @@
                 cond_from_dtype=cond_from_dtype,
                 result_expr=result_expr,
                 check_result=check_result,
+                raw_case=case_str,
             )
             cases.append(case)
         else:
             if not r_remaining_case.search(case_str):
-                warn(f"case not machine-readable: '{case_str}'")
+                warn(f"case for {func_name} not machine-readable: '{case_str}'")
     return cases
 
 
@@ -700,6 +731,7 @@ class BinaryCase(Case):
     x2_cond_from_dtype: FromDtypeFunc
     cond: BinaryCond
     check_result: BinaryResultCheck
+    raw_case: Optional[str] = field(default=None)
 
 
 r_binary_case = re.compile("If (.+), the result (.+)")
@@ -937,12 +969,18 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
         def partial_cond(i1: float, i2: float) -> bool:
             return math.copysign(1, i1) == math.copysign(1, i2)
 
+        x1_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
+        x2_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
+
     elif value_str == "different mathematical signs":
         partial_expr = "copysign(1, x1_i) != copysign(1, x2_i)"
 
         def partial_cond(i1: float, i2: float) -> bool:
             return math.copysign(1, i1) != math.copysign(1, i2)
 
+        x1_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
+        x2_cond_from_dtypes.append(BoundFromDtype(kwargs={"max_value": -1}))
+
     else:
         unary_cond, expr_template, cond_from_dtype = parse_cond(value_str)
         # Do not define partial_cond via the def keyword or lambda
@@ -1007,7 +1045,7 @@ def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
             return use_x1_or_x2_strat.flatmap(
                 lambda t: cond_from_dtype(dtype)
                 if t[0]
-                else xps.from_dtype(dtype)
+                else hh.from_dtype(dtype)
             )
 
         def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
@@ -1015,7 +1053,7 @@ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
             return use_x1_or_x2_strat.flatmap(
                 lambda t: cond_from_dtype(dtype)
                 if t[1]
-                else xps.from_dtype(dtype)
+                else hh.from_dtype(dtype)
             )
 
         x1_cond_from_dtypes.append(
@@ -1058,13 +1096,14 @@ def cond(i1: float, i2: float) -> bool:
         x2_cond_from_dtype=x2_cond_from_dtype,
         result_expr=result_expr,
         check_result=check_result,
+        raw_case=case_str,
     )
 
 
 r_redundant_case = re.compile("result.+determined by the rule already stated above")
 
 
-def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
+def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]:
     """
     Parses a Sphinx-formatted docstring of a binary function to return a list
     of codified binary 
cases, e.g. @@ -1095,7 +1134,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: ... ''' ... >>> case_block = r_case_block.search(logaddexp.__doc__).group(1) - >>> binary_cases = parse_binary_case_block(case_block) + >>> binary_cases = parse_binary_case_block(case_block, 'logaddexp') >>> for case in binary_cases: ... print(repr(case)) BinaryCase( NaN>) @@ -1113,10 +1152,10 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: case = parse_binary_case(case_str) cases.append(case) except ParseError as e: - warn(f"not machine-readable: '{e.value}'") + warn(f"case for {func_name} not machine-readable: '{e.value}'") else: if not r_remaining_case.match(case_str): - warn(f"case not machine-readable: '{case_str}'") + warn(f"case for {func_name} not machine-readable: '{case_str}'") return cases @@ -1125,8 +1164,9 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: iop_params = [] func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()} for stub in category_to_funcs["elementwise"]: + func_name = stub.__name__ if stub.__doc__ is None: - warn(f"{stub.__name__}() stub has no docstring") + warn(f"{func_name}() stub has no docstring") continue if m := r_case_block.search(stub.__doc__): case_block = m.group(1) @@ -1134,10 +1174,10 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: continue marks = [] try: - func = getattr(xp, stub.__name__) + func = getattr(xp, func_name) except AttributeError: marks.append( - pytest.mark.skip(reason=f"{stub.__name__} not found in array module") + pytest.mark.skip(reason=f"{func_name} not found in array module") ) func = None sig = inspect.signature(stub) @@ -1146,10 +1186,10 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: warn(f"{func=} has no parameters") continue if param_names[0] == "x": - if cases := parse_unary_case_block(case_block): - name_to_func = {stub.__name__: func} - if stub.__name__ in func_to_op.keys(): - op_name = func_to_op[stub.__name__] + if cases := parse_unary_case_block(case_block, func_name): + name_to_func = {func_name: func} + if func_name in func_to_op.keys(): + op_name = func_to_op[func_name] op = getattr(operator, op_name) name_to_func[op_name] = op for func_name, func in name_to_func.items(): @@ -1158,19 +1198,21 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: p = pytest.param(func_name, func, case, id=id_) unary_params.append(p) else: - warn(f"Special cases found for {stub.__name__} but none were parsed") + warn(f"Special cases found for {func_name} but none were parsed") continue if len(sig.parameters) == 1: warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'") continue if param_names[0] == "x1" and param_names[1] == "x2": - if cases := parse_binary_case_block(case_block): - name_to_func = {stub.__name__: func} - if stub.__name__ in func_to_op.keys(): - op_name = func_to_op[stub.__name__] + if cases := parse_binary_case_block(case_block, func_name): + name_to_func = {func_name: func} + if func_name in func_to_op.keys(): + op_name = func_to_op[func_name] op = getattr(operator, op_name) name_to_func[op_name] = op # We collect inplace operator test cases seperately + if "equal" in func_name: + continue iop_name = "__i" + op_name[2:] iop = getattr(operator, iop_name) for case in cases: @@ -1183,7 +1225,7 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: p = pytest.param(func_name, func, case, id=id_) binary_params.append(p) else: - warn(f"Special cases found for {stub.__name__} but 
none were parsed") + warn(f"Special cases found for {func_name} but none were parsed") continue else: warn( @@ -1201,134 +1243,69 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]: # its False - Hypothesis will complain if we reject too many examples, thus # indicating we've done something wrong. +# sanity checks +assert len(unary_params) != 0 +assert len(binary_params) != 0 +assert len(iop_params) != 0 + @pytest.mark.parametrize("func_name, func, case", unary_params) -@given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), - data=st.data(), -) -def test_unary(func_name, func, case, x, data): - set_idx = data.draw( - xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" +def test_unary(func_name, func, case): + with catch_warnings(): + # XXX: We are using example here to generate one example draw, but + # hypothesis issues a warning from this. We should consider either + # drawing multiple examples like a normal test, or just hard-coding a + # single example test case without using hypothesis. + filterwarnings('ignore', category=NonInteractiveExampleWarning) + in_value = case.cond_from_dtype(xp.float64).example() + x = xp.asarray(in_value, dtype=xp.float64) + out = func(x) + out_value = float(out) + assert case.check_result(in_value, out_value), ( + f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n" ) - set_value = data.draw(case.cond_from_dtype(x.dtype), label="set value") - x[set_idx] = set_value - note(f"{x=}") - - res = func(x) - - good_example = False - for idx in sh.ndindex(res.shape): - in_ = float(x[idx]) - if case.cond(in_): - good_example = True - out = float(res[idx]) - f_in = f"{sh.fmt_idx('x', idx)}={in_}" - f_out = f"{sh.fmt_idx('out', idx)}={out}" - assert case.check_result(in_, out), ( - f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" - f"condition: {case.cond_expr}\n" - f"{f_in}" - ) - break - assume(good_example) - - -x1_strat, x2_strat = hh.two_mutual_arrays( - dtypes=dh.float_dtypes, - two_shapes=hh.mutually_broadcastable_shapes(2, min_side=1), -) @pytest.mark.parametrize("func_name, func, case", binary_params) -@given(x1=x1_strat, x2=x2_strat, data=st.data()) -def test_binary(func_name, func, case, x1, x2, data): - result_shape = sh.broadcast_shapes(x1.shape, x2.shape) - all_indices = list(sh.iter_indices(x1.shape, x2.shape, result_shape)) - - indices_strat = st.shared(st.sampled_from(all_indices)) - set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx") - set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value") - x1[set_x1_idx] = set_x1_value - note(f"{x1=}") - set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx") - set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value") - x2[set_x2_idx] = set_x2_value - note(f"{x2=}") - - res = func(x1, x2) - # sanity check - ph.assert_result_shape(func_name, [x1.shape, x2.shape], res.shape, result_shape) - - good_example = False - for l_idx, r_idx, o_idx in all_indices: - l = float(x1[l_idx]) - r = float(x2[r_idx]) - if case.cond(l, r): - good_example = True - o = float(res[o_idx]) - f_left = f"{sh.fmt_idx('x1', l_idx)}={l}" - f_right = f"{sh.fmt_idx('x2', r_idx)}={r}" - f_out = f"{sh.fmt_idx('out', o_idx)}={o}" - assert case.check_result(l, r, o), ( - f"{f_out}, but should be {case.result_expr} [{func_name}()]\n" - f"condition: {case}\n" - f"{f_left}, {f_right}" - ) - break - assume(good_example) +@settings(max_examples=1) 
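The rewritten test_unary above replaces per-example Hypothesis generation with a single Strategy.example() draw, which Hypothesis flags as non-interactive use; the filter suppresses exactly that warning. A minimal sketch of the same pattern (the strategy here is arbitrary):

    from warnings import catch_warnings, filterwarnings

    from hypothesis import strategies as st
    from hypothesis.errors import NonInteractiveExampleWarning

    with catch_warnings():
        # example() is intended for interactive exploration, hence the warning
        filterwarnings("ignore", category=NonInteractiveExampleWarning)
        value = st.floats(allow_nan=False).example()  # one concrete draw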
+@given(data=st.data())
+def test_binary(func_name, func, case, data):
+    # We don't use example() like in test_unary because the same internal shared
+    # strategies used in both x1's and x2's don't "sync" with example() draws.
+    x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
+    x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
+    x1 = xp.asarray(x1_value, dtype=xp.float64)
+    x2 = xp.asarray(x2_value, dtype=xp.float64)
+
+    out = func(x1, x2)
+    out_value = float(out)
+
+    assert case.check_result(x1_value, x2_value, out_value), (
+        f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
+        f"condition: {case}\n"
+        f"x1={x1_value}, x2={x2_value}"
+    )
+
 
 @pytest.mark.parametrize("iop_name, iop, case", iop_params)
-@given(
-    oneway_dtypes=oneway_promotable_dtypes(dh.float_dtypes),
-    oneway_shapes=oneway_broadcastable_shapes(),
-    data=st.data(),
-)
-def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
-    x1 = data.draw(
-        xps.arrays(dtype=oneway_dtypes.result_dtype, shape=oneway_shapes.result_shape),
-        label="x1",
+@settings(max_examples=1)
+@given(data=st.data())
+def test_iop(iop_name, iop, case, data):
+    # See test_binary comment
+    x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
+    x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
+    x1 = xp.asarray(x1_value, dtype=xp.float64)
+    x2 = xp.asarray(x2_value, dtype=xp.float64)
+
+    res = iop(x1, x2)
+    res_value = float(res)
+
+    assert case.check_result(x1_value, x2_value, res_value), (
+        f"x1={res_value}, but should be {case.result_expr} [{iop_name}()]\n"
+        f"condition: {case}\n"
+        f"x1={x1_value}, x2={x2_value}"
     )
-    x2 = data.draw(
-        xps.arrays(dtype=oneway_dtypes.input_dtype, shape=oneway_shapes.input_shape),
-        label="x2",
-    )
-
-    all_indices = list(sh.iter_indices(x1.shape, x2.shape, x1.shape))
-
-    indices_strat = st.shared(st.sampled_from(all_indices))
-    set_x1_idx = data.draw(indices_strat.map(lambda t: t[0]), label="set x1 idx")
-    set_x1_value = data.draw(case.x1_cond_from_dtype(x1.dtype), label="set x1 value")
-    x1[set_x1_idx] = set_x1_value
-    note(f"{x1=}")
-    set_x2_idx = data.draw(indices_strat.map(lambda t: t[1]), label="set x2 idx")
-    set_x2_value = data.draw(case.x2_cond_from_dtype(x2.dtype), label="set x2 value")
-    x2[set_x2_idx] = set_x2_value
-    note(f"{x2=}")
-
-    res = xp.asarray(x1, copy=True)
-    res = iop(res, x2)
-    # sanity check
-    ph.assert_result_shape(iop_name, [x1.shape, x2.shape], res.shape)
-
-    good_example = False
-    for l_idx, r_idx, o_idx in all_indices:
-        l = float(x1[l_idx])
-        r = float(x2[r_idx])
-        if case.cond(l, r):
-            good_example = True
-            o = float(res[o_idx])
-            f_left = f"{sh.fmt_idx('x1', l_idx)}={l}"
-            f_right = f"{sh.fmt_idx('x2', r_idx)}={r}"
-            f_out = f"{sh.fmt_idx('out', o_idx)}={o}"
-            assert case.check_result(l, r, o), (
-                f"{f_out}, but should be {case.result_expr} [{iop_name}()]\n"
-                f"condition: {case}\n"
-                f"{f_left}, {f_right}"
-            )
-            break
-    assume(good_example)
 
 
 @pytest.mark.parametrize(
@@ -1345,7 +1322,7 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
 def test_empty_arrays(func_name, expected):  # TODO: parse docstrings to get expected
     func = getattr(xp, func_name)
     out = func(xp.asarray([], dtype=dh.default_float))
-    ph.assert_shape(func_name, out.shape, ())  # sanity check
+    ph.assert_shape(func_name, out_shape=out.shape, expected=())  # sanity check
     msg = f"{out=!r}, but should be {expected}"
     if math.isnan(expected):
         assert xp.isnan(out), msg
@@ 
-1354,21 +1331,23 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp @pytest.mark.parametrize( - "func_name", [f.__name__ for f in category_to_funcs["statistical"]] + "func_name", [f.__name__ for f in category_to_funcs["statistical"] + if f.__name__ not in ['cumulative_sum', 'cumulative_prod']] ) @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)), + x=hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes(min_side=1)), data=st.data(), ) def test_nan_propagation(func_name, x, data): func = getattr(xp, func_name) - set_idx = data.draw( - xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" + nan_positions = data.draw( + hh.arrays(dtype=hh.bool_dtype, shape=x.shape), label="nan_positions" ) - x[set_idx] = float("nan") + assume(xp.any(nan_positions)) + x = xp.where(nan_positions, xp.asarray(float("nan")), x) note(f"{x=}") out = func(x) - ph.assert_shape(func_name, out.shape, ()) # sanity check + ph.assert_shape(func_name, out_shape=out.shape, expected=()) # sanity check assert xp.isnan(out), f"{out=!r}, but should be NaN" diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index 8f05bc13..0e3aa9d4 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -1,31 +1,165 @@ +import cmath import math from typing import Optional import pytest from hypothesis import assume, given from hypothesis import strategies as st -from hypothesis.control import reject +from ndindex import iter_indices from . import _array_module as xp from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . import xps from ._array_module import _UndefinedStub from .typing import DataType -pytestmark = pytest.mark.ci + +@pytest.mark.min_version("2023.12") +@pytest.mark.unvectorized +@given( + x=hh.arrays( + dtype=hh.numeric_dtypes, + shape=hh.shapes(min_dims=1)), + data=st.data(), +) +def test_cumulative_sum(x, data): + axes = st.integers(-x.ndim, x.ndim - 1) + if x.ndim == 1: + axes = axes | st.none() + axis = data.draw(axes, label='axis') + _axis, = sh.normalize_axis(axis, x.ndim) + dtype = data.draw(kwarg_dtypes(x.dtype)) + include_initial = data.draw(st.booleans(), label="include_initial") + + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("dtype", dtype, None), + ("include_initial", include_initial, False), + ), + label="kw", + ) + + out = xp.cumulative_sum(x, **kw) + + expected_shape = list(x.shape) + if include_initial: + expected_shape[_axis] += 1 + expected_shape = tuple(expected_shape) + ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape) + + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. 
+ # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: + ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + + scalar_type = dh.get_scalar_type(out.dtype) + + for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): + x_arr = x[x_idx.raw] + out_arr = out[out_idx.raw] + + if include_initial: + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0) + + for n in range(x.shape[_axis]): + start = 1 if include_initial else 0 + out_val = out_arr[n + start] + assume(cmath.isfinite(out_val)) + elements = [] + for idx in range(n + 1): + s = scalar_type(x_arr[idx]) + elements.append(s) + expected = sum(elements) + if dh.is_int_dtype(out.dtype): + m, M = dh.dtype_ranges[out.dtype] + assume(m <= expected <= M) + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, + idx=out_idx.raw, out=out_val, + expected=expected) + else: + condition_number = _sum_condition_number(elements) + assume(condition_number < 1e6) + ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type, + idx=out_idx.raw, out=out_val, + expected=expected) + + + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@given( + x=hh.arrays( + dtype=hh.numeric_dtypes, + shape=hh.shapes(min_dims=1)), + data=st.data(), +) +def test_cumulative_prod(x, data): + axes = st.integers(-x.ndim, x.ndim - 1) + if x.ndim == 1: + axes = axes | st.none() + axis = data.draw(axes, label='axis') + _axis, = sh.normalize_axis(axis, x.ndim) + dtype = data.draw(kwarg_dtypes(x.dtype)) + include_initial = data.draw(st.booleans(), label="include_initial") + + kw = data.draw( + hh.specified_kwargs( + ("axis", axis, None), + ("dtype", dtype, None), + ("include_initial", include_initial, False), + ), + label="kw", + ) + + out = xp.cumulative_prod(x, **kw) + + expected_shape = list(x.shape) + if include_initial: + expected_shape[_axis] += 1 + expected_shape = tuple(expected_shape) + ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape) + + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. 
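For reference, the value check in test_cumulative_sum above recomputes each output element as a plain Python partial sum before comparing, with assume() guarding against overflow and ill-conditioning. Schematically, for a made-up 1-D slice with include_initial=False:

    x_slice = [2, 3, 5]  # hypothetical values along the traversed axis
    partials = [sum(x_slice[: n + 1]) for n in range(len(x_slice))]
    assert partials == [2, 5, 10]  # what out[n] is checked against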
+ # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check + else: + ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + + scalar_type = dh.get_scalar_type(out.dtype) + + for x_idx, out_idx, in iter_indices(x.shape, expected_shape, skip_axes=_axis): + #x_arr = x[x_idx.raw] + out_arr = out[out_idx.raw] + + if include_initial: + ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1) + + #TODO: add value testing of cumulative_prod def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] + dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)] + assert len(dtypes) > 0 # sanity check return st.none() | st.sampled_from(dtypes) +@pytest.mark.unvectorized @given( - x=xps.arrays( - dtype=xps.numeric_dtypes(), + x=hh.arrays( + dtype=hh.real_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -33,13 +167,14 @@ def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: ) def test_max(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.max(x, **kw) - ph.assert_dtype("max", x.dtype, out.dtype) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "max", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -49,12 +184,12 @@ def test_max(x, data): s = scalar_type(x[idx]) elements.append(s) expected = max(elements) - ph.assert_scalar_equals("max", scalar_type, out_idx, max_, expected) + ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected) @given( - x=xps.arrays( - dtype=xps.floating_dtypes(), + x=hh.arrays( + dtype=hh.real_floating_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -62,20 +197,22 @@ def test_max(x, data): ) def test_mean(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.mean(x, **kw) - ph.assert_dtype("mean", x.dtype, out.dtype) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "mean", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) # Values testing mean is too finicky +@pytest.mark.unvectorized @given( - x=xps.arrays( - dtype=xps.numeric_dtypes(), + x=hh.arrays( + dtype=hh.real_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -83,13 +220,14 @@ def test_mean(x, data): ) def test_min(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.min(x, **kw) - ph.assert_dtype("min", x.dtype, out.dtype) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("min", in_dtype=x.dtype, 
out_dtype=out.dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "min", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -99,12 +237,23 @@ def test_min(x, data): s = scalar_type(x[idx]) elements.append(s) expected = min(elements) - ph.assert_scalar_equals("min", scalar_type, out_idx, min_, expected) + ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected) + + +def _prod_condition_number(elements): + # Relative condition number using the infinity norm + abs_max = max([abs(i) for i in elements]) + abs_min = min([abs(i) for i in elements]) + + if abs_min == 0: + return float('inf') + return abs_max / abs_min +@pytest.mark.unvectorized @given( - x=xps.arrays( - dtype=xps.numeric_dtypes(), + x=hh.arrays( + dtype=hh.numeric_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -119,44 +268,29 @@ def test_prod(x, data): ), label="kw", ) + keepdims = kw.get("keepdims", False) - try: + with hh.reject_overflow(): out = xp.prod(x, **kw) - except OverflowError: - reject() dtype = kw.get("dtype", None) - if dtype is None: - if dh.is_int_dtype(x.dtype): - if x.dtype in dh.uint_dtypes: - default_dtype = dh.default_uint - else: - default_dtype = dh.default_int - m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[default_dtype] - if m < d_m or M > d_M: - _dtype = x.dtype - else: - _dtype = default_dtype - else: - if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: - _dtype = x.dtype - else: - _dtype = dh.default_float + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. 
+ # See https://github.com/data-apis/array-api-tests/issues/106 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check else: - _dtype = dtype - # We ignore asserting the out dtype if what we expect is undefined - # See https://github.com/data-apis/array-api-tests/issues/106 - if not isinstance(_dtype, _UndefinedStub): - ph.assert_dtype("prod", x.dtype, out.dtype, _dtype) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "prod", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): prod = scalar_type(out[out_idx]) - assume(math.isfinite(prod)) + assume(cmath.isfinite(prod)) elements = [] for idx in indices: s = scalar_type(x[idx]) @@ -165,47 +299,65 @@ def test_prod(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("prod", scalar_type, out_idx, prod, expected) + ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, + out=prod, expected=expected) + else: + condition_number = _prod_condition_number(elements) + assume(condition_number < 1e15) + ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx, + out=prod, expected=expected) +@pytest.mark.skip(reason="flaky") # TODO: fix! @given( - x=xps.arrays( - dtype=xps.floating_dtypes(), + x=hh.arrays( + dtype=hh.real_floating_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, - ).filter(lambda x: x.size >= 2), + ).filter(lambda x: math.prod(x.shape) >= 2), data=st.data(), ) def test_std(x, data): axis = data.draw(hh.axes(x.ndim), label="axis") - _axes = sh.normalise_axis(axis, x.ndim) + _axes = sh.normalize_axis(axis, x.ndim) N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), label="correction", ) - keepdims = data.draw(st.booleans(), label="keepdims") + _keepdims = data.draw(st.booleans(), label="keepdims") kw = data.draw( hh.specified_kwargs( ("axis", axis, None), ("correction", correction, 0.0), - ("keepdims", keepdims, False), + ("keepdims", _keepdims, False), ), label="kw", ) + keepdims = kw.get("keepdims", False) out = xp.std(x, **kw) - ph.assert_dtype("std", x.dtype, out.dtype) + ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_keepdimable_shape( - "std", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) # We can't easily test the result(s) as standard deviation methods vary a lot +def _sum_condition_number(elements): + sum_abs = sum([abs(i) for i in elements]) + abs_sum = abs(sum(elements)) + + if abs_sum == 0: + return float('inf') + + return sum_abs / abs_sum + +# @pytest.mark.unvectorized @given( - x=xps.arrays( - dtype=xps.numeric_dtypes(), + x=hh.arrays( + dtype=hh.numeric_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, ), @@ -220,41 +372,29 @@ def test_sum(x, data): ), label="kw", ) + keepdims = kw.get("keepdims", False) - try: + with hh.reject_overflow(): out = xp.sum(x, **kw) - except OverflowError: - reject() dtype = kw.get("dtype", 
None) - if dtype is None: - if dh.is_int_dtype(x.dtype): - if x.dtype in dh.uint_dtypes: - default_dtype = dh.default_uint - else: - default_dtype = dh.default_int - m, M = dh.dtype_ranges[x.dtype] - d_m, d_M = dh.dtype_ranges[default_dtype] - if m < d_m or M > d_M: - _dtype = x.dtype - else: - _dtype = default_dtype - else: - if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]: - _dtype = x.dtype - else: - _dtype = dh.default_float + expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype) + if expected_dtype is None: + # If a default uint cannot exist (i.e. in PyTorch which doesn't support + # uint32 or uint64), we skip testing the output dtype. + # See https://github.com/data-apis/array-api-tests/issues/160 + if x.dtype in dh.uint_dtypes: + assert dh.is_int_dtype(out.dtype) # sanity check else: - _dtype = dtype - ph.assert_dtype("sum", x.dtype, out.dtype, _dtype) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "sum", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(out.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): sum_ = scalar_type(out[out_idx]) - assume(math.isfinite(sum_)) + assume(cmath.isfinite(sum_)) elements = [] for idx in indices: s = scalar_type(x[idx]) @@ -263,39 +403,50 @@ def test_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("sum", scalar_type, out_idx, sum_, expected) + ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, + out=sum_, expected=expected) + else: + # Avoid value testing for ill conditioned summations. See + # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and + # https://en.wikipedia.org/wiki/Condition_number. + condition_number = _sum_condition_number(elements) + assume(condition_number < 1e6) + ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) +@pytest.mark.unvectorized +@pytest.mark.skip(reason="flaky") # TODO: fix! 
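To see why the condition-number guard above is needed, consider a sum with heavy cancellation; its exact value is numerically meaningless, so the test's assume() discards such draws rather than asserting on them (numbers below are illustrative):

    elements = [1e8, 1.0, -1e8]
    sum_abs = sum(abs(e) for e in elements)  # 200000001.0
    abs_sum = abs(sum(elements))             # 1.0 is all that survives cancellation
    condition_number = sum_abs / abs_sum     # ~2e8, far above the 1e6 cutoff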
@given( - x=xps.arrays( - dtype=xps.floating_dtypes(), + x=hh.arrays( + dtype=hh.real_floating_dtypes, shape=hh.shapes(min_side=1), elements={"allow_nan": False}, - ).filter(lambda x: x.size >= 2), + ).filter(lambda x: math.prod(x.shape) >= 2), data=st.data(), ) def test_var(x, data): axis = data.draw(hh.axes(x.ndim), label="axis") - _axes = sh.normalise_axis(axis, x.ndim) + _axes = sh.normalize_axis(axis, x.ndim) N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes) correction = data.draw( st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N), label="correction", ) - keepdims = data.draw(st.booleans(), label="keepdims") + _keepdims = data.draw(st.booleans(), label="keepdims") kw = data.draw( hh.specified_kwargs( ("axis", axis, None), ("correction", correction, 0.0), - ("keepdims", keepdims, False), + ("keepdims", _keepdims, False), ), label="kw", ) + keepdims = kw.get("keepdims", False) out = xp.var(x, **kw) - ph.assert_dtype("var", x.dtype, out.dtype) + ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype) ph.assert_keepdimable_shape( - "var", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) # We can't easily test the result(s) as variance methods vary a lot diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py deleted file mode 100644 index 9bbaf930..00000000 --- a/array_api_tests/test_type_promotion.py +++ /dev/null @@ -1,315 +0,0 @@ -""" -https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html -""" -from collections import defaultdict -from typing import List, Tuple, Union - -import pytest -from hypothesis import given, reject -from hypothesis import strategies as st - -from . import _array_module as xp -from . import dtype_helpers as dh -from . import hypothesis_helpers as hh -from . import pytest_helpers as ph -from . import xps -from .stubs import category_to_funcs -from .typing import DataType, Param, ScalarType - -bitwise_shift_funcs = [ - "bitwise_left_shift", - "bitwise_right_shift", - "__lshift__", - "__rshift__", - "__ilshift__", - "__irshift__", -] - - -# We pass kwargs to the elements strategy used by xps.arrays() so that we don't -# generate array elements that are erroneous or undefined for a function. 
-func_elements = defaultdict( - lambda: None, {func: {"min_value": 1} for func in bitwise_shift_funcs} -) - - -def make_id( - func_name: str, - in_dtypes: Tuple[Union[DataType, ScalarType], ...], - out_dtype: DataType, -) -> str: - f_args = dh.fmt_types(in_dtypes) - f_out_dtype = dh.dtype_to_name[out_dtype] - return f"{func_name}({f_args}) -> {f_out_dtype}" - - -def mark_stubbed_dtypes(*dtypes): - for dtype in dtypes: - if isinstance(dtype, xp._UndefinedStub): - return pytest.mark.skip(reason=f"xp.{dtype.name} not defined") - else: - return () - - -func_params: List[Param[str, Tuple[DataType, ...], DataType]] = [] -for func_name in [f.__name__ for f in category_to_funcs["elementwise"]]: - valid_in_dtypes = dh.func_in_dtypes[func_name] - ndtypes = ph.nargs(func_name) - if ndtypes == 1: - for in_dtype in valid_in_dtypes: - out_dtype = xp.bool if dh.func_returns_bool[func_name] else in_dtype - p = pytest.param( - func_name, - (in_dtype,), - out_dtype, - id=make_id(func_name, (in_dtype,), out_dtype), - marks=mark_stubbed_dtypes(in_dtype, out_dtype), - ) - func_params.append(p) - elif ndtypes == 2: - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - out_dtype = ( - xp.bool if dh.func_returns_bool[func_name] else promoted_dtype - ) - p = pytest.param( - func_name, - (in_dtype1, in_dtype2), - out_dtype, - id=make_id(func_name, (in_dtype1, in_dtype2), out_dtype), - marks=mark_stubbed_dtypes(in_dtype1, in_dtype2, out_dtype), - ) - func_params.append(p) - else: - raise NotImplementedError() - - -@pytest.mark.parametrize("func_name, in_dtypes, out_dtype", func_params) -@given(data=st.data()) -def test_func_promotion(func_name, in_dtypes, out_dtype, data): - func = getattr(xp, func_name) - elements = func_elements[func_name] - if len(in_dtypes) == 1: - x = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements), - label="x", - ) - out = func(x) - else: - arrays = [] - shapes = data.draw( - hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes" - ) - for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): - x = data.draw( - xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}" - ) - arrays.append(x) - try: - out = func(*arrays) - except OverflowError: - reject() - ph.assert_dtype(func_name, in_dtypes, out.dtype, out_dtype) - - -promotion_params: List[Param[Tuple[DataType, DataType], DataType]] = [] -for (dtype1, dtype2), promoted_dtype in dh.promotion_table.items(): - p = pytest.param( - (dtype1, dtype2), - promoted_dtype, - id=make_id("", (dtype1, dtype2), promoted_dtype), - marks=mark_stubbed_dtypes(dtype1, dtype2, promoted_dtype), - ) - promotion_params.append(p) - - -numeric_promotion_params = promotion_params[1:] - - -op_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = [] -op_to_symbol = {**dh.unary_op_to_symbol, **dh.binary_op_to_symbol} -for op, symbol in op_to_symbol.items(): - if op == "__matmul__": - continue - valid_in_dtypes = dh.func_in_dtypes[op] - ndtypes = ph.nargs(op) - if ndtypes == 1: - for in_dtype in valid_in_dtypes: - out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype - p = pytest.param( - op, - f"{symbol}x", - (in_dtype,), - out_dtype, - id=make_id(op, (in_dtype,), out_dtype), - marks=mark_stubbed_dtypes(in_dtype, out_dtype), - ) - op_params.append(p) - else: - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if in_dtype1 in valid_in_dtypes and in_dtype2 in valid_in_dtypes: - 
out_dtype = xp.bool if dh.func_returns_bool[op] else promoted_dtype - p = pytest.param( - op, - f"x1 {symbol} x2", - (in_dtype1, in_dtype2), - out_dtype, - id=make_id(op, (in_dtype1, in_dtype2), out_dtype), - marks=mark_stubbed_dtypes(in_dtype1, in_dtype2, out_dtype), - ) - op_params.append(p) -# We generate params for abs seperately as it does not have an associated symbol -for in_dtype in dh.func_in_dtypes["__abs__"]: - p = pytest.param( - "__abs__", - "abs(x)", - (in_dtype,), - in_dtype, - id=make_id("__abs__", (in_dtype,), in_dtype), - marks=mark_stubbed_dtypes(in_dtype), - ) - op_params.append(p) - - -@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", op_params) -@given(data=st.data()) -def test_op_promotion(op, expr, in_dtypes, out_dtype, data): - elements = func_elements[func_name] - if len(in_dtypes) == 1: - x = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=hh.shapes(), elements=elements), - label="x", - ) - out = eval(expr, {"x": x}) - else: - locals_ = {} - shapes = data.draw( - hh.mutually_broadcastable_shapes(len(in_dtypes)), label="shapes" - ) - for i, (dtype, shape) in enumerate(zip(in_dtypes, shapes), 1): - locals_[f"x{i}"] = data.draw( - xps.arrays(dtype=dtype, shape=shape, elements=elements), label=f"x{i}" - ) - try: - out = eval(expr, locals_) - except OverflowError: - reject() - ph.assert_dtype(op, in_dtypes, out.dtype, out_dtype) - - -inplace_params: List[Param[str, str, Tuple[DataType, ...], DataType]] = [] -for op, symbol in dh.inplace_op_to_symbol.items(): - if op == "__imatmul__": - continue - valid_in_dtypes = dh.func_in_dtypes[op] - for (in_dtype1, in_dtype2), promoted_dtype in dh.promotion_table.items(): - if ( - in_dtype1 == promoted_dtype - and in_dtype1 in valid_in_dtypes - and in_dtype2 in valid_in_dtypes - ): - p = pytest.param( - op, - f"x1 {symbol} x2", - (in_dtype1, in_dtype2), - promoted_dtype, - id=make_id(op, (in_dtype1, in_dtype2), promoted_dtype), - marks=mark_stubbed_dtypes(in_dtype1, in_dtype2, promoted_dtype), - ) - inplace_params.append(p) - - -@pytest.mark.parametrize("op, expr, in_dtypes, out_dtype", inplace_params) -@given(shape=hh.shapes(), data=st.data()) -def test_inplace_op_promotion(op, expr, in_dtypes, out_dtype, shape, data): - # TODO: test broadcastable shapes (that don't change x1's shape) - elements = func_elements[func_name] - x1 = data.draw( - xps.arrays(dtype=in_dtypes[0], shape=shape, elements=elements), label="x1" - ) - x2 = data.draw( - xps.arrays(dtype=in_dtypes[1], shape=shape, elements=elements), label="x2" - ) - locals_ = {"x1": x1, "x2": x2} - try: - exec(expr, locals_) - except OverflowError: - reject() - x1 = locals_["x1"] - ph.assert_dtype(op, in_dtypes, x1.dtype, out_dtype, repr_name="x1.dtype") - - -op_scalar_params: List[Param[str, str, DataType, ScalarType, DataType]] = [] -for op, symbol in dh.binary_op_to_symbol.items(): - if op == "__matmul__": - continue - for in_dtype in dh.func_in_dtypes[op]: - out_dtype = xp.bool if dh.func_returns_bool[op] else in_dtype - for in_stype in dh.dtype_to_scalars[in_dtype]: - p = pytest.param( - op, - f"x {symbol} s", - in_dtype, - in_stype, - out_dtype, - id=make_id(op, (in_dtype, in_stype), out_dtype), - marks=mark_stubbed_dtypes(in_dtype, out_dtype), - ) - op_scalar_params.append(p) - - -@pytest.mark.parametrize("op, expr, in_dtype, in_stype, out_dtype", op_scalar_params) -@given(data=st.data()) -def test_op_scalar_promotion(op, expr, in_dtype, in_stype, out_dtype, data): - elements = func_elements[func_name] - kw = {k: in_stype is float for k in ("allow_nan", 
"allow_infinity")} - s = data.draw(xps.from_dtype(in_dtype, **kw).map(in_stype), label="scalar") - x = data.draw( - xps.arrays(dtype=in_dtype, shape=hh.shapes(), elements=elements), label="x" - ) - try: - out = eval(expr, {"x": x, "s": s}) - except OverflowError: - reject() - ph.assert_dtype(op, [in_dtype, in_stype], out.dtype, out_dtype) - - -inplace_scalar_params: List[Param[str, str, DataType, ScalarType]] = [] -for op, symbol in dh.inplace_op_to_symbol.items(): - if op == "__imatmul__": - continue - for dtype in dh.func_in_dtypes[op]: - for in_stype in dh.dtype_to_scalars[dtype]: - p = pytest.param( - op, - f"x {symbol} s", - dtype, - in_stype, - id=make_id(op, (dtype, in_stype), dtype), - marks=mark_stubbed_dtypes(dtype), - ) - inplace_scalar_params.append(p) - - -@pytest.mark.parametrize("op, expr, dtype, in_stype", inplace_scalar_params) -@given(data=st.data()) -def test_inplace_op_scalar_promotion(op, expr, dtype, in_stype, data): - elements = func_elements[func_name] - kw = {k: in_stype is float for k in ("allow_nan", "allow_infinity")} - s = data.draw(xps.from_dtype(dtype, **kw).map(in_stype), label="scalar") - x = data.draw( - xps.arrays(dtype=dtype, shape=hh.shapes(), elements=elements), label="x" - ) - locals_ = {"x": x, "s": s} - try: - exec(expr, locals_) - except OverflowError: - reject() - x = locals_["x"] - assert x.dtype == dtype, f"{x.dtype=!s}, but should be {dtype}" - ph.assert_dtype(op, [dtype, in_stype], x.dtype, dtype, repr_name="x.dtype") - - -if __name__ == "__main__": - for (i, j), p in dh.promotion_table.items(): - print(f"({i}, {j}) -> {p}") diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py index 7c09fb27..b6e0a4fe 100644 --- a/array_api_tests/test_utility_functions.py +++ b/array_api_tests/test_utility_functions.py @@ -7,24 +7,23 @@ from . import hypothesis_helpers as hh from . import pytest_helpers as ph from . import shape_helpers as sh -from . 
import xps - -pytestmark = pytest.mark.ci +@pytest.mark.unvectorized @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)), + x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)), data=st.data(), ) def test_all(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.all(x, **kw) - ph.assert_dtype("all", x.dtype, out.dtype, xp.bool) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "all", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw ) scalar_type = dh.get_scalar_type(x.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -34,22 +33,25 @@ def test_all(x, data): s = scalar_type(x[idx]) elements.append(s) expected = all(elements) - ph.assert_scalar_equals("all", scalar_type, out_idx, result, expected) + ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx, + out=result, expected=expected, kw=kw) +@pytest.mark.unvectorized @given( - x=xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), + x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data(), ) def test_any(x, data): kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw") + keepdims = kw.get("keepdims", False) out = xp.any(x, **kw) - ph.assert_dtype("any", x.dtype, out.dtype, xp.bool) - _axes = sh.normalise_axis(kw.get("axis", None), x.ndim) + ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool) + _axes = sh.normalize_axis(kw.get("axis", None), x.ndim) ph.assert_keepdimable_shape( - "any", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw + "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw, ) scalar_type = dh.get_scalar_type(x.dtype) for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)): @@ -59,4 +61,81 @@ def test_any(x, data): s = scalar_type(x[idx]) elements.append(s) expected = any(elements) - ph.assert_scalar_equals("any", scalar_type, out_idx, result, expected) + ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx, + out=result, expected=expected, kw=kw) + + +@pytest.mark.unvectorized +@pytest.mark.min_version("2024.12") +@given( + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_diff(x, data): + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + ) + if axis is None: + axis_kw = {"axis": -1} + n_axis = x.ndim - 1 + else: + axis_kw = {"axis": axis} + n_axis = axis + x.ndim if axis < 0 else axis + + n = data.draw(st.integers(1, min(x.shape[n_axis], 3))) + + out = xp.diff(x, **axis_kw, n=n) + + expected_shape = list(x.shape) + expected_shape[n_axis] -= n + + assert out.shape == tuple(expected_shape) + + # value test + if n == 1: + for idx in sh.ndindex(out.shape): + l = list(idx) + l[n_axis] += 1 + assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }" + + +@pytest.mark.min_version("2024.12") +@pytest.mark.unvectorized +@given( + x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)), + data=st.data(), +) +def test_diff_append_prepend(x, data): + axis = data.draw( + st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(), + label="axis" + 
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@given(
+    x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
+    data=st.data(),
+)
+def test_diff_append_prepend(x, data):
+    axis = data.draw(
+        st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
+        label="axis"
+    )
+    if axis is None:
+        axis_kw = {"axis": -1}
+        n_axis = x.ndim - 1
+    else:
+        axis_kw = {"axis": axis}
+        n_axis = axis + x.ndim if axis < 0 else axis
+
+    n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
+
+    append_shape = list(x.shape)
+    append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis")
+    append_shape[n_axis] = append_axis_len
+    append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append")
+
+    prepend_shape = list(x.shape)
+    prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis")
+    prepend_shape[n_axis] = prepend_axis_len
+    prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend")
+
+    out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
+
+    in_1 = xp.concat((prepend, x, append), **axis_kw)
+    out_1 = xp.diff(in_1, **axis_kw, n=n)
+
+    assert out.shape == out_1.shape
+    for idx in sh.ndindex(out.shape):
+        assert out[idx] == out_1[idx], f"{idx = }"
+
diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py
index f0ed8c50..84311ff3 100644
--- a/array_api_tests/typing.py
+++ b/array_api_tests/typing.py
@@ -12,8 +12,8 @@
 ]
 
 DataType = Type[Any]
-Scalar = Union[bool, int, float]
-ScalarType = Union[Type[bool], Type[int], Type[float]]
+Scalar = Union[bool, int, float, complex]
+ScalarType = Union[Type[bool], Type[int], Type[float], Type[complex]]
 Array = Any
 Shape = Tuple[int, ...]
 AtomicIndex = Union[int, "ellipsis", slice, None]  # noqa
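`Scalar` and `ScalarType` are widened here because the suite now exercises complex dtypes, and value tests need the matching builtin for each array dtype (the suite gets it from `dh.get_scalar_type`, as used in `test_all` above). A rough sketch of that kind of mapping, illustrative only and not the suite's actual implementation:

```python
# Illustrative dtype -> builtin mapping; the real helper lives in
# array_api_tests/dtype_helpers.py and works off the suite's dtype tables.
import numpy as np

def get_scalar_type(dtype) -> type:
    if np.issubdtype(dtype, np.bool_):
        return bool
    if np.issubdtype(dtype, np.integer):
        return int
    if np.issubdtype(dtype, np.complexfloating):  # complex before generic float
        return complex
    return float

assert get_scalar_type(np.dtype(np.int32)) is int
assert get_scalar_type(np.dtype(np.complex64)) is complex
assert get_scalar_type(np.dtype(np.float64)) is float
```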
diff --git a/conftest.py b/conftest.py
index e0453e40..dc50c9ae 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,15 +1,38 @@
 from functools import lru_cache
 from pathlib import Path
+import argparse
+import warnings
+import os
 
 from hypothesis import settings
+from hypothesis.errors import InvalidArgument
 from pytest import mark
 
 from array_api_tests import _array_module as xp
+from array_api_tests import api_version
 from array_api_tests._array_module import _UndefinedStub
+from array_api_tests.stubs import EXTENSIONS
+from array_api_tests import xp_name, xp as array_module
 
 from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata  # noqa
 
-settings.register_profile("xp_default", deadline=800)
+
+def pytest_report_header(config):
+    disabled_extensions = config.getoption("--disable-extension")
+    enabled_extensions = sorted({
+        ext for ext in EXTENSIONS + ['fft'] if ext not in disabled_extensions and xp_has_ext(ext)
+    })
+
+    try:
+        array_module_version = array_module.__version__
+    except AttributeError:
+        array_module_version = "version unknown"
+
+    # make it easier to catch typos in environment variables (ARRAY_API_*** instead of ARRAY_API_TESTS_*** etc)
+    env_vars = "\n".join([f"{k} = {v}" for k, v in os.environ.items() if 'ARRAY_API' in k])
+    env_vars = f"Environment variables:\n{'-'*22}\n{env_vars}\n\n"
+
+    header1 = f"Array API Tests Module: {xp_name} ({array_module_version}). API Version: {api_version}. Enabled Extensions: {', '.join(enabled_extensions)}"
+    return env_vars + header1
 
 
 def pytest_addoption(parser):
     # Hypothesis max examples
@@ -18,7 +41,8 @@ def pytest_addoption(parser):
         "--hypothesis-max-examples",
         "--max-examples",
         action="store",
-        default=None,
+        default=100,
+        type=int,
         help="set the Hypothesis max_examples setting",
     )
     # Hypothesis deadline
@@ -28,6 +52,13 @@ def pytest_addoption(parser):
         action="store_true",
         help="disable the Hypothesis deadline",
     )
+    # Hypothesis derandomize
+    parser.addoption(
+        "--hypothesis-derandomize",
+        "--derandomize",
+        action="store_true",
+        help="set the Hypothesis derandomize parameter",
+    )
     # disable extensions
     parser.addoption(
         "--disable-extension",
@@ -44,10 +75,16 @@ def pytest_addoption(parser):
         help="disable testing functions with output shapes dependent on input",
     )
     # CI
+    parser.addoption("--ci", action="store_true", help=argparse.SUPPRESS)  # deprecated
     parser.addoption(
-        "--ci",
-        action="store_true",
-        help="run just the tests appropiate for CI",
+        "--skips-file",
+        action="store",
+        help="file with tests to skip. Defaults to skips.txt"
+    )
+    parser.addoption(
+        "--xfails-file",
+        action="store",
+        help="file with tests to xfail. Defaults to xfails.txt"
     )
 
 
@@ -58,20 +95,29 @@ def pytest_configure(config):
     config.addinivalue_line(
         "markers", "data_dependent_shapes: output shapes are dependent on inputs"
     )
-    config.addinivalue_line("markers", "ci: primary test")
+    config.addinivalue_line(
+        "markers",
+        "min_version(api_version): run when the tested API version is greater than or equal to api_version",
+    )
+    config.addinivalue_line(
+        "markers",
+        "unvectorized: asserts against values via element-wise iteration (not performant!)",
+    )
     # Hypothesis
-    hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
-    disable_deadline = config.getoption("--hypothesis-disable-deadline")
-    profile_settings = {}
-    if hypothesis_max_examples is not None:
-        profile_settings["max_examples"] = int(hypothesis_max_examples)
-    if disable_deadline is not None:
-        profile_settings["deadline"] = None
-    if profile_settings:
-        settings.register_profile("xp_override", **profile_settings)
-        settings.load_profile("xp_override")
-    else:
-        settings.load_profile("xp_default")
+    deadline = None if config.getoption("--hypothesis-disable-deadline") else 800
+    settings.register_profile(
+        "array-api-tests",
+        max_examples=config.getoption("--hypothesis-max-examples"),
+        derandomize=config.getoption("--hypothesis-derandomize"),
+        deadline=deadline,
+    )
+    settings.load_profile("array-api-tests")
+    # CI
+    if config.getoption("--ci"):
+        warnings.warn(
+            "Custom pytest option --ci is deprecated as any tests not for CI "
+            "are now located in meta_tests/"
+        )
 
 
 @lru_cache
@@ -82,26 +128,99 @@ def xp_has_ext(ext: str) -> bool:
     return False
 
 
-skip_ids = []
-skips_path = Path(__file__).parent / "skips.txt"
-if skips_path.exists():
-    with open(skips_path) as f:
-        for line in f:
-            if line.startswith("array_api_tests"):
-                id_ = line.strip("\n")
-                skip_ids.append(id_)
+def check_id_match(id_, pattern):
+    id_ = id_.removeprefix('array-api-tests/')
+
+    if id_ == pattern:
+        return True
+
+    if id_.startswith(pattern.removesuffix("/") + "/"):
+        return True
+
+    if pattern.endswith(".py") and id_.startswith(pattern):
+        return True
+
+    if id_.split("::", maxsplit=2)[0] == pattern:
+        return True
+
+    if id_.split("[", maxsplit=2)[0] == pattern:
+        return True
+
+    return False
+
+
+def get_xfail_mark():
+    """Return the mark (xfail or skip) to apply to ids from the xfails file."""
+    m = os.environ.get("ARRAY_API_TESTS_XFAIL_MARK", "xfail")
+    if m == "xfail":
+        return mark.xfail
+    elif m == "skip":
+        return mark.skip
+    else:
+        raise ValueError(
+            f'ARRAY_API_TESTS_XFAIL_MARK value should be one of "skip" or "xfail", '
+            f'got {m} instead.'
+        )
 
 
 def pytest_collection_modifyitems(config, items):
+    # 1. Prepare for iterating over items
+    # -----------------------------------
+
+    skips_file = skips_path = config.getoption('--skips-file')
+    if skips_file is None:
+        skips_file = Path(__file__).parent / "skips.txt"
+        if skips_file.exists():
+            skips_path = skips_file
+
+    skip_ids = []
+    if skips_path:
+        with open(os.path.expanduser(skips_path)) as f:
+            for line in f:
+                if line.startswith("array_api_tests"):
+                    id_ = line.strip("\n")
+                    skip_ids.append(id_)
+
+    xfails_file = xfails_path = config.getoption('--xfails-file')
+    if xfails_file is None:
+        xfails_file = Path(__file__).parent / "xfails.txt"
+        if xfails_file.exists():
+            xfails_path = xfails_file
+
+    xfail_ids = []
+    if xfails_path:
+        with open(os.path.expanduser(xfails_path)) as f:
+            for line in f:
+                if not line.strip() or line.startswith('#'):
+                    continue
+                id_ = line.strip("\n")
+                xfail_ids.append(id_)
+
+    skip_id_matched = {id_: False for id_ in skip_ids}
+    xfail_id_matched = {id_: False for id_ in xfail_ids}
+
     disabled_exts = config.getoption("--disable-extension")
     disabled_dds = config.getoption("--disable-data-dependent-shapes")
-    ci = config.getoption("--ci")
+    unvectorized_max_examples = max(1, config.getoption("--hypothesis-max-examples")//10)
+
+    # 2. Iterate through items and apply markers accordingly
+    # ------------------------------------------------------
+
+    xfail_mark = get_xfail_mark()
+
     for item in items:
         markers = list(item.iter_markers())
-        # skip if specified in skips.txt
+        # skip if specified in skips file
         for id_ in skip_ids:
-            if item.nodeid.startswith(id_):
-                item.add_marker(mark.skip(reason="skips.txt"))
+            if check_id_match(item.nodeid, id_):
+                item.add_marker(mark.skip(reason=f"--skips-file ({skips_file})"))
+                skip_id_matched[id_] = True
+                break
+        # xfail if specified in xfails file
+        for id_ in xfail_ids:
+            if check_id_match(item.nodeid, id_):
+                item.add_marker(xfail_mark(reason=f"--xfails-file ({xfails_file})"))
+                xfail_id_matched[id_] = True
                 break
         # skip if disabled or non-existent extension
         ext_mark = next((m for m in markers if m.name == "xp_extension"), None)
@@ -121,8 +240,51 @@ def pytest_collection_modifyitems(config, items):
                         mark.skip(reason="disabled via --disable-data-dependent-shapes")
                     )
                     break
-        # skip if test not appropriate for CI
-        if ci:
-            ci_mark = next((m for m in markers if m.name == "ci"), None)
-            if ci_mark is None:
-                item.add_marker(mark.skip(reason="disabled via --ci"))
+        # skip if test is for a greater api_version
+        ver_mark = next((m for m in markers if m.name == "min_version"), None)
+        if ver_mark is not None:
+            min_version = ver_mark.args[0]
+            if api_version < min_version:
+                item.add_marker(
+                    mark.skip(
+                        reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
+                    )
+                )
+        # reduce the number of generated Hypothesis examples for unvectorized tests
+        if any(m.name == "unvectorized" for m in markers):
+            # TODO: limit generated examples when settings already applied
+            if not hasattr(item.obj, "_hypothesis_internal_settings_applied"):
+                try:
+                    item.obj = settings(max_examples=unvectorized_max_examples)(item.obj)
+                except InvalidArgument as e:
+                    warnings.warn(
+                        f"Tried decorating {item.name} with settings() but got "
+                        f"hypothesis.errors.InvalidArgument: {e}"
+                    )
+
+
+    # 3. Warn on bad skipped/xfailed ids
+    # ----------------------------------
+
+    bad_ids_end_msg = (
+        "Note the relevant tests might not have been collected by pytest, or "
+        "another specified id might have already matched a test."
+    )
+    bad_skip_ids = [id_ for id_, matched in skip_id_matched.items() if not matched]
+    if bad_skip_ids:
+        f_bad_ids = "\n".join(f"    {id_}" for id_ in bad_skip_ids)
+        warnings.warn(
+            f"{len(bad_skip_ids)} ids in skips file don't match any collected tests:\n"
+            f"{f_bad_ids}\n"
+            f"(skips file: {skips_file})\n"
+            f"{bad_ids_end_msg}"
+        )
+    bad_xfail_ids = [id_ for id_, matched in xfail_id_matched.items() if not matched]
+    if bad_xfail_ids:
+        f_bad_ids = "\n".join(f"    {id_}" for id_ in bad_xfail_ids)
+        warnings.warn(
+            f"{len(bad_xfail_ids)} ids in xfails file don't match any collected tests:\n"
+            f"{f_bad_ids}\n"
+            f"(xfails file: {xfails_file})\n"
+            f"{bad_ids_end_msg}"
+        )
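`check_id_match` in the conftest above accepts skip/xfail ids at several granularities: a directory, a file, a test function, or a single parametrized case. A few concrete cases, runnable against the definition above:

```python
# Concrete matches for check_id_match as defined in conftest.py above.
nodeid = "array_api_tests/test_special_cases.py::test_binary[floor_divide]"

assert check_id_match(nodeid, "array_api_tests/test_special_cases.py")                # file
assert check_id_match(nodeid, "array_api_tests/test_special_cases.py::test_binary")  # function
assert check_id_match(nodeid, nodeid)                                                # exact case
assert not check_id_match(nodeid, "array_api_tests/test_linalg.py")                  # no match
```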
diff --git a/meta_tests/README.md b/meta_tests/README.md
new file mode 100644
index 00000000..fb563cf6
--- /dev/null
+++ b/meta_tests/README.md
@@ -0,0 +1 @@
+Testing the utilities used in `array_api_tests/`
\ No newline at end of file
diff --git a/array_api_tests/meta/__init__.py b/meta_tests/__init__.py
similarity index 100%
rename from array_api_tests/meta/__init__.py
rename to meta_tests/__init__.py
diff --git a/meta_tests/test_array_helpers.py b/meta_tests/test_array_helpers.py
new file mode 100644
index 00000000..c46df50d
--- /dev/null
+++ b/meta_tests/test_array_helpers.py
@@ -0,0 +1,35 @@
+from hypothesis import given
+from hypothesis import strategies as st
+
+from array_api_tests import _array_module as xp
+from array_api_tests.hypothesis_helpers import (int_dtypes, arrays,
+                                                two_mutually_broadcastable_shapes)
+from array_api_tests.shape_helpers import iter_indices, broadcast_shapes
+from array_api_tests.array_helpers import exactly_equal, notequal, less
+
+# TODO: These meta-tests currently only work with NumPy
+
+def test_exactly_equal():
+    a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
+    b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
+
+    res = xp.asarray([True, False, True, False, True, False, True, False])
+    assert xp.all(xp.equal(exactly_equal(a, b), res))
+
+def test_notequal():
+    a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
+    b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
+
+    res = xp.asarray([False, True, False, False, False, True, False, True])
+    assert xp.all(xp.equal(notequal(a, b), res))
+
+
+@given(two_mutually_broadcastable_shapes, int_dtypes, int_dtypes, st.data())
+def test_less(shapes, dtype1, dtype2, data):
+    x = data.draw(arrays(shape=shapes[0], dtype=dtype1))
+    y = data.draw(arrays(shape=shapes[1], dtype=dtype2))
+
+    res = less(x, y)
+
+    for i, j, k in iter_indices(x.shape, y.shape, broadcast_shapes(x.shape, y.shape)):
+        assert res[k] == (int(x[i]) < int(y[j]))
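The vectors in `test_exactly_equal` pin down the helper's semantics: NaNs in the same position compare equal, while `+0.0` and `-0.0` are distinct. A sketch of those semantics in plain NumPy (illustrative reimplementation; the suite's version lives in `array_api_tests/array_helpers.py`):

```python
# Illustrative reimplementation of the semantics test_exactly_equal checks.
import numpy as np

def exactly_equal(x, y):
    both_nan = np.isnan(x) & np.isnan(y)
    same_sign = np.signbit(x) == np.signbit(y)  # distinguishes -0.0 from +0.0
    return both_nan | ((x == y) & same_sign)

a = np.asarray([0.0, -0.0, np.nan, 1.0])
b = np.asarray([-0.0, -0.0, np.nan, 1.0])
assert list(exactly_equal(a, b)) == [False, True, True, True]
```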
diff --git a/array_api_tests/meta/test_broadcasting.py b/meta_tests/test_broadcasting.py
similarity index 95%
rename from array_api_tests/meta/test_broadcasting.py
rename to meta_tests/test_broadcasting.py
index 72de61cf..2f6310c1 100644
--- a/array_api_tests/meta/test_broadcasting.py
+++ b/meta_tests/test_broadcasting.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from .. import shape_helpers as sh
+from array_api_tests import shape_helpers as sh
 
 
 @pytest.mark.parametrize(
diff --git a/array_api_tests/meta/test_equality_mapping.py b/meta_tests/test_equality_mapping.py
similarity index 93%
rename from array_api_tests/meta/test_equality_mapping.py
rename to meta_tests/test_equality_mapping.py
index 86fa7e14..8ac481f6 100644
--- a/array_api_tests/meta/test_equality_mapping.py
+++ b/meta_tests/test_equality_mapping.py
@@ -1,6 +1,6 @@
 import pytest
 
-from ..dtype_helpers import EqualityMapping
+from array_api_tests.dtype_helpers import EqualityMapping
 
 
 def test_raises_on_distinct_eq_key():
diff --git a/array_api_tests/meta/test_hypothesis_helpers.py b/meta_tests/test_hypothesis_helpers.py
similarity index 71%
rename from array_api_tests/meta/test_hypothesis_helpers.py
rename to meta_tests/test_hypothesis_helpers.py
index 647cc145..b14b728c 100644
--- a/array_api_tests/meta/test_hypothesis_helpers.py
+++ b/meta_tests/test_hypothesis_helpers.py
@@ -1,21 +1,23 @@
 from math import prod
+from typing import Type
 
 import pytest
 from hypothesis import given, settings
 from hypothesis import strategies as st
+from hypothesis.errors import Unsatisfiable
 
-from .. import _array_module as xp
-from .. import array_helpers as ah
-from .. import dtype_helpers as dh
-from .. import hypothesis_helpers as hh
-from .. import shape_helpers as sh
-from .. import xps
-from .._array_module import _UndefinedStub
+from array_api_tests import _array_module as xp
+from array_api_tests import array_helpers as ah
+from array_api_tests import dtype_helpers as dh
+from array_api_tests import hypothesis_helpers as hh
+from array_api_tests import shape_helpers as sh
+from array_api_tests import xps
+from array_api_tests._array_module import _UndefinedStub
 
 UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
 pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
 
-@given(hh.mutually_promotable_dtypes(dtypes=dh.float_dtypes))
+@given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes))
 def test_mutually_promotable_dtypes(pair):
     assert pair in (
         (xp.float32, xp.float32),
@@ -126,12 +128,9 @@ def run(n, d, data):
 
     assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
 
-
-@given(m=hh.symmetric_matrices(hh.shared_floating_dtypes,
-                               finite=st.shared(st.booleans(), key='finite')),
-       dtype=hh.shared_floating_dtypes,
-       finite=st.shared(st.booleans(), key='finite'))
-def test_symmetric_matrices(m, dtype, finite):
+@given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data())
+def test_symmetric_matrices(finite, dtype, data):
+    m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite), label="m")
     assert m.dtype == dtype
     # TODO: This part of this test should be part of the .mT test
     ah.assert_exactly_equal(m, m.mT)
@@ -139,8 +138,33 @@ def test_symmetric_matrices(m, dtype, finite):
     if finite:
         ah.assert_finite(m)
 
-@given(m=hh.positive_definite_matrices(hh.shared_floating_dtypes),
-       dtype=hh.shared_floating_dtypes)
-def test_positive_definite_matrices(m, dtype):
+
+@given(dtype=xps.floating_dtypes(), data=st.data())
+def test_positive_definite_matrices(dtype, data):
+    m = data.draw(hh.positive_definite_matrices(st.just(dtype)), label="m")
     assert m.dtype == dtype
     # TODO: Test that it actually is positive definite
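For reference, the standard constructions behind values like these strategies draw (an assumption for illustration; `hypothesis_helpers` may build its draws differently):

```python
# Standard constructions for symmetric and positive definite test matrices
# (illustrative; the suite's strategies may differ).
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 4))

m = (a + a.T) / 2               # symmetric: m == m.T holds exactly
assert (m == m.T).all()

p = a @ a.T + 1e-3 * np.eye(4)  # Gram matrix plus eps*I is positive definite
assert (np.linalg.eigvalsh(p) > 0).all()
```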
+
+
+def make_raising_func(cls: Type[Exception], msg: str):
+    def raises():
+        raise cls(msg)
+
+    return raises
+
+@pytest.mark.parametrize(
+    "func",
+    [
+        make_raising_func(OverflowError, "foo"),
+        make_raising_func(RuntimeError, "Overflow when unpacking long"),
+        make_raising_func(Exception, "Got an overflow"),
+    ]
+)
+def test_reject_overflow(func):
+    @given(data=st.data())
+    def test_case(data):
+        with hh.reject_overflow():
+            func()
+
+    with pytest.raises(Unsatisfiable):
+        test_case()
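The parametrized cases above fix the contract for `hh.reject_overflow`: exceptions that look like overflow must reject the current Hypothesis example, so a test body that always overflows ends in `Unsatisfiable` rather than a failure. A plausible sketch of such a context manager (the real helper lives in `array_api_tests/hypothesis_helpers.py` and may differ):

```python
# Plausible sketch only; see array_api_tests/hypothesis_helpers.py for the
# suite's actual implementation.
from contextlib import contextmanager
from hypothesis import reject

@contextmanager
def reject_overflow():
    try:
        yield
    except OverflowError:
        reject()  # discard the current example
    except Exception as e:
        if "overflow" in str(e).lower():
            reject()
        else:
            raise  # genuine failures still propagate
```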
diff --git a/meta_tests/test_linalg.py b/meta_tests/test_linalg.py
new file mode 100644
index 00000000..82794b6c
--- /dev/null
+++ b/meta_tests/test_linalg.py
@@ -0,0 +1,16 @@
+import pytest
+
+from hypothesis import given
+
+from array_api_tests.hypothesis_helpers import symmetric_matrices
+from array_api_tests import array_helpers as ah
+from array_api_tests import _array_module as xp
+
+@pytest.mark.xp_extension('linalg')
+@given(x=symmetric_matrices(finite=True))
+def test_symmetric_matrices(x):
+    upper = xp.triu(x)
+    lower = xp.tril(x)
+    lowerT = ah._matrix_transpose(lower)
+
+    ah.assert_exactly_equal(upper, lowerT)
diff --git a/array_api_tests/meta/test_partial_adopters.py b/meta_tests/test_partial_adopters.py
similarity index 68%
rename from array_api_tests/meta/test_partial_adopters.py
rename to meta_tests/test_partial_adopters.py
index 6eda5c89..de3a7e76 100644
--- a/array_api_tests/meta/test_partial_adopters.py
+++ b/meta_tests/test_partial_adopters.py
@@ -1,10 +1,10 @@
 import pytest
 from hypothesis import given
 
-from .. import dtype_helpers as dh
-from .. import hypothesis_helpers as hh
-from .. import _array_module as xp
-from .._array_module import _UndefinedStub
+from array_api_tests import dtype_helpers as dh
+from array_api_tests import hypothesis_helpers as hh
+from array_api_tests import _array_module as xp
+from array_api_tests._array_module import _UndefinedStub
 
 
 # e.g. PyTorch only supports uint8 currently
diff --git a/meta_tests/test_pytest_helpers.py b/meta_tests/test_pytest_helpers.py
new file mode 100644
index 00000000..a0aa0930
--- /dev/null
+++ b/meta_tests/test_pytest_helpers.py
@@ -0,0 +1,29 @@
+from pytest import raises
+
+from array_api_tests import xp as _xp
+from array_api_tests import _array_module as xp
+from array_api_tests import pytest_helpers as ph
+
+
+def test_assert_dtype():
+    ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16)
+    with raises(AssertionError):
+        ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32)
+    ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool)
+    ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8)
+    ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool)
+
+
+def test_assert_array_elements():
+    ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0))
+    ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0))
+    ph.assert_array_elements("neg zeros", out=xp.asarray(-0.0), expected=xp.asarray(-0.0))
+    if hasattr(_xp, "signbit"):
+        with raises(AssertionError):
+            ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
+        with raises(AssertionError):
+            ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))
+
+    ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan")))
+    with raises(AssertionError):
+        ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0))
diff --git a/array_api_tests/meta/test_signatures.py b/meta_tests/test_signatures.py
similarity index 96%
rename from array_api_tests/meta/test_signatures.py
rename to meta_tests/test_signatures.py
index 2efe1881..937f73f3 100644
--- a/array_api_tests/meta/test_signatures.py
+++ b/meta_tests/test_signatures.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from ..test_signatures import _test_inspectable_func
+from array_api_tests.test_signatures import _test_inspectable_func
 
 
 def stub(foo, /, bar=None, *, baz=None):
diff --git a/array_api_tests/meta/test_special_cases.py b/meta_tests/test_special_cases.py
similarity index 75%
rename from array_api_tests/meta/test_special_cases.py
rename to meta_tests/test_special_cases.py
index 826e5969..40c7806c 100644
--- a/array_api_tests/meta/test_special_cases.py
+++ b/meta_tests/test_special_cases.py
@@ -1,6 +1,6 @@
 import math
 
-from ..test_special_cases import parse_result
+from array_api_tests.test_special_cases import parse_result
 
 
 def test_parse_result():
diff --git a/array_api_tests/meta/test_utils.py b/meta_tests/test_utils.py
similarity index 81%
rename from array_api_tests/meta/test_utils.py
rename to meta_tests/test_utils.py
index 268a81aa..911ba899 100644
--- a/array_api_tests/meta/test_utils.py
+++ b/meta_tests/test_utils.py
@@ -1,18 +1,15 @@
 import pytest
-from hypothesis import given, reject
+from hypothesis import given
 from hypothesis import strategies as st
 
-from .. import _array_module as xp
-from .. import dtype_helpers as dh
-from .. import shape_helpers as sh
-from .. import xps
-from ..test_creation_functions import frange
-from ..test_manipulation_functions import roll_ndindex
-from ..test_operators_and_elementwise_functions import (
-    mock_int_dtype,
-    oneway_broadcastable_shapes,
-    oneway_promotable_dtypes,
-)
+from array_api_tests import _array_module as xp
+from array_api_tests import dtype_helpers as dh
+from array_api_tests import hypothesis_helpers as hh
+from array_api_tests import shape_helpers as sh
+from array_api_tests import xps
+from array_api_tests.test_creation_functions import frange
+from array_api_tests.test_manipulation_functions import roll_ndindex
+from array_api_tests.test_operators_and_elementwise_functions import mock_int_dtype
 
 
 @pytest.mark.parametrize(
@@ -108,18 +105,16 @@ def test_fmt_idx(idx, expected):
 
 @given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
 def test_int_to_dtype(x, dtype):
-    try:
+    with hh.reject_overflow():
         d = xp.asarray(x, dtype=dtype)
-    except OverflowError:
-        reject()
     assert mock_int_dtype(x, dtype) == d
 
 
-@given(oneway_promotable_dtypes(dh.all_dtypes))
+@given(hh.oneway_promotable_dtypes(dh.all_dtypes))
 def test_oneway_promotable_dtypes(D):
     assert D.result_dtype == dh.result_type(*D)
 
 
-@given(oneway_broadcastable_shapes())
+@given(hh.oneway_broadcastable_shapes())
 def test_oneway_broadcastable_shapes(S):
     assert S.result_shape == sh.broadcast_shapes(*S)
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..935c67fe
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,8 @@
+[pytest]
+filterwarnings =
+    # Ignore floating-point warnings from NumPy
+    ignore:invalid value encountered in:RuntimeWarning
+    ignore:overflow encountered in:RuntimeWarning
+    ignore:divide by zero encountered in:RuntimeWarning
+
+
diff --git a/reporting.py b/reporting.py
index f7c7d6b9..579aa211 100644
--- a/reporting.py
+++ b/reporting.py
@@ -10,7 +10,7 @@
 
 from hypothesis.strategies import SearchStrategy
 
-from pytest import mark, fixture
+from pytest import hookimpl, fixture
 
 try:
     import pytest_jsonreport  # noqa
 except ImportError:
@@ -33,6 +33,8 @@ def to_json_serializable(o):
         return tuple(to_json_serializable(i) for i in o)
     if isinstance(o, list):
         return [to_json_serializable(i) for i in o]
+    if callable(o):
+        return repr(o)
 
     # Ensure everything is JSON serializable. If this warning is issued, it
     # means the given type needs to be added above if possible.
@@ -44,12 +46,12 @@ def to_json_serializable(o):
     return o
 
 
-@mark.optionalhook
+@hookimpl(optionalhook=True)
 def pytest_metadata(metadata):
     """
     Additional global metadata for --json-report.
     """
-    metadata['array_api_tests_module'] = xp.mod_name
+    metadata['array_api_tests_module'] = xp.__name__
     metadata['array_api_tests_version'] = __version__
 
 
 @fixture(autouse=True)
@@ -91,7 +93,7 @@ def finalizer():
 
     request.addfinalizer(finalizer)
 
 
-@mark.optionalhook
+@hookimpl(optionalhook=True)
 def pytest_json_modifyreport(json_report):
     # Deduplicate warnings. These duplicate warnings can cause the file size
     # to become huge. For instance, a warning from np.bool which is emitted
@@ -103,7 +105,7 @@ def pytest_json_modifyreport(json_report):
     # doesn't store a full stack of where it was issued from. The resulting
     # warnings will be in order of the first time each warning is issued since
     # collections.Counter is ordered just like dict().
-    counted_warnings = Counter([frozenset(i.items()) for i in json_report['warnings']])
+    counted_warnings = Counter([frozenset(i.items()) for i in json_report.get('warnings', dict())])
     deduped_warnings = [{**dict(i), 'count': counted_warnings[i]} for i in counted_warnings]
     json_report['warnings'] = deduped_warnings
diff --git a/requirements.txt b/requirements.txt
index bfc39d4c..c5508119 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
 pytest
 pytest-json-report
-hypothesis>=6.55.0
-ndindex>=1.6
+hypothesis>=6.130.5
+ndindex>=1.8
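The `Counter`-over-`frozenset` trick in `reporting.py` above deduplicates warning dicts while preserving first-seen order. A self-contained illustration:

```python
from collections import Counter

warnings_list = [
    {"message": "np.bool is deprecated", "category": "DeprecationWarning"},
    {"message": "np.bool is deprecated", "category": "DeprecationWarning"},
    {"message": "divide by zero", "category": "RuntimeWarning"},
]
# Each dict becomes a hashable frozenset of items; Counter preserves
# first-seen order, so the rebuilt list keeps the original ordering.
counted = Counter(frozenset(w.items()) for w in warnings_list)
deduped = [{**dict(k), "count": n} for k, n in counted.items()]

assert len(deduped) == 2
assert deduped[0]["count"] == 2
```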