diff --git a/.flake8 b/.flake8
new file mode 100644
index 00000000..e354eaae
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,2 @@
+[flake8]
+select = F
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 00000000..a832af83
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+array_api_tests/_version.py export-subst
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
new file mode 100644
index 00000000..9b49c09b
--- /dev/null
+++ b/.github/workflows/lint.yml
@@ -0,0 +1,16 @@
+name: Linting
+
+on: [push, pull_request]
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+    - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+ - name: Run pre-commit hook
+ uses: pre-commit/action@v3.0.1
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
deleted file mode 100644
index fc9a37a8..00000000
--- a/.github/workflows/main.yml
+++ /dev/null
@@ -1,30 +0,0 @@
-name: Tests
-
-on: [push, pull_request]
-
-jobs:
- build:
-
- runs-on: ubuntu-latest
- strategy:
- matrix:
- python-version: [3.8, 3.9]
-
- steps:
- - uses: actions/checkout@v1
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
- with:
- python-version: ${{ matrix.python-version }}
- # - name: Install dependencies
- # run: |
- # python -m pip install --upgrade pip
- # pip install ...
- - name: Lint with pyfalkes
- run: |
- pip install pyflakes
- pyflakes .
- # - name: Test with pytest
- # run: |
- # pip install pytest
- # pytest
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 00000000..2fab2072
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,34 @@
+name: Test Array API Strict
+
+on: [push, pull_request]
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.10", "3.11"]
+
+ steps:
+ - name: Checkout array-api-tests
+      uses: actions/checkout@v4
+ with:
+ submodules: 'true'
+ - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install array-api-strict
+ python -m pip install -r requirements.txt
+ - name: Run the test suite
+ env:
+ ARRAY_API_TESTS_MODULE: array_api_strict
+ ARRAY_API_STRICT_API_VERSION: 2024.12
+ run: |
+ pytest -v -rxXfE --skips-file array-api-strict-skips.txt array_api_tests/
+        # We also have internal tests that aren't really necessary for adopters
+ pytest -v -rxXfE meta_tests/
diff --git a/.gitignore b/.gitignore
index b6e47617..fc5b8b8a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -117,6 +117,10 @@ venv.bak/
# Rope project settings
.ropeproject
+# IDE
+.idea/
+.vscode/
+
# mkdocs documentation
/site
@@ -127,3 +131,6 @@ dmypy.json
# Pyre type checker
.pyre/
+
+# pytest-json-report
+.report.json
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..c225c24e
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "array_api_tests/array-api"]
+ path = array-api
+ url = https://github.com/data-apis/array-api/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..a2ee60df
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,6 @@
+repos:
+- repo: https://github.com/pycqa/flake8
+ rev: '4.0.1'
+ hooks:
+ - id: flake8
+ args: [--select, F]
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..f8c5a8c3
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include versioneer.py
+include array_api_tests/_version.py
diff --git a/README.md b/README.md
index 0ff0bce0..fa17b763 100644
--- a/README.md
+++ b/README.md
@@ -1,132 +1,414 @@
-# Array API Standard Test Suite
+# Test Suite for Array API Compliance
-This is the test suite for the PyData Array APIs standard.
+This is the test suite for array libraries adopting the [Python Array API
+standard](https://data-apis.org/array-api/latest).
-**NOTE: This test suite is still a work in progress.**
+Keeping full coverage of the spec is an ongoing priority as the Array API evolves.
+Feedback and contributions are welcome!
-Feedback and contributions are welcome, but be aware that this suite is not
-yet completed. In particular, there are still many parts of the array API
-specification that are not yet tested here.
+## Quickstart
-## Running the tests
+### Setup
-To run the tests, first install the testing dependencies
+Currently we pin the Array API specification repo [`array-api`](https://github.com/data-apis/array-api/)
+as a git submodule. This might change in the future to better support vendoring
+use cases (see [#107](https://github.com/data-apis/array-api-tests/issues/107)),
+but for now be sure to pull the submodules too, e.g.
- pip install pytest hypothesis
+```bash
+$ git submodule update --init
+```
-or
+To run the tests, install the testing dependencies.
- conda install pytest hypothesis
+```bash
+$ pip install -r requirements.txt
+```
-as well as the array libraries that you want to test. To run the tests, you
-need to set the array library that is to be tested. There are two ways to do
-this. One way is to set the `ARRAY_API_TESTS_MODULE` environment variable. For
-example
+Ensure the array library that you want to test is installed.
- ARRAY_API_TESTS_MODULE=numpy pytest
+### Specifying the array module
-Alternately, edit the `array_api_tests/_array_module.py` file and change the
-line
+You need to specify the array library to test. It can be specified via the
+`ARRAY_API_TESTS_MODULE` environment variable, e.g.
-```py
-array_module = None
+```bash
+$ export ARRAY_API_TESTS_MODULE=array_api_strict
```
-to
+To specify a runtime-defined module, define `xp` using the `exec('...')` syntax:
+```bash
+$ export ARRAY_API_TESTS_MODULE="exec('import quantity_array, numpy; xp = quantity_array.quantity_namespace(numpy)')"
```
-import numpy as array_module
+
+Alternatively, import/define the `xp` variable in `array_api_tests/__init__.py`.
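+
+For example, a minimal sketch of hardcoding the module there (as the comment in
+`__init__.py` itself suggests):
+
+```py
+# in array_api_tests/__init__.py, in place of the env-var logic:
+import array_api_strict as xp
+```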
+
+### Specifying the API version
+
+You can specify the API version to use when testing via the
+`ARRAY_API_TESTS_VERSION` environment variable, e.g.
+
+```bash
+$ export ARRAY_API_TESTS_VERSION="2023.12"
```
-(replacing `numpy` with the array module namespace to be tested).
+Currently this defaults to the array module's `__array_api_version__` value, and
+if that attribute doesn't exist then we fall back to `"2024.12"`.
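+
+For reference, this is how the version is resolved in `array_api_tests/__init__.py`:
+
+```py
+api_version = os.getenv(
+    "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2024.12")
+)
+```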
-## Notes on Interpreting Errors
+### Run the suite
-- Some tests cannot be run unless other tests pass first. This is because very
- basic APIs such as certain array creation APIs are required for a large
- fraction of the tests to run. TODO: Write which tests are required to pass
- first here.
+Simply run `pytest` against the `array_api_tests/` folder to run the full suite.
-- If an error message involves `_UndefinedStub`, it means some name that is
- required for the test to run is not defined in the array library.
+```bash
+$ pytest array_api_tests/
+```
-- Due to the nature of the array api spec, virtually every array library will
- produce a large number of errors from nonconformance. It is still a work in
- progress to enable reporting the errors in a way that makes them easy to
- understand, even if there are a large number of them.
+The suite tries to logically organise its tests. `pytest` allows you to only run
+a specific test case, which is useful when developing functions.
-- The spec documents are the ground source of truth. If the test suite appears
- to be testing something that is different from the spec, or something that
- isn't actually mentioned in the spec, this is a bug. [Please report
- it](https://github.com/data-apis/array-api-tests/issues/new). Furthermore,
- be aware that some aspects of the spec are either impossible or extremely
- difficult to actually test, so they are not covered in the test suite (TODO:
- list what these are).
+```bash
+$ pytest array_api_tests/test_creation_functions.py::test_zeros
+```
-## Contributing
+## What the test suite covers
+
+We are interested in array libraries conforming to the
+[spec](https://data-apis.org/array-api/latest/API_specification/index.html).
+Ideally this means that if a library has fully adopted the Array API, the test
+suite passes. We take great care to _not_ test things which are out-of-scope,
+so as to not unexpectedly fail the suite.
+
+### Primary tests
+
+Every function—including array object methods—has a respective test
+method<sup>1</sup>. We use
+[Hypothesis](https://hypothesis.readthedocs.io/en/latest/)
+to generate a diverse set of valid inputs. This means array inputs will cover
+different dtypes and shapes, as well as contain interesting elements. These
+examples are also generated with interesting arrangements of non-array
+positional arguments and keyword arguments.
+
+Each test case will cover the following areas if relevant:
+
+* **Smoke testing**: We pass our generated examples to all functions. As these
+ examples solely consist of *valid* inputs, we are testing that functions can
+ be called using their documented inputs without raising errors.
+
+* **Data type**: For functions returning/modifying arrays, we assert that output
+ arrays have the correct data types. Most functions
+ [type-promote](https://data-apis.org/array-api/latest/API_specification/type_promotion.html)
+ input arrays and some functions have bespoke rules—in both cases we simulate
+ the correct behaviour to find the expected data types.
+
+* **Shape**: For functions returning/modifying arrays, we assert that output
+ arrays have the correct shape. Most functions
+ [broadcast](https://data-apis.org/array-api/latest/API_specification/broadcasting.html)
+ input arrays and some functions have bespoke rules—in both cases we simulate
+ the correct behaviour to find the expected shapes.
+
+* **Values**: We assert output values (including the elements of
+ returned/modified arrays) are as expected. Except for manipulation functions
+ or special cases, the spec allows floating-point inputs to have inexact
+ outputs, so with such examples we only assert values are roughly as expected.
+
+### Additional tests
-### Adding Tests
+In addition to having one test case for each function, we test other properties
+of the functions and some miscellaneous things.
-It is important that every test in the test suite only uses APIs that are part
-of the standard. This means that, for instance, when creating test arrays, you
-should only use array creation functions that are part of the spec, such as
-`ones` or `full`. It also means that many array testing functions that are
-built-in to libraries like numpy are reimplemented in the test suite (see
-`array_api_tests/pytest_helpers.py`, `array_api_tests/array_helpers.py`, and
-`array_api_tests/hypothesis_helpers.py`).
+* **Special cases**: For functions with special case behaviour, we assert that
+ these functions return the correct values.
-In order to enforce this, the `array_api_tests._array_module` should be used
-everywhere in place of the actual array module that is being tested.
+* **Signatures**: We assert functions have the correct signatures.
-### Hypothesis
+* **Constants**: We assert that
+ [constants](https://data-apis.org/array-api/latest/API_specification/constants.html)
+ behave expectedly, are roughly the expected value, and that any related
+ functions interact with them correctly.
-The test suite uses [Hypothesis](https://hypothesis.readthedocs.io/en/latest/)
-to generate random input data. Any test that should be applied over all
-possible array inputs should use hypothesis tests. Custom Hypothesis
-strategies are in the `array_api_tests/hypothesis_helpers.py` file.
+Be aware that some aspects of the spec are impractical or impossible to actually
+test, so they are not covered in the suite.
-### Parameterization
+## Interpreting errors
-Any test that applies over all functions in a module should use
-`pytest.mark.parametrize` to parameterize over them. For example,
+First and foremost, note that most tests have to assume that certain aspects of
+the Array API have been correctly adopted, as fundamental APIs such as array
+creation and equalities are hard requirements for many assertions. This means a
+test case for one function might fail because another function has bugs or even
+no implementation.
-```py
-from . import function_stubs
+This means that adopting libraries will at first see a vast number of errors,
+due to these failures cascading. The granular nature of the spec also means
+that details such as type promotion are likely to fail in otherwise
+nearly-conforming functions.
+
+We hope to improve the user experience of such "noisy" errors in
+[#51](https://github.com/data-apis/array-api-tests/issues/51). For now, if an
+error message involves `_UndefinedStub`, it means an attribute of the array
+library (including functions) or its objects (e.g. the array) is missing.
+
+The spec is the suite's source of truth. If the suite appears to assume
+behaviour different from the spec, or test something that is not documented,
+this is a bug—please [report such
+issues](https://github.com/data-apis/array-api-tests/issues/) to us.
+
+
+## Running on CI
+
+See our existing [GitHub Actions workflow for `array-api-strict`](https://github.com/data-apis/array-api-tests/blob/master/.github/workflows/test.yml)
+for an example of using the test suite on CI. Note [`array-api-strict`](https://github.com/data-apis/array-api-strict)
+is an implementation of the array API that uses NumPy under the hood.
+
+### Releases
+
+We recommend pinning against a [release tag](https://github.com/data-apis/array-api-tests/releases)
+when running on CI.
+
+We use [calendar versioning](https://calver.org/) for the releases. You should
+expect that any version may be "breaking" compared to the previous one, in that
+new tests (or improvements to existing tests) may cause a previously passing
+library to fail.
+
+### Configuration
+
+#### Data-dependent shapes
+
+Use the `--disable-data-dependent-shapes` flag to skip testing functions which have
+[data-dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html).
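+
+For example:
+
+```bash
+$ pytest array_api_tests/ --disable-data-dependent-shapes
+```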
+
+#### Extensions
+
+By default, tests for the optional Array API extensions such as
+[`linalg`](https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html)
+will be skipped if not present in the specified array module. You can purposely
+skip testing extension(s) via the `--disable-extension` option.
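+
+For example, to skip testing the `linalg` extension:
+
+```bash
+$ pytest array_api_tests/ --disable-extension linalg
+```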
+
+#### Skip or XFAIL test cases
+
+Test cases you want to skip can be specified in a skips or XFAILS file. The
+difference between skip and XFAIL is that XFAIL tests are still run and
+reported as XPASS if they pass.
+
+By default, the skips and xfails files are `skips.txt` and `xfails.txt` in the root
+of this repository, but any file can be specified with the `--skips-file` and
+`--xfails-file` command line flags.
+
+The files should list the test ids to be skipped/xfailed. Empty lines and
+lines starting with `#` are ignored. The test id can be any substring of the
+test ids to skip/xfail.
-@pytest.mark.parametrize('name', function_stubs.__all__)
-def test_whatever(name):
- ...
```
+# skips.txt or xfails.txt
+# Line comments can be denoted with the hash symbol (#)
-will parameterize `test_whatever` over all the functions stubs generated from
-the spec. Parameterization should be preferred over using Hypothesis whenever
-there are a finite number of input possibilities, as this will cause pytest to
-report failures for all input values separately, as opposed to Hypothesis
-which will only report one failure.
+# Skip specific test case, e.g. when argsort() does not respect relative order
+# https://github.com/numpy/numpy/issues/20778
+array_api_tests/test_sorting_functions.py::test_argsort
-### Error Strings
+# Skip specific test case parameter, e.g. you forgot to implement in-place adds
+array_api_tests/test_add[__iadd__(x1, x2)]
+array_api_tests/test_add[__iadd__(x, s)]
-Any assertion or exception should be accompanied with a useful error message.
-The test suite is designed to be ran by people who are not familiar with the
-test suite code, so the error messages should be self explanatory as to why
-the module fails a given test.
+# Skip module, e.g. when your set functions treat NaNs as non-distinct
+# https://github.com/numpy/numpy/issues/20326
+array_api_tests/test_set_functions.py
+```
-### Meta-errors
+Here is an example GitHub Actions workflow file, where the xfails are stored
+in `array-api-tests-xfails.txt` in the base of the `your-array-library` repo.
+
+If you want, you can use `-o xfail_strict=True`, which causes XPASS tests (XFAIL
+tests that actually pass) to fail the test suite. However, be aware that
+XFAILures can be flaky (see below), so this may not be a good idea unless you
+use some other mitigation of such flakiness.
+
+If you don't want this behavior, you can remove it, or use `--skips-file`
+instead of `--xfails-file`.
+
+```yaml
+# ./.github/workflows/array_api.yml
+jobs:
+ tests:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ['3.8', '3.9', '3.10', '3.11']
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ with:
+ path: your-array-library
+
+ - name: Checkout array-api-tests
+ uses: actions/checkout@v3
+ with:
+ repository: data-apis/array-api-tests
+ submodules: 'true'
+ path: array-api-tests
+
+ - name: Run the array API test suite
+ env:
+ ARRAY_API_TESTS_MODULE: your.array.api.namespace
+ run: |
+ export PYTHONPATH="${GITHUB_WORKSPACE}/your-array-library"
+ cd ${GITHUB_WORKSPACE}/array-api-tests
+ pytest -v -rxXfE --ci --xfails-file ${GITHUB_WORKSPACE}/your-array-library/array-api-tests-xfails.txt array_api_tests/
+```
+
+> **Warning**
+>
+> XFAIL tests that use Hypothesis (basically every test in the test suite except
+> those in test_has_names.py) can be flaky, due to the fact that Hypothesis
+> might not always run the test with an input that causes the test to fail.
+> There are several ways to avoid this problem:
+>
+> - Increase the maximum number of examples, e.g., by adding `--max-examples
+> 200` to the test command (the default is `20`, see below). This will
+> make it more likely that the failing case will be found, but it will also
+> make the tests take longer to run.
+> - Don't use `-o xfail_strict=True`. This will make it so that if an XFAIL
+> test passes, it will alert you in the test summary but will not cause the
+> test run to register as failed.
+> - Use skips instead of XFAILS. The difference between XFAIL and skip is that
+> a skipped test is never run at all, whereas an XFAIL test is always run
+> but ignored if it fails.
+> - Save the [Hypothesis examples
+> database](https://hypothesis.readthedocs.io/en/latest/database.html)
+> persistently on CI. That way as soon as a run finds one failing example,
+> it will always re-run future runs with that example. But note that the
+> Hypothesis examples database may be cleared when a new version of
+> Hypothesis or the test suite is released.
+
+#### Max examples
+
+The tests make heavy use of
+[Hypothesis](https://hypothesis.readthedocs.io/en/latest/). You can configure
+how many examples are generated using the `--max-examples` flag, which
+defaults to `20`. Lower values can be useful for quick checks, and larger
+values should result in more rigorous runs. For example, `--max-examples
+10_000` may find bugs where default runs don't but will take much longer to
+run.
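+
+For example:
+
+```bash
+$ pytest array_api_tests/ --max-examples 200
+```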
+
+#### Skipping Dtypes
+
+The test suite will automatically skip testing of inessential dtypes if they
+are not present on the array module namespace, but dtypes can also be skipped
+manually by setting the environment variable `ARRAY_API_TESTS_SKIP_DTYPES` to
+a comma-separated list of dtypes to skip. For example
+
+```
+ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
+```
+
+Note that skipping certain essential dtypes such as `bool` and the default
+floating-point dtype is not supported.
+
+#### Turning xfails into skips
+
+Keeping a large number of `xfails` can have drastic effects on the run time. This
+is due to the way Hypothesis works: when it detects a failure, it does a large
+amount of work to simplify the failing example.
+If the run time of the test suite becomes a problem, you can use the
+`ARRAY_API_TESTS_XFAIL_MARK` environment variable: setting it to `skip` skips the
+entries from the `xfails.txt` file instead of xfailing them. Anecdotally, we saw
+speed-ups by a factor of 4-5, which allowed us to use 4-5× larger values of
+`--max-examples` within the same time budget.
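+
+For example:
+
+```bash
+$ ARRAY_API_TESTS_XFAIL_MARK=skip pytest --xfails-file xfails.txt array_api_tests/
+```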
+
+#### Limiting the array sizes
+
+The test suite generates random arrays as inputs to the functions it tests.
+"Unvectorized" tests iterate over the elements of arrays, which can be slow. If
+the run time becomes a problem, you can limit the maximum number of elements in
+generated arrays by setting the environment variable
+`ARRAY_API_TESTS_MAX_ARRAY_SIZE` to the desired value. By default, it is set to
+1024.
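+
+For example:
+
+```bash
+$ ARRAY_API_TESTS_MAX_ARRAY_SIZE=128 pytest array_api_tests/
+```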
+
+
+## Contributing
+
+### Remain in-scope
+
+It is important that every test only uses APIs that are part of the standard.
+For instance, when creating input arrays you should only use the [array creation
+functions](https://data-apis.org/array-api/latest/API_specification/creation_functions.html)
+that are documented in the spec. The same goes for testing arrays—you'll find
+many utilities that parallel NumPy's own test utils in the `*_helpers.py` files.
+
+### Tools
+
+Hypothesis should almost always be used for the primary tests, and can be useful
+elsewhere. Effort should be made so drawn arguments are labeled with their
+respective names. For
+[`st.data()`](https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.data),
+draws should be accompanied by the `label` kwarg, i.e.
+`data.draw(<strategy>, label=<label>)`.
+
+[`pytest.mark.parametrize`](https://docs.pytest.org/en/latest/how-to/parametrize.html)
+should be used to run tests over multiple arguments. Parameterization should be
+preferred over using Hypothesis when there are a small number of possible
+inputs, as this allows better failure reporting. Note that using both parametrize
+and Hypothesis for a single test method is possible and can be quite useful; see
+the sketch below.
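+
+For instance, a minimal sketch of combining the two (the test body and names
+here are hypothetical):
+
+```py
+import pytest
+from hypothesis import given, strategies as st
+
+from . import xp, xps
+
+
+@pytest.mark.parametrize("func_name", ["sort", "argsort"])
+@given(data=st.data())
+def test_stable_kwarg(func_name, data):
+    # Label the draw so failure reports name the argument
+    x = data.draw(xps.arrays(dtype=xp.int8, shape=(4,)), label="x")
+    # Smoke-test that the stable kwarg is accepted
+    getattr(xp, func_name)(x, stable=True)
+```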
-Any error that indicates a bug in the test suite itself, rather than in the
-array module not following the spec, should use `RuntimeError` whenever
-possible.
+### Error messages
-(TODO: Update this policy to something better. See [#5](https://github.com/data-apis/array-api-tests/issues/5).)
+Any assertion should be accompanied with a descriptive error message, including
+the relevant values. Error messages should be self-explanatory as to why a given
+test fails, as one should not need prior knowledge of how the test is
+implemented.
-### Automatically Generated Files
+### Generated files
-Some files in the test suite are automatically generated from the API spec
-files. These files should not be edited directly. To regenerate these files,
-run the script
+Some files in the suite are automatically generated from the spec, and should
+not be edited directly. To regenerate these files, run the script
./generate_stubs.py path/to/array-api
-where `path/to/array-api` is the path to the local clone of the `array-api`
-repo. To modify the automatically generated files, edit the code that
-generates them in the `generate_stubs.py` script.
+where `path/to/array-api` is the path to a local clone of the [`array-api`
+repo](https://github.com/data-apis/array-api/). Edit `generate_stubs.py` to make
+changes to the generated files.
+
+
+### Release
+
+To make a release, first make an annotated tag with the version, e.g.:
+
+```
+git tag -a 2022.01.01
+```
+
+Be sure to use the calver version number for the tag name. Don't worry too much
+about the tag message, e.g. just write "2022.01.01".
+
+Versioneer will automatically set the version number of the `array_api_tests`
+package based on the git tag. Push the tag to GitHub:
+
+```
+git push --tags upstream 2022.01.01
+```
+
+Then go to the [tags page on
+GitHub](https://github.com/data-apis/array-api-tests/tags) and convert the tag
+into a release. If you want, you can add release notes, which GitHub can
+generate for you.
+
+
+---
+
+<sup>1</sup> The only exceptions to having just one primary test per function are:
+
+* [`asarray()`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.creation_functions.asarray.html),
+ which is tested by `test_asarray_scalars` and `test_asarray_arrays` in
+ `test_creation_functions.py`. Testing `asarray()` works with scalars (and
+ nested sequences of scalars) is fundamental to testing that it works with
+ arrays, as said arrays can only be generated by passing scalar sequences to
+ `asarray()`.
+
+* Indexing methods
+ ([`__getitem__()`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.__getitem__.html)
+ and
+ [`__setitem__()`](https://data-apis.org/array-api/latest/API_specification/generated/signatures.array_object.array.__setitem__.html)),
+ which respectively have both a test for non-array indices and a test for
+ boolean array indices. This is because [masking is
+ opt-in](https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing)
+ (and boolean arrays need to be generated by indexing arrays anyway).
diff --git a/_config.yml b/_config.yml
new file mode 100644
index 00000000..c7418817
--- /dev/null
+++ b/_config.yml
@@ -0,0 +1 @@
+theme: jekyll-theme-slate
\ No newline at end of file
diff --git a/array-api b/array-api
new file mode 160000
index 00000000..772fb461
--- /dev/null
+++ b/array-api
@@ -0,0 +1 @@
+Subproject commit 772fb461da6ff904ecfcac4a24676e40efcbdb0c
diff --git a/array-api-strict-skips.txt b/array-api-strict-skips.txt
new file mode 100644
index 00000000..afc1b845
--- /dev/null
+++ b/array-api-strict-skips.txt
@@ -0,0 +1,34 @@
+# Known special case issue in NumPy. Not worth working around here
+# https://github.com/numpy/numpy/issues/21213
+array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -infinity and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ipow__(x1_i is -0 and x2_i > 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0]
+
+# The test suite is incorrectly checking sums that have loss of significance
+# (https://github.com/data-apis/array-api-tests/issues/168)
+array_api_tests/test_statistical_functions.py::test_sum
+
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, all libraries do just that
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[floor_divide(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_binary[__floordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
+array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+
+# FIXME needs array-api-strict >=2.3.2
+array_api_tests/test_data_type_functions.py::test_finfo
+array_api_tests/test_data_type_functions.py::test_finfo_dtype
+array_api_tests/test_data_type_functions.py::test_iinfo
+array_api_tests/test_data_type_functions.py::test_iinfo_dtype
diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py
index e69de29b..d01af52d 100644
--- a/array_api_tests/__init__.py
+++ b/array_api_tests/__init__.py
@@ -0,0 +1,93 @@
+import os
+from functools import wraps
+from importlib import import_module
+
+from hypothesis import strategies as st
+from hypothesis.extra import array_api
+
+from . import _version
+
+__all__ = ["xp", "api_version", "xps"]
+
+
+# You can comment the following out and instead import the specific array module
+# you want to test, e.g. `import array_api_strict as xp`.
+if "ARRAY_API_TESTS_MODULE" in os.environ:
+ env_var = os.environ["ARRAY_API_TESTS_MODULE"]
+ if env_var.startswith("exec('") and env_var.endswith("')"):
+ script = env_var[6:][:-2]
+ namespace = {}
+ exec(script, namespace)
+ xp = namespace["xp"]
+ xp_name = xp.__name__
+ else:
+ xp_name = env_var
+ _module, _sub = xp_name, None
+ if "." in xp_name:
+ _module, _sub = xp_name.split(".", 1)
+ xp = import_module(_module)
+ if _sub:
+ try:
+ xp = getattr(xp, _sub)
+ except AttributeError:
+ # _sub may be a submodule that needs to be imported. We can't
+ # do this in every case because some array modules are not
+ # submodules that can be imported (like mxnet.nd).
+ xp = import_module(xp_name)
+else:
+ raise RuntimeError(
+ "No array module specified - either edit __init__.py or set the "
+ "ARRAY_API_TESTS_MODULE environment variable."
+ )
+
+
+# If xp.bool is not available, like in some versions of NumPy and CuPy, try
+# patching in xp.bool_.
+try:
+ xp.bool
+except AttributeError as e:
+ if hasattr(xp, "bool_"):
+ xp.bool = xp.bool_
+ else:
+ raise e
+
+
+# We monkey patch floats() to always disable subnormals as they are out-of-scope
+
+_floats = st.floats
+
+
+@wraps(_floats)
+def floats(*a, **kw):
+ kw["allow_subnormal"] = False
+ return _floats(*a, **kw)
+
+
+st.floats = floats
+
+
+# We do the same with xps.from_dtype() - this is not strictly necessary, as
+# the underlying floats() will never generate subnormals. We only do this
+# because internal logic in xps.from_dtype() assumes xp.finfo() has its
+# attributes as scalar floats, which is expected behaviour but disrupts many
+# unrelated tests.
+try:
+ __from_dtype = array_api._from_dtype
+
+ @wraps(__from_dtype)
+ def _from_dtype(*a, **kw):
+ kw["allow_subnormal"] = False
+ return __from_dtype(*a, **kw)
+
+ array_api._from_dtype = _from_dtype
+except AttributeError:
+ # Ignore monkey patching if Hypothesis changes the private API
+ pass
+
+
+api_version = os.getenv(
+ "ARRAY_API_TESTS_VERSION", getattr(xp, "__array_api_version__", "2024.12")
+)
+xps = array_api.make_strategies_namespace(xp, api_version=api_version)
+
+__version__ = _version.get_versions()["version"]
diff --git a/array_api_tests/_array_module.py b/array_api_tests/_array_module.py
index aa9f117c..1c52a983 100644
--- a/array_api_tests/_array_module.py
+++ b/array_api_tests/_array_module.py
@@ -1,35 +1,5 @@
-import os
-from importlib import import_module
+from . import stubs, xp
-from . import function_stubs
-
-# Replace this with a specific array module to test it, for example,
-#
-# import numpy as array_module
-array_module = None
-
-if array_module is None:
- if 'ARRAY_API_TESTS_MODULE' in os.environ:
- mod_name = os.environ['ARRAY_API_TESTS_MODULE']
- _module, _sub = mod_name, None
- if '.' in mod_name:
- _module, _sub = mod_name.split('.', 1)
- mod = import_module(_module)
- if _sub:
- try:
- mod = getattr(mod, _sub)
- except AttributeError:
- # _sub may be a submodule that needs to be imported. WE can't
- # do this in every case because some array modules are not
- # submodules that can be imported (like mxnet.nd).
- mod = import_module(mod_name)
- else:
- raise RuntimeError("No array module specified. Either edit _array_module.py or set the ARRAY_API_TESTS_MODULE environment variable")
-else:
- mod = array_module
- mod_name = mod.__name__
-# Names from the spec. This is what should actually be imported from this
-# file.
class _UndefinedStub:
"""
@@ -45,7 +15,7 @@ def __init__(self, name):
self.name = name
def _raise(self, *args, **kwargs):
- raise AssertionError(f"{self.name} is not defined in {mod_name}")
+ raise AssertionError(f"{self.name} is not defined in {xp.__name__}")
def __repr__(self):
return f"<undefined stub for {self.name!r}>"
@@ -53,34 +23,20 @@ def __repr__(self):
__call__ = _raise
__getattr__ = _raise
-_integer_dtypes = [
- 'int8',
- 'int16',
- 'int32',
- 'int64',
- 'uint8',
- 'uint16',
- 'uint32',
- 'uint64',
-]
-
-_floating_dtypes = [
- 'float32',
- 'float64',
-]
-
-_numeric_dtypes = [
- *_integer_dtypes,
- *_floating_dtypes,
-]
-
_dtypes = [
- 'bool',
- *_numeric_dtypes
+ "bool",
+ "uint8", "uint16", "uint32", "uint64",
+ "int8", "int16", "int32", "int64",
+ "float32", "float64",
+ "complex64", "complex128",
]
+_constants = ["e", "inf", "nan", "pi"]
+_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
+_funcs += ["take", "isdtype", "conj", "imag", "real"] # TODO: bump spec and update array-api-tests to new spec layout
+_top_level_attrs = _dtypes + _constants + _funcs + stubs.EXTENSIONS + ["fft"]
-for func_name in function_stubs.__all__ + _dtypes:
+for attr in _top_level_attrs:
try:
- globals()[func_name] = getattr(mod, func_name)
+ globals()[attr] = getattr(xp, attr)
except AttributeError:
- globals()[func_name] = _UndefinedStub(func_name)
+ globals()[attr] = _UndefinedStub(attr)
diff --git a/array_api_tests/_version.py b/array_api_tests/_version.py
new file mode 100644
index 00000000..0567c1a6
--- /dev/null
+++ b/array_api_tests/_version.py
@@ -0,0 +1,644 @@
+
+# This file helps to compute a version number in source trees obtained from
+# git-archive tarball (such as those provided by githubs download-from-tag
+# feature). Distribution tarballs (built by setup.py sdist) and build
+# directories (produced by setup.py build) will contain a much shorter file
+# that just contains the computed version number.
+
+# This file is released into the public domain. Generated by
+# versioneer-0.21 (https://github.com/python-versioneer/python-versioneer)
+
+"""Git implementation of _version.py."""
+
+import errno
+import os
+import re
+import subprocess
+import sys
+from typing import Callable, Dict
+
+
+def get_keywords():
+ """Get the keywords needed to look up the version information."""
+ # these strings will be replaced by git during git-archive.
+ # setup.py/versioneer.py will grep for the variable names, so they must
+ # each be defined on a line of their own. _version.py will just call
+ # get_keywords().
+ git_refnames = "$Format:%d$"
+ git_full = "$Format:%H$"
+ git_date = "$Format:%ci$"
+ keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
+ return keywords
+
+
+class VersioneerConfig:
+ """Container for Versioneer configuration parameters."""
+
+
+def get_config():
+ """Create, populate and return the VersioneerConfig() object."""
+ # these strings are filled in when 'setup.py versioneer' creates
+ # _version.py
+ cfg = VersioneerConfig()
+ cfg.VCS = "git"
+ cfg.style = "pep440"
+ cfg.tag_prefix = ""
+ cfg.parentdir_prefix = ""
+ cfg.versionfile_source = "array_api_tests/_version.py"
+ cfg.verbose = False
+ return cfg
+
+
+class NotThisMethod(Exception):
+ """Exception raised if a method is not valid for the current scenario."""
+
+
+LONG_VERSION_PY: Dict[str, str] = {}
+HANDLERS: Dict[str, Dict[str, Callable]] = {}
+
+
+def register_vcs_handler(vcs, method): # decorator
+ """Create decorator to mark a method as the handler of a VCS."""
+ def decorate(f):
+ """Store f in HANDLERS[vcs][method]."""
+ if vcs not in HANDLERS:
+ HANDLERS[vcs] = {}
+ HANDLERS[vcs][method] = f
+ return f
+ return decorate
+
+
+def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
+ env=None):
+ """Call the given command(s)."""
+ assert isinstance(commands, list)
+ process = None
+ for command in commands:
+ try:
+ dispcmd = str([command] + args)
+ # remember shell=False, so use git.cmd on windows, not just git
+ process = subprocess.Popen([command] + args, cwd=cwd, env=env,
+ stdout=subprocess.PIPE,
+ stderr=(subprocess.PIPE if hide_stderr
+ else None))
+ break
+ except OSError:
+ e = sys.exc_info()[1]
+ if e.errno == errno.ENOENT:
+ continue
+ if verbose:
+ print("unable to run %s" % dispcmd)
+ print(e)
+ return None, None
+ else:
+ if verbose:
+ print("unable to find command, tried %s" % (commands,))
+ return None, None
+ stdout = process.communicate()[0].strip().decode()
+ if process.returncode != 0:
+ if verbose:
+ print("unable to run %s (error)" % dispcmd)
+ print("stdout was %s" % stdout)
+ return None, process.returncode
+ return stdout, process.returncode
+
+
+def versions_from_parentdir(parentdir_prefix, root, verbose):
+ """Try to determine the version from the parent directory name.
+
+ Source tarballs conventionally unpack into a directory that includes both
+ the project name and a version string. We will also support searching up
+ two directory levels for an appropriately named parent directory
+ """
+ rootdirs = []
+
+ for _ in range(3):
+ dirname = os.path.basename(root)
+ if dirname.startswith(parentdir_prefix):
+ return {"version": dirname[len(parentdir_prefix):],
+ "full-revisionid": None,
+ "dirty": False, "error": None, "date": None}
+ rootdirs.append(root)
+ root = os.path.dirname(root) # up a level
+
+ if verbose:
+ print("Tried directories %s but none started with prefix %s" %
+ (str(rootdirs), parentdir_prefix))
+ raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
+
+
+@register_vcs_handler("git", "get_keywords")
+def git_get_keywords(versionfile_abs):
+ """Extract version information from the given file."""
+ # the code embedded in _version.py can just fetch the value of these
+ # keywords. When used from setup.py, we don't want to import _version.py,
+ # so we do it with a regexp instead. This function is not used from
+ # _version.py.
+ keywords = {}
+ try:
+ with open(versionfile_abs, "r") as fobj:
+ for line in fobj:
+ if line.strip().startswith("git_refnames ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["refnames"] = mo.group(1)
+ if line.strip().startswith("git_full ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["full"] = mo.group(1)
+ if line.strip().startswith("git_date ="):
+ mo = re.search(r'=\s*"(.*)"', line)
+ if mo:
+ keywords["date"] = mo.group(1)
+ except OSError:
+ pass
+ return keywords
+
+
+@register_vcs_handler("git", "keywords")
+def git_versions_from_keywords(keywords, tag_prefix, verbose):
+ """Get version information from git keywords."""
+ if "refnames" not in keywords:
+ raise NotThisMethod("Short version file found")
+ date = keywords.get("date")
+ if date is not None:
+ # Use only the last line. Previous lines may contain GPG signature
+ # information.
+ date = date.splitlines()[-1]
+
+ # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
+ # datestamp. However we prefer "%ci" (which expands to an "ISO-8601
+ # -like" string, which we must then edit to make compliant), because
+ # it's been around since git-1.5.3, and it's too difficult to
+ # discover which version we're using, or to work around using an
+ # older one.
+ date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
+ refnames = keywords["refnames"].strip()
+ if refnames.startswith("$Format"):
+ if verbose:
+ print("keywords are unexpanded, not using")
+ raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
+ refs = {r.strip() for r in refnames.strip("()").split(",")}
+ # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
+ # just "foo-1.0". If we see a "tag: " prefix, prefer those.
+ TAG = "tag: "
+ tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
+ if not tags:
+ # Either we're using git < 1.8.3, or there really are no tags. We use
+ # a heuristic: assume all version tags have a digit. The old git %d
+ # expansion behaves like git log --decorate=short and strips out the
+ # refs/heads/ and refs/tags/ prefixes that would let us distinguish
+ # between branches and tags. By ignoring refnames without digits, we
+ # filter out many common branch names like "release" and
+ # "stabilization", as well as "HEAD" and "master".
+ tags = {r for r in refs if re.search(r'\d', r)}
+ if verbose:
+ print("discarding '%s', no digits" % ",".join(refs - tags))
+ if verbose:
+ print("likely tags: %s" % ",".join(sorted(tags)))
+ for ref in sorted(tags):
+ # sorting will prefer e.g. "2.0" over "2.0rc1"
+ if ref.startswith(tag_prefix):
+ r = ref[len(tag_prefix):]
+ # Filter out refs that exactly match prefix or that don't start
+ # with a number once the prefix is stripped (mostly a concern
+ # when prefix is '')
+ if not re.match(r'\d', r):
+ continue
+ if verbose:
+ print("picking %s" % r)
+ return {"version": r,
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False, "error": None,
+ "date": date}
+ # no suitable tags, so version is "0+unknown", but full hex is still there
+ if verbose:
+ print("no suitable tags, using unknown + full revision id")
+ return {"version": "0+unknown",
+ "full-revisionid": keywords["full"].strip(),
+ "dirty": False, "error": "no suitable tags", "date": None}
+
+
+@register_vcs_handler("git", "pieces_from_vcs")
+def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
+ """Get version from 'git describe' in the root of the source tree.
+
+ This only gets called if the git-archive 'subst' keywords were *not*
+ expanded, and _version.py hasn't already been rewritten with a short
+ version string, meaning we're inside a checked out source tree.
+ """
+ GITS = ["git"]
+ TAG_PREFIX_REGEX = "*"
+ if sys.platform == "win32":
+ GITS = ["git.cmd", "git.exe"]
+ TAG_PREFIX_REGEX = r"\*"
+
+ _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
+ hide_stderr=True)
+ if rc != 0:
+ if verbose:
+ print("Directory %s not under git control" % root)
+ raise NotThisMethod("'git rev-parse --git-dir' returned error")
+
+ # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
+ # if there isn't one, this yields HEX[-dirty] (no NUM)
+ describe_out, rc = runner(GITS, ["describe", "--tags", "--dirty",
+ "--always", "--long",
+ "--match",
+ "%s%s" % (tag_prefix, TAG_PREFIX_REGEX)],
+ cwd=root)
+ # --long was added in git-1.5.5
+ if describe_out is None:
+ raise NotThisMethod("'git describe' failed")
+ describe_out = describe_out.strip()
+ full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
+ if full_out is None:
+ raise NotThisMethod("'git rev-parse' failed")
+ full_out = full_out.strip()
+
+ pieces = {}
+ pieces["long"] = full_out
+ pieces["short"] = full_out[:7] # maybe improved later
+ pieces["error"] = None
+
+ branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
+ cwd=root)
+ # --abbrev-ref was added in git-1.6.3
+ if rc != 0 or branch_name is None:
+ raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
+ branch_name = branch_name.strip()
+
+ if branch_name == "HEAD":
+ # If we aren't exactly on a branch, pick a branch which represents
+ # the current commit. If all else fails, we are on a branchless
+ # commit.
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
+ # --contains was added in git-1.5.4
+ if rc != 0 or branches is None:
+ raise NotThisMethod("'git branch --contains' returned error")
+ branches = branches.split("\n")
+
+ # Remove the first line if we're running detached
+ if "(" in branches[0]:
+ branches.pop(0)
+
+ # Strip off the leading "* " from the list of branches.
+ branches = [branch[2:] for branch in branches]
+ if "master" in branches:
+ branch_name = "master"
+ elif not branches:
+ branch_name = None
+ else:
+ # Pick the first branch that is returned. Good or bad.
+ branch_name = branches[0]
+
+ pieces["branch"] = branch_name
+
+ # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
+ # TAG might have hyphens.
+ git_describe = describe_out
+
+ # look for -dirty suffix
+ dirty = git_describe.endswith("-dirty")
+ pieces["dirty"] = dirty
+ if dirty:
+ git_describe = git_describe[:git_describe.rindex("-dirty")]
+
+ # now we have TAG-NUM-gHEX or HEX
+
+ if "-" in git_describe:
+ # TAG-NUM-gHEX
+ mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
+ if not mo:
+ # unparsable. Maybe git-describe is misbehaving?
+ pieces["error"] = ("unable to parse git-describe output: '%s'"
+ % describe_out)
+ return pieces
+
+ # tag
+ full_tag = mo.group(1)
+ if not full_tag.startswith(tag_prefix):
+ if verbose:
+ fmt = "tag '%s' doesn't start with prefix '%s'"
+ print(fmt % (full_tag, tag_prefix))
+ pieces["error"] = ("tag '%s' doesn't start with prefix '%s'"
+ % (full_tag, tag_prefix))
+ return pieces
+ pieces["closest-tag"] = full_tag[len(tag_prefix):]
+
+ # distance: number of commits since tag
+ pieces["distance"] = int(mo.group(2))
+
+ # commit: short hex revision ID
+ pieces["short"] = mo.group(3)
+
+ else:
+ # HEX: no tags
+ pieces["closest-tag"] = None
+ count_out, rc = runner(GITS, ["rev-list", "HEAD", "--count"], cwd=root)
+ pieces["distance"] = int(count_out) # total number of commits
+
+ # commit date: see ISO-8601 comment in git_versions_from_keywords()
+ date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
+ # Use only the last line. Previous lines may contain GPG signature
+ # information.
+ date = date.splitlines()[-1]
+ pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
+
+ return pieces
+
+
+def plus_or_dot(pieces):
+ """Return a + if we don't already have one, else return a ."""
+ if "+" in pieces.get("closest-tag", ""):
+ return "."
+ return "+"
+
+
+def render_pep440(pieces):
+ """Build up version string, with post-release "local version identifier".
+
+ Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
+ get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
+
+ Exceptions:
+ 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += plus_or_dot(pieces)
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0+untagged.%d.g%s" % (pieces["distance"],
+ pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def render_pep440_branch(pieces):
+ """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
+
+ The ".dev0" means not master branch. Note that .dev0 sorts backwards
+ (a feature branch will appear "older" than the master branch).
+
+ Exceptions:
+ 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0"
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+untagged.%d.g%s" % (pieces["distance"],
+ pieces["short"])
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def pep440_split_post(ver):
+ """Split pep440 version string at the post-release segment.
+
+ Returns the release segments before the post-release and the
+ post-release version number (or -1 if no post-release segment is present).
+ """
+ vc = str.split(ver, ".post")
+ return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
+
+
+def render_pep440_pre(pieces):
+ """TAG[.postN.devDISTANCE] -- No -dirty.
+
+ Exceptions:
+ 1: no tags. 0.post0.devDISTANCE
+ """
+ if pieces["closest-tag"]:
+ if pieces["distance"]:
+ # update the post release segment
+ tag_version, post_version = pep440_split_post(pieces["closest-tag"])
+ rendered = tag_version
+ if post_version is not None:
+ rendered += ".post%d.dev%d" % (post_version+1, pieces["distance"])
+ else:
+ rendered += ".post0.dev%d" % (pieces["distance"])
+ else:
+ # no commits, use the tag as the version
+ rendered = pieces["closest-tag"]
+ else:
+ # exception #1
+ rendered = "0.post0.dev%d" % pieces["distance"]
+ return rendered
+
+
+def render_pep440_post(pieces):
+ """TAG[.postDISTANCE[.dev0]+gHEX] .
+
+ The ".dev0" means dirty. Note that .dev0 sorts backwards
+ (a dirty tree will appear "older" than the corresponding clean one),
+ but you shouldn't be releasing software with -dirty anyways.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "g%s" % pieces["short"]
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ rendered += "+g%s" % pieces["short"]
+ return rendered
+
+
+def render_pep440_post_branch(pieces):
+ """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
+
+ The ".dev0" means not master branch.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += plus_or_dot(pieces)
+ rendered += "g%s" % pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["branch"] != "master":
+ rendered += ".dev0"
+ rendered += "+g%s" % pieces["short"]
+ if pieces["dirty"]:
+ rendered += ".dirty"
+ return rendered
+
+
+def render_pep440_old(pieces):
+ """TAG[.postDISTANCE[.dev0]] .
+
+ The ".dev0" means dirty.
+
+ Exceptions:
+ 1: no tags. 0.postDISTANCE[.dev0]
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"] or pieces["dirty"]:
+ rendered += ".post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ else:
+ # exception #1
+ rendered = "0.post%d" % pieces["distance"]
+ if pieces["dirty"]:
+ rendered += ".dev0"
+ return rendered
+
+
+def render_git_describe(pieces):
+ """TAG[-DISTANCE-gHEX][-dirty].
+
+ Like 'git describe --tags --dirty --always'.
+
+ Exceptions:
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ if pieces["distance"]:
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
+ else:
+ # exception #1
+ rendered = pieces["short"]
+ if pieces["dirty"]:
+ rendered += "-dirty"
+ return rendered
+
+
+def render_git_describe_long(pieces):
+ """TAG-DISTANCE-gHEX[-dirty].
+
+ Like 'git describe --tags --dirty --always -long'.
+ The distance/hash is unconditional.
+
+ Exceptions:
+ 1: no tags. HEX[-dirty] (note: no 'g' prefix)
+ """
+ if pieces["closest-tag"]:
+ rendered = pieces["closest-tag"]
+ rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
+ else:
+ # exception #1
+ rendered = pieces["short"]
+ if pieces["dirty"]:
+ rendered += "-dirty"
+ return rendered
+
+
+def render(pieces, style):
+ """Render the given version pieces into the requested style."""
+ if pieces["error"]:
+ return {"version": "unknown",
+ "full-revisionid": pieces.get("long"),
+ "dirty": None,
+ "error": pieces["error"],
+ "date": None}
+
+ if not style or style == "default":
+ style = "pep440" # the default
+
+ if style == "pep440":
+ rendered = render_pep440(pieces)
+ elif style == "pep440-branch":
+ rendered = render_pep440_branch(pieces)
+ elif style == "pep440-pre":
+ rendered = render_pep440_pre(pieces)
+ elif style == "pep440-post":
+ rendered = render_pep440_post(pieces)
+ elif style == "pep440-post-branch":
+ rendered = render_pep440_post_branch(pieces)
+ elif style == "pep440-old":
+ rendered = render_pep440_old(pieces)
+ elif style == "git-describe":
+ rendered = render_git_describe(pieces)
+ elif style == "git-describe-long":
+ rendered = render_git_describe_long(pieces)
+ else:
+ raise ValueError("unknown style '%s'" % style)
+
+ return {"version": rendered, "full-revisionid": pieces["long"],
+ "dirty": pieces["dirty"], "error": None,
+ "date": pieces.get("date")}
+
+
+def get_versions():
+ """Get version information or return default if unable to do so."""
+ # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
+ # __file__, we can work backwards from there to the root. Some
+ # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
+ # case we can only use expanded keywords.
+
+ cfg = get_config()
+ verbose = cfg.verbose
+
+ try:
+ return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,
+ verbose)
+ except NotThisMethod:
+ pass
+
+ try:
+ root = os.path.realpath(__file__)
+ # versionfile_source is the relative path from the top of the source
+ # tree (where the .git directory might live) to this file. Invert
+ # this to find the root from __file__.
+ for _ in cfg.versionfile_source.split('/'):
+ root = os.path.dirname(root)
+ except NameError:
+ return {"version": "0+unknown", "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to find root of source tree",
+ "date": None}
+
+ try:
+ pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
+ return render(pieces, cfg.style)
+ except NotThisMethod:
+ pass
+
+ try:
+ if cfg.parentdir_prefix:
+ return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
+ except NotThisMethod:
+ pass
+
+ return {"version": "0+unknown", "full-revisionid": None,
+ "dirty": None,
+ "error": "unable to compute version", "date": None}
diff --git a/array_api_tests/algos.py b/array_api_tests/algos.py
new file mode 100644
index 00000000..8aa77b3f
--- /dev/null
+++ b/array_api_tests/algos.py
@@ -0,0 +1,53 @@
+__all__ = ["broadcast_shapes"]
+
+
+from .typing import Shape
+
+
+# We use a custom exception to differentiate from potential bugs
+class BroadcastError(ValueError):
+ pass
+
+
+def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
+ """Broadcasts `shape1` and `shape2`"""
+ N1 = len(shape1)
+ N2 = len(shape2)
+ N = max(N1, N2)
+ shape = [None for _ in range(N)]
+ i = N - 1
+ while i >= 0:
+        # A negative index means this shape has no dimension at this
+        # position; missing dimensions are treated as 1, per the
+        # broadcasting rules.
+        n1 = N1 - N + i
+        d1 = shape1[n1] if n1 >= 0 else 1
+        n2 = N2 - N + i
+        d2 = shape2[n2] if n2 >= 0 else 1
+
+ if d1 == 1:
+ shape[i] = d2
+ elif d2 == 1:
+ shape[i] = d1
+ elif d1 == d2:
+ shape[i] = d1
+ else:
+ raise BroadcastError
+
+ i = i - 1
+
+ return tuple(shape)
+
+
+def broadcast_shapes(*shapes: Shape):
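+    """Broadcast the given shapes together, per the spec's broadcasting rules."""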
+ if len(shapes) == 0:
+ raise ValueError("shapes=[] must be non-empty")
+ elif len(shapes) == 1:
+ return shapes[0]
+ result = _broadcast_shapes(shapes[0], shapes[1])
+ for i in range(2, len(shapes)):
+ result = _broadcast_shapes(result, shapes[i])
+ return result
diff --git a/array_api_tests/array_helpers.py b/array_api_tests/array_helpers.py
index 457cbd56..a74dab24 100644
--- a/array_api_tests/array_helpers.py
+++ b/array_api_tests/array_helpers.py
@@ -1,28 +1,32 @@
-from ._array_module import (isnan, all, equal, not_equal, logical_and,
- logical_or, isfinite, greater, less, zeros, ones,
- full, bool, int8, int16, int32, int64, uint8,
- uint16, uint32, uint64, float32, float64, nan,
- inf, pi, remainder, divide, isinf)
-
+from ._array_module import (isnan, all, any, equal, not_equal, logical_and,
+                            logical_or, isfinite, greater, less, less_equal,
+ zeros, ones, full, bool, int8, int16, int32,
+ int64, uint8, uint16, uint32, uint64, float32,
+ float64, nan, inf, pi, remainder, divide, isinf,
+ negative, asarray)
# These are exported here so that they can be included in the special cases
# tests from this file.
from ._array_module import logical_not, subtract, floor, ceil, where
-
-__all__ = ['logical_and', 'logical_or', 'logical_not', 'less', 'greater',
- 'subtract', 'floor', 'ceil', 'where', 'isfinite', 'equal',
- 'not_equal', 'zero', 'one', 'NaN', 'infinity', 'π', 'isnegzero',
- 'non_zero', 'isposzero', 'exactly_equal', 'assert_exactly_equal',
- 'notequal', 'assert_finite', 'assert_non_zero', 'ispositive',
+from . import _array_module as xp
+from . import dtype_helpers as dh
+
+__all__ = ['all', 'any', 'logical_and', 'logical_or', 'logical_not', 'less',
+ 'less_equal', 'greater', 'subtract', 'negative', 'floor', 'ceil',
+ 'where', 'isfinite', 'equal', 'not_equal', 'zero', 'one', 'NaN',
+ 'infinity', 'π', 'isnegzero', 'non_zero', 'isposzero',
+ 'exactly_equal', 'assert_exactly_equal', 'notequal',
+ 'assert_finite', 'assert_non_zero', 'ispositive',
'assert_positive', 'isnegative', 'assert_negative', 'isintegral',
'assert_integral', 'isodd', 'iseven', "assert_iseven",
'assert_isinf', 'positive_mathematical_sign',
'assert_positive_mathematical_sign', 'negative_mathematical_sign',
'assert_negative_mathematical_sign', 'same_sign',
- 'assert_same_sign']
+ 'assert_same_sign', 'float64',
+ 'asarray', 'full', 'true', 'false', 'isnan']
def zero(shape, dtype):
"""
- Returns a scalar 0 of the given dtype.
+ Returns a full 0 array of the given dtype.
This should be used in place of the literal "0" in the test suite, as the
spec does not require any behavior with Python literals (and in
@@ -36,7 +40,7 @@ def zero(shape, dtype):
def one(shape, dtype):
"""
- Returns a scalar 1 of the given dtype.
+ Returns a full 1 array of the given dtype.
This should be used in place of the literal "1" in the test suite, as the
spec does not require any behavior with Python literals (and in
@@ -49,7 +53,7 @@ def one(shape, dtype):
def NaN(shape, dtype):
"""
- Returns a scalar nan of the given dtype.
+ Returns a full nan array of the given dtype.
Note that this is only defined for floating point dtypes.
"""
@@ -59,7 +63,7 @@ def NaN(shape, dtype):
def infinity(shape, dtype):
"""
- Returns a scalar positive infinity of the given dtype.
+ Returns a full positive infinity array of the given dtype.
Note that this is only defined for floating point dtypes.
@@ -72,7 +76,7 @@ def infinity(shape, dtype):
def π(shape, dtype):
"""
- Returns a scalar π.
+ Returns a full π array of the given dtype.
Note that this function is only defined for floating point dtype.
@@ -83,22 +87,38 @@ def π(shape, dtype):
raise RuntimeError(f"Unexpected dtype {dtype} in π().")
return full(shape, pi, dtype=dtype)
+def true(shape):
+ """
+ Returns a full True array with dtype=bool.
+ """
+ return full(shape, True, dtype=bool)
+
+def false(shape):
+ """
+ Returns a full False array with dtype=bool.
+ """
+ return full(shape, False, dtype=bool)
+
def isnegzero(x):
"""
- Returns a mask where x is -0.
+ Returns a mask where x is -0. Is all False if x has integer dtype.
"""
# TODO: If copysign or signbit are added to the spec, use those instead.
shape = x.shape
dtype = x.dtype
+ if dh.is_int_dtype(dtype):
+ return false(shape)
return equal(divide(one(shape, dtype), x), -infinity(shape, dtype))
def isposzero(x):
"""
- Returns a mask where x is +0 (but not -0).
+ Returns a mask where x is +0 (but not -0). Is all True if x has integer dtype.
"""
# TODO: If copysign or signbit are added to the spec, use those instead.
shape = x.shape
dtype = x.dtype
+ if dh.is_int_dtype(dtype):
+ return true(shape)
return equal(divide(one(shape, dtype), x), infinity(shape, dtype))
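+
+# Note: the divide trick above relies on IEEE 754 arithmetic, where
+# divide(1, -0.0) evaluates to -inf and divide(1, +0.0) to +inf, so comparing
+# against ±infinity separates the two zeros without copysign/signbit (plain
+# Python floats would raise ZeroDivisionError here, hence the array module).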
def exactly_equal(x, y):
@@ -144,7 +164,17 @@ def notequal(x, y):
return not_equal(x, y)
-def assert_exactly_equal(x, y):
+def less(x, y):
+ """
+    Same as xp.less(x, y), except it also supports comparing uint64 arrays
+    with signed integer dtypes
+ """
+ if x.dtype == uint64 and dh.dtype_signed[y.dtype]:
+ return xp.where(y < 0, xp.asarray(False), xp.less(x, xp.astype(y, uint64)))
+ if y.dtype == uint64 and dh.dtype_signed[x.dtype]:
+ return xp.where(x < 0, xp.asarray(True), xp.less(xp.astype(x, uint64), y))
+ return xp.less(x, y)
+
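+# Illustrative: under standard promotion, uint64 and int64 have no common
+# dtype, so this helper resolves the sign first. E.g. for x = asarray(2**63,
+# dtype=uint64) and y = asarray(-1, dtype=int64), less(x, y) is all-False
+# (y < 0), while less(y, x) is all-True, with no cast of a negative value.
+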
+def assert_exactly_equal(x, y, msg_extra=None):
"""
Test that the arrays x and y are exactly equal.
@@ -152,11 +182,13 @@ def assert_exactly_equal(x, y):
equal.
"""
- assert x.shape == y.shape, "The input arrays do not have the same shapes"
+ extra = '' if not msg_extra else f' ({msg_extra})'
- assert x.dtype == y.dtype, "The input arrays do not have the same dtype"
+ assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"
- assert all(exactly_equal(x, y)), "The input arrays have different values"
+ assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"
+
+ assert all(exactly_equal(x, y)), f"The input arrays have different values ({x!r} != {y!r}){extra}"
def assert_finite(x):
"""
@@ -182,9 +214,19 @@ def isnegative(x):
def assert_negative(x):
assert all(isnegative(x)), "The input array is not negative"
+def inrange(x, a, b, epsilon=0, open=False):
+ """
+    Returns a mask for values of x in the range [a-epsilon, b+epsilon], inclusive
+
+    If open=True, the range is (a-epsilon, b+epsilon) (i.e., exclusive).
+ """
+ eps = full(x.shape, epsilon, dtype=x.dtype)
+ l = less if open else less_equal
+ return logical_and(l(a-eps, x), l(x, b+eps))
+
def isintegral(x):
"""
- Returns a mask the shape of x where the values are integral
+ Returns a mask on x where the values are integral
x is integral if its dtype is an integer dtype, or if it is a floating
point value that can be exactly represented as an integer.
@@ -237,7 +279,8 @@ def positive_mathematical_sign(x):
nans, as signed nans are not required by the spec.
"""
- return logical_or(greater(x, 0), isposzero(x))
+ z = zero(x.shape, x.dtype)
+ return logical_or(greater(x, z), isposzero(x))
def assert_positive_mathematical_sign(x):
assert all(positive_mathematical_sign(x)), "The input arrays do not have a positive mathematical sign"
@@ -251,7 +294,10 @@ def negative_mathematical_sign(x):
nans, as signed nans are not required by the spec.
"""
- return logical_or(less(x, 0), isnegzero(x))
+ z = zero(x.shape, x.dtype)
+ if x.dtype in [float32, float64]:
+ return logical_or(less(x, z), isnegzero(x))
+ return less(x, z)
def assert_negative_mathematical_sign(x):
assert all(negative_mathematical_sign(x)), "The input arrays do not have a negative mathematical sign"
@@ -272,21 +318,13 @@ def same_sign(x, y):
def assert_same_sign(x, y):
assert all(same_sign(x, y)), "The input arrays do not have the same sign"
-def is_integer_dtype(dtype):
- return dtype in [int8, int16, int32, int16, uint8, uint16, uint32, uint64]
-
-def is_float_dtype(dtype):
- # TODO: Return True even for floating point dtypes that aren't part of the
- # spec, like np.float16
- return dtype in [float32, float64]
-
-dtype_ranges = {
- int8: [-128, +127],
- int16: [-32_767, +32_767],
- int32: [-2_147_483_647, +2_147_483_647],
- int64: [-9_223_372_036_854_775_807, +9_223_372_036_854_775_807],
- uint8: [0, +255],
- uint16: [0, +65_535],
- uint32: [0, +4_294_967_295],
- uint64: [0, +18_446_744_073_709_551_615],
-}
+def _matrix_transpose(x):
+ if not isinstance(xp.matrix_transpose, xp._UndefinedStub):
+ return xp.matrix_transpose(x)
+ if hasattr(x, 'mT'):
+ return x.mT
+ if not isinstance(xp.permute_dims, xp._UndefinedStub):
+ perm = list(range(x.ndim))
+ perm[-1], perm[-2] = perm[-2], perm[-1]
+ return xp.permute_dims(x, axes=tuple(perm))
+ raise NotImplementedError("No way to compute matrix transpose")
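+
+# Illustrative fallback order: for x of shape (2, 3, 4) each branch above
+# yields shape (2, 4, 3), whether via matrix_transpose(x), x.mT, or
+# permute_dims(x, axes=(0, 2, 1)), depending on what the module provides.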
diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py
new file mode 100644
index 00000000..f7fa306b
--- /dev/null
+++ b/array_api_tests/dtype_helpers.py
@@ -0,0 +1,594 @@
+import os
+import re
+from collections import defaultdict
+from collections.abc import Mapping
+from functools import lru_cache
+from typing import Any, DefaultDict, Dict, List, NamedTuple, Sequence, Tuple, Union
+from warnings import warn
+
+from . import api_version
+from . import xp
+from .stubs import name_to_func
+from .typing import DataType, ScalarType
+
+__all__ = [
+ "uint_names",
+ "int_names",
+ "all_int_names",
+ "real_float_names",
+ "real_names",
+ "complex_names",
+ "numeric_names",
+ "dtype_names",
+ "int_dtypes",
+ "uint_dtypes",
+ "all_int_dtypes",
+ "real_float_dtypes",
+ "real_dtypes",
+ "numeric_dtypes",
+ "all_dtypes",
+ "all_float_dtypes",
+ "bool_and_all_int_dtypes",
+ "dtype_to_name",
+ "kind_to_dtypes",
+ "is_int_dtype",
+ "is_float_dtype",
+ "get_scalar_type",
+ "is_scalar",
+ "dtype_ranges",
+ "default_int",
+ "default_uint",
+ "default_float",
+ "default_complex",
+ "promotion_table",
+ "dtype_nbits",
+ "dtype_signed",
+ "dtype_components",
+ "func_in_dtypes",
+ "func_returns_bool",
+ "binary_op_to_symbol",
+ "unary_op_to_symbol",
+ "inplace_op_to_symbol",
+ "op_to_func",
+ "fmt_types",
+]
+
+
+class EqualityMapping(Mapping):
+ """
+ Mapping that uses equality for indexing
+
+ Typical mappings (e.g. the built-in dict) use hashing for indexing. This
+ isn't ideal for the Array API, as no __hash__() method is specified for
+ dtype objects - but __eq__() is!
+
+ See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
+ """
+
+ def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]):
+ keys = [k for k, _ in key_value_pairs]
+ for i, key in enumerate(keys):
+            if not (key == key): # specifically checking __eq__, not __ne__
+ raise ValueError(f"Key {key!r} does not have equality with itself")
+ other_keys = keys[:]
+ other_keys.pop(i)
+ for other_key in other_keys:
+ if key == other_key:
+ raise ValueError(f"Key {key!r} has equality with key {other_key!r}")
+ self._key_value_pairs = key_value_pairs
+
+ def __getitem__(self, key):
+ for k, v in self._key_value_pairs:
+ if key == k:
+ return v
+ else:
+ raise KeyError(f"{key!r} not found")
+
+ def __iter__(self):
+ return (k for k, _ in self._key_value_pairs)
+
+ def __len__(self):
+ return len(self._key_value_pairs)
+
+ def __str__(self):
+ return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}"
+
+ def __repr__(self):
+ return f"EqualityMapping({self})"
+
+
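+# Illustrative usage: EqualityMapping behaves like a dict but looks keys up
+# with ==, so dtype objects that define __eq__ but not __hash__ still work:
+#
+#     m = EqualityMapping([(xp.int8, 8), (xp.int16, 16)])
+#     m[xp.int8]  # -> 8
+#     list(m)     # -> [xp.int8, xp.int16]
+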
+uint_names = ("uint8", "uint16", "uint32", "uint64")
+int_names = ("int8", "int16", "int32", "int64")
+all_int_names = uint_names + int_names
+real_float_names = ("float32", "float64")
+real_names = uint_names + int_names + real_float_names
+complex_names = ("complex64", "complex128")
+numeric_names = real_names + complex_names
+dtype_names = ("bool",) + numeric_names
+
+_skip_dtypes = os.getenv("ARRAY_API_TESTS_SKIP_DTYPES", '')
+_skip_dtypes = _skip_dtypes.split(',')
+skip_dtypes = []
+for dtype in _skip_dtypes:
+ if dtype and dtype not in dtype_names:
+ raise ValueError(f"Invalid dtype name in ARRAY_API_TESTS_SKIP_DTYPES: {dtype}")
+ skip_dtypes.append(dtype)
+
+_name_to_dtype = {}
+for name in dtype_names:
+ if name in skip_dtypes:
+ continue
+ try:
+ dtype = getattr(xp, name)
+ except AttributeError:
+ continue
+ _name_to_dtype[name] = dtype
+dtype_to_name = EqualityMapping([(d, n) for n, d in _name_to_dtype.items()])
+
+
+def _make_dtype_tuple_from_names(names: List[str]) -> Tuple[DataType]:
+ dtypes = []
+ for name in names:
+ try:
+ dtype = _name_to_dtype[name]
+ except KeyError:
+ continue
+ dtypes.append(dtype)
+ return tuple(dtypes)
+
+
+uint_dtypes = _make_dtype_tuple_from_names(uint_names)
+int_dtypes = _make_dtype_tuple_from_names(int_names)
+real_float_dtypes = _make_dtype_tuple_from_names(real_float_names)
+all_int_dtypes = uint_dtypes + int_dtypes
+real_dtypes = all_int_dtypes + real_float_dtypes
+complex_dtypes = _make_dtype_tuple_from_names(complex_names)
+numeric_dtypes = real_dtypes
+if api_version > "2021.12":
+ numeric_dtypes += complex_dtypes
+all_dtypes = (xp.bool,) + numeric_dtypes
+all_float_dtypes = real_float_dtypes
+if api_version > "2021.12":
+ all_float_dtypes += complex_dtypes
+bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
+
+
+kind_to_dtypes = {
+ "bool": [xp.bool],
+ "signed integer": int_dtypes,
+ "unsigned integer": uint_dtypes,
+ "integral": all_int_dtypes,
+ "real floating": real_float_dtypes,
+ "complex floating": complex_dtypes,
+ "numeric": numeric_dtypes,
+}
+
+
+def available_kinds():
+ return {
+ kind for kind, dtypes in kind_to_dtypes.items() if dtypes
+ }
+
+def is_int_dtype(dtype):
+ return dtype in all_int_dtypes
+
+
+def is_float_dtype(dtype, *, include_complex=True):
+ # None equals NumPy's xp.float64 object, so we specifically check it here.
+ # xp.float64 is in fact an alias of np.dtype('float64'), and its equality
+ # with None is meant to be deprecated at some point.
+ # See https://github.com/numpy/numpy/issues/18434
+ if dtype is None:
+ return False
+ valid_dtypes = real_float_dtypes
+ if api_version > "2021.12" and include_complex:
+ valid_dtypes += complex_dtypes
+ return dtype in valid_dtypes
+
+def get_scalar_type(dtype: DataType) -> ScalarType:
+ if dtype in all_int_dtypes:
+ return int
+ elif dtype in real_float_dtypes:
+ return float
+ elif dtype in complex_dtypes:
+ return complex
+ else:
+ return bool
+
+def is_scalar(x):
+ return isinstance(x, (int, float, complex, bool))
+
+
+def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
+ dtype_value_pairs = []
+ for name, value in mapping.items():
+ assert isinstance(name, str) and name in dtype_names # sanity check
+ if name in _name_to_dtype:
+ dtype = _name_to_dtype[name]
+ else:
+ continue
+ dtype_value_pairs.append((dtype, value))
+ return EqualityMapping(dtype_value_pairs)
+
+
+class MinMax(NamedTuple):
+ min: Union[int, float]
+ max: Union[int, float]
+
+ def __contains__(self, other):
+ assert isinstance(other, (int, float))
+ return self.min <= other <= self.max
+
+dtype_ranges = _make_dtype_mapping_from_names(
+ {
+ "int8": MinMax(-128, +127),
+ "int16": MinMax(-32_768, +32_767),
+ "int32": MinMax(-2_147_483_648, +2_147_483_647),
+ "int64": MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
+ "uint8": MinMax(0, +255),
+ "uint16": MinMax(0, +65_535),
+ "uint32": MinMax(0, +4_294_967_295),
+ "uint64": MinMax(0, +18_446_744_073_709_551_615),
+ "float32": MinMax(-3.4028234663852886e38, 3.4028234663852886e38),
+ "float64": MinMax(-1.7976931348623157e308, 1.7976931348623157e308),
+ }
+)
+
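+# Since MinMax defines __contains__, range checks read naturally, e.g.
+# `300 in dtype_ranges[xp.uint8]` is False, and `m, M = dtype_ranges[xp.int8]`
+# unpacks to (-128, 127) like any NamedTuple (illustrative).
+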
+
+r_nbits = re.compile(r"[a-z]+([0-9]+)")
+_dtype_nbits: Dict[str, int] = {}
+for name in numeric_names:
+ m = r_nbits.fullmatch(name)
+ assert m is not None # sanity check / for mypy
+ _dtype_nbits[name] = int(m.group(1))
+dtype_nbits = _make_dtype_mapping_from_names(_dtype_nbits)
+
+
+dtype_signed = _make_dtype_mapping_from_names(
+ {**{name: True for name in int_names}, **{name: False for name in uint_names}}
+)
+
+
+dtype_components = _make_dtype_mapping_from_names(
+ {"complex64": xp.float32, "complex128": xp.float64}
+)
+
+def as_real_dtype(dtype):
+ """
+ Return the corresponding real dtype for a given floating-point dtype.
+ """
+ if dtype in real_float_dtypes:
+ return dtype
+ elif dtype_to_name[dtype] in complex_names:
+ return dtype_components[dtype]
+ else:
+ raise ValueError("as_real_dtype requires a floating-point dtype")
+
+def accumulation_result_dtype(x_dtype, dtype_kwarg):
+ """
+ Result dtype logic for sum(), prod(), and trace()
+
+    Note: may return None if a default uint cannot exist (e.g., for PyTorch,
+ which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106
+
+ """
+ if dtype_kwarg is None:
+ if is_int_dtype(x_dtype):
+ if x_dtype in uint_dtypes:
+ default_dtype = default_uint
+ else:
+ default_dtype = default_int
+ if default_dtype is None:
+ _dtype = None
+ else:
+ m, M = dtype_ranges[x_dtype]
+ d_m, d_M = dtype_ranges[default_dtype]
+ if m < d_m or M > d_M:
+ _dtype = x_dtype
+ else:
+ _dtype = default_dtype
+ elif api_version >= '2023.12':
+ # Starting in 2023.12, floats should not promote with dtype=None
+ _dtype = x_dtype
+ elif is_float_dtype(x_dtype, include_complex=False):
+ if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
+ _dtype = x_dtype
+ else:
+ _dtype = default_float
+ elif api_version > "2021.12":
+ # Complex dtype
+ if dtype_nbits[x_dtype] > dtype_nbits[default_complex]:
+ _dtype = x_dtype
+ else:
+ _dtype = default_complex
+ else:
+ raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
+ else:
+ _dtype = dtype_kwarg
+
+ return _dtype
+
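+# Illustrative expectations under the rules above (assuming NumPy-like
+# defaults, i.e. default_int == int64): summing an int8 array with dtype=None
+# accumulates in int64, since int8's range fits inside int64's; a float64
+# array keeps float64, either because it is at least as wide as default_float
+# or, from 2023.12 onward, because floats never promote with dtype=None.
+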
+if not hasattr(xp, "asarray"):
+ default_int = xp.int32
+ default_float = xp.float32
+ # TODO: when api_version > '2021.12', just assign to xp.complex64,
+ # otherwise default to None. Need array-api spec to be bumped first (#187).
+ try:
+ default_complex = xp.complex64
+ except AttributeError:
+ default_complex = None
+ warn(
+ "array module does not have attribute asarray. "
+ "default int is assumed int32, default float is assumed float32"
+ )
+else:
+ default_int = xp.asarray(int()).dtype
+ if default_int not in int_dtypes:
+ warn(f"inferred default int is {default_int!r}, which is not an int")
+ default_float = xp.asarray(float()).dtype
+ if default_float not in real_float_dtypes:
+ warn(f"inferred default float is {default_float!r}, which is not a float")
+ if api_version > "2021.12" and ({'complex64', 'complex128'} - set(skip_dtypes)):
+ default_complex = xp.asarray(complex()).dtype
+ if default_complex not in complex_dtypes:
+ warn(
+ f"inferred default complex is {default_complex!r}, "
+ "which is not a complex"
+ )
+ else:
+ default_complex = None
+
+if dtype_nbits[default_int] == 32:
+ default_uint = _name_to_dtype.get("uint32")
+else:
+ default_uint = _name_to_dtype.get("uint64")
+
+_promotion_table: Dict[Tuple[str, str], str] = {
+ ("bool", "bool"): "bool",
+ # ints
+ ("int8", "int8"): "int8",
+ ("int8", "int16"): "int16",
+ ("int8", "int32"): "int32",
+ ("int8", "int64"): "int64",
+ ("int16", "int16"): "int16",
+ ("int16", "int32"): "int32",
+ ("int16", "int64"): "int64",
+ ("int32", "int32"): "int32",
+ ("int32", "int64"): "int64",
+ ("int64", "int64"): "int64",
+ # uints
+ ("uint8", "uint8"): "uint8",
+ ("uint8", "uint16"): "uint16",
+ ("uint8", "uint32"): "uint32",
+ ("uint8", "uint64"): "uint64",
+ ("uint16", "uint16"): "uint16",
+ ("uint16", "uint32"): "uint32",
+ ("uint16", "uint64"): "uint64",
+ ("uint32", "uint32"): "uint32",
+ ("uint32", "uint64"): "uint64",
+ ("uint64", "uint64"): "uint64",
+ # ints and uints (mixed sign)
+ ("int8", "uint8"): "int16",
+ ("int8", "uint16"): "int32",
+ ("int8", "uint32"): "int64",
+ ("int16", "uint8"): "int16",
+ ("int16", "uint16"): "int32",
+ ("int16", "uint32"): "int64",
+ ("int32", "uint8"): "int32",
+ ("int32", "uint16"): "int32",
+ ("int32", "uint32"): "int64",
+ ("int64", "uint8"): "int64",
+ ("int64", "uint16"): "int64",
+ ("int64", "uint32"): "int64",
+ # floats
+ ("float32", "float32"): "float32",
+ ("float32", "float64"): "float64",
+ ("float64", "float64"): "float64",
+ # complex
+ ("complex64", "complex64"): "complex64",
+ ("complex64", "complex128"): "complex128",
+ ("complex128", "complex128"): "complex128",
+}
+_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
+_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
+for (in_name1, in_name2), res_name in _promotion_table.items():
+ if in_name1 not in _name_to_dtype or in_name2 not in _name_to_dtype or res_name not in _name_to_dtype:
+ continue
+ in_dtype1 = _name_to_dtype[in_name1]
+ in_dtype2 = _name_to_dtype[in_name2]
+ res_dtype = _name_to_dtype[res_name]
+
+ _promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
+promotion_table = EqualityMapping(_promotion_table_pairs)
+
+
+def result_type(*dtypes: DataType):
+ if len(dtypes) == 0:
+        raise ValueError("dtypes must be non-empty")
+ elif len(dtypes) == 1:
+ return dtypes[0]
+ result = promotion_table[dtypes[0], dtypes[1]]
+ for i in range(2, len(dtypes)):
+ result = promotion_table[result, dtypes[i]]
+ return result
+
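+# Illustrative: result_type folds pairwise through promotion_table, so
+# result_type(xp.int8, xp.uint8, xp.int32) resolves (int8, uint8) -> int16,
+# then (int16, int32) -> int32. Pairs absent from the table (e.g. mixed
+# int/float) raise KeyError, mirroring promotions the spec leaves undefined.
+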
+
+r_alias = re.compile("[aA]lias")
+r_in_dtypes = re.compile("x1?: array\n.+have an? (.+) data type.")
+r_int_note = re.compile(
+ "If one or both of the input arrays have integer data types, "
+ "the result is implementation-dependent"
+)
+category_to_dtypes = {
+ "boolean": (xp.bool,),
+ "integer": all_int_dtypes,
+ "floating-point": real_float_dtypes,
+ "real-valued": real_float_dtypes,
+ "real-valued floating-point": real_float_dtypes,
+ "complex floating-point": complex_dtypes,
+ "numeric": numeric_dtypes,
+ "integer or boolean": bool_and_all_int_dtypes,
+}
+func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
+for name, func in name_to_func.items():
+ assert func.__doc__ is not None # for mypy
+ if m := r_in_dtypes.search(func.__doc__):
+ dtype_category = m.group(1)
+ if dtype_category == "numeric" and r_int_note.search(func.__doc__):
+ dtype_category = "floating-point"
+ dtypes = category_to_dtypes[dtype_category]
+ func_in_dtypes[name] = dtypes
+
+
+func_returns_bool = {
+ # elementwise
+ "abs": False,
+ "acos": False,
+ "acosh": False,
+ "add": False,
+ "asin": False,
+ "asinh": False,
+ "atan": False,
+ "atan2": False,
+ "atanh": False,
+ "bitwise_and": False,
+ "bitwise_invert": False,
+ "bitwise_left_shift": False,
+ "bitwise_or": False,
+ "bitwise_right_shift": False,
+ "bitwise_xor": False,
+ "ceil": False,
+ "cos": False,
+ "cosh": False,
+ "divide": False,
+ "equal": True,
+ "exp": False,
+ "expm1": False,
+ "floor": False,
+ "floor_divide": False,
+ "greater": True,
+ "greater_equal": True,
+ "isfinite": True,
+ "isinf": True,
+ "isnan": True,
+ "less": True,
+ "less_equal": True,
+ "log": False,
+ "logaddexp": False,
+ "log10": False,
+ "log1p": False,
+ "log2": False,
+ "logical_and": True,
+ "logical_not": True,
+ "logical_or": True,
+ "logical_xor": True,
+ "multiply": False,
+ "negative": False,
+ "not_equal": True,
+ "positive": False,
+ "pow": False,
+ "remainder": False,
+ "round": False,
+ "sign": False,
+ "sin": False,
+ "sinh": False,
+ "sqrt": False,
+ "square": False,
+ "subtract": False,
+ "tan": False,
+ "tanh": False,
+ "trunc": False,
+ # searching
+ "where": False,
+ # linalg
+ "matmul": False,
+}
+
+
+unary_op_to_symbol = {
+ "__invert__": "~",
+ "__neg__": "-",
+ "__pos__": "+",
+}
+
+
+binary_op_to_symbol = {
+ "__add__": "+",
+ "__and__": "&",
+ "__eq__": "==",
+ "__floordiv__": "//",
+ "__ge__": ">=",
+ "__gt__": ">",
+ "__le__": "<=",
+ "__lshift__": "<<",
+ "__lt__": "<",
+ "__matmul__": "@",
+ "__mod__": "%",
+ "__mul__": "*",
+ "__ne__": "!=",
+ "__or__": "|",
+ "__pow__": "**",
+ "__rshift__": ">>",
+ "__sub__": "-",
+ "__truediv__": "/",
+ "__xor__": "^",
+}
+
+
+op_to_func = {
+ "__abs__": "abs",
+ "__add__": "add",
+ "__and__": "bitwise_and",
+ "__eq__": "equal",
+ "__floordiv__": "floor_divide",
+ "__ge__": "greater_equal",
+ "__gt__": "greater",
+ "__le__": "less_equal",
+ "__lt__": "less",
+ "__matmul__": "matmul",
+ "__mod__": "remainder",
+ "__mul__": "multiply",
+ "__ne__": "not_equal",
+ "__or__": "bitwise_or",
+ "__pow__": "pow",
+ "__lshift__": "bitwise_left_shift",
+ "__rshift__": "bitwise_right_shift",
+ "__sub__": "subtract",
+ "__truediv__": "divide",
+ "__xor__": "bitwise_xor",
+ "__invert__": "bitwise_invert",
+ "__neg__": "negative",
+ "__pos__": "positive",
+}
+
+
+# Construct func_in_dtypes and func_returns_bool entries for the operators
+for op, elwise_func in op_to_func.items():
+ func_in_dtypes[op] = func_in_dtypes[elwise_func]
+ func_returns_bool[op] = func_returns_bool[elwise_func]
+inplace_op_to_symbol = {}
+for op, symbol in binary_op_to_symbol.items():
+ if op == "__matmul__" or func_returns_bool[op]:
+ continue
+ iop = f"__i{op[2:]}"
+ inplace_op_to_symbol[iop] = f"{symbol}="
+ func_in_dtypes[iop] = func_in_dtypes[op]
+ func_returns_bool[iop] = func_returns_bool[op]
+func_in_dtypes["__bool__"] = (xp.bool,)
+func_in_dtypes["__int__"] = all_int_dtypes
+func_in_dtypes["__index__"] = all_int_dtypes
+func_in_dtypes["__float__"] = real_float_dtypes
+func_in_dtypes["from_dlpack"] = numeric_dtypes
+func_in_dtypes["__dlpack__"] = numeric_dtypes
+
+
+@lru_cache
+def fmt_types(types: Tuple[Union[DataType, ScalarType], ...]) -> str:
+ f_types = []
+ for type_ in types:
+ try:
+ f_types.append(dtype_to_name[type_])
+ except KeyError:
+ # i.e. dtype is bool, int, or float
+ f_types.append(type_.__name__)
+ return ", ".join(f_types)
diff --git a/array_api_tests/function_stubs/__init__.py b/array_api_tests/function_stubs/__init__.py
deleted file mode 100644
index 8028bd1f..00000000
--- a/array_api_tests/function_stubs/__init__.py
+++ /dev/null
@@ -1,54 +0,0 @@
-"""
-Stub definitions for functions defined in the spec
-
-These are used to test function signatures.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-__all__ = []
-
-from .array_object import __abs__, __add__, __and__, __eq__, __floordiv__, __ge__, __getitem__, __gt__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, dtype, device, ndim, shape, size, T
-
-__all__ += ['__abs__', '__add__', '__and__', '__eq__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
-
-from .constants import e, inf, nan, pi
-
-__all__ += ['e', 'inf', 'nan', 'pi']
-
-from .creation_functions import arange, empty, empty_like, eye, full, full_like, linspace, ones, ones_like, zeros, zeros_like
-
-__all__ += ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like']
-
-from .elementwise_functions import abs, acos, acosh, add, asin, asinh, atan, atan2, atanh, bitwise_and, bitwise_left_shift, bitwise_invert, bitwise_or, bitwise_right_shift, bitwise_xor, ceil, cos, cosh, divide, equal, exp, expm1, floor, floor_divide, greater, greater_equal, isfinite, isinf, isnan, less, less_equal, log, log1p, log2, log10, logical_and, logical_not, logical_or, logical_xor, multiply, negative, not_equal, positive, pow, remainder, round, sign, sin, sinh, square, sqrt, subtract, tan, tanh, trunc
-
-__all__ += ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
-
-from .linear_algebra_functions import cholesky, cross, det, diagonal, dot, eig, eigvalsh, einsum, inv, lstsq, matmul, matrix_power, matrix_rank, norm, outer, pinv, qr, slogdet, solve, svd, trace, transpose
-
-__all__ += ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
-
-from .manipulation_functions import concat, expand_dims, flip, reshape, roll, squeeze, stack
-
-__all__ += ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
-
-from .searching_functions import argmax, argmin, nonzero, where
-
-__all__ += ['argmax', 'argmin', 'nonzero', 'where']
-
-from .set_functions import unique
-
-__all__ += ['unique']
-
-from .sorting_functions import argsort, sort
-
-__all__ += ['argsort', 'sort']
-
-from .statistical_functions import max, mean, min, prod, std, sum, var
-
-__all__ += ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']
-
-from .utility_functions import all, any
-
-__all__ += ['all', 'any']
diff --git a/array_api_tests/function_stubs/_types.py b/array_api_tests/function_stubs/_types.py
deleted file mode 100644
index ef894e7e..00000000
--- a/array_api_tests/function_stubs/_types.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""
-This file defines the types for type annotations.
-
-The type variables should be replaced with the actual types for a given
-library, e.g., for NumPy TypeVar('array') would be replaced with ndarray.
-"""
-
-from typing import Literal, Optional, Tuple, Union, TypeVar
-
-array = TypeVar('array')
-device = TypeVar('device')
-dtype = TypeVar('dtype')
-
-__all__ = ['Literal', 'Optional', 'Tuple', 'Union', 'array', 'device', 'dtype']
-
diff --git a/array_api_tests/function_stubs/array_object.py b/array_api_tests/function_stubs/array_object.py
deleted file mode 100644
index 4f8080d6..00000000
--- a/array_api_tests/function_stubs/array_object.py
+++ /dev/null
@@ -1,189 +0,0 @@
-"""
-Function stubs for array object.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/array_object.md
-"""
-
-from __future__ import annotations
-
-from ._types import array
-
-def __abs__(x: array, /) -> array:
- """
- Note: __abs__ is a method of the array object.
- """
- pass
-
-def __add__(x1: array, x2: array, /) -> array:
- """
- Note: __add__ is a method of the array object.
- """
- pass
-
-def __and__(x1: array, x2: array, /) -> array:
- """
- Note: __and__ is a method of the array object.
- """
- pass
-
-def __eq__(x1: array, x2: array, /) -> array:
- """
- Note: __eq__ is a method of the array object.
- """
- pass
-
-def __floordiv__(x1: array, x2: array, /) -> array:
- """
- Note: __floordiv__ is a method of the array object.
- """
- pass
-
-def __ge__(x1: array, x2: array, /) -> array:
- """
- Note: __ge__ is a method of the array object.
- """
- pass
-
-def __getitem__(x, key, /):
- """
- Note: __getitem__ is a method of the array object.
- """
- pass
-
-def __gt__(x1: array, x2: array, /) -> array:
- """
- Note: __gt__ is a method of the array object.
- """
- pass
-
-def __invert__(x: array, /) -> array:
- """
- Note: __invert__ is a method of the array object.
- """
- pass
-
-def __le__(x1: array, x2: array, /) -> array:
- """
- Note: __le__ is a method of the array object.
- """
- pass
-
-def __len__(x, /):
- """
- Note: __len__ is a method of the array object.
- """
- pass
-
-def __lshift__(x1: array, x2: array, /) -> array:
- """
- Note: __lshift__ is a method of the array object.
- """
- pass
-
-def __lt__(x1: array, x2: array, /) -> array:
- """
- Note: __lt__ is a method of the array object.
- """
- pass
-
-def __matmul__(x1: array, x2: array, /) -> array:
- """
- Note: __matmul__ is a method of the array object.
- """
- pass
-
-def __mod__(x1: array, x2: array, /) -> array:
- """
- Note: __mod__ is a method of the array object.
- """
- pass
-
-def __mul__(x1: array, x2: array, /) -> array:
- """
- Note: __mul__ is a method of the array object.
- """
- pass
-
-def __ne__(x1: array, x2: array, /) -> array:
- """
- Note: __ne__ is a method of the array object.
- """
- pass
-
-def __neg__(x: array, /) -> array:
- """
- Note: __neg__ is a method of the array object.
- """
- pass
-
-def __or__(x1: array, x2: array, /) -> array:
- """
- Note: __or__ is a method of the array object.
- """
- pass
-
-def __pos__(x: array, /) -> array:
- """
- Note: __pos__ is a method of the array object.
- """
- pass
-
-def __pow__(x1: array, x2: array, /) -> array:
- """
- Note: __pow__ is a method of the array object.
- """
- pass
-
-def __rshift__(x1: array, x2: array, /) -> array:
- """
- Note: __rshift__ is a method of the array object.
- """
- pass
-
-def __setitem__(x, key, value, /):
- """
- Note: __setitem__ is a method of the array object.
- """
- pass
-
-def __sub__(x1: array, x2: array, /) -> array:
- """
- Note: __sub__ is a method of the array object.
- """
- pass
-
-def __truediv__(x1: array, x2: array, /) -> array:
- """
- Note: __truediv__ is a method of the array object.
- """
- pass
-
-def __xor__(x1: array, x2: array, /) -> array:
- """
- Note: __xor__ is a method of the array object.
- """
- pass
-
-# Note: dtype is an attribute of the array object.
-dtype = None
-
-# Note: device is an attribute of the array object.
-device = None
-
-# Note: ndim is an attribute of the array object.
-ndim = None
-
-# Note: shape is an attribute of the array object.
-shape = None
-
-# Note: size is an attribute of the array object.
-size = None
-
-# Note: T is an attribute of the array object.
-T = None
-
-__all__ = ['__abs__', '__add__', '__and__', '__eq__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
diff --git a/array_api_tests/function_stubs/constants.py b/array_api_tests/function_stubs/constants.py
deleted file mode 100644
index 602f0399..00000000
--- a/array_api_tests/function_stubs/constants.py
+++ /dev/null
@@ -1,22 +0,0 @@
-"""
-Function stubs for constants.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/constants.md
-"""
-
-from __future__ import annotations
-
-
-e = None
-
-inf = None
-
-nan = None
-
-pi = None
-
-__all__ = ['e', 'inf', 'nan', 'pi']
diff --git a/array_api_tests/function_stubs/creation_functions.py b/array_api_tests/function_stubs/creation_functions.py
deleted file mode 100644
index 7fa71090..00000000
--- a/array_api_tests/function_stubs/creation_functions.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""
-Function stubs for creation functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/creation_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Optional, Tuple, Union, array, device, dtype
-
-def arange(start: Union[int, float], /, *, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def empty(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def empty_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def eye(N: int, /, *, M: Optional[int] = None, k: Optional[int] = 0, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def full_like(x: array, fill_value: Union[int, float], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def linspace(start: Union[int, float], stop: Union[int, float], num: int, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None, endpoint: bool = True) -> array:
- pass
-
-def ones(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def ones_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def zeros(shape: Union[int, Tuple[int, ...]], /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-def zeros_like(x: array, /, *, dtype: Optional[dtype] = None, device: Optional[device] = None) -> array:
- pass
-
-__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like']
diff --git a/array_api_tests/function_stubs/elementwise_functions.py b/array_api_tests/function_stubs/elementwise_functions.py
deleted file mode 100644
index 608aad43..00000000
--- a/array_api_tests/function_stubs/elementwise_functions.py
+++ /dev/null
@@ -1,180 +0,0 @@
-"""
-Function stubs for elementwise functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/elementwise_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import array
-
-def abs(x: array, /) -> array:
- pass
-
-def acos(x: array, /) -> array:
- pass
-
-def acosh(x: array, /) -> array:
- pass
-
-def add(x1: array, x2: array, /) -> array:
- pass
-
-def asin(x: array, /) -> array:
- pass
-
-def asinh(x: array, /) -> array:
- pass
-
-def atan(x: array, /) -> array:
- pass
-
-def atan2(x1: array, x2: array, /) -> array:
- pass
-
-def atanh(x: array, /) -> array:
- pass
-
-def bitwise_and(x1: array, x2: array, /) -> array:
- pass
-
-def bitwise_left_shift(x1: array, x2: array, /) -> array:
- pass
-
-def bitwise_invert(x: array, /) -> array:
- pass
-
-def bitwise_or(x1: array, x2: array, /) -> array:
- pass
-
-def bitwise_right_shift(x1: array, x2: array, /) -> array:
- pass
-
-def bitwise_xor(x1: array, x2: array, /) -> array:
- pass
-
-def ceil(x: array, /) -> array:
- pass
-
-def cos(x: array, /) -> array:
- pass
-
-def cosh(x: array, /) -> array:
- pass
-
-def divide(x1: array, x2: array, /) -> array:
- pass
-
-def equal(x1: array, x2: array, /) -> array:
- pass
-
-def exp(x: array, /) -> array:
- pass
-
-def expm1(x: array, /) -> array:
- pass
-
-def floor(x: array, /) -> array:
- pass
-
-def floor_divide(x1: array, x2: array, /) -> array:
- pass
-
-def greater(x1: array, x2: array, /) -> array:
- pass
-
-def greater_equal(x1: array, x2: array, /) -> array:
- pass
-
-def isfinite(x: array, /) -> array:
- pass
-
-def isinf(x: array, /) -> array:
- pass
-
-def isnan(x: array, /) -> array:
- pass
-
-def less(x1: array, x2: array, /) -> array:
- pass
-
-def less_equal(x1: array, x2: array, /) -> array:
- pass
-
-def log(x: array, /) -> array:
- pass
-
-def log1p(x: array, /) -> array:
- pass
-
-def log2(x: array, /) -> array:
- pass
-
-def log10(x: array, /) -> array:
- pass
-
-def logical_and(x1: array, x2: array, /) -> array:
- pass
-
-def logical_not(x: array, /) -> array:
- pass
-
-def logical_or(x1: array, x2: array, /) -> array:
- pass
-
-def logical_xor(x1: array, x2: array, /) -> array:
- pass
-
-def multiply(x1: array, x2: array, /) -> array:
- pass
-
-def negative(x: array, /) -> array:
- pass
-
-def not_equal(x1: array, x2: array, /) -> array:
- pass
-
-def positive(x: array, /) -> array:
- pass
-
-def pow(x1: array, x2: array, /) -> array:
- pass
-
-def remainder(x1: array, x2: array, /) -> array:
- pass
-
-def round(x: array, /) -> array:
- pass
-
-def sign(x: array, /) -> array:
- pass
-
-def sin(x: array, /) -> array:
- pass
-
-def sinh(x: array, /) -> array:
- pass
-
-def square(x: array, /) -> array:
- pass
-
-def sqrt(x: array, /) -> array:
- pass
-
-def subtract(x1: array, x2: array, /) -> array:
- pass
-
-def tan(x: array, /) -> array:
- pass
-
-def tanh(x: array, /) -> array:
- pass
-
-def trunc(x: array, /) -> array:
- pass
-
-__all__ = ['abs', 'acos', 'acosh', 'add', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_and', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'ceil', 'cos', 'cosh', 'divide', 'equal', 'exp', 'expm1', 'floor', 'floor_divide', 'greater', 'greater_equal', 'isfinite', 'isinf', 'isnan', 'less', 'less_equal', 'log', 'log1p', 'log2', 'log10', 'logical_and', 'logical_not', 'logical_or', 'logical_xor', 'multiply', 'negative', 'not_equal', 'positive', 'pow', 'remainder', 'round', 'sign', 'sin', 'sinh', 'square', 'sqrt', 'subtract', 'tan', 'tanh', 'trunc']
diff --git a/array_api_tests/function_stubs/linear_algebra_functions.py b/array_api_tests/function_stubs/linear_algebra_functions.py
deleted file mode 100644
index 7bbd952f..00000000
--- a/array_api_tests/function_stubs/linear_algebra_functions.py
+++ /dev/null
@@ -1,82 +0,0 @@
-"""
-Function stubs for linear algebra functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/linear_algebra_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Literal, Optional, Tuple, Union, array
-from .constants import inf
-
-def cholesky():
- pass
-
-def cross(x1: array, x2: array, /, *, axis: int = -1) -> array:
- pass
-
-def det(x: array, /) -> array:
- pass
-
-def diagonal(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
- pass
-
-def dot():
- pass
-
-def eig():
- pass
-
-def eigvalsh():
- pass
-
-def einsum():
- pass
-
-def inv(x: array, /) -> array:
- pass
-
-def lstsq():
- pass
-
-def matmul():
- pass
-
-def matrix_power():
- pass
-
-def matrix_rank():
- pass
-
-def norm(x: array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float, Literal[inf, -inf, 'fro', 'nuc']]] = None) -> array:
- pass
-
-def outer(x1: array, x2: array, /) -> array:
- pass
-
-def pinv():
- pass
-
-def qr():
- pass
-
-def slogdet():
- pass
-
-def solve():
- pass
-
-def svd():
- pass
-
-def trace(x: array, /, *, axis1: int = 0, axis2: int = 1, offset: int = 0) -> array:
- pass
-
-def transpose(x: array, /, *, axes: Optional[Tuple[int, ...]] = None) -> array:
- pass
-
-__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'dot', 'eig', 'eigvalsh', 'einsum', 'inv', 'lstsq', 'matmul', 'matrix_power', 'matrix_rank', 'norm', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'trace', 'transpose']
diff --git a/array_api_tests/function_stubs/manipulation_functions.py b/array_api_tests/function_stubs/manipulation_functions.py
deleted file mode 100644
index f4055e44..00000000
--- a/array_api_tests/function_stubs/manipulation_functions.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""
-Function stubs for manipulation functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/manipulation_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Optional, Tuple, Union, array
-
-def concat(arrays: Tuple[array], /, *, axis: Optional[int] = 0) -> array:
- pass
-
-def expand_dims(x: array, axis: int, /) -> array:
- pass
-
-def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
- pass
-
-def reshape(x: array, shape: Tuple[int, ...], /) -> array:
- pass
-
-def roll(x: array, shift: Union[int, Tuple[int, ...]], /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
- pass
-
-def squeeze(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> array:
- pass
-
-def stack(arrays: Tuple[array], /, *, axis: int = 0) -> array:
- pass
-
-__all__ = ['concat', 'expand_dims', 'flip', 'reshape', 'roll', 'squeeze', 'stack']
diff --git a/array_api_tests/function_stubs/searching_functions.py b/array_api_tests/function_stubs/searching_functions.py
deleted file mode 100644
index 10c072c1..00000000
--- a/array_api_tests/function_stubs/searching_functions.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""
-Function stubs for searching functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/searching_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Tuple, array
-
-def argmax(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
- pass
-
-def argmin(x: array, /, *, axis: int = None, keepdims: bool = False) -> array:
- pass
-
-def nonzero(x: array, /) -> Tuple[array, ...]:
- pass
-
-def where(condition: array, x1: array, x2: array, /) -> array:
- pass
-
-__all__ = ['argmax', 'argmin', 'nonzero', 'where']
diff --git a/array_api_tests/function_stubs/set_functions.py b/array_api_tests/function_stubs/set_functions.py
deleted file mode 100644
index 048dbcf6..00000000
--- a/array_api_tests/function_stubs/set_functions.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""
-Function stubs for set functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/set_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Tuple, Union, array
-
-def unique(x: array, /, *, return_counts: bool = False, return_index: bool = False, return_inverse: bool = False, sorted: bool = True) -> Union[array, Tuple[array, ...]]:
- pass
-
-__all__ = ['unique']
diff --git a/array_api_tests/function_stubs/sorting_functions.py b/array_api_tests/function_stubs/sorting_functions.py
deleted file mode 100644
index 2040de54..00000000
--- a/array_api_tests/function_stubs/sorting_functions.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""
-Function stubs for sorting functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/sorting_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import array
-
-def argsort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array:
- pass
-
-def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True) -> array:
- pass
-
-__all__ = ['argsort', 'sort']
diff --git a/array_api_tests/function_stubs/statistical_functions.py b/array_api_tests/function_stubs/statistical_functions.py
deleted file mode 100644
index 3a0cc7bc..00000000
--- a/array_api_tests/function_stubs/statistical_functions.py
+++ /dev/null
@@ -1,36 +0,0 @@
-"""
-Function stubs for statistical functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/statistical_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Optional, Tuple, Union, array
-
-def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-def mean(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-def prod(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-def std(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
- pass
-
-def sum(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-def var(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False) -> array:
- pass
-
-__all__ = ['max', 'mean', 'min', 'prod', 'std', 'sum', 'var']
diff --git a/array_api_tests/function_stubs/utility_functions.py b/array_api_tests/function_stubs/utility_functions.py
deleted file mode 100644
index ae427401..00000000
--- a/array_api_tests/function_stubs/utility_functions.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""
-Function stubs for utility functions.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/utility_functions.md
-"""
-
-from __future__ import annotations
-
-from ._types import Optional, Tuple, Union, array
-
-def all(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-def any(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array:
- pass
-
-__all__ = ['all', 'any']
diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py
index 288b005c..e1df108c 100644
--- a/array_api_tests/hypothesis_helpers.py
+++ b/array_api_tests/hypothesis_helpers.py
@@ -1,114 +1,524 @@
-from functools import reduce
-from operator import mul
-from math import sqrt
+from __future__ import annotations
-from hypothesis.strategies import (lists, integers, builds, sampled_from,
- shared, tuples as hypotheses_tuples,
- floats, just, composite, one_of, none,
- booleans)
-from hypothesis import assume
+import os
+import re
+from contextlib import contextmanager
+from functools import wraps
+import math
+import struct
+from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
+from hypothesis import assume, reject
+from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
+ integers, complex_numbers, just, lists, none, one_of,
+ sampled_from, shared, builds, nothing, permutations)
+
+from . import _array_module as xp, api_version
+from . import array_helpers as ah
+from . import dtype_helpers as dh
+from . import shape_helpers as sh
+from . import xps
+from ._array_module import _UndefinedStub
+from ._array_module import bool as bool_dtype
+from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
+from .stubs import category_to_funcs
from .pytest_helpers import nargs
-from .array_helpers import dtype_ranges
-from ._array_module import (_integer_dtypes, _floating_dtypes,
- _numeric_dtypes, _dtypes, ones, full, float32,
- float64, bool as bool_dtype)
-from . import _array_module
+from .typing import Array, DataType, Scalar, Shape
+
+
+def _float32ify(n: Union[int, float]) -> float:
+    """Round-trip `n` through IEEE 754 binary32 to get the nearest float32-representable value."""
+    n = float(n)
+    return struct.unpack("!f", struct.pack("!f", n))[0]
+
+
+@wraps(xps.from_dtype)
+def from_dtype(dtype, **kwargs) -> SearchStrategy[Scalar]:
+ """xps.from_dtype() without the crazy large numbers."""
+ if dtype == xp.bool:
+ return xps.from_dtype(dtype, **kwargs)
+
+ if dtype in dh.complex_dtypes:
+ component_dtype = dh.dtype_components[dtype]
+ else:
+ component_dtype = dtype
+
+ min_, max_ = dh.dtype_ranges[component_dtype]
+
+ if "min_value" not in kwargs.keys() and min_ != 0:
+ assert min_ < 0 # sanity check
+ min_value = -1 * math.floor(math.sqrt(abs(min_)))
+ if component_dtype == xp.float32:
+ min_value = _float32ify(min_value)
+ kwargs["min_value"] = min_value
+ if "max_value" not in kwargs.keys():
+ assert max_ > 0 # sanity check
+ max_value = math.floor(math.sqrt(max_))
+ if component_dtype == xp.float32:
+ max_value = _float32ify(max_value)
+ kwargs["max_value"] = max_value
+
+ if dtype in dh.complex_dtypes:
+ component_strat = xps.from_dtype(dh.dtype_components[dtype], **kwargs)
+ return builds(complex, component_strat, component_strat)
+ else:
+ return xps.from_dtype(dtype, **kwargs)
+
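+# Illustrative bound from the sqrt clamping above: for int8 (range
+# [-128, 127]) the strategy draws values in [-11, 11], so products and sums of
+# a couple of drawn values stay inside the dtype's range.
+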
+
+@wraps(xps.arrays)
+def arrays_no_scalars(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
+ """xps.arrays() without the crazy large numbers."""
+ if isinstance(dtype, SearchStrategy):
+ return dtype.flatmap(lambda d: arrays(d, *args, elements=elements, **kwargs))
+
+ if elements is None:
+ elements = from_dtype(dtype)
+ elif isinstance(elements, Mapping):
+ elements = from_dtype(dtype, **elements)
+
+ return xps.arrays(dtype, *args, elements=elements, **kwargs)
+
+
+def _f(a, flag):
+    # Unwrap a 0-D array to a scalar (a[()]) when flag is set; otherwise a no-op.
+    return a[()] if a.ndim == 0 and flag else a
+
+
+@wraps(xps.arrays)
+def arrays(dtype, *args, elements=None, **kwargs) -> SearchStrategy[Array]:
+ """xps.arrays() without the crazy large numbers. Also draw 0D arrays or numpy scalars.
+
+ Is only relevant for numpy: on all other libraries, array[()] is no-op.
+ """
+ return builds(_f, arrays_no_scalars(dtype, *args, elements=elements, **kwargs), booleans())
+
+
+_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.real_float_dtypes, dh.complex_dtypes]
+_sorted_dtypes = [d for category in _dtype_categories for d in category]
+
+def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
+ dtype1, dtype2 = dtype_pair
+ if dtype1 == dtype2:
+ return _sorted_dtypes.index(dtype1)
+ key = len(_sorted_dtypes)
+ rank1 = _sorted_dtypes.index(dtype1)
+ rank2 = _sorted_dtypes.index(dtype2)
+ for category in _dtype_categories:
+ if dtype1 in category and dtype2 in category:
+ break
+ else:
+ key += len(_sorted_dtypes) ** 2
+ key += 2 * (rank1 + rank2)
+ if rank1 > rank2:
+ key += 1
+ return key
+
+_promotable_dtypes = list(dh.promotion_table.keys())
+_promotable_dtypes = [
+ (d1, d2) for d1, d2 in _promotable_dtypes
+ if not isinstance(d1, _UndefinedStub) or not isinstance(d2, _UndefinedStub)
+]
+promotable_dtypes: List[Tuple[DataType, DataType]] = sorted(_promotable_dtypes, key=_dtypes_sorter)
+
+def mutually_promotable_dtypes(
+ max_size: Optional[int] = 2,
+ *,
+ dtypes: Sequence[DataType] = dh.all_dtypes,
+) -> SearchStrategy[Tuple[DataType, ...]]:
+ dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
+ assert len(dtypes) > 0, "all dtypes undefined" # sanity check
+ if max_size == 2:
+ return sampled_from(
+ [(i, j) for i, j in promotable_dtypes if i in dtypes and j in dtypes]
+ )
+ if isinstance(max_size, int) and max_size < 2:
+ raise ValueError(f'{max_size=} should be >=2')
+ strats = []
+ category_samples = {
+ category: [d for d in dtypes if d in category] for category in _dtype_categories
+ }
+ for samples in category_samples.values():
+ if len(samples) > 0:
+ strat = lists(sampled_from(samples), min_size=2, max_size=max_size)
+ strats.append(strat)
+ if len(category_samples[dh.uint_dtypes]) > 0 and len(category_samples[dh.int_dtypes]) > 0:
+ mixed_samples = category_samples[dh.uint_dtypes] + category_samples[dh.int_dtypes]
+ strat = lists(sampled_from(mixed_samples), min_size=2, max_size=max_size)
+ if xp.uint64 in mixed_samples:
+ strat = strat.filter(
+ lambda l: not (xp.uint64 in l and any(d in dh.int_dtypes for d in l))
+ )
+ return one_of(strats).map(tuple)
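+
+# Illustrative draws: with max_size=2 this samples promotable pairs such as
+# (int8, int32) or (float32, float64); with max_size > 2 it draws lists within
+# a single dtype category (plus the mixed int/uint case), filtering out
+# uint64-with-signed combinations, which have no defined promotion.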
+
+
+@composite
+def pair_of_mutually_promotable_dtypes(draw, max_size=2, *, dtypes=dh.all_dtypes):
+    sample = draw(mutually_promotable_dtypes(max_size, dtypes=dtypes))
+ permuted = draw(permutations(sample))
+ return sample, tuple(permuted)
+
-from .function_stubs import elementwise_functions
+class OnewayPromotableDtypes(NamedTuple):
+ input_dtype: DataType
+ result_dtype: DataType
-integer_dtype_objects = [getattr(_array_module, t) for t in _integer_dtypes]
-floating_dtype_objects = [getattr(_array_module, t) for t in _floating_dtypes]
-numeric_dtype_objects = [getattr(_array_module, t) for t in _numeric_dtypes]
-dtype_objects = [getattr(_array_module, t) for t in _dtypes]
-integer_dtypes = sampled_from(integer_dtype_objects)
-floating_dtypes = sampled_from(floating_dtype_objects)
-numeric_dtypes = sampled_from(numeric_dtype_objects)
-dtypes = sampled_from(dtype_objects)
+@composite
+def oneway_promotable_dtypes(
+ draw, dtypes: Sequence[DataType]
+) -> OnewayPromotableDtypes:
+ """Return a strategy for input dtypes that promote to result dtypes."""
+ d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes))
+ result_dtype = dh.result_type(d1, d2)
+ if d1 == result_dtype:
+ return OnewayPromotableDtypes(d2, d1)
+ elif d2 == result_dtype:
+ return OnewayPromotableDtypes(d1, d2)
+ else:
+ reject()
+
+
+class OnewayBroadcastableShapes(NamedTuple):
+ input_shape: Shape
+ result_shape: Shape
+
+
+@composite
+def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
+ """Return a strategy for input shapes that broadcast to result shapes."""
+ result_shape = draw(shapes(min_side=1))
+ input_shape = draw(
+ xps.broadcastable_shapes(
+ result_shape,
+ # Override defaults so bad shapes are less likely to be generated.
+ max_side=None if result_shape == () else max(result_shape),
+ max_dims=len(result_shape),
+ ).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
+ )
+ return OnewayBroadcastableShapes(input_shape, result_shape)
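+
+# Illustrative: a drawn pair might be input_shape=(1, 4), result_shape=(3, 4);
+# by construction broadcast_shapes(input_shape, result_shape) == result_shape,
+# i.e. broadcasting only ever goes one way.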
+
+
+# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
+# ARRAY_API_TESTS_SKIP_DTYPES
+all_dtypes = sampled_from(_sorted_dtypes)
+int_dtypes = sampled_from(dh.all_int_dtypes)
+uint_dtypes = sampled_from(dh.uint_dtypes)
+real_dtypes = sampled_from(dh.real_dtypes)
+# Warning: The hypothesis "floating_dtypes" is what we call
+# "real_floating_dtypes"
+floating_dtypes = sampled_from(dh.all_float_dtypes)
+real_floating_dtypes = sampled_from(dh.real_float_dtypes)
+numeric_dtypes = sampled_from(dh.numeric_dtypes)
+# Note: this always returns complex dtypes, even if api_version < 2022.12
+complex_dtypes: SearchStrategy[Any] = sampled_from(dh.complex_dtypes) if dh.complex_dtypes else nothing()
+
+def all_floating_dtypes() -> SearchStrategy[DataType]:
+ strat = floating_dtypes
+ if api_version >= "2022.12" and not complex_dtypes.is_empty:
+ strat |= complex_dtypes
+ return strat
-shared_dtypes = shared(dtypes)
# shared() allows us to draw either the function or the function name and they
# will both correspond to the same function.
# TODO: Extend this to all functions, not just elementwise
-elementwise_functions_names = shared(sampled_from(elementwise_functions.__all__))
+elementwise_functions_names = shared(sampled_from([f.__name__ for f in category_to_funcs["elementwise"]]))
array_functions_names = elementwise_functions_names
multiarg_array_functions_names = array_functions_names.filter(
lambda func_name: nargs(func_name) > 1)
elementwise_function_objects = elementwise_functions_names.map(
- lambda i: getattr(_array_module, i))
+ lambda i: getattr(xp, i))
array_functions = elementwise_function_objects
multiarg_array_functions = multiarg_array_functions_names.map(
- lambda i: getattr(_array_module, i))
+ lambda i: getattr(xp, i))
# Limit the total size of an array shape
-MAX_ARRAY_SIZE = 10000
+MAX_ARRAY_SIZE = int(os.environ.get("ARRAY_API_TESTS_MAX_ARRAY_SIZE", 1024))
# Size to use for 2-dim arrays
-SQRT_MAX_ARRAY_SIZE = int(sqrt(MAX_ARRAY_SIZE))
+SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
-# np.prod and others have overflow and math.prod is Python 3.8+ only
-def prod(seq):
- return reduce(mul, seq, 1)
 # hypothesis.strategies.tuples only generates tuples of a fixed size
def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False):
return lists(elements, min_size=min_size, max_size=max_size,
unique_by=unique_by, unique=unique).map(tuple)
-shapes = tuples(integers(0, 10)).filter(lambda shape: prod(shape) < MAX_ARRAY_SIZE)
-
# Use this to avoid memory errors with NumPy.
# See https://github.com/numpy/numpy/issues/15753
+# Note, the hypothesis default for max_dims is min_dims + 2 (i.e., 0 + 2)
+def shapes(**kw):
+ kw.setdefault('min_dims', 0)
+ kw.setdefault('min_side', 0)
+ return xps.array_shapes(**kw).filter(
+ lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
+ )
-shapes = tuples(integers(0, 10)).filter(
- lambda shape: prod([i for i in shape if i]) < MAX_ARRAY_SIZE)
+def _factorize(n: int) -> List[int]:
+ # Simple prime factorization. Only needs to handle n ~ MAX_ARRAY_SIZE
+ factors = []
+ while n % 2 == 0:
+ factors.append(2)
+ n //= 2
-sizes = integers(0, MAX_ARRAY_SIZE)
-sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
+ for i in range(3, int(math.sqrt(n)) + 1, 2):
+ while n % i == 0:
+ factors.append(i)
+ n //= i
+
+ if n > 1: # n is a prime number greater than 2
+ factors.append(n)
-ones_arrays = builds(ones, shapes, dtype=shared_dtypes)
+ return factors
-nonbroadcastable_ones_array_two_args = hypotheses_tuples(ones_arrays, ones_arrays)
+MAX_SIDE = MAX_ARRAY_SIZE // 64
+# NumPy only supports up to 32 dims. TODO: Get this from the new inspection APIs
+MAX_DIMS = min(MAX_ARRAY_SIZE // MAX_SIDE, 32)
-# TODO: Generate general arrays here, rather than just scalars.
-numeric_arrays = builds(full, just((1,)), floats())
@composite
-def shared_scalars(draw):
+def reshape_shapes(draw, arr_shape, ndims=integers(1, MAX_DIMS)):
"""
- Strategy to generate a scalar that matches the dtype from shared_dtypes
+    Generate shape tuples whose product equals the product of the shape
+    drawn from arr_shape.
"""
- dtype = draw(shared_dtypes)
- if dtype in dtype_ranges:
- m, M = dtype_ranges[dtype]
+ shape = draw(arr_shape)
+
+ array_size = math.prod(shape)
+
+ n_dims = draw(ndims)
+
+ # Handle special cases
+ if array_size == 0:
+ # Generate a random tuple, and ensure at least one of the entries is 0
+ result = list(draw(shapes(min_dims=n_dims, max_dims=n_dims)))
+ pos = draw(integers(0, n_dims - 1))
+ result[pos] = 0
+ return tuple(result)
+
+ if array_size == 1:
+ return tuple(1 for _ in range(n_dims))
+
+ # Get prime factorization
+ factors = _factorize(array_size)
+
+ # Distribute prime factors randomly
+ result = [1] * n_dims
+ for factor in factors:
+ pos = draw(integers(0, n_dims - 1))
+ result[pos] *= factor
+
+ assert math.prod(result) == array_size
+
+ # An element of the reshape tuple can be -1, which means it is a stand-in
+ # for the remaining factors.
+ if draw(booleans()):
+ pos = draw(integers(0, n_dims - 1))
+ result[pos] = -1
+
+ return tuple(result)
+
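+# e.g. if arr_shape draws (6, 4), the 24 elements factor as [2, 2, 2, 3],
+# so this may draw shapes such as (2, 12), (4, 3, 2), or (-1, 8).
+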
+one_d_shapes = xps.array_shapes(min_dims=1, max_dims=1, min_side=0, max_side=SQRT_MAX_ARRAY_SIZE)
+
+# Matrix shapes assume stacks of matrices
+@composite
+def matrix_shapes(draw, stack_shapes=shapes()):
+ stack_shape = draw(stack_shapes)
+ mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
+ shape = stack_shape + mat_shape
+ assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE)
+ return shape
+
+square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
+
+@composite
+def finite_matrices(draw, shape=matrix_shapes()):
+ return draw(arrays(dtype=floating_dtypes,
+ shape=shape,
+ elements=dict(allow_nan=False,
+ allow_infinity=False)))
+
+rtol_shared_matrix_shapes = shared(matrix_shapes())
+# Should we set a max_value here?
+_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
+rtols = one_of(floats(**_rtol_float_kw),
+ arrays(dtype=real_floating_dtypes,
+ shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
+ elements=_rtol_float_kw))
+
+
+def mutually_broadcastable_shapes(
+ num_shapes: int,
+ *,
+ base_shape: Shape = (),
+ min_dims: int = 0,
+ max_dims: Optional[int] = None,
+ min_side: int = 0,
+ max_side: Optional[int] = None,
+) -> SearchStrategy[Tuple[Shape, ...]]:
+ if max_dims is None:
+ max_dims = min(max(len(base_shape), min_dims) + 5, 32)
+ if max_side is None:
+ max_side = max(base_shape[-max_dims:] + (min_side,)) + 5
+ return (
+ xps.mutually_broadcastable_shapes(
+ num_shapes,
+ base_shape=base_shape,
+ min_dims=min_dims,
+ max_dims=max_dims,
+ min_side=min_side,
+ max_side=max_side,
+ )
+ .map(lambda BS: BS.input_shapes)
+ .filter(lambda shapes: all(
+ math.prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
+ ))
+ )
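+
+# e.g. mutually_broadcastable_shapes(3) may draw ((2, 1), (1, 3), ()),
+# which all broadcast with one another to (2, 3).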
+
+two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2)
+
+# TODO: Add support for complex Hermitian matrices
+@composite
+def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.):
+ # for now, only generate elements from (1, bound); TODO: restore
+ # generating from (-bound, -1/bound).or.(1/bound, bound)
+ # Note that using `assume` triggers a HealthCheck for filtering too much.
+ shape = draw(square_matrix_shapes)
+ dtype = draw(dtypes)
+ if not isinstance(finite, bool):
+ finite = draw(finite)
+ if finite:
+ elements = {'allow_nan': False, 'allow_infinity': False,
+ 'min_value': 1, 'max_value': bound}
+ else:
+ elements = None
+ a = draw(arrays(dtype=dtype, shape=shape, elements=elements))
+ at = ah._matrix_transpose(a)
+ H = (a + at)*0.5
+ if finite:
+ assume(not xp.any(xp.isinf(H)))
+ return H
+
+@composite
+def positive_definite_matrices(draw, dtypes=floating_dtypes):
+ # For now just generate stacks of identity matrices
+ # TODO: Generate arbitrary positive definite matrices, for instance, by
+ # using something like
+ # https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
+ base_shape = draw(shapes())
+ n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
+ shape = base_shape + (n, n)
+ assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE)
+ dtype = draw(dtypes)
+ return broadcast_to(eye(n, dtype=dtype), shape)
+
+@composite
+def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()):
+ # For now, just generate stacks of diagonal matrices.
+ stack_shape = draw(stack_shapes)
+    n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)))
+ dtype = draw(dtypes)
+ elements = one_of(
+ from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),
+ from_dtype(dtype, max_value=-0.5, allow_nan=False, allow_infinity=False),
+ )
+ d = draw(arrays(dtype, shape=(*stack_shape, 1, n), elements=elements))
+
+    # Functions that require invertible matrices may do anything when the
+    # matrix is singular, including raising an exception, so we make sure the
+    # diagonals are sufficiently nonzero to avoid any numerical issues.
+ assert xp.all(xp.abs(d) >= 0.5)
+
+ diag_mask = xp.arange(n) == xp.reshape(xp.arange(n), (n, 1))
+ return xp.where(diag_mask, d, xp.zeros_like(d))
+
+# TODO: Better name
+@composite
+def two_broadcastable_shapes(draw):
+ """
+ This will produce two shapes (shape1, shape2) such that shape2 can be
+ broadcast to shape1.
+ """
+ shape1, shape2 = draw(two_mutually_broadcastable_shapes)
+ assume(sh.broadcast_shapes(shape1, shape2) == shape1)
+ return (shape1, shape2)
+
+sizes = integers(0, MAX_ARRAY_SIZE)
+sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)
+
+numeric_arrays = arrays(
+ dtype=shared(floating_dtypes, key='dtypes'),
+ shape=shared(xps.array_shapes(), key='shapes'),
+)
+
+@composite
+def scalars(draw, dtypes, finite=False, **kwds):
+ """
+ Strategy to generate a scalar that matches a dtype strategy
+
+ dtypes should be one of the shared_* dtypes strategies.
+ """
+ dtype = draw(dtypes)
+ mM = kwds.pop('mM', None)
+ if dh.is_int_dtype(dtype):
+ if mM is None:
+ m, M = dh.dtype_ranges[dtype]
+ else:
+ m, M = mM
return draw(integers(m, M))
elif dtype == bool_dtype:
return draw(booleans())
elif dtype == float64:
- return draw(floats())
+ if finite:
+ return draw(floats(allow_nan=False, allow_infinity=False, **kwds))
+        return draw(floats(**kwds))
elif dtype == float32:
- return draw(floats(width=32))
+ if finite:
+ return draw(floats(width=32, allow_nan=False, allow_infinity=False, **kwds))
+ return draw(floats(width=32, **kwds))
+ elif dtype == complex64:
+ if finite:
+ return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
+ return draw(complex_numbers(width=32))
+ elif dtype == complex128:
+ if finite:
+ return draw(complex_numbers(allow_nan=False, allow_infinity=False))
+ return draw(complex_numbers())
else:
raise ValueError(f"Unrecognized dtype {dtype}")
@composite
-def integer_indices(draw, sizes):
+def array_scalars(draw, dtypes):
+ dtype = draw(dtypes)
+ return full((), draw(scalars(just(dtype))), dtype=dtype)
+
+@composite
+def python_integer_indices(draw, sizes):
size = draw(sizes)
if size == 0:
assume(False)
return draw(integers(-size, size - 1))
+@composite
+def integer_indices(draw, sizes):
+ # Return either a Python integer or a 0-D array with some integer dtype
+ idx = draw(python_integer_indices(sizes))
+ dtype = draw(int_dtypes | uint_dtypes)
+ m, M = dh.dtype_ranges[dtype]
+ if m <= idx <= M:
+ return draw(one_of(just(idx),
+ just(full((), idx, dtype=dtype))))
+ return idx
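+
+# e.g. integer_indices(just(5)) may yield the Python int -3, or the same
+# index as a 0-D array such as full((), -3, dtype=xp.int8).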
+
@composite
def slices(draw, sizes):
size = draw(sizes)
# The spec does not specify out of bounds behavior.
- start = draw(one_of(integers(-size, max(0, size-1)), none()))
- stop = draw(one_of(integers(-size, size)), none())
max_step_size = draw(integers(1, max(1, size)))
step = draw(one_of(integers(-max_step_size, -1), integers(1, max_step_size), none()))
+ start = draw(one_of(integers(-size, size), none()))
+ if step is None or step > 0:
+        stop = draw(one_of(integers(-size, size), none()))
+    else:
+        stop = draw(one_of(integers(-size - 1, size - 1), none()))
s = slice(start, stop, step)
l = list(range(size))
sliced_list = l[s]
@@ -129,18 +539,139 @@ def multiaxis_indices(draw, shapes):
# Generate tuples no longer than the shape, with indices corresponding to
# each dimension.
shape = draw(shapes)
- guard = draw(tuples(just(object()), max_size=len(shape)))
+ n_entries = draw(integers(0, len(shape)))
# from hypothesis import note
- # note(f"multiaxis_indices guard: {guard}")
+ # note(f"multiaxis_indices n_entries: {n_entries}")
- for size, _ in zip(shape, guard):
- res.append(draw(one_of(
+ k = 0
+ for i in range(n_entries):
+ size = shape[k]
+ idx = draw(one_of(
integer_indices(just(size)),
slices(just(size)),
- just(...))))
+ just(...)))
+ if idx is ... and k >= 0:
+ # If there is an ellipsis, index from the end of the shape
+ k = k - n_entries
+ k += 1
+ res.append(idx)
# Sometimes add more entries than necessary to test this.
- if len(guard) == len(shape) and ... not in res:
- # note("Adding extra")
- extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
- res += extra
+
+ # Avoid using 'in', which might do == on an array.
+ res_has_ellipsis = any(i is ... for i in res)
+ if not res_has_ellipsis:
+ if n_entries < len(shape):
+ # The spec requires either an ellipsis or exactly as many indices
+ # as dimensions.
+ assume(False)
+ elif n_entries == len(shape):
+ # note("Adding extra")
+ extra = draw(lists(one_of(integer_indices(sizes), slices(sizes)), min_size=0, max_size=3))
+ res += extra
return tuple(res)
+
+
+def two_mutual_arrays(
+ dtypes: Sequence[DataType] = dh.all_dtypes,
+ two_shapes: SearchStrategy[Tuple[Shape, Shape]] = two_mutually_broadcastable_shapes,
+) -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
+ if not isinstance(dtypes, Sequence):
+ raise TypeError(f"{dtypes=} not a sequence")
+ dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
+ assert len(dtypes) > 0 # sanity check
+ mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtypes))
+ mutual_shapes = shared(two_shapes)
+ arrays1 = arrays(
+ dtype=mutual_dtypes.map(lambda pair: pair[0]),
+ shape=mutual_shapes.map(lambda pair: pair[0]),
+ )
+ arrays2 = arrays(
+ dtype=mutual_dtypes.map(lambda pair: pair[1]),
+ shape=mutual_shapes.map(lambda pair: pair[1]),
+ )
+ return arrays1, arrays2
+
+
+@composite
+def array_and_py_scalar(draw, dtypes, mM=None, positive=False):
+ """Draw a pair: (array, scalar) or (scalar, array)."""
+ dtype = draw(sampled_from(dtypes))
+
+ scalar_var = draw(scalars(just(dtype), finite=True, mM=mM))
+ if positive:
+        assume(scalar_var > 0)
+
+    elements = {}
+ if dtype in dh.real_float_dtypes:
+ elements = {'allow_nan': False, 'allow_infinity': False,
+ 'min_value': 1.0 / (2<<5), 'max_value': 2<<5}
+ if positive:
+ elements = {'min_value': 0}
+ array_var = draw(arrays(dtype, shape=shapes(min_dims=1), elements=elements))
+
+ if draw(booleans()):
+ return scalar_var, array_var
+ else:
+ return array_var, scalar_var
+
+
+@composite
+def kwargs(draw, **kw):
+ """
+ Strategy for keyword arguments
+
+ For a signature like f(x, /, dtype=None, val=1) use
+
+    @given(x=arrays(), kw=kwargs(dtype=none() | dtypes, val=integers()))
+ def test_f(x, kw):
+ res = f(x, **kw)
+
+ kw may omit the keyword argument, meaning the default for f will be used.
+
+ """
+ result = {}
+ for k, strat in kw.items():
+ if draw(booleans()):
+ result[k] = draw(strat)
+ return result
+
+
+class KVD(NamedTuple):
+ keyword: str
+ value: Any
+ default: Any
+
+
+@composite
+def specified_kwargs(draw, *keys_values_defaults: KVD):
+ """Generates valid kwargs given expected defaults.
+
+    When we can't realistically use hh.kwargs() and thus test whether xp in
+    fact defaults correctly, this strategy lets us omit generated arguments
+    that happen to equal the default value anyway.
+ """
+ kw = {}
+ for keyword, value, default in keys_values_defaults:
+ if value is not default or draw(booleans()):
+ kw[keyword] = value
+ return kw
+
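+# e.g. specified_kwargs(KVD("keepdims", False, False)) may draw {} or
+# {"keepdims": False}, while specified_kwargs(KVD("keepdims", True, False))
+# always draws {"keepdims": True}, as non-default values must be passed on.
+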
+
+def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
+ """Generate valid arguments for some axis keywords"""
+ axes_strats = [none()]
+ if ndim != 0:
+ axes_strats.append(integers(-ndim, ndim - 1))
+ axes_strats.append(xps.valid_tuple_axes(ndim))
+ return one_of(axes_strats)
+
+
+@contextmanager
+def reject_overflow():
+ try:
+ yield
+ except Exception as e:
+ if isinstance(e, OverflowError) or re.search("[Oo]verflow", str(e)):
+ reject()
+ else:
+        raise
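+
+# Usage sketch: wrap a call that may overflow on some backends, e.g.
+#
+#     with reject_overflow():
+#         out = xp.add(x1, x2)
+#
+# so the drawn example is rejected rather than failing the test.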
diff --git a/array_api_tests/meta_tests/test_array_helpers.py b/array_api_tests/meta_tests/test_array_helpers.py
deleted file mode 100644
index 1e9571a4..00000000
--- a/array_api_tests/meta_tests/test_array_helpers.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from ..array_helpers import exactly_equal, notequal
-import numpy as np
-
-# TODO: These meta-tests currently only work with NumPy
-
-def test_exactly_equal():
- a = np.array([0, 0., -0., -0., np.nan, np.nan, 1, 1])
- b = np.array([0, -1, -0., 0., np.nan, 1, 1, 2])
-
- res = np.array([True, False, True, False, True, False, True, False])
- np.testing.assert_equal(exactly_equal(a, b), res)
-
-def test_notequal():
- a = np.array([0, 0., -0., -0., np.nan, np.nan, 1, 1])
- b = np.array([0, -1, -0., 0., np.nan, 1, 1, 2])
-
- res = np.array([False, True, False, False, False, True, False, True])
- np.testing.assert_equal(notequal(a, b), res)
diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py
index 354ad50e..f6b7ae25 100644
--- a/array_api_tests/pytest_helpers.py
+++ b/array_api_tests/pytest_helpers.py
@@ -1,7 +1,37 @@
+import cmath
+import math
from inspect import getfullargspec
-from . import function_stubs
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
-def raises(exceptions, function, message=''):
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import shape_helpers as sh
+from . import stubs
+from . import xp as _xp
+from .typing import Array, DataType, Scalar, ScalarType, Shape
+
+__all__ = [
+ "raises",
+ "doesnt_raise",
+ "nargs",
+ "fmt_kw",
+ "is_pos_zero",
+ "is_neg_zero",
+ "assert_dtype",
+ "assert_kw_dtype",
+ "assert_default_float",
+ "assert_default_int",
+ "assert_default_index",
+ "assert_shape",
+ "assert_result_shape",
+ "assert_keepdimable_shape",
+ "assert_0d_equals",
+ "assert_fill",
+ "assert_array_elements",
+]
+
+
+def raises(exceptions, function, message=""):
"""
Like pytest.raises() except it allows custom error messages
"""
@@ -11,11 +41,14 @@ def raises(exceptions, function, message=''):
return
except Exception as e:
if message:
- raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions}): {message}")
+ raise AssertionError(
+ f"Unexpected exception {e!r} (expected {exceptions}): {message}"
+ )
raise AssertionError(f"Unexpected exception {e!r} (expected {exceptions})")
raise AssertionError(message)
-def doesnt_raise(function, message=''):
+
+def doesnt_raise(function, message=""):
"""
The inverse of raises().
@@ -31,5 +64,538 @@ def doesnt_raise(function, message=''):
raise AssertionError(f"Unexpected exception {e!r}: {message}")
raise AssertionError(f"Unexpected exception {e!r}")
+
def nargs(func_name):
- return len(getfullargspec(getattr(function_stubs, func_name)).args)
+ return len(getfullargspec(stubs.name_to_func[func_name]).args)
+
+
+def fmt_kw(kw: Dict[str, Any]) -> str:
+ return ", ".join(f"{k}={v}" for k, v in kw.items())
+
+
+def is_pos_zero(n: float) -> bool:
+ return n == 0 and math.copysign(1, n) == 1
+
+
+def is_neg_zero(n: float) -> bool:
+ return n == 0 and math.copysign(1, n) == -1
+
+
+def assert_dtype(
+ func_name: str,
+ *,
+ in_dtype: Union[DataType, Sequence[DataType]],
+ out_dtype: DataType,
+ expected: Optional[DataType] = None,
+ repr_name: str = "out.dtype",
+):
+ """
+ Assert the output dtype is as expected.
+
+ If expected=None, we infer the expected dtype as in_dtype, to test
+ out_dtype, e.g.
+
+ >>> x = xp.arange(5, dtype=xp.uint8)
+ >>> out = xp.abs(x)
+ >>> assert_dtype('abs', in_dtype=x.dtype, out_dtype=out.dtype)
+
+ is equivalent to
+
+ >>> assert out.dtype == xp.uint8
+
+ Or for multiple input dtypes, the expected dtype is inferred from their
+ resulting type promotion, e.g.
+
+ >>> x1 = xp.arange(5, dtype=xp.uint8)
+ >>> x2 = xp.arange(5, dtype=xp.uint16)
+ >>> out = xp.add(x1, x2)
+ >>> assert_dtype('add', in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
+
+ is equivalent to
+
+ >>> assert out.dtype == xp.uint16
+
+ We can also specify the expected dtype ourselves, e.g.
+
+ >>> x = xp.arange(5, dtype=xp.int8)
+ >>> out = xp.sum(x)
+ >>> default_int = xp.asarray(0).dtype
+    >>> assert_dtype('sum', in_dtype=x.dtype, out_dtype=out.dtype, expected=default_int)
+
+ """
+ __tracebackhide__ = True
+ in_dtypes = in_dtype if isinstance(in_dtype, Sequence) and not isinstance(in_dtype, str) else [in_dtype]
+ f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
+ f_out_dtype = dh.dtype_to_name[out_dtype]
+ if expected is None:
+ expected = dh.result_type(*in_dtypes)
+ f_expected = dh.dtype_to_name[expected]
+ msg = (
+ f"{repr_name}={f_out_dtype}, but should be {f_expected} "
+ f"[{func_name}({f_in_dtypes})]"
+ )
+ assert out_dtype == expected, msg
+
+
+def assert_float_to_complex_dtype(
+ func_name: str, *, in_dtype: DataType, out_dtype: DataType
+):
+ if in_dtype == xp.float32:
+ expected = xp.complex64
+ else:
+ assert in_dtype == xp.float64 # sanity check
+ expected = xp.complex128
+ assert_dtype(
+ func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
+ )
+
+
+def assert_complex_to_float_dtype(
+ func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
+):
+ if in_dtype == xp.complex64:
+ expected = xp.float32
+ elif in_dtype == xp.complex128:
+ expected = xp.float64
+ else:
+ assert in_dtype in (xp.float32, xp.float64) # sanity check
+ expected = in_dtype
+ assert_dtype(
+ func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
+ )
+
+
+def assert_kw_dtype(
+ func_name: str,
+ *,
+ kw_dtype: DataType,
+ out_dtype: DataType,
+):
+ """
+ Assert the output dtype is the passed keyword dtype, e.g.
+
+ >>> kw = {'dtype': xp.uint8}
+    >>> out = xp.ones(5, **kw)
+ >>> assert_kw_dtype('ones', kw_dtype=kw['dtype'], out_dtype=out.dtype)
+
+ """
+ __tracebackhide__ = True
+ f_kw_dtype = dh.dtype_to_name[kw_dtype]
+ f_out_dtype = dh.dtype_to_name[out_dtype]
+ msg = (
+ f"out.dtype={f_out_dtype}, but should be {f_kw_dtype} "
+ f"[{func_name}(dtype={f_kw_dtype})]"
+ )
+ assert out_dtype == kw_dtype, msg
+
+
+def assert_default_float(func_name: str, out_dtype: DataType):
+ """
+ Assert the output dtype is the default float, e.g.
+
+ >>> out = xp.ones(5)
+ >>> assert_default_float('ones', out.dtype)
+
+ """
+ __tracebackhide__ = True
+ f_dtype = dh.dtype_to_name[out_dtype]
+ f_default = dh.dtype_to_name[dh.default_float]
+ msg = (
+ f"out.dtype={f_dtype}, should be default "
+ f"floating-point dtype {f_default} [{func_name}()]"
+ )
+ assert out_dtype == dh.default_float, msg
+
+
+def assert_default_complex(func_name: str, out_dtype: DataType):
+ """
+ Assert the output dtype is the default complex, e.g.
+
+ >>> out = xp.asarray(4+2j)
+ >>> assert_default_complex('asarray', out.dtype)
+
+ """
+ __tracebackhide__ = True
+ f_dtype = dh.dtype_to_name[out_dtype]
+ f_default = dh.dtype_to_name[dh.default_complex]
+ msg = (
+ f"out.dtype={f_dtype}, should be default "
+ f"complex dtype {f_default} [{func_name}()]"
+ )
+ assert out_dtype == dh.default_complex, msg
+
+
+def assert_default_int(func_name: str, out_dtype: DataType):
+ """
+ Assert the output dtype is the default int, e.g.
+
+ >>> out = xp.full(5, 42)
+ >>> assert_default_int('full', out.dtype)
+
+ """
+ __tracebackhide__ = True
+ f_dtype = dh.dtype_to_name[out_dtype]
+ f_default = dh.dtype_to_name[dh.default_int]
+ msg = (
+ f"out.dtype={f_dtype}, should be default "
+ f"integer dtype {f_default} [{func_name}()]"
+ )
+ assert out_dtype == dh.default_int, msg
+
+
+def assert_default_index(func_name: str, out_dtype: DataType, repr_name="out.dtype"):
+ """
+ Assert the output dtype is the default index dtype, e.g.
+
+ >>> out = xp.argmax(xp.arange(5))
+    >>> assert_default_index('argmax', out.dtype)
+
+ """
+ __tracebackhide__ = True
+ f_dtype = dh.dtype_to_name[out_dtype]
+ msg = (
+ f"{repr_name}={f_dtype}, should be the default index dtype, "
+ f"which is either int32 or int64 [{func_name}()]"
+ )
+ assert out_dtype in (xp.int32, xp.int64), msg
+
+
+def assert_shape(
+ func_name: str,
+ *,
+ out_shape: Union[int, Shape],
+ expected: Union[int, Shape],
+ repr_name="out.shape",
+ kw: dict = {},
+):
+ """
+ Assert the output shape is as expected, e.g.
+
+ >>> out = xp.ones((3, 3, 3))
+ >>> assert_shape('ones', out_shape=out.shape, expected=(3, 3, 3))
+
+ """
+ __tracebackhide__ = True
+ if isinstance(out_shape, int):
+ out_shape = (out_shape,)
+ if isinstance(expected, int):
+ expected = (expected,)
+ msg = (
+ f"{repr_name}={out_shape}, but should be {expected} [{func_name}({fmt_kw(kw)})]"
+ )
+ assert out_shape == expected, msg
+
+
+def assert_result_shape(
+ func_name: str,
+ in_shapes: Sequence[Shape],
+ out_shape: Shape,
+ expected: Optional[Shape] = None,
+ *,
+ repr_name="out.shape",
+ kw: dict = {},
+):
+ """
+ Assert the output shape is as expected.
+
+ If expected=None, we infer the expected shape as the result of broadcasting
+ in_shapes, to test against out_shape, e.g.
+
+ >>> out = xp.add(xp.ones((3, 1)), xp.ones((1, 3)))
+    >>> assert_result_shape('add', in_shapes=[(3, 1), (1, 3)], out_shape=out.shape)
+
+ is equivalent to
+
+ >>> assert out.shape == (3, 3)
+
+ """
+ __tracebackhide__ = True
+ if expected is None:
+ expected = sh.broadcast_shapes(*in_shapes)
+ f_in_shapes = " . ".join(str(s) for s in in_shapes)
+ f_sig = f" {f_in_shapes} "
+ if kw:
+ f_sig += f", {fmt_kw(kw)}"
+ msg = f"{repr_name}={out_shape}, but should be {expected} [{func_name}({f_sig})]"
+ assert out_shape == expected, msg
+
+
+def assert_keepdimable_shape(
+ func_name: str,
+ *,
+ in_shape: Shape,
+ out_shape: Shape,
+ axes: Tuple[int, ...],
+ keepdims: bool,
+ kw: dict = {},
+):
+ """
+ Assert the output shape from a keepdimable function is as expected, e.g.
+
+ >>> x = xp.asarray([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+ >>> out1 = xp.max(x, keepdims=False)
+ >>> out2 = xp.max(x, keepdims=True)
+ >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out1.shape, axes=(0, 1), keepdims=False)
+ >>> assert_keepdimable_shape('max', in_shape=x.shape, out_shape=out2.shape, axes=(0, 1), keepdims=True)
+
+ is equivalent to
+
+ >>> assert out1.shape == ()
+ >>> assert out2.shape == (1, 1)
+
+ """
+ __tracebackhide__ = True
+ if keepdims:
+ shape = tuple(1 if axis in axes else side for axis, side in enumerate(in_shape))
+ else:
+ shape = tuple(side for axis, side in enumerate(in_shape) if axis not in axes)
+ assert_shape(func_name, out_shape=out_shape, expected=shape, kw=kw)
+
+
+def assert_0d_equals(
+ func_name: str,
+ *,
+ x_repr: str,
+ x_val: Array,
+ out_repr: str,
+ out_val: Array,
+ kw: dict = {},
+):
+ """
+ Assert a 0d array is as expected, e.g.
+
+ >>> x = xp.asarray([0, 1, 2])
+ >>> kw = {'copy': True}
+ >>> res = xp.asarray(x, **kw)
+ >>> res[0] = 42
+    >>> assert_0d_equals('asarray', x_repr='x[0]', x_val=x[0], out_repr='res[0]', out_val=res[0], kw=kw)
+
+ is equivalent to
+
+ >>> assert res[0] == x[0]
+
+ """
+ __tracebackhide__ = True
+ msg = (
+ f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
+ f"[{func_name}({fmt_kw(kw)})]"
+ )
+ if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):
+ assert xp.isnan(x_val), msg
+ else:
+ assert x_val == out_val, msg
+
+
+def assert_scalar_equals(
+ func_name: str,
+ *,
+ type_: ScalarType,
+ idx: Shape,
+ out: Scalar,
+ expected: Scalar,
+ repr_name: str = "out",
+ kw: dict = {},
+):
+ """
+ Assert a 0d array, converted to a scalar, is as expected, e.g.
+
+ >>> x = xp.ones(5, dtype=xp.uint8)
+ >>> out = xp.sum(x)
+    >>> assert_scalar_equals('sum', type_=int, idx=(), out=int(out), expected=5)
+
+ is equivalent to
+
+ >>> assert int(out) == 5
+
+ NOTE: This function does *exact* comparison, even for floats. For
+ approximate float comparisons use assert_scalar_isclose
+ """
+ __tracebackhide__ = True
+ repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
+ f_func = f"{func_name}({fmt_kw(kw)})"
+ if type_ in [bool, int]:
+ msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
+ assert out == expected, msg
+ elif cmath.isnan(expected):
+ msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
+ assert cmath.isnan(out), msg
+ else:
+ msg = f"{repr_name}={out}, but should be {expected} [{f_func}]"
+ assert out == expected, msg
+
+
+def assert_scalar_isclose(
+ func_name: str,
+ *,
+ rel_tol: float = 0.25,
+ abs_tol: float = 1,
+ type_: ScalarType,
+ idx: Shape,
+ out: Scalar,
+ expected: Scalar,
+ repr_name: str = "out",
+ kw: dict = {},
+):
+ """
+ Assert a 0d array, converted to a scalar, is close to the expected value, e.g.
+
+    >>> x = xp.ones(5, dtype=xp.float64)
+    >>> out = xp.sum(x)
+    >>> assert_scalar_isclose('sum', type_=float, idx=(), out=float(out), expected=5.)
+
+    is equivalent to
+
+    >>> assert math.isclose(float(out), 5.)
+
+ """
+ __tracebackhide__ = True
+ repr_name = repr_name if idx == () else f"{repr_name}[{idx}]"
+ f_func = f"{func_name}({fmt_kw(kw)})"
+ msg = f"{repr_name}={out}, but should be roughly {expected} [{f_func}]"
+ assert type_ in [float, complex] # Sanity check
+ assert cmath.isclose(out, expected, rel_tol=rel_tol, abs_tol=abs_tol), msg
+
+
+def assert_fill(
+ func_name: str,
+ *,
+ fill_value: Scalar,
+ dtype: DataType,
+ out: Array,
+ kw: dict = {},
+):
+ """
+    Assert all elements of an array are as expected, e.g.
+
+ >>> out = xp.full(5, 42, dtype=xp.uint8)
+ >>> assert_fill('full', fill_value=42, dtype=xp.uint8, out=out, kw=dict(shape=5))
+
+ is equivalent to
+
+ >>> assert xp.all(out == 42)
+
+ """
+ __tracebackhide__ = True
+ msg = f"out not filled with {fill_value} [{func_name}({fmt_kw(kw)})]\n{out=}"
+ if cmath.isnan(fill_value):
+ assert xp.all(xp.isnan(out)), msg
+ else:
+ assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
+
+
+def _has_functional_signbit() -> bool:
+ # signbit can be available but not implemented (e.g., in array-api-strict)
+ if not hasattr(_xp, "signbit"):
+ return False
+ try:
+ assert _xp.all(_xp.signbit(_xp.asarray(0.0)) == False)
+    except Exception:
+ return False
+ return True
+
+def _real_float_strict_equals(out: Array, expected: Array) -> bool:
+ nan_mask = xp.isnan(out)
+ if not xp.all(nan_mask == xp.isnan(expected)):
+ return False
+ ignore_mask = nan_mask
+
+    # Test the sign of zeroes if xp.signbit() is available; otherwise skip
+    # the check, as signed zeros aren't worth the extra perf cost.
+ if _has_functional_signbit():
+ out_zero_mask = out == 0
+ out_sign_mask = _xp.signbit(out)
+        out_pos_zero_mask = out_zero_mask & ~out_sign_mask
+        out_neg_zero_mask = out_zero_mask & out_sign_mask
+        expected_zero_mask = expected == 0
+        expected_sign_mask = _xp.signbit(expected)
+        expected_pos_zero_mask = expected_zero_mask & ~expected_sign_mask
+        expected_neg_zero_mask = expected_zero_mask & expected_sign_mask
+ pos_zero_match = out_pos_zero_mask == expected_pos_zero_mask
+ neg_zero_match = out_neg_zero_mask == expected_neg_zero_mask
+ if not (xp.all(pos_zero_match) and xp.all(neg_zero_match)):
+ return False
+ ignore_mask |= out_zero_mask
+
+ replacement = xp.asarray(42, dtype=out.dtype) # i.e. an arbitrary non-zero value that equals itself
+ assert replacement == replacement # sanity check
+ match = xp.where(ignore_mask, replacement, out) == xp.where(ignore_mask, replacement, expected)
+ return xp.all(match)
+
+
+def _assert_float_element(at_out: Array, at_expected: Array, msg: str):
+ if xp.isnan(at_expected):
+ assert xp.isnan(at_out), msg
+ elif at_expected == 0.0 or at_expected == -0.0:
+ scalar_at_expected = float(at_expected)
+ scalar_at_out = float(at_out)
+ if is_pos_zero(scalar_at_expected):
+ assert is_pos_zero(scalar_at_out), msg
+ else:
+ assert is_neg_zero(scalar_at_expected) # sanity check
+ assert is_neg_zero(scalar_at_out), msg
+ else:
+ assert at_out == at_expected, msg
+
+
+def assert_array_elements(
+ func_name: str,
+ *,
+ out: Array,
+ expected: Array,
+ out_repr: str = "out",
+ kw: dict = {},
+):
+ """
+ Assert array elements are (strictly) as expected, e.g.
+
+ >>> x = xp.arange(5)
+ >>> out = xp.asarray(x)
+ >>> assert_array_elements('asarray', out=out, expected=x)
+
+ is equivalent to
+
+ >>> assert xp.all(out == x)
+
+ """
+ __tracebackhide__ = True
+ dh.result_type(out.dtype, expected.dtype) # sanity check
+ assert_shape(func_name, out_shape=out.shape, expected=expected.shape, kw=kw) # sanity check
+ f_func = f"[{func_name}({fmt_kw(kw)})]"
+
+ # First we try short-circuit for a successful assertion by using vectorised checks.
+ if out.dtype in dh.real_float_dtypes:
+ if _real_float_strict_equals(out, expected):
+ return
+ elif out.dtype in dh.complex_dtypes:
+ real_match = _real_float_strict_equals(xp.real(out), xp.real(expected))
+ imag_match = _real_float_strict_equals(xp.imag(out), xp.imag(expected))
+ if real_match and imag_match:
+ return
+ else:
+ match = out == expected
+ if xp.all(match):
+ return
+
+ # In case of mismatch, generate a more helpful error. Cycling through all indices is
+ # costly in some array api implementations, so we only do this in the case of a failure.
+ msg_template = "{}={}, but should be {} " + f_func
+ if out.dtype in dh.real_float_dtypes:
+ for idx in sh.ndindex(out.shape):
+ at_out = out[idx]
+ at_expected = expected[idx]
+ msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
+ _assert_float_element(at_out, at_expected, msg)
+ elif out.dtype in dh.complex_dtypes:
+ assert (out.dtype in dh.complex_dtypes) == (expected.dtype in dh.complex_dtypes)
+ for idx in sh.ndindex(out.shape):
+ at_out = out[idx]
+ at_expected = expected[idx]
+ msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
+ _assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
+ _assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
+ else:
+ for idx in sh.ndindex(out.shape):
+ at_out = out[idx]
+ at_expected = expected[idx]
+ msg = msg_template.format(sh.fmt_idx(out_repr, idx), at_out, at_expected)
+ assert at_out == at_expected, msg
diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py
new file mode 100644
index 00000000..52c6f3fc
--- /dev/null
+++ b/array_api_tests/shape_helpers.py
@@ -0,0 +1,177 @@
+import math
+from itertools import product
+from typing import Iterator, List, Optional, Sequence, Tuple, Union
+
+from ndindex import iter_indices as _iter_indices
+
+from .typing import AtomicIndex, Index, Scalar, Shape
+
+__all__ = [
+ "broadcast_shapes",
+ "normalize_axis",
+ "ndindex",
+ "axis_ndindex",
+ "axes_ndindex",
+ "reshape",
+ "fmt_idx",
+]
+
+
+class BroadcastError(ValueError):
+ """Shapes do not broadcast with eachother"""
+
+
+def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
+ """Broadcasts `shape1` and `shape2`"""
+ N1 = len(shape1)
+ N2 = len(shape2)
+ N = max(N1, N2)
+ shape = [None for _ in range(N)]
+ i = N - 1
+ while i >= 0:
+        n1 = N1 - N + i
+        d1 = shape1[n1] if n1 >= 0 else 1
+        n2 = N2 - N + i
+        d2 = shape2[n2] if n2 >= 0 else 1
+
+ if d1 == 1:
+ shape[i] = d2
+ elif d2 == 1:
+ shape[i] = d1
+ elif d1 == d2:
+ shape[i] = d1
+ else:
+ raise BroadcastError()
+
+ i = i - 1
+
+ return tuple(shape)
+
+
+def broadcast_shapes(*shapes: Shape):
+ if len(shapes) == 0:
+ raise ValueError("shapes=[] must be non-empty")
+ elif len(shapes) == 1:
+ return shapes[0]
+ result = _broadcast_shapes(shapes[0], shapes[1])
+ for i in range(2, len(shapes)):
+ result = _broadcast_shapes(result, shapes[i])
+ return result
+
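+# e.g. broadcast_shapes((3, 1), (1, 4)) -> (3, 4), and
+# broadcast_shapes((2, 3, 4), (3, 1), ()) -> (2, 3, 4).
+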
+
+def normalize_axis(
+ axis: Optional[Union[int, Sequence[int]]], ndim: int
+) -> Tuple[int, ...]:
+ if axis is None:
+ return tuple(range(ndim))
+ elif isinstance(axis, Sequence) and not isinstance(axis, tuple):
+ axis = tuple(axis)
+ axes = axis if isinstance(axis, tuple) else (axis,)
+ axes = tuple(axis if axis >= 0 else ndim + axis for axis in axes)
+ return axes
+
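+# e.g. normalize_axis(-1, 3) -> (2,), normalize_axis(None, 2) -> (0, 1),
+# and normalize_axis([0, -2], 3) -> (0, 1).
+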
+
+def ndindex(shape: Shape) -> Iterator[Index]:
+ """Yield every index of a shape"""
+ return (indices[0] for indices in iter_indices(shape))
+
+
+def iter_indices(
+ *shapes: Shape, skip_axes: Tuple[int, ...] = ()
+) -> Iterator[Tuple[Index, ...]]:
+ """Wrapper for ndindex.iter_indices()"""
+ # Prevent iterations if any shape has 0-sides
+ for shape in shapes:
+ if 0 in shape:
+ return
+ for indices in _iter_indices(*shapes, skip_axes=skip_axes):
+ yield tuple(i.raw for i in indices) # type: ignore
+
+
+def axis_ndindex(
+ shape: Shape, axis: int
+) -> Iterator[Tuple[Tuple[Union[int, slice], ...], ...]]:
+ """Generate indices that index all elements in dimensions beyond `axis`"""
+ assert axis >= 0 # sanity check
+ axis_indices = [range(side) for side in shape[:axis]]
+ for _ in range(axis, len(shape)):
+ axis_indices.append([slice(None, None)])
+ yield from product(*axis_indices)
+
+
+def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
+ """Generate indices that index all elements except in `axes` dimensions"""
+ base_indices = []
+ axes_indices = []
+ for axis, side in enumerate(shape):
+ if axis in axes:
+ base_indices.append([None])
+ axes_indices.append(range(side))
+ else:
+ base_indices.append(range(side))
+ axes_indices.append([None])
+ for base_idx in product(*base_indices):
+ indices = []
+ for idx in product(*axes_indices):
+ idx = list(idx)
+ for axis, side in enumerate(idx):
+ if axis not in axes:
+ idx[axis] = base_idx[axis]
+ idx = tuple(idx)
+ indices.append(idx)
+ yield list(indices)
+
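+# e.g. axes_ndindex((2, 2), (0,)) yields [(0, 0), (1, 0)] and then
+# [(0, 1), (1, 1)]: each list walks the axes dims with the rest held fixed.
+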
+
+def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]:
+ """Reshape a flat sequence"""
+ if any(s == 0 for s in shape):
+ raise ValueError(
+ f"{shape=} contains 0-sided dimensions, "
+ f"but that's not representable in lists"
+ )
+ if len(shape) == 0:
+ assert len(flat_seq) == 1 # sanity check
+ return flat_seq[0]
+ elif len(shape) == 1:
+ return flat_seq
+ size = len(flat_seq)
+ n = math.prod(shape[1:])
+ return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
+
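+# e.g. reshape([1, 2, 3, 4, 5, 6], (2, 3)) -> [[1, 2, 3], [4, 5, 6]],
+# and reshape([7], ()) -> 7.
+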
+
+def fmt_i(i: AtomicIndex) -> str:
+ if isinstance(i, int):
+ return str(i)
+ elif isinstance(i, slice):
+ res = ""
+ if i.start is not None:
+ res += str(i.start)
+ res += ":"
+ if i.stop is not None:
+ res += str(i.stop)
+ if i.step is not None:
+ res += f":{i.step}"
+ return res
+ elif i is None:
+ return "None"
+ else:
+ return "..."
+
+
+def fmt_idx(sym: str, idx: Index) -> str:
+ if idx == ():
+ return sym
+ res = f"{sym}["
+ _idx = idx if isinstance(idx, tuple) else (idx,)
+ if len(_idx) == 1:
+ res += fmt_i(_idx[0])
+ else:
+ res += ", ".join(fmt_i(i) for i in _idx)
+ res += "]"
+ return res
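+
+# e.g. fmt_idx("x", (0, slice(None, 2), ...)) -> "x[0, :2, ...]", and
+# fmt_idx("x", ()) -> "x".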
diff --git a/array_api_tests/special_cases/__init__.py b/array_api_tests/special_cases/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/array_api_tests/special_cases/test_abs.py b/array_api_tests/special_cases/test_abs.py
deleted file mode 100644
index 4ed04d02..00000000
--- a/array_api_tests/special_cases/test_abs.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""
-Special cases tests for abs.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import abs
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_abs_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `abs(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = abs(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_abs_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `abs(x, /)`:
-
- - If `x_i` is `-0`, the result is `+0`.
-
- """
- res = abs(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_abs_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `abs(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `+infinity`.
-
- """
- res = abs(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_acos.py b/array_api_tests/special_cases/test_acos.py
deleted file mode 100644
index b1c3cb56..00000000
--- a/array_api_tests/special_cases/test_acos.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""
-Special cases tests for acos.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, greater, less, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import acos
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_acos_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `acos(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = acos(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_acos_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `acos(x, /)`:
-
- - If `x_i` is `1`, the result is `+0`.
-
- """
- res = acos(arg1)
- mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_acos_special_cases_one_arg_greater(arg1):
- """
- Special case test for `acos(x, /)`:
-
- - If `x_i` is greater than `1`, the result is `NaN`.
-
- """
- res = acos(arg1)
- mask = greater(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_acos_special_cases_one_arg_less(arg1):
- """
- Special case test for `acos(x, /)`:
-
- - If `x_i` is less than `-1`, the result is `NaN`.
-
- """
- res = acos(arg1)
- mask = less(arg1, -one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_acosh.py b/array_api_tests/special_cases/test_acosh.py
deleted file mode 100644
index 8749eaf2..00000000
--- a/array_api_tests/special_cases/test_acosh.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""
-Special cases tests for acosh.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, less, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import acosh
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_acosh_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `acosh(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = acosh(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_acosh_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `acosh(x, /)`:
-
- - If `x_i` is `1`, the result is `+0`.
-
- """
- res = acosh(arg1)
- mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_acosh_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `acosh(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = acosh(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_acosh_special_cases_one_arg_less(arg1):
- """
- Special case test for `acosh(x, /)`:
-
- - If `x_i` is less than `1`, the result is `NaN`.
-
- """
- res = acosh(arg1)
- mask = less(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_add.py b/array_api_tests/special_cases/test_add.py
deleted file mode 100644
index eaccb803..00000000
--- a/array_api_tests/special_cases/test_add.py
+++ /dev/null
@@ -1,226 +0,0 @@
-"""
-Special cases tests for add.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, isfinite,
- logical_and, logical_or, non_zero, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import add
-
-from hypothesis import given
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_either(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`.
-
- """
- res = add(arg1, arg2)
- mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_1(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is `-infinity`, the result is `NaN`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_2(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is `+infinity`, the result is `NaN`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_3(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_4(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is `-infinity`, the result is `-infinity`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_5(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is a finite number, the result is `+infinity`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), isfinite(arg2))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_6(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is a finite number, the result is `-infinity`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), isfinite(arg2))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_7(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is a finite number and `x2_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(isfinite(arg1), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_8(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is a finite number and `x2_i` is `-infinity`, the result is `-infinity`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(isfinite(arg1), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_9(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is `-0`, the result is `-0`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_10(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is `+0`, the result is `+0`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_11(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is `-0`, the result is `+0`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_12(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is `+0`, the result is `+0`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__equal_13(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is a nonzero finite number and `x2_i` is `-x1_i`, the result is `+0`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), exactly_equal(arg2, -arg1))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_either__equal(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is either `+0` or `-0` and `x2_i` is a nonzero finite number, the result is `x2_i`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2)))
- assert_exactly_equal(res[mask], (arg2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_add_special_cases_two_args_equal__either(arg1, arg2):
- """
- Special case test for `add(x1, x2, /)`:
-
- - If `x1_i` is a nonzero finite number and `x2_i` is either `+0` or `-0`, the result is `x1_i`.
-
- """
- res = add(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))))
- assert_exactly_equal(res[mask], (arg1)[mask])
-
-# TODO: Implement REMAINING test for:
-# - In the remaining cases, when neither `infinity`, `+0`, `-0`, nor a `NaN` is involved, and the operands have the same mathematical sign or have different magnitudes, the sum must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported round mode. If the magnitude is too large to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign.
diff --git a/array_api_tests/special_cases/test_asin.py b/array_api_tests/special_cases/test_asin.py
deleted file mode 100644
index 0a41b716..00000000
--- a/array_api_tests/special_cases/test_asin.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for asin.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, greater, less, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import asin
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_asin_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `asin(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = asin(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asin_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `asin(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = asin(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asin_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `asin(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = asin(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asin_special_cases_one_arg_greater(arg1):
- """
- Special case test for `asin(x, /)`:
-
- - If `x_i` is greater than `1`, the result is `NaN`.
-
- """
- res = asin(arg1)
- mask = greater(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asin_special_cases_one_arg_less(arg1):
- """
- Special case test for `asin(x, /)`:
-
- - If `x_i` is less than `-1`, the result is `NaN`.
-
- """
- res = asin(arg1)
- mask = less(arg1, -one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_asinh.py b/array_api_tests/special_cases/test_asinh.py
deleted file mode 100644
index a54d3346..00000000
--- a/array_api_tests/special_cases/test_asinh.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for asinh.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import asinh
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_asinh_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `asinh(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = asinh(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asinh_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `asinh(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = asinh(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asinh_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `asinh(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = asinh(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asinh_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `asinh(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = asinh(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_asinh_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `asinh(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `-infinity`.
-
- """
- res = asinh(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_atan.py b/array_api_tests/special_cases/test_atan.py
deleted file mode 100644
index 4b6936ed..00000000
--- a/array_api_tests/special_cases/test_atan.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for atan.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero, π
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import atan
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_atan_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `atan(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = atan(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atan_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `atan(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = atan(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atan_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `atan(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = atan(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atan_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `atan(x, /)`:
-
- - If `x_i` is `+infinity`, the result is an implementation-dependent approximation to `+π/2`.
-
- """
- res = atan(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays)
-def test_atan_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `atan(x, /)`:
-
- - If `x_i` is `-infinity`, the result is an implementation-dependent approximation to `-π/2`.
-
- """
- res = atan(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask])
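
Note that the spec wording in these atan cases is "implementation-dependent approximation to ±π/2", yet the generated asserts pin the result to exact equality with the helper value `π(shape, dtype)/2`, which presumes the library rounds to the same nearest representable value as the helper. Where that is too strict, a ULP-based tolerance is a common alternative; a hedged sketch (names illustrative, not from this suite):

    import numpy as np

    x = np.array([np.inf, -np.inf], dtype=np.float32)
    res = np.arctan(x)
    expected = np.array([np.pi / 2, -np.pi / 2], dtype=np.float32)
    # Accept a few ULPs of slack around ±π/2 instead of bit-exact equality.
    tol = 4 * np.spacing(np.float32(np.pi / 2))
    assert np.allclose(res, expected, rtol=0, atol=tol)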
diff --git a/array_api_tests/special_cases/test_atan2.py b/array_api_tests/special_cases/test_atan2.py
deleted file mode 100644
index 9d7452e7..00000000
--- a/array_api_tests/special_cases/test_atan2.py
+++ /dev/null
@@ -1,314 +0,0 @@
-"""
-Special cases tests for atan2.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, isfinite,
- less, logical_and, logical_or, zero, π)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import atan2
-
-from hypothesis import given
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_either(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_greater__equal_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is greater than `0` and `x2_i` is `+0`, the result is an implementation-dependent approximation to `+π/2`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_greater__equal_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is greater than `0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `+π/2`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__greater_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__greater_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is greater than `0`, the result is `-0`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is `+0`, the result is `+0`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `+π`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_3(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is `+0`, the result is `-0`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_4(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `-π`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_5(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is finite, the result is an implementation-dependent approximation to `+π/2`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), isfinite(arg2))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_6(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is finite, the result is an implementation-dependent approximation to `-π/2`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), isfinite(arg2))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_7(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is `+infinity`, the result is an implementation-dependent approximation to `+π/4`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype)/4)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_8(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `+3π/4`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+3*π(arg1.shape, arg1.dtype)/4)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_9(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is `+infinity`, the result is an implementation-dependent approximation to `-π/4`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/4)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__equal_10(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `-3π/4`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-3*π(arg1.shape, arg1.dtype)/4)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__less_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is an implementation-dependent approximation to `+π`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_equal__less_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is less than `0`, the result is an implementation-dependent approximation to `-π`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_less__equal_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is less than `0` and `x2_i` is `+0`, the result is an implementation-dependent approximation to `-π/2`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_less__equal_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is less than `0` and `x2_i` is `-0`, the result is an implementation-dependent approximation to `-π/2`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype)/2)[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_greater_equal__equal_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is greater than `0`, `x1_i` is a finite number, and `x2_i` is `+infinity`, the result is `+0`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_greater_equal__equal_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is greater than `0`, `x1_i` is a finite number, and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `+π`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (+π(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_less_equal__equal_1(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is less than `0`, `x1_i` is a finite number, and `x2_i` is `+infinity`, the result is `-0`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_atan2_special_cases_two_args_less_equal__equal_2(arg1, arg2):
- """
- Special case test for `atan2(x1, x2, /)`:
-
- - If `x1_i` is less than `0`, `x1_i` is a finite number, and `x2_i` is `-infinity`, the result is an implementation-dependent approximation to `-π`.
-
- """
- res = atan2(arg1, arg2)
- mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-π(arg1.shape, arg1.dtype))[mask])
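
The two-argument atan2 cases compose per-argument predicates with `logical_and`, so each spec clause becomes one boolean mask over the broadcast input pair. A condensed NumPy sketch of the "x1 greater than 0 and x2 is +0 gives +π/2" branch (a sketch of the masking logic, not the suite's helpers):

    import numpy as np

    x1 = np.array([1.0, 2.0, -3.0])
    x2 = np.array([0.0, -0.0, 0.0])
    res = np.arctan2(x1, x2)

    # x2 is +0 specifically: zero-valued with the sign bit clear.
    is_pos_zero = (x2 == 0) & ~np.signbit(x2)
    mask = (x1 > 0) & is_pos_zero                # [True, False, False]
    assert np.all(res[mask] == np.pi / 2)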
diff --git a/array_api_tests/special_cases/test_atanh.py b/array_api_tests/special_cases/test_atanh.py
deleted file mode 100644
index 6e26cc99..00000000
--- a/array_api_tests/special_cases/test_atanh.py
+++ /dev/null
@@ -1,106 +0,0 @@
-"""
-Special cases tests for atanh.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, less, one,
- zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import atanh
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = atanh(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is `-1`, the result is `-infinity`.
-
- """
- res = atanh(arg1)
- mask = exactly_equal(arg1, -one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is `+1`, the result is `+infinity`.
-
- """
- res = atanh(arg1)
- mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = atanh(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = atanh(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_less(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is less than `-1`, the result is `NaN`.
-
- """
- res = atanh(arg1)
- mask = less(arg1, -one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_atanh_special_cases_one_arg_greater(arg1):
- """
- Special case test for `atanh(x, /)`:
-
- - If `x_i` is greater than `1`, the result is `NaN`.
-
- """
- res = atanh(arg1)
- mask = greater(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
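
The atanh cases cover the whole domain boundary: poles at ±1 and NaN outside [-1, 1]. With NumPy these inputs emit floating-point warnings by default, so an equivalent check typically runs under an errstate guard; a minimal sketch, assuming NumPy's warning behavior:

    import numpy as np

    x = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
    with np.errstate(divide="ignore", invalid="ignore"):
        res = np.arctanh(x)

    assert np.isneginf(res[1]) and np.isposinf(res[3])  # poles at -1 and +1
    assert np.isnan(res[0]) and np.isnan(res[4])        # NaN outside [-1, 1]
    assert res[2] == 0.0 and not np.signbit(res[2])     # +0 maps to +0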
diff --git a/array_api_tests/special_cases/test_ceil.py b/array_api_tests/special_cases/test_ceil.py
deleted file mode 100644
index 5056db66..00000000
--- a/array_api_tests/special_cases/test_ceil.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""
-Special cases tests for ceil.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import assert_exactly_equal, isintegral
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import ceil
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_ceil_special_cases_one_arg_equal(arg1):
- """
- Special case test for `ceil(x, /)`:
-
- - If `x_i` is already integer-valued, the result is `x_i`.
-
- """
- res = ceil(arg1)
- mask = isintegral(arg1)
- assert_exactly_equal(res[mask], (arg1)[mask])
diff --git a/array_api_tests/special_cases/test_cos.py b/array_api_tests/special_cases/test_cos.py
deleted file mode 100644
index e80a7130..00000000
--- a/array_api_tests/special_cases/test_cos.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for cos.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import cos
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_cos_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `cos(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = cos(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cos_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `cos(x, /)`:
-
- - If `x_i` is `+0`, the result is `1`.
-
- """
- res = cos(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cos_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `cos(x, /)`:
-
- - If `x_i` is `-0`, the result is `1`.
-
- """
- res = cos(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cos_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `cos(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `NaN`.
-
- """
- res = cos(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cos_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `cos(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `NaN`.
-
- """
- res = cos(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_cosh.py b/array_api_tests/special_cases/test_cosh.py
deleted file mode 100644
index bdca4a82..00000000
--- a/array_api_tests/special_cases/test_cosh.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for cosh.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import cosh
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_cosh_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `cosh(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = cosh(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cosh_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `cosh(x, /)`:
-
- - If `x_i` is `+0`, the result is `1`.
-
- """
- res = cosh(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cosh_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `cosh(x, /)`:
-
- - If `x_i` is `-0`, the result is `1`.
-
- """
- res = cosh(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cosh_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `cosh(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = cosh(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_cosh_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `cosh(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `+infinity`.
-
- """
- res = cosh(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_divide.py b/array_api_tests/special_cases/test_divide.py
deleted file mode 100644
index fe1596c9..00000000
--- a/array_api_tests/special_cases/test_divide.py
+++ /dev/null
@@ -1,293 +0,0 @@
-"""
-Special cases tests for divide.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, assert_negative_mathematical_sign,
- assert_positive_mathematical_sign, exactly_equal, greater, infinity,
- isfinite, isnegative, ispositive, less, logical_and, logical_not,
- logical_or, non_zero, same_sign, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import divide
-
-from hypothesis import given
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_either(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_either__either_1(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_either__either_2(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+0` or `-0`, the result is `NaN`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__greater_1(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__greater_2(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is greater than `0`, the result is `-0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__less_1(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `-0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__less_2(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `-0` and `x2_i` is less than `0`, the result is `+0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_greater__equal_1(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is greater than `0` and `x2_i` is `+0`, the result is `+infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_greater__equal_2(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is greater than `0` and `x2_i` is `-0`, the result is `-infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(greater(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_less__equal_1(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is less than `0` and `x2_i` is `+0`, the result is `-infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_less__equal_2(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is less than `0` and `x2_i` is `-0`, the result is `+infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_1(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `+infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_2(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `-infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_3(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is a positive (i.e., greater than `0`) finite number, the result is `-infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), ispositive(arg2)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_4(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is `-infinity` and `x2_i` is a negative (i.e., less than `0`) finite number, the result is `+infinity`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(isfinite(arg2), isnegative(arg2)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_5(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `+infinity`, the result is `+0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_6(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is a positive (i.e., greater than `0`) finite number and `x2_i` is `-infinity`, the result is `-0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), ispositive(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_7(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `+infinity`, the result is `-0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_equal__equal_8(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` is a negative (i.e., less than `0`) finite number and `x2_i` is `-infinity`, the result is `+0`.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), isnegative(arg1)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_same_sign_both(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` and `x2_i` have the same mathematical sign and are both nonzero finite numbers, the result has a positive mathematical sign.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(same_sign(arg1, arg2), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2))))
- assert_positive_mathematical_sign(res[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_divide_special_cases_two_args_different_signs_both(arg1, arg2):
- """
- Special case test for `divide(x1, x2, /)`:
-
- - If `x1_i` and `x2_i` have different mathematical signs and are both nonzero finite numbers, the result has a negative mathematical sign.
-
- """
- res = divide(arg1, arg2)
- mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_and(isfinite(arg2), non_zero(arg2))))
- assert_negative_mathematical_sign(res[mask])
-
-# TODO: Implement REMAINING test for:
-# - In the remaining cases, where neither `-infinity`, `+0`, `-0`, nor `NaN` is involved, the quotient must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too large to represent, the operation overflows and the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the operation underflows and the result is a zero of appropriate mathematical sign.
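
The TODO above asks for a check of the "remaining cases" clause, i.e. that ordinary quotients are correctly rounded. One way such a check is sometimes approximated (not this suite's approach) is to recompute in a wider precision and round back; a hedged sketch:

    import numpy as np

    rng = np.random.default_rng(0)
    x1 = rng.normal(size=100).astype(np.float32)
    x2 = rng.normal(size=100).astype(np.float32)
    x2[x2 == 0] = np.float32(1.0)   # stay within the finite, nonzero cases

    res = x1 / x2
    # float64's 53-bit significand exceeds 2*24 + 2, so dividing in float64 and
    # rounding once back to float32 reproduces the correctly rounded float32
    # quotient (double rounding is innocuous for division at these precisions).
    ref = (x1.astype(np.float64) / x2.astype(np.float64)).astype(np.float32)
    assert np.array_equal(res, ref)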
diff --git a/array_api_tests/special_cases/test_exp.py b/array_api_tests/special_cases/test_exp.py
deleted file mode 100644
index 47399648..00000000
--- a/array_api_tests/special_cases/test_exp.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for exp.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import exp
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_exp_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `exp(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = exp(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_exp_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `exp(x, /)`:
-
- - If `x_i` is `+0`, the result is `1`.
-
- """
- res = exp(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_exp_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `exp(x, /)`:
-
- - If `x_i` is `-0`, the result is `1`.
-
- """
- res = exp(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_exp_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `exp(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = exp(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_exp_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `exp(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `+0`.
-
- """
- res = exp(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_expm1.py b/array_api_tests/special_cases/test_expm1.py
deleted file mode 100644
index d96b742e..00000000
--- a/array_api_tests/special_cases/test_expm1.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for expm1.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import expm1
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_expm1_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `expm1(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = expm1(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_expm1_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `expm1(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = expm1(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_expm1_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `expm1(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = expm1(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_expm1_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `expm1(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = expm1(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_expm1_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `expm1(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `-1`.
-
- """
- res = expm1(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-one(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_floor.py b/array_api_tests/special_cases/test_floor.py
deleted file mode 100644
index a7a0f473..00000000
--- a/array_api_tests/special_cases/test_floor.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""
-Special cases tests for floor.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import assert_exactly_equal, isintegral
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import floor
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_floor_special_cases_one_arg_equal(arg1):
- """
- Special case test for `floor(x, /)`:
-
- - If `x_i` is already integer-valued, the result is `x_i`.
-
- """
- res = floor(arg1)
- mask = isintegral(arg1)
- assert_exactly_equal(res[mask], (arg1)[mask])
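
ceil and floor share a single case: integer-valued inputs map to themselves. The `isintegral` helper can be expressed as a finiteness check plus comparison with the truncation; a minimal sketch of the assumed semantics:

    import numpy as np

    def isintegral(x):
        # Hypothetical stand-in for array_helpers.isintegral: finite and equal
        # to its own truncation (NaN and infinity drop out via isfinite).
        return np.isfinite(x) & (x == np.trunc(x))

    x = np.array([2.0, 2.5, -0.0, np.inf, np.nan])
    res = np.floor(x)
    mask = isintegral(x)                     # [True, False, True, False, False]
    assert np.array_equal(res[mask], x[mask])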
diff --git a/array_api_tests/special_cases/test_log.py b/array_api_tests/special_cases/test_log.py
deleted file mode 100644
index 0ea6cd25..00000000
--- a/array_api_tests/special_cases/test_log.py
+++ /dev/null
@@ -1,80 +0,0 @@
-"""
-Special cases tests for log.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, less, logical_or,
- one, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import log
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_log_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `log(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = log(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `log(x, /)`:
-
- - If `x_i` is `1`, the result is `+0`.
-
- """
- res = log(arg1)
- mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `log(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = log(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log_special_cases_one_arg_less(arg1):
- """
- Special case test for `log(x, /)`:
-
- - If `x_i` is less than `0`, the result is `NaN`.
-
- """
- res = log(arg1)
- mask = less(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log_special_cases_one_arg_either(arg1):
- """
- Special case test for `log(x, /)`:
-
- - If `x_i` is either `+0` or `-0`, the result is `-infinity`.
-
- """
- res = log(arg1)
- mask = logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
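
The "either +0 or -0" branch recurs in all four log variants. Because `exactly_equal` distinguishes zero signs, the generated code needs a `logical_or` of two checks; under ordinary comparison semantics the same mask collapses to `x == 0`, since +0 and -0 compare equal. A sketch of the equivalence, assuming NumPy semantics:

    import numpy as np

    x = np.array([0.0, -0.0, 1.0, -1.0])
    either_zero = ((x == 0) & ~np.signbit(x)) | ((x == 0) & np.signbit(x))
    assert np.array_equal(either_zero, x == 0)   # the two sign checks collapse

    with np.errstate(divide="ignore", invalid="ignore"):
        res = np.log(x)
    assert np.all(np.isneginf(res[x == 0]))      # both zeros give -infinity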
diff --git a/array_api_tests/special_cases/test_log10.py b/array_api_tests/special_cases/test_log10.py
deleted file mode 100644
index 8dc5a5de..00000000
--- a/array_api_tests/special_cases/test_log10.py
+++ /dev/null
@@ -1,80 +0,0 @@
-"""
-Special cases tests for log10.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, less, logical_or,
- one, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import log10
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_log10_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `log10(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = log10(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log10_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `log10(x, /)`:
-
- - If `x_i` is `1`, the result is `+0`.
-
- """
- res = log10(arg1)
- mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log10_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `log10(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = log10(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log10_special_cases_one_arg_less(arg1):
- """
- Special case test for `log10(x, /)`:
-
- - If `x_i` is less than `0`, the result is `NaN`.
-
- """
- res = log10(arg1)
- mask = less(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log10_special_cases_one_arg_either(arg1):
- """
- Special case test for `log10(x, /)`:
-
- - If `x_i` is either `+0` or `-0`, the result is `-infinity`.
-
- """
- res = log10(arg1)
- mask = logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_log1p.py b/array_api_tests/special_cases/test_log1p.py
deleted file mode 100644
index 432a761b..00000000
--- a/array_api_tests/special_cases/test_log1p.py
+++ /dev/null
@@ -1,92 +0,0 @@
-"""
-Special cases tests for log1p.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, less, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import log1p
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_log1p_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `log1p(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = log1p(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log1p_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `log1p(x, /)`:
-
- - If `x_i` is `-1`, the result is `-infinity`.
-
- """
- res = log1p(arg1)
- mask = exactly_equal(arg1, -one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log1p_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `log1p(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = log1p(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log1p_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `log1p(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = log1p(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log1p_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `log1p(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = log1p(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log1p_special_cases_one_arg_less(arg1):
- """
- Special case test for `log1p(x, /)`:
-
- - If `x_i` is less than `-1`, the result is `NaN`.
-
- """
- res = log1p(arg1)
- mask = less(arg1, -one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_log2.py b/array_api_tests/special_cases/test_log2.py
deleted file mode 100644
index 41797dd7..00000000
--- a/array_api_tests/special_cases/test_log2.py
+++ /dev/null
@@ -1,80 +0,0 @@
-"""
-Special cases tests for log2.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, infinity, less, logical_or,
- one, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import log2
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_log2_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `log2(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = log2(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log2_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `log2(x, /)`:
-
- - If `x_i` is `1`, the result is `+0`.
-
- """
- res = log2(arg1)
- mask = exactly_equal(arg1, one(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log2_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `log2(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = log2(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log2_special_cases_one_arg_less(arg1):
- """
- Special case test for `log2(x, /)`:
-
- - If `x_i` is less than `0`, the result is `NaN`.
-
- """
- res = log2(arg1)
- mask = less(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_log2_special_cases_one_arg_either(arg1):
- """
- Special case test for `log2(x, /)`:
-
- - If `x_i` is either `+0` or `-0`, the result is `-infinity`.
-
- """
- res = log2(arg1)
- mask = logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_multiply.py b/array_api_tests/special_cases/test_multiply.py
deleted file mode 100644
index 0ab2eec0..00000000
--- a/array_api_tests/special_cases/test_multiply.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""
-Special cases tests for multiply.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, assert_isinf,
- assert_negative_mathematical_sign, assert_positive_mathematical_sign,
- exactly_equal, infinity, isfinite, logical_and, logical_not,
- logical_or, non_zero, same_sign, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import multiply
-
-from hypothesis import given
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_either(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If either `x1_i` or `x2_i` is `NaN`, the result is `NaN`.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_or(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), exactly_equal(arg2, NaN(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_either__either_1(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+0` or `-0`, the result is `NaN`.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, zero(arg2.shape, arg2.dtype)), exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_either__either_2(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` is either `+0` or `-0` and `x2_i` is either `+infinity` or `-infinity`, the result is `NaN`.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_either__either_3(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))))
- assert_isinf(res[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_same_sign_except(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` and `x2_i` have the same mathematical sign, the result has a positive mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, NaN(res.shape, res.dtype))))
- assert_positive_mathematical_sign(res[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_different_signs_except(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` and `x2_i` have different mathematical signs, the result has a negative mathematical sign, unless the result is `NaN`. If the result is `NaN`, the "sign" of `NaN` is implementation-defined.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, NaN(res.shape, res.dtype))))
- assert_negative_mathematical_sign(res[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_either__equal(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` is either `+infinity` or `-infinity` and `x2_i` is a nonzero finite number, the result is a signed infinity with the mathematical sign determined by the rule already stated above.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))), logical_and(isfinite(arg2), non_zero(arg2)))
- assert_isinf(res[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_multiply_special_cases_two_args_equal__either(arg1, arg2):
- """
- Special case test for `multiply(x1, x2, /)`:
-
- - If `x1_i` is a nonzero finite number and `x2_i` is either `+infinity` or `-infinity`, the result is a signed infinity with the mathematical sign determined by the rule already stated above.
-
- """
- res = multiply(arg1, arg2)
- mask = logical_and(logical_and(isfinite(arg1), non_zero(arg1)), logical_or(exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype))))
- assert_isinf(res[mask])
-
-# TODO: Implement REMAINING test for:
-# - In the remaining cases, where neither `infinity` nor `NaN` is involved, the product must be computed and rounded to the nearest representable value according to IEEE 754-2019 and a supported rounding mode. If the magnitude is too large to represent, the result is an `infinity` of appropriate mathematical sign. If the magnitude is too small to represent, the result is a zero of appropriate mathematical sign.
diff --git a/array_api_tests/special_cases/test_pow.py b/array_api_tests/special_cases/test_pow.py
deleted file mode 100644
index a422ffd3..00000000
--- a/array_api_tests/special_cases/test_pow.py
+++ /dev/null
@@ -1,327 +0,0 @@
-"""
-Special cases tests for pow.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (NaN, assert_exactly_equal, exactly_equal, greater, infinity, isfinite,
- isintegral, isodd, less, logical_and, logical_not, notequal, one, zero)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import pow
-
-from hypothesis import given
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_notequal__equal(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is not equal to `1` and `x2_i` is `NaN`, the result is `NaN`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(logical_not(exactly_equal(arg1, one(arg1.shape, arg1.dtype))), exactly_equal(arg2, NaN(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_even_if_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x2_i` is `+0`, the result is `1`, even if `x1_i` is `NaN`.
-
- """
- res = pow(arg1, arg2)
- mask = exactly_equal(arg2, zero(arg2.shape, arg2.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_even_if_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x2_i` is `-0`, the result is `1`, even if `x1_i` is `NaN`.
-
- """
- res = pow(arg1, arg2)
- mask = exactly_equal(arg2, -zero(arg2.shape, arg2.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__notequal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `NaN` and `x2_i` is not equal to `0`, the result is `NaN`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, NaN(arg1.shape, arg1.dtype)), notequal(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__notequal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `1` and `x2_i` is not `NaN`, the result is `1`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, one(arg1.shape, arg1.dtype)), logical_not(exactly_equal(arg2, NaN(arg2.shape, arg2.dtype))))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_absgreater__equal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `abs(x1_i)` is greater than `1` and `x2_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_absgreater__equal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `abs(x1_i)` is greater than `1` and `x2_i` is `-infinity`, the result is `+0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(greater(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_absequal__equal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `abs(x1_i)` is `1` and `x2_i` is `+infinity`, the result is `1`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_absequal__equal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `abs(x1_i)` is `1` and `x2_i` is `-infinity`, the result is `1`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_absless__equal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `abs(x1_i)` is less than `1` and `x2_i` is `+infinity`, the result is `+0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_absless__equal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `abs(x1_i)` is less than `1` and `x2_i` is `-infinity`, the result is `+infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(less(abs(arg1), one(arg1.shape, arg1.dtype)), exactly_equal(arg2, -infinity(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__greater_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is greater than `0`, the result is `+infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__greater_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is greater than `0`, the result is `+0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), greater(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__less_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `+infinity` and `x2_i` is less than `0`, the result is `+0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__less_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `+0` and `x2_i` is less than `0`, the result is `+infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, zero(arg1.shape, arg1.dtype)), less(arg2, zero(arg2.shape, arg2.dtype)))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__greater_equal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__greater_equal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is an odd integer value, the result is `-0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__greater_notequal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-infinity`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2))))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__greater_notequal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-0`, `x2_i` is greater than `0`, and `x2_i` is not an odd integer value, the result is `+0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(greater(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2))))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__less_equal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__less_equal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is an odd integer value, the result is `-infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), isodd(arg2)))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__less_notequal_1(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-infinity`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+0`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2))))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_equal__less_notequal_2(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is `-0`, `x2_i` is less than `0`, and `x2_i` is not an odd integer value, the result is `+infinity`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), logical_and(less(arg2, zero(arg2.shape, arg2.dtype)), logical_not(isodd(arg2))))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays, numeric_arrays)
-def test_pow_special_cases_two_args_less_equal__equal_notequal(arg1, arg2):
- """
- Special case test for `pow(x1, x2, /)`:
-
- - If `x1_i` is less than `0`, `x1_i` is a finite number, `x2_i` is a finite number, and `x2_i` is not an integer value, the result is `NaN`.
-
- """
- res = pow(arg1, arg2)
- mask = logical_and(logical_and(less(arg1, zero(arg1.shape, arg1.dtype)), isfinite(arg1)), logical_and(isfinite(arg2), logical_not(isintegral(arg2))))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_round.py b/array_api_tests/special_cases/test_round.py
deleted file mode 100644
index 89b66db0..00000000
--- a/array_api_tests/special_cases/test_round.py
+++ /dev/null
@@ -1,42 +0,0 @@
-"""
-Special cases tests for round.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import (assert_exactly_equal, assert_iseven, assert_positive, ceil, equal,
- floor, isintegral, logical_and, not_equal, one, subtract)
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import round
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_round_special_cases_one_arg_equal(arg1):
- """
- Special case test for `round(x, /)`:
-
- - If `x_i` is already integer-valued, the result is `x_i`.
-
- """
- res = round(arg1)
- mask = isintegral(arg1)
- assert_exactly_equal(res[mask], (arg1)[mask])
-
-
-@given(numeric_arrays)
-def test_round_special_cases_one_arg_two_integers_equally_close(arg1):
- """
- Special case test for `round(x, /)`:
-
- - If two integers are equally close to `x_i`, the result is the even integer closest to `x_i`.
-
- """
- res = round(arg1)
- mask = logical_and(not_equal(floor(arg1), ceil(arg1)), equal(subtract(arg1, floor(arg1)), subtract(ceil(arg1), arg1)))
- assert_iseven(res[mask])
- assert_positive(subtract(one(arg1[mask].shape, arg1[mask].dtype), abs(subtract(arg1[mask], res[mask]))))
diff --git a/array_api_tests/special_cases/test_sign.py b/array_api_tests/special_cases/test_sign.py
deleted file mode 100644
index dd661811..00000000
--- a/array_api_tests/special_cases/test_sign.py
+++ /dev/null
@@ -1,53 +0,0 @@
-"""
-Special cases tests for sign.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import assert_exactly_equal, exactly_equal, greater, less, logical_or, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import sign
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_sign_special_cases_one_arg_less(arg1):
- """
- Special case test for `sign(x, /)`:
-
- - If `x_i` is less than `0`, the result is `-1`.
-
- """
- res = sign(arg1)
- mask = less(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sign_special_cases_one_arg_either(arg1):
- """
- Special case test for `sign(x, /)`:
-
- - If `x_i` is either `-0` or `+0`, the result is `0`.
-
- """
- res = sign(arg1)
- mask = logical_or(exactly_equal(arg1, -zero(arg1.shape, arg1.dtype)), exactly_equal(arg1, zero(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sign_special_cases_one_arg_greater(arg1):
- """
- Special case test for `sign(x, /)`:
-
- - If `x_i` is greater than `0`, the result is `+1`.
-
- """
- res = sign(arg1)
- mask = greater(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_sin.py b/array_api_tests/special_cases/test_sin.py
deleted file mode 100644
index 4af01736..00000000
--- a/array_api_tests/special_cases/test_sin.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""
-Special cases tests for sin.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, logical_or, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import sin
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_sin_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `sin(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = sin(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sin_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `sin(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = sin(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sin_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `sin(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = sin(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sin_special_cases_one_arg_either(arg1):
- """
- Special case test for `sin(x, /)`:
-
- - If `x_i` is either `+infinity` or `-infinity`, the result is `NaN`.
-
- """
- res = sin(arg1)
- mask = logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_sinh.py b/array_api_tests/special_cases/test_sinh.py
deleted file mode 100644
index 4d2ff217..00000000
--- a/array_api_tests/special_cases/test_sinh.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for sinh.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import sinh
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_sinh_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `sinh(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = sinh(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sinh_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `sinh(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = sinh(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sinh_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `sinh(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = sinh(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sinh_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `sinh(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = sinh(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sinh_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `sinh(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `-infinity`.
-
- """
- res = sinh(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-infinity(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_sqrt.py b/array_api_tests/special_cases/test_sqrt.py
deleted file mode 100644
index 18244755..00000000
--- a/array_api_tests/special_cases/test_sqrt.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for sqrt.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, less, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import sqrt
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_sqrt_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `sqrt(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = sqrt(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sqrt_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `sqrt(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = sqrt(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sqrt_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `sqrt(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = sqrt(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sqrt_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `sqrt(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+infinity`.
-
- """
- res = sqrt(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (infinity(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_sqrt_special_cases_one_arg_less(arg1):
- """
- Special case test for `sqrt(x, /)`:
-
- - If `x_i` is less than `0`, the result is `NaN`.
-
- """
- res = sqrt(arg1)
- mask = less(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_tan.py b/array_api_tests/special_cases/test_tan.py
deleted file mode 100644
index ec09878d..00000000
--- a/array_api_tests/special_cases/test_tan.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""
-Special cases tests for tan.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, logical_or, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import tan
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_tan_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `tan(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = tan(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tan_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `tan(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = tan(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tan_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `tan(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = tan(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tan_special_cases_one_arg_either(arg1):
- """
- Special case test for `tan(x, /)`:
-
- - If `x_i` is either `+infinity` or `-infinity`, the result is `NaN`.
-
- """
- res = tan(arg1)
- mask = logical_or(exactly_equal(arg1, infinity(arg1.shape, arg1.dtype)), exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype)))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_tanh.py b/array_api_tests/special_cases/test_tanh.py
deleted file mode 100644
index 91304c2f..00000000
--- a/array_api_tests/special_cases/test_tanh.py
+++ /dev/null
@@ -1,79 +0,0 @@
-"""
-Special cases tests for tanh.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import NaN, assert_exactly_equal, exactly_equal, infinity, one, zero
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import tanh
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_tanh_special_cases_one_arg_equal_1(arg1):
- """
- Special case test for `tanh(x, /)`:
-
- - If `x_i` is `NaN`, the result is `NaN`.
-
- """
- res = tanh(arg1)
- mask = exactly_equal(arg1, NaN(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (NaN(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tanh_special_cases_one_arg_equal_2(arg1):
- """
- Special case test for `tanh(x, /)`:
-
- - If `x_i` is `+0`, the result is `+0`.
-
- """
- res = tanh(arg1)
- mask = exactly_equal(arg1, zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tanh_special_cases_one_arg_equal_3(arg1):
- """
- Special case test for `tanh(x, /)`:
-
- - If `x_i` is `-0`, the result is `-0`.
-
- """
- res = tanh(arg1)
- mask = exactly_equal(arg1, -zero(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-zero(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tanh_special_cases_one_arg_equal_4(arg1):
- """
- Special case test for `tanh(x, /)`:
-
- - If `x_i` is `+infinity`, the result is `+1`.
-
- """
- res = tanh(arg1)
- mask = exactly_equal(arg1, infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (one(arg1.shape, arg1.dtype))[mask])
-
-
-@given(numeric_arrays)
-def test_tanh_special_cases_one_arg_equal_5(arg1):
- """
- Special case test for `tanh(x, /)`:
-
- - If `x_i` is `-infinity`, the result is `-1`.
-
- """
- res = tanh(arg1)
- mask = exactly_equal(arg1, -infinity(arg1.shape, arg1.dtype))
- assert_exactly_equal(res[mask], (-one(arg1.shape, arg1.dtype))[mask])
diff --git a/array_api_tests/special_cases/test_trunc.py b/array_api_tests/special_cases/test_trunc.py
deleted file mode 100644
index c6a11c6e..00000000
--- a/array_api_tests/special_cases/test_trunc.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""
-Special cases tests for trunc.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import assert_exactly_equal, isintegral
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import trunc
-
-from hypothesis import given
-
-
-@given(numeric_arrays)
-def test_trunc_special_cases_one_arg_equal(arg1):
- """
- Special case test for `trunc(x, /)`:
-
- - If `x_i` is already integer-valued, the result is `x_i`.
-
- """
- res = trunc(arg1)
- mask = isintegral(arg1)
- assert_exactly_equal(res[mask], (arg1)[mask])
diff --git a/array_api_tests/stubs.py b/array_api_tests/stubs.py
new file mode 100644
index 00000000..9025c461
--- /dev/null
+++ b/array_api_tests/stubs.py
@@ -0,0 +1,98 @@
+import inspect
+import sys
+from importlib import import_module
+from importlib.util import find_spec
+from pathlib import Path
+from types import FunctionType, ModuleType
+from typing import Dict, List
+
+from . import api_version
+
+__all__ = [
+ "name_to_func",
+ "array_methods",
+ "array_attributes",
+ "category_to_funcs",
+ "EXTENSIONS",
+ "extension_to_funcs",
+]
+
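+# e.g. api_version "2024.12" maps to the stubs package name suffix "_2024_12"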
+spec_module = "_" + api_version.replace('.', '_')
+
+spec_dir = Path(__file__).parent.parent / "array-api" / "spec" / api_version / "API_specification"
+assert spec_dir.exists(), f"{spec_dir} not found - try `git submodule update --init`"
+sigs_dir = Path(__file__).parent.parent / "array-api" / "src" / "array_api_stubs" / spec_module
+assert sigs_dir.exists()
+
+sigs_abs_path: str = str(sigs_dir.parent.parent.resolve())
+sys.path.append(sigs_abs_path)
+assert find_spec(f"array_api_stubs.{spec_module}") is not None
+
+name_to_mod: Dict[str, ModuleType] = {}
+for path in sigs_dir.glob("*.py"):
+ name = path.name.replace(".py", "")
+ name_to_mod[name] = import_module(f"array_api_stubs.{spec_module}.{name}")
+
+array = name_to_mod["array_object"].array
+array_methods = [
+ f for n, f in inspect.getmembers(array, predicate=inspect.isfunction)
+ if n != "__init__" # probably exists for Sphinx
+]
+array_attributes = [
+ n for n, f in inspect.getmembers(array, predicate=lambda x: isinstance(x, property))
+]
+
+category_to_funcs: Dict[str, List[FunctionType]] = {}
+for name, mod in name_to_mod.items():
+ if name.endswith("_functions"):
+ category = name.replace("_functions", "")
+ objects = [getattr(mod, name) for name in mod.__all__]
+ assert all(isinstance(o, FunctionType) for o in objects) # sanity check
+ category_to_funcs[category] = objects
+
+all_funcs = []
+for funcs in [array_methods, *category_to_funcs.values()]:
+ all_funcs.extend(funcs)
+name_to_func: Dict[str, FunctionType] = {f.__name__: f for f in all_funcs}
+
+info_funcs = []
+if api_version >= "2023.12":
+    # The info functions in the stubs live in info.py, but "info" is not
+    # itself a namespace in the standard.
+ info_mod = name_to_mod["info"]
+
+ # Note that __array_namespace_info__ is in info.__all__ but it is in the
+ # top-level namespace, not the info namespace.
+ info_funcs = [getattr(info_mod, name) for name in info_mod.__all__
+ if name != '__array_namespace_info__']
+ assert all(isinstance(f, FunctionType) for f in info_funcs)
+ name_to_func.update({f.__name__: f for f in info_funcs})
+
+ all_funcs.append(info_mod.__array_namespace_info__)
+ name_to_func['__array_namespace_info__'] = info_mod.__array_namespace_info__
+ category_to_funcs['info'] = [info_mod.__array_namespace_info__]
+
+EXTENSIONS: List[str] = ["linalg"]
+if api_version >= "2022.12":
+ EXTENSIONS.append("fft")
+extension_to_funcs: Dict[str, List[FunctionType]] = {}
+for ext in EXTENSIONS:
+ mod = name_to_mod[ext]
+ objects = [getattr(mod, name) for name in mod.__all__]
+ assert all(isinstance(o, FunctionType) for o in objects) # sanity check
+ funcs = []
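+    # Stub docstrings flag aliases of top-level functions (e.g. linalg.matmul
+    # aliases matmul); point these back at the canonical stub.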
+ for func in objects:
+ if "Alias" in func.__doc__:
+ funcs.append(name_to_func[func.__name__])
+ else:
+ funcs.append(func)
+ extension_to_funcs[ext] = funcs
+
+for funcs in extension_to_funcs.values():
+ for func in funcs:
+        if func.__name__ not in name_to_func:
+ name_to_func[func.__name__] = func
+
+# sanity check public attributes are not empty
+for attr in __all__:
+ assert len(locals()[attr]) != 0, f"{attr} is empty"
diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py
new file mode 100644
index 00000000..4d4af350
--- /dev/null
+++ b/array_api_tests/test_array_object.py
@@ -0,0 +1,341 @@
+import cmath
+import math
+from itertools import product
+from typing import List, Sequence, Tuple, Union, get_args
+
+import pytest
+from hypothesis import assume, given, note
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xps
+from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
+
+
+def scalar_objects(
+ dtype: DataType, shape: Shape
+) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
+ """Generates scalars or nested sequences which are valid for xp.asarray()"""
+ size = math.prod(shape)
+ return st.lists(hh.from_dtype(dtype), min_size=size, max_size=size).map(
+ lambda l: sh.reshape(l, shape)
+ )
+
+
+def normalize_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
+ """
+ Normalize an indexing key.
+
+ * If a non-tuple index, wrap as a tuple.
+ * Represent ellipsis as equivalent slices.
+ """
+ _key = tuple(key) if isinstance(key, tuple) else (key,)
+ if Ellipsis in _key:
+ nonexpanding_key = tuple(i for i in _key if i is not None)
+ start_a = nonexpanding_key.index(Ellipsis)
+ stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
+ slices = tuple(slice(None) for _ in range(start_a, stop_a))
+ start_pos = _key.index(Ellipsis)
+ _key = _key[:start_pos] + slices + _key[start_pos + 1 :]
+ return _key
+
+
+def get_indexed_axes_and_out_shape(
+ key: Tuple[Union[int, slice, None], ...], shape: Shape
+) -> Tuple[Tuple[Sequence[int], ...], Shape]:
+ """
+ From the (normalized) key and input shape, calculates:
+
+ * indexed_axes: For each dimension, the axes which the key indexes.
+ * out_shape: The resulting shape of indexing an array (of the input shape)
+ with the key.
+ """
+ axes_indices = []
+ out_shape = []
+ a = 0
+ for i in key:
+ if i is None:
+ out_shape.append(1)
+ else:
+ side = shape[a]
+ if isinstance(i, int):
+ if i < 0:
+ i += side
+ axes_indices.append((i,))
+ else:
+ indices = range(side)[i]
+ axes_indices.append(indices)
+ out_shape.append(len(indices))
+ a += 1
+ return tuple(axes_indices), tuple(out_shape)
+
+
+@given(shape=hh.shapes(), dtype=hh.all_dtypes, data=st.data())
+def test_getitem(shape, dtype, data):
+ zero_sided = any(side == 0 for side in shape)
+ if zero_sided:
+ x = xp.zeros(shape, dtype=dtype)
+ else:
+ obj = data.draw(scalar_objects(dtype, shape), label="obj")
+ x = xp.asarray(obj, dtype=dtype)
+ note(f"{x=}")
+ key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
+
+ out = x[key]
+
+ ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
+ _key = normalize_key(key, shape)
+ axes_indices, expected_shape = get_indexed_axes_and_out_shape(_key, shape)
+ ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
+ out_zero_sided = any(side == 0 for side in expected_shape)
+ if not zero_sided and not out_zero_sided:
+ out_obj = []
+ for idx in product(*axes_indices):
+ val = obj
+ for i in idx:
+ val = val[i]
+ out_obj.append(val)
+ out_obj = sh.reshape(out_obj, expected_shape)
+ expected = xp.asarray(out_obj, dtype=dtype)
+ ph.assert_array_elements("__getitem__", out=out, expected=expected)
+
+
+@pytest.mark.unvectorized
+@given(
+ shape=hh.shapes(),
+ dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
+ data=st.data(),
+)
+def test_setitem(shape, dtypes, data):
+ zero_sided = any(side == 0 for side in shape)
+ if zero_sided:
+ x = xp.zeros(shape, dtype=dtypes.result_dtype)
+ else:
+ obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj")
+ x = xp.asarray(obj, dtype=dtypes.result_dtype)
+ note(f"{x=}")
+ key = data.draw(xps.indices(shape=shape), label="key")
+ _key = normalize_key(key, shape)
+ axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
+ value_strat = hh.arrays(dtype=dtypes.result_dtype, shape=out_shape)
+ if out_shape == ():
+ # We can pass scalars if we're only indexing one element
+ value_strat |= hh.from_dtype(dtypes.result_dtype)
+ value = data.draw(value_strat, label="value")
+
+ res = xp.asarray(x, copy=True)
+ res[key] = value
+
+ ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
+ ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.shape")
+ f_res = sh.fmt_idx("x", key)
+ if isinstance(value, get_args(Scalar)):
+ msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
+ if cmath.isnan(value):
+ assert xp.isnan(res[key]), msg
+ else:
+ assert res[key] == value, msg
+ else:
+ ph.assert_array_elements("__setitem__", out=res[key], expected=value, out_repr=f_res)
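+    # Elements not selected by the key must be left untouched by __setitem__.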
+ unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
+ for idx in unaffected_indices:
+ ph.assert_0d_equals(
+ "__setitem__",
+ x_repr=f"old {f_res}",
+ x_val=x[idx],
+ out_repr=f"modified {f_res}",
+ out_val=res[idx],
+ )
+
+
+@pytest.mark.unvectorized
+@pytest.mark.data_dependent_shapes
+@given(hh.shapes(), st.data())
+def test_getitem_masking(shape, data):
+ x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
+ mask_shapes = st.one_of(
+ st.sampled_from([x.shape, ()]),
+ st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
+ lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
+ ),
+ hh.shapes(),
+ )
+ key = data.draw(hh.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
+
+ if key.ndim > x.ndim or not all(
+ ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
+ ):
+ with pytest.raises(IndexError):
+ x[key]
+ return
+
+ out = x[key]
+
+ ph.assert_dtype("__getitem__", in_dtype=x.dtype, out_dtype=out.dtype)
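+    # Boolean indexing collapses the first key.ndim axes into a single axis
+    # whose length equals the number of True elements in the mask.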
+ if key.ndim == 0:
+ expected_shape = (1,) if key else (0,)
+ expected_shape += x.shape
+ else:
+ size = int(xp.sum(xp.astype(key, xp.uint8)))
+ expected_shape = (size,) + x.shape[key.ndim :]
+ ph.assert_shape("__getitem__", out_shape=out.shape, expected=expected_shape)
+ if not any(s == 0 for s in key.shape):
+ assume(key.ndim == x.ndim) # TODO: test key.ndim < x.ndim scenarios
+ out_indices = sh.ndindex(out.shape)
+ for x_idx in sh.ndindex(x.shape):
+ if key[x_idx]:
+ out_idx = next(out_indices)
+ ph.assert_0d_equals(
+ "__getitem__",
+ x_repr=f"x[{x_idx}]",
+ x_val=x[x_idx],
+ out_repr=f"out[{out_idx}]",
+ out_val=out[out_idx],
+ )
+
+
+@pytest.mark.unvectorized
+@given(hh.shapes(), st.data())
+def test_setitem_masking(shape, data):
+ x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
+ key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
+ value = data.draw(
+ hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
+ )
+
+ res = xp.asarray(x, copy=True)
+ res[key] = value
+
+ ph.assert_dtype("__setitem__", in_dtype=x.dtype, out_dtype=res.dtype, repr_name="x.dtype")
+ ph.assert_shape("__setitem__", out_shape=res.shape, expected=x.shape, repr_name="x.dtype")
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for idx in sh.ndindex(x.shape):
+ if key[idx]:
+ if isinstance(value, get_args(Scalar)):
+ ph.assert_scalar_equals(
+ "__setitem__",
+ type_=scalar_type,
+ idx=idx,
+ out=scalar_type(res[idx]),
+ expected=value,
+ repr_name="modified x",
+ )
+ else:
+ ph.assert_0d_equals(
+ "__setitem__",
+ x_repr="value",
+ x_val=value,
+ out_repr=f"modified x[{idx}]",
+ out_val=res[idx]
+ )
+ else:
+ ph.assert_0d_equals(
+ "__setitem__",
+ x_repr=f"old x[{idx}]",
+ x_val=x[idx],
+ out_repr=f"modified x[{idx}]",
+ out_val=res[idx]
+ )
+
+
+# ### Fancy indexing ###
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@pytest.mark.parametrize("idx_max_dims", [1, None])
+@given(shape=hh.shapes(min_dims=2), data=st.data())
+def test_getitem_arrays_and_ints_1(shape, data, idx_max_dims):
+ # min_dims=2 : test multidim `x` arrays
+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
+ _test_getitem_arrays_and_ints(shape, data, idx_max_dims)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@pytest.mark.parametrize("idx_max_dims", [1, None])
+@given(shape=hh.shapes(min_dims=1), data=st.data())
+def test_getitem_arrays_and_ints_2(shape, data, idx_max_dims):
+ # min_dims=1 : favor 1D `x` arrays
+ # index arrays are 1D for idx_max_dims=1 and multidim for idx_max_dims=None
+ _test_getitem_arrays_and_ints(shape, data, idx_max_dims)
+
+
+def _test_getitem_arrays_and_ints(shape, data, idx_max_dims):
+    assume(len(shape) > 0 and all(side > 0 for side in shape))
+
+ dtype = xp.int32
+ obj = data.draw(scalar_objects(dtype, shape), label="obj")
+ x = xp.asarray(obj, dtype=dtype)
+
+ # draw a mix of ints and index arrays
+ arr_index = [data.draw(st.booleans()) for _ in range(len(shape))]
+ assume(sum(arr_index) > 0)
+
+ # draw shapes for index arrays: max_dims=1 ==> 1D indexing arrays ONLY
+ # max_dims=None ==> multidim indexing arrays
+ if sum(arr_index) > 0:
+ index_shapes = data.draw(
+ hh.mutually_broadcastable_shapes(
+ sum(arr_index), min_dims=1, max_dims=idx_max_dims, min_side=1
+ )
+ )
+ index_shapes = list(index_shapes)
+
+ # prepare the indexing tuple, a mix of integer indices and index arrays
+ key = []
+    for i, typ in enumerate(arr_index):
+ if typ:
+ # draw an array index
+ this_idx = data.draw(
+ xps.arrays(
+ dtype,
+ shape=index_shapes.pop(),
+ elements=st.integers(0, shape[i]-1)
+ )
+ )
+ key.append(this_idx)
+
+ else:
+ # draw an integer
+ key.append(data.draw(st.integers(-shape[i], shape[i]-1)))
+
+ key = tuple(key)
+ out = x[key]
+
+ arrays = [xp.asarray(k) for k in key]
+ bcast_shape = sh.broadcast_shapes(*[arr.shape for arr in arrays])
+ bcast_key = [xp.broadcast_to(arr, bcast_shape) for arr in arrays]
+
+ for idx in sh.ndindex(bcast_shape):
+ tpl = tuple(k[idx] for k in bcast_key)
+ assert out[idx] == x[tpl], f"failing at {idx = } w/ {key = }"
+
+
+def make_scalar_casting_param(
+ method_name: str, dtype: DataType, stype: ScalarType
+) -> Param:
+ dtype_name = dh.dtype_to_name[dtype]
+ return pytest.param(
+ method_name, dtype, stype, id=f"{method_name}({dtype_name})"
+ )
+
+
+@pytest.mark.parametrize(
+ "method_name, dtype, stype",
+ [make_scalar_casting_param("__bool__", xp.bool, bool)]
+ + [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_dtypes]
+ + [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_dtypes]
+ + [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_dtypes],
+)
+@given(data=st.data())
+def test_scalar_casting(method_name, dtype, stype, data):
+ x = data.draw(hh.arrays(dtype, shape=()), label="x")
+ method = getattr(x, method_name)
+ out = method()
+ assert isinstance(
+ out, stype
+ ), f"{method_name}({x})={out}, which is not a {stype.__name__} scalar"
diff --git a/array_api_tests/test_broadcasting.py b/array_api_tests/test_broadcasting.py
deleted file mode 100644
index b49434ff..00000000
--- a/array_api_tests/test_broadcasting.py
+++ /dev/null
@@ -1,129 +0,0 @@
-"""
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md
-"""
-
-from functools import reduce
-
-import pytest
-
-from hypothesis import given
-
-from .hypothesis_helpers import nonbroadcastable_ones_array_two_args
-from .pytest_helpers import raises, doesnt_raise, nargs
-
-from .function_stubs import elementwise_functions
-from . import _array_module
-
-# The spec does not specify what exception is raised on broadcast errors. We
-# use a custom exception to distinguish it from potential bugs in
-# broadcast_shapes().
-class BroadcastError(Exception):
- pass
-
-# The spec only specifies broadcasting for two shapes.
-def broadcast_shapes(shape1, shape2):
- """
- Broadcast shapes `shape1` and `shape2`.
-
- The code in this function should follow the pseudocode in the spec as
- closely as possible.
- """
- N1 = len(shape1)
- N2 = len(shape2)
- N = max(N1, N2)
- shape = [None]*N
- i = N - 1
- while i >= 0:
- n1 = N1 - N + i
- if N1 - N + i >= 0:
- d1 = shape1[n1]
- else:
- d1 = 1
- n2 = N2 - N + i
- if N2 - N + i >= 0:
- d2 = shape2[n2]
- else:
- d2 = 1
-
- if d1 == 1:
- shape[i] = d2
- elif d2 == 1:
- shape[i] = d1
- elif d1 == d2:
- shape[i] = d1
- else:
- raise BroadcastError
-
- i = i - 1
-
- return tuple(shape)
-
-def test_broadcast_shapes_explicit_spec():
- """
- Explicit broadcast shapes examples from the spec
- """
- shape1 = (8, 1, 6, 1)
- shape2 = (7, 1, 5)
- result = (8, 7, 6, 5)
- assert broadcast_shapes(shape1, shape2) == result
-
- shape1 = (5, 4)
- shape2 = (1,)
- result = (5, 4)
- assert broadcast_shapes(shape1, shape2) == result
-
- shape1 = (5, 4)
- shape2 = (4,)
- result = (5, 4)
- assert broadcast_shapes(shape1, shape2) == result
-
- shape1 = (15, 3, 5)
- shape2 = (15, 1, 5)
- result = (15, 3, 5)
- assert broadcast_shapes(shape1, shape2) == result
-
- shape1 = (15, 3, 5)
- shape2 = (3, 5)
- result = (15, 3, 5)
- assert broadcast_shapes(shape1, shape2) == result
-
- shape1 = (15, 3, 5)
- shape2 = (3, 1)
- result = (15, 3, 5)
- assert broadcast_shapes(shape1, shape2) == result
-
- shape1 = (3,)
- shape2 = (4,)
- raises(BroadcastError, lambda: broadcast_shapes(shape1, shape2)) # dimension does not match
-
- shape1 = (2, 1)
- shape2 = (8, 4, 3)
- raises(BroadcastError, lambda: broadcast_shapes(shape1, shape2)) # second dimension does not match
-
- shape1 = (15, 3, 5)
- shape2 = (15, 3)
- raises(BroadcastError, lambda: broadcast_shapes(shape1, shape2)) # singleton dimensions can only be prepended, not appended
-
-# TODO: Extend this to all functions (not just elementwise), and handle
-# functions that take more than 2 args
-@pytest.mark.parametrize('func_name', [i for i in
- elementwise_functions.__all__ if
- nargs(i) > 1])
-@given(args=nonbroadcastable_ones_array_two_args)
-def test_broadcasting_hypothesis(func_name, args):
- assert nargs(func_name) == 2
- func = getattr(_array_module, func_name)
-
- if isinstance(func, _array_module._UndefinedStub):
- func._raise()
-
- shapes = [i.shape for i in args]
- try:
- broadcast_shape = reduce(broadcast_shapes, shapes)
- except BroadcastError:
- raises(Exception, lambda: func(*args),
- f"{func_name} should raise an exception from not being able to broadcast inputs with shapes {shapes}")
- else:
- result = doesnt_raise(lambda: func(*args),
- f"{func_name} raised an unexpected exception from broadcastable inputs with shapes {shapes}")
- assert result.shape == broadcast_shape, "broadcast shapes incorrect"
diff --git a/array_api_tests/test_constants.py b/array_api_tests/test_constants.py
index d58580d2..145a2736 100644
--- a/array_api_tests/test_constants.py
+++ b/array_api_tests/test_constants.py
@@ -1,38 +1,56 @@
-from ._array_module import (e, inf, nan, pi, equal, isnan, abs, full, float32,
- float64, less, isinf, greater)
-from .array_helpers import one
+import math
+from typing import Any, SupportsFloat
-def test_e():
- # Check that e acts as a scalar
- E = full((1,), e, dtype=float64)
+import pytest
- # We don't require any accuracy. This is just a smoke test to check that
- # 'e' is actually the constant e.
- assert less(abs(E - 2.71), one((1,), dtype=float64)), "e is not the constant e"
+from . import dtype_helpers as dh
+from . import xp
+from .typing import Array
-def test_pi():
- # Check that pi acts as a scalar
- PI = full((1,), pi, dtype=float64)
- # We don't require any accuracy. This is just a smoke test to check that
- # 'pi' is actually the constant π.
- assert less(abs(PI - 3.14), one((1,), dtype=float64)), "pi is not the constant π"
+def assert_scalar_float(name: str, c: Any):
+ assert isinstance(c, SupportsFloat), f"{name}={c!r} does not look like a float"
-def test_inf():
- # Check that inf acts as a scalar
- INF = full((1,), inf, dtype=float64)
- assert isinf(inf), "inf is not infinity"
- assert isinf(INF), "inf is not infinity"
- assert greater(inf, 0), "inf is not positive"
- assert greater(INF, 0), "inf is not positive"
+def assert_0d_float(name: str, x: Array):
+ assert dh.is_float_dtype(
+ x.dtype
+ ), f"xp.asarray(xp.{name})={x!r}, but should have float dtype"
-def test_nan():
- # Check that nan acts as a scalar
- NAN = full((1,), nan, dtype=float64)
- assert isnan(nan), "nan is not Not a Number"
- assert isnan(NAN), "nan is not Not a Number"
+@pytest.mark.parametrize("name, n", [("e", math.e), ("pi", math.pi)])
+def test_irrational_numbers(name, n):
+ assert hasattr(xp, name)
+ c = getattr(xp, name)
+ assert_scalar_float(name, c)
+ floor = math.floor(n)
+ assert c > floor, f"xp.{name}={c!r} <= {floor}"
+ ceil = math.ceil(n)
+ assert c < ceil, f"xp.{name}={c!r} >= {ceil}"
+ x = xp.asarray(c)
+ assert_0d_float("name", x)
- assert not equal(nan, nan), "nan should be unequal to itself"
- assert not equal(NAN, NAN), "nan should be unequal to itself"
+
+def test_inf():
+ assert hasattr(xp, "inf")
+ assert_scalar_float("inf", xp.inf)
+ assert math.isinf(xp.inf)
+ assert xp.inf > 0, "xp.inf not greater than 0"
+ x = xp.asarray(xp.inf)
+ assert_0d_float("inf", x)
+ assert xp.isinf(x), "xp.isinf(xp.asarray(xp.inf))=False"
+
+
+def test_nan():
+ assert hasattr(xp, "nan")
+ assert_scalar_float("nan", xp.nan)
+ assert math.isnan(xp.nan)
+ assert xp.nan != xp.nan, "xp.nan should not have equality with itself"
+ x = xp.asarray(xp.nan)
+ assert_0d_float("nan", x)
+ assert xp.isnan(x), "xp.isnan(xp.asarray(xp.nan))=False"
+
+
+def test_newaxis():
+ assert hasattr(xp, "newaxis")
+ assert xp.newaxis is None
diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py
index 7a5feb8d..1f144c72 100644
--- a/array_api_tests/test_creation_functions.py
+++ b/array_api_tests/test_creation_functions.py
@@ -1,216 +1,592 @@
-from ._array_module import (arange, ceil, empty, _floating_dtypes, eye, full,
-equal, all, linspace, ones, zeros, isnan)
-from .array_helpers import (is_integer_dtype, dtype_ranges,
- assert_exactly_equal, isintegral)
-from .hypothesis_helpers import (numeric_dtypes, dtypes, MAX_ARRAY_SIZE,
- shapes, sizes, sqrt_sizes, shared_dtypes,
- shared_scalars)
-
-from hypothesis import assume, given
-from hypothesis.strategies import integers, floats, one_of, none, booleans
-
-int_range = integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE)
-float_range = floats(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE, allow_nan=False)
-@given(one_of(int_range, float_range),
- one_of(none(), int_range, float_range),
- one_of(none(), int_range, float_range).filter(lambda x: x != 0),
- one_of(none(), numeric_dtypes))
-def test_arange(start, stop, step, dtype):
- if dtype in dtype_ranges:
- m, M = dtype_ranges[dtype]
- if (not (m <= start <= M)
- or isinstance(stop, int) and not (m <= stop <= M)
- or isinstance(step, int) and not (m <= step <= M)):
- assume(False)
-
- kwargs = {} if dtype is None else {'dtype': dtype}
-
- all_int = (is_integer_dtype(dtype)
- and isinstance(start, int)
- and (stop is None or isinstance(stop, int))
- and (step is None or isinstance(step, int)))
-
+import cmath
+import math
+from itertools import count
+from typing import Iterator, NamedTuple, Union
+
+from hypothesis import assume, given, note
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xps
+from .typing import DataType, Scalar
+
+
+class frange(NamedTuple):
+ start: float
+ stop: float
+ step: float
+
+ def __iter__(self) -> Iterator[float]:
+ pos_range = self.stop > self.start
+ pos_step = self.step > 0
+ if pos_step != pos_range:
+ return
+ if pos_range:
+ for n in count(self.start, self.step):
+ if n >= self.stop:
+ break
+ yield n
+ else:
+ for n in count(self.start, self.step):
+ if n <= self.stop:
+ break
+ yield n
+
+ def __len__(self) -> int:
+ return max(math.ceil((self.stop - self.start) / self.step), 0)
+
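+# frange mirrors range() for real-valued steps, e.g. list(frange(0.0, 1.0, 0.25))
+# == [0.0, 0.25, 0.5, 0.75] and len(frange(0.0, 1.0, 0.25)) == 4.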
+
+# Testing xp.arange() requires bounding the start/stop/step arguments to test
+# only argument combinations compliant with the Array API, as well as to avoid
+# producing arrays with sizes not supported by an array module.
+#
+# We first make sure generated integers can be represented by an array module's
+# default integer type, as even if a float array should be produced a module
+# might represent integer arguments as 0d arrays.
+#
+# This means that float arguments also need to be bound, so that they do not
+# require any integer arguments to be outside the representable bounds.
+int_min, int_max = dh.dtype_ranges[dh.default_int]
+float_min = float(int_min * (hh.MAX_ARRAY_SIZE - 1))
+float_max = float(int_max * (hh.MAX_ARRAY_SIZE - 1))
+
+
+def reals(min_value=None, max_value=None) -> st.SearchStrategy[Union[int, float]]:
+ round_ = int
+ if min_value is not None and min_value > 0:
+ round_ = math.ceil
+ elif max_value is not None and max_value < 0:
+ round_ = math.floor
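+    # Round the integer bounds inwards so drawn integers respect the given
+    # real bounds, e.g. min_value=0.5 -> integer lower bound ceil(0.5) == 1.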
+ int_min_value = int_min if min_value is None else max(round_(min_value), int_min)
+ int_max_value = int_max if max_value is None else min(round_(max_value), int_max)
+ return st.one_of(
+ st.integers(int_min_value, int_max_value),
+ # We do not assign float bounds to the floats() strategy, instead opting
+ # to filter out-of-bound values. Passing such min/max values will modify
+ # test case reduction behaviour so that simple bugs will become harder
+ # for users to identify. Hypothesis plans to improve floats() behaviour
+ # in https://github.com/HypothesisWorks/hypothesis/issues/2907
+ st.floats(min_value, max_value, allow_nan=False, allow_infinity=False).filter(
+ lambda n: float_min <= n <= float_max
+ ),
+ )
+
+
+@given(dtype=st.none() | hh.real_dtypes, data=st.data())
+def test_arange(dtype, data):
+ if dtype is None or dh.is_float_dtype(dtype):
+ start = data.draw(reals(), label="start")
+ stop = data.draw(reals() | st.none(), label="stop")
+ else:
+ start = data.draw(hh.from_dtype(dtype), label="start")
+ stop = data.draw(hh.from_dtype(dtype), label="stop")
if stop is None:
- # NB: "start" is really the stop
- # step is ignored in this case
- a = arange(start, **kwargs)
- if all_int:
- r = range(start)
- elif step is None:
- a = arange(start, stop, **kwargs)
- if all_int:
- r = range(start, stop)
+ _start = 0
+ _stop = start
else:
- a = arange(start, stop, step, **kwargs)
- if all_int:
- r = range(start, stop, step)
- if dtype is None:
- # TODO: What is the correct dtype of a?
- pass
- else:
- assert a.dtype == dtype, "arange() produced an incorrect dtype"
- assert a.ndim == 1, "arange() should return a 1-dimensional array"
- if all_int:
- assert a.shape == (len(r),), "arange() produced incorrect shape"
- if len(r) <= MAX_ARRAY_SIZE:
- assert list(a) == list(r), "arange() produced incorrect values"
- else:
- # This is already implied by the len(r) test above
- if (stop is not None
- and step is not None
- and (step > 0 and stop >= start
- or step < 0 and stop <= start)):
- assert a.size == ceil((stop-start)/step), "arange() produced an array of the incorrect size"
-
-@given(one_of(shapes, sizes), one_of(none(), dtypes))
-def test_empty(shape, dtype):
- if dtype is None:
- a = empty(shape)
- assert a.dtype in _floating_dtypes, "empty() should produce an array with the default floating point dtype"
+ _start = start
+ _stop = stop
+
+ # tol is the minimum tolerance for step values, used to avoid scenarios
+ # where xp.arange() produces arrays that would be over MAX_ARRAY_SIZE.
+ tol = max(abs(_stop - _start) / (math.sqrt(hh.MAX_ARRAY_SIZE)), 0.01)
+ assert tol != 0, "tol must not equal 0" # sanity check
+ assume(-tol > int_min)
+ assume(tol < int_max)
+ if dtype is None or dh.is_float_dtype(dtype):
+ step = data.draw(reals(min_value=tol) | reals(max_value=-tol), label="step")
else:
- a = empty(shape, dtype=dtype)
- assert a.dtype == dtype
-
- if isinstance(shape, int):
- shape = (shape,)
- assert a.shape == shape, "empty() produced an array with an incorrect shape"
+ step_strats = []
+ if dtype in dh.int_dtypes:
+ step_min = min(math.floor(-tol), -1)
+ step_strats.append(hh.from_dtype(dtype, max_value=step_min))
+ step_max = max(math.ceil(tol), 1)
+ step_strats.append(hh.from_dtype(dtype, min_value=step_max))
+ step = data.draw(st.one_of(step_strats), label="step")
+ assert step != 0, "step must not equal 0" # sanity check
-# TODO: implement empty_like (requires hypothesis arrays support)
-def test_empty_like():
- pass
+ all_int = all(arg is None or isinstance(arg, int) for arg in [start, stop, step])
-@given(sqrt_sizes, one_of(none(), sqrt_sizes), one_of(none(), integers()), numeric_dtypes)
-def test_eye(N, M, k, dtype):
- kwargs = {k: v for k, v in {'M': M, 'k': k, 'dtype': dtype}.items() if v
- is not None}
- a = eye(N, **kwargs)
if dtype is None:
- assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
+ if all_int:
+ _dtype = dh.default_int
+ else:
+ _dtype = dh.default_float
else:
- assert a.dtype == dtype, "eye() did not produce the correct dtype"
-
- if M is None:
- M = N
- assert a.shape == (N, M), "eye() produced an array with incorrect shape"
-
- if k is None:
- k = 0
- for i in range(N):
- for j in range(M):
- if j - i == k:
- assert a[i, j] == 1, "eye() did not produce a 1 on the diagonal"
- else:
- assert a[i, j] == 0, "eye() did not produce a 0 off the diagonal"
-
-@given(shapes, shared_scalars(), one_of(none(), shared_dtypes))
-def test_full(shape, fill_value, dtype):
- kwargs = {} if dtype is None else {'dtype': dtype}
+ _dtype = dtype
+
+ # sanity checks
+ if dh.is_int_dtype(_dtype):
+ m, M = dh.dtype_ranges[_dtype]
+ assert m <= _start <= M
+ assert m <= _stop <= M
+ assert m <= step <= M
+ # Ignore ridiculous distances so we don't fail like
+ #
+ # >>> torch.arange(9132051521638391890, 0, -91320515216383920)
+ # RuntimeError: invalid size, possible overflow?
+ #
+ assume(abs(_start - _stop) < M // 2)
+
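+    # frange is assumed to mirror range() semantics for float arguments,
+    # giving the reference length and element values for the checks below.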
+ r = frange(_start, _stop, step)
+ size = len(r)
+ assert (
+ size <= hh.MAX_ARRAY_SIZE
+ ), f"{size=} should be no more than {hh.MAX_ARRAY_SIZE}" # sanity check
+
+ args_samples = [(start, stop), (start, stop, step)]
+ if stop is None:
+ args_samples.insert(0, (start,))
+ args = data.draw(st.sampled_from(args_samples), label="args")
+ kvds = [hh.KVD("dtype", dtype, None)]
+ if len(args) != 3:
+ kvds.insert(0, hh.KVD("step", step, 1))
+ kwargs = data.draw(hh.specified_kwargs(*kvds), label="kwargs")
- a = full(shape, fill_value, **kwargs)
+ out = xp.arange(*args, **kwargs)
if dtype is None:
- # TODO: Should it actually match the fill_value?
- # assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
- pass
- else:
- assert a.dtype == dtype
-
- assert a.shape == shape, "full() produced an array with incorrect shape"
- if isnan(fill_value):
- assert all(isnan(a)), "full() array did not equal the fill value"
- else:
- assert all(equal(a, fill_value)), "full() array did not equal the fill value"
-
-# TODO: implement full_like (requires hypothesis arrays support)
-def test_full_like():
- pass
-
-@given(one_of(integers(), floats(allow_nan=False, allow_infinity=False)),
- one_of(integers(), floats(allow_nan=False, allow_infinity=False)),
- sizes,
- one_of(none(), dtypes),
- one_of(none(), booleans()),)
-def test_linspace(start, stop, num, dtype, endpoint):
- if dtype in dtype_ranges:
- m, M = dtype_ranges[dtype]
- if (isinstance(start, int) and not (m <= start <= M)
- or isinstance(stop, int) and not (m <= stop <= M)):
- assume(False)
- # Skip on int start or stop that cannot be exactly represented as a float,
- # since we do not have good approx_equal helpers yet.
- if (dtype is None or dtype in _floating_dtypes
- and ((isinstance(start, int) and not isintegral(start))
- or (isinstance(stop, int) and not isintegral(stop)))):
- assume(False)
-
- kwargs = {k: v for k, v in {'dtype': dtype, 'endpoint': endpoint}.items()
- if v is not None}
- a = linspace(start, stop, num, **kwargs)
-
+ if all_int:
+ ph.assert_default_int("arange", out.dtype)
+ else:
+ ph.assert_default_float("arange", out.dtype)
+ else:
+ ph.assert_kw_dtype("arange", kw_dtype=dtype, out_dtype=out.dtype)
+ f_sig = ", ".join(str(n) for n in args)
+ if len(kwargs) > 0:
+ f_sig += f", {ph.fmt_kw(kwargs)}"
+ f_func = f"[arange({f_sig})]"
+ assert out.ndim == 1, f"{out.ndim=}, but should be 1 [{f_func}]"
+ # We check size is roughly as expected to avoid edge cases e.g.
+ #
+ # >>> xp.arange(2, step=0.333333333333333)
+ # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66, 2.0]
+ # >>> xp.arange(2, step=0.3333333333333333)
+ # [0.0, 0.33, 0.66, 1.0, 1.33, 1.66]
+ #
+ # >>> start, stop, step = 0, 108086391056891901, 1080863910568919
+ # >>> x = xp.arange(start, stop, step, dtype=xp.uint64)
+ # >>> x.size
+ # 100
+ # >>> r = range(start, stop, step)
+ # >>> len(r)
+ # 101
+ #
+ min_size = math.floor(size * 0.9)
+ max_size = max(math.ceil(size * 1.1), 1)
+ out_size = math.prod(out.shape)
+ assert (
+ min_size <= out_size <= max_size
+ ), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}"
+ if dh.is_int_dtype(_dtype):
+ elements = list(r)
+ assume(out_size == len(elements))
+ ph.assert_array_elements("arange", out=out, expected=xp.asarray(elements, dtype=_dtype))
+ else:
+ assume(out_size == size)
+ if out_size > 0:
+ assert xp.equal(
+ out[0], xp.asarray(_start, dtype=out.dtype)
+ ), f"out[0]={out[0]}, but should be {_start} {f_func}"
+
+
+@given(shape=hh.shapes(min_side=1), data=st.data())
+def test_asarray_scalars(shape, data):
+ kw = data.draw(
+ hh.kwargs(dtype=st.none() | hh.all_dtypes, copy=st.none()), label="kw"
+ )
+ dtype = kw.get("dtype", None)
if dtype is None:
- assert a.dtype in _floating_dtypes, "linspace() should produce an array with the default floating point dtype"
+ dtype_family = data.draw(
+ st.sampled_from(
+ [(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)]
+ ),
+ label="expected out dtypes",
+ )
+ _dtype = dtype_family[0]
else:
- assert a.dtype == dtype, "linspace() did not produce the correct dtype"
-
- assert a.shape == (num,), "linspace() did not produce an array with the correct shape"
-
- if endpoint in [None, True]:
- if num > 1:
- assert all(equal(a[-1], full((), stop, dtype=dtype))), "linspace() produced an array that does not include the endpoint"
+ _dtype = dtype
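+    # Mix in scalars of a lower kind (ints among floats, bools among ints);
+    # asarray's dtype inference should still resolve to the kind under test.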
+ if dh.is_float_dtype(_dtype):
+ elements_strat = hh.from_dtype(_dtype) | hh.from_dtype(xp.int32)
+ elif dh.is_int_dtype(_dtype):
+ elements_strat = hh.from_dtype(_dtype) | st.booleans()
else:
- # linspace(..., num, endpoint=False) is the same as the first num
- # elements of linspace(..., num+1, endpoint=True)
- b = linspace(start, stop, num + 1, **{**kwargs, 'endpoint': True})
- assert_exactly_equal(b[:-1], a)
-
- if num > 0:
- # We need to cast start to dtype
- assert all(equal(a[0], full((), start, dtype=dtype))), "linspace() produced an array that does not start with the start"
-
- # TODO: This requires an assert_approx_equal function
-
- # n = num - 1 if endpoint in [None, True] else num
- # for i in range(1, num):
- # assert all(equal(a[i], full((), i*(stop - start)/n + start, dtype=dtype))), f"linspace() produced an array with an incorrect value at index {i}"
-
-@given(shapes, one_of(none(), dtypes))
-def test_ones(shape, dtype):
- kwargs = {} if dtype is None else {'dtype': dtype}
+ elements_strat = hh.from_dtype(_dtype)
+ size = math.prod(shape)
+ obj_strat = st.lists(elements_strat, min_size=size, max_size=size)
+ scalar_type = dh.get_scalar_type(_dtype)
+ if dtype is None:
+        # For asarray to infer the dtype we're testing, obj requires at least
+        # one element to be the scalar equivalent of the inferred dtype, so we
+        # filter out invalid examples. Note we use type() rather than
+        # isinstance(), as Python booleans are instances of int, e.g.
+        # isinstance(False, int) is True.
+ obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l))
+ _obj = data.draw(obj_strat, label="_obj")
+ obj = sh.reshape(_obj, shape)
+ note(f"{obj=}")
- a = ones(shape, **kwargs)
+ out = xp.asarray(obj, **kw)
if dtype is None:
- # TODO: Should it actually match the fill_value?
- # assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
- pass
+ msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be "
+ if dtype_family == (xp.float32, xp.float64):
+ msg += "default floating-point dtype (float32 or float64)"
+ elif dtype_family == (xp.int32, xp.int64):
+ msg += "default integer dtype (int32 or int64)"
+ else:
+ msg += "boolean dtype"
+ msg += " [asarray()]"
+ assert out.dtype in dtype_family, msg
else:
- assert a.dtype == dtype
-
- assert a.shape == shape, "ones() produced an array with incorrect shape"
- assert all(equal(a, full((), 1, **kwargs))), "ones() array did not equal 1"
-
-# TODO: implement ones_like (requires hypothesis arrays support)
-def test_ones_like():
- pass
+ assert kw["dtype"] == _dtype # sanity check
+ ph.assert_kw_dtype("asarray", kw_dtype=_dtype, out_dtype=out.dtype)
+ ph.assert_shape("asarray", out_shape=out.shape, expected=shape)
+ for idx, v_expect in zip(sh.ndindex(out.shape), _obj):
+ v = scalar_type(out[idx])
+ ph.assert_scalar_equals("asarray", type_=scalar_type, idx=idx, out=v, expected=v_expect, kw=kw)
+
+
+def scalar_eq(s1: Scalar, s2: Scalar) -> bool:
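+    """NaN-aware equality: unlike the builtin ==, two NaNs compare equal
+    here, so drawn NaN values can be compared reliably below."""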
+ if cmath.isnan(s1):
+ return cmath.isnan(s2)
+ else:
+ return s1 == s2
+
+
+@given(
+ shape=hh.shapes(),
+ dtypes=hh.oneway_promotable_dtypes(dh.all_dtypes),
+ data=st.data(),
+)
+def test_asarray_arrays(shape, dtypes, data):
+    # Generate arrays only, since we draw the copy= kwarg below (and e.g.
+    # np.asarray(scalar, copy=False) errors out)
+ x = data.draw(hh.arrays_no_scalars(dtype=dtypes.input_dtype, shape=shape), label="x")
+ dtypes_strat = st.just(dtypes.input_dtype)
+ if dtypes.input_dtype == dtypes.result_dtype:
+ dtypes_strat |= st.none()
+ kw = data.draw(
+ hh.kwargs(dtype=dtypes_strat, copy=st.none() | st.booleans()),
+ label="kw",
+ )
+
+ out = xp.asarray(x, **kw)
+
+ dtype = kw.get("dtype", None)
+ if dtype is None:
+ ph.assert_dtype("asarray", in_dtype=x.dtype, out_dtype=out.dtype)
+ else:
+ ph.assert_kw_dtype("asarray", kw_dtype=dtype, out_dtype=out.dtype)
+ ph.assert_shape("asarray", out_shape=out.shape, expected=x.shape)
+ ph.assert_array_elements("asarray", out=out, expected=x, kw=kw)
+ copy = kw.get("copy", None)
+ if copy is not None:
+ stype = dh.get_scalar_type(x.dtype)
+ idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
+ old_value = stype(x[idx])
+ scalar_strat = hh.from_dtype(dtypes.input_dtype).filter(
+ lambda n: not scalar_eq(n, old_value)
+ )
+ value = data.draw(
+ scalar_strat | scalar_strat.map(lambda n: xp.asarray(n, dtype=x.dtype)),
+ label="mutating value",
+ )
+ x[idx] = value
+ note(f"mutated {x=}")
+ # sanity check
+ ph.assert_scalar_equals(
+ "__setitem__", type_=stype, idx=idx, out=stype(x[idx]), expected=value, repr_name="x"
+ )
+ new_out_value = stype(out[idx])
+ f_out = f"{sh.fmt_idx('out', idx)}={new_out_value}"
+ if copy:
+ assert scalar_eq(
+ new_out_value, old_value
+ ), f"{f_out}, but should be {old_value} even after x was mutated"
+ else:
+ assert scalar_eq(
+ new_out_value, value
+ ), f"{f_out}, but should be {value} after x was mutated"
+
+
+@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes))
+def test_empty(shape, kw):
+ out = xp.empty(shape, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_default_float("empty", out.dtype)
+ else:
+ ph.assert_kw_dtype("empty", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("empty", out_shape=out.shape, expected=shape, kw=dict(shape=shape))
+
+
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
+ kw=hh.kwargs(dtype=st.none() | hh.all_dtypes),
+)
+def test_empty_like(x, kw):
+ out = xp.empty_like(x, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_dtype("empty_like", in_dtype=x.dtype, out_dtype=out.dtype)
+ else:
+ ph.assert_kw_dtype("empty_like", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("empty_like", out_shape=out.shape, expected=x.shape)
+
+
+@given(
+ n_rows=hh.sqrt_sizes,
+ n_cols=st.none() | hh.sqrt_sizes,
+ kw=hh.kwargs(
+ k=st.integers(),
+ dtype=hh.numeric_dtypes,
+ ),
+)
+def test_eye(n_rows, n_cols, kw):
+ out = xp.eye(n_rows, n_cols, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_default_float("eye", out.dtype)
+ else:
+ ph.assert_kw_dtype("eye", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ _n_cols = n_rows if n_cols is None else n_cols
+ ph.assert_shape("eye", out_shape=out.shape, expected=(n_rows, _n_cols), kw=dict(n_rows=n_rows, n_cols=n_cols))
+ k = kw.get("k", 0)
+ expected = xp.asarray(
+ [[1 if j - i == k else 0 for j in range(_n_cols)] for i in range(n_rows)],
+ dtype=out.dtype # Note: dtype already checked above.
+ )
+ if 0 in expected.shape:
+ expected = xp.reshape(expected, (n_rows, _n_cols))
+ ph.assert_array_elements("eye", out=out, expected=expected, kw=kw)
+
+
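+# Dtypes whose values may not be representable in the default dtypes: fill
+# values for full() are not drawn from these when no dtype kwarg is given.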
+default_unsafe_dtypes = [xp.uint64]
+if dh.default_int == xp.int32:
+ default_unsafe_dtypes.extend([xp.uint32, xp.int64])
+if dh.default_float == xp.float32:
+ default_unsafe_dtypes.append(xp.float64)
+if dh.default_complex == xp.complex64:
+    default_unsafe_dtypes.append(xp.complex128)
+default_safe_dtypes: st.SearchStrategy = hh.all_dtypes.filter(
+ lambda d: d not in default_unsafe_dtypes
+)
+
+
+@st.composite
+def full_fill_values(draw) -> Union[bool, int, float, complex]:
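+    # st.shared(..., key="full_kw") ties this draw to the kw= argument of
+    # test_full below: both strategies see the same drawn kwargs, so the fill
+    # value always matches the dtype (if any) that full() is called with.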
+ kw = draw(
+ st.shared(hh.kwargs(dtype=st.none() | hh.all_dtypes), key="full_kw")
+ )
+ dtype = kw.get("dtype", None) or draw(default_safe_dtypes)
+ return draw(hh.from_dtype(dtype))
+
+
+@given(
+ shape=hh.shapes(),
+ fill_value=full_fill_values(),
+ kw=st.shared(hh.kwargs(dtype=st.none() | hh.all_dtypes), key="full_kw"),
+)
+def test_full(shape, fill_value, kw):
+ with hh.reject_overflow():
+ out = xp.full(shape, fill_value, **kw)
+ if kw.get("dtype", None):
+ dtype = kw["dtype"]
+ elif isinstance(fill_value, bool):
+ dtype = xp.bool
+ elif isinstance(fill_value, int):
+ dtype = dh.default_int
+ elif isinstance(fill_value, float):
+ dtype = dh.default_float
+ else:
+ assert isinstance(fill_value, complex) # sanity check
+ dtype = dh.default_complex
+ # Ignore large components so we don't fail like
+ #
+ # >>> torch.fill(complex(0.0, 3.402823466385289e+38))
+ # RuntimeError: value cannot be converted to complex without overflow
+ #
+ M = dh.dtype_ranges[dh.dtype_components[dtype]].max
+ assume(all(abs(c) < math.sqrt(M) for c in [fill_value.real, fill_value.imag]))
+ if kw.get("dtype", None) is None:
+ if isinstance(fill_value, bool):
+ assert out.dtype == xp.bool, f"{out.dtype=}, but should be bool [full()]"
+ elif isinstance(fill_value, int):
+ ph.assert_default_int("full", out.dtype)
+ elif isinstance(fill_value, float):
+ ph.assert_default_float("full", out.dtype)
+ else:
+ assert isinstance(fill_value, complex) # sanity check
+ ph.assert_default_complex("full", out.dtype)
+ else:
+ ph.assert_kw_dtype("full", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("full", out_shape=out.shape, expected=shape, kw=dict(shape=shape))
+ ph.assert_fill("full", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value))
+
+
+@given(kw=hh.kwargs(dtype=st.none() | hh.all_dtypes), data=st.data())
+def test_full_like(kw, data):
+ dtype = kw.get("dtype", None) or data.draw(hh.all_dtypes, label="dtype")
+ x = data.draw(hh.arrays(dtype=dtype, shape=hh.shapes()), label="x")
+ fill_value = data.draw(hh.from_dtype(dtype), label="fill_value")
+ out = xp.full_like(x, fill_value, **kw)
+ dtype = kw.get("dtype", None) or x.dtype
+ if kw.get("dtype", None) is None:
+ ph.assert_dtype("full_like", in_dtype=x.dtype, out_dtype=out.dtype)
+ else:
+ ph.assert_kw_dtype("full_like", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("full_like", out_shape=out.shape, expected=x.shape)
+ ph.assert_fill("full_like", fill_value=fill_value, dtype=dtype, out=out, kw=dict(fill_value=fill_value))
+
+
+finite_kw = {"allow_nan": False, "allow_infinity": False}
+
+
+@given(
+ num=hh.sizes,
+ dtype=st.none() | hh.real_floating_dtypes,
+ endpoint=st.booleans(),
+ data=st.data(),
+)
+def test_linspace(num, dtype, endpoint, data):
+ _dtype = dh.default_float if dtype is None else dtype
+
+ start = data.draw(hh.from_dtype(_dtype, **finite_kw), label="start")
+ stop = data.draw(hh.from_dtype(_dtype, **finite_kw), label="stop")
+ # avoid overflow errors
+ assume(not xp.isnan(xp.asarray(stop - start, dtype=_dtype)))
+ assume(not xp.isnan(xp.asarray(start - stop, dtype=_dtype)))
+ # avoid generating very large distances
+ # https://github.com/data-apis/array-api-tests/issues/125
+ assume(abs(stop - start) < math.sqrt(dh.dtype_ranges[_dtype].max))
+
+ kw = data.draw(
+ hh.specified_kwargs(
+ hh.KVD("dtype", dtype, None),
+ hh.KVD("endpoint", endpoint, True),
+ ),
+ label="kw",
+ )
+ out = xp.linspace(start, stop, num, **kw)
-@given(shapes, one_of(none(), dtypes))
-def test_zeros(shape, dtype):
- kwargs = {} if dtype is None else {'dtype': dtype}
+ if dtype is None:
+ ph.assert_default_float("linspace", out.dtype)
+ else:
+ ph.assert_kw_dtype("linspace", kw_dtype=dtype, out_dtype=out.dtype)
+ ph.assert_shape("linspace", out_shape=out.shape, expected=num, kw=dict(start=start, stop=stop, num=num))
+ f_func = f"[linspace({start}, {stop}, {num})]"
+ if num > 0:
+ assert xp.equal(
+ out[0], xp.asarray(start, dtype=out.dtype)
+ ), f"out[0]={out[0]}, but should be {start} {f_func}"
+ if endpoint:
+ if num > 1:
+ assert xp.equal(
+ out[-1], xp.asarray(stop, dtype=out.dtype)
+ ), f"out[-1]={out[-1]}, but should be {stop} {f_func}"
+ else:
+        # linspace(..., num, endpoint=False) is equivalent to the first num
+        # elements of linspace(..., num+1, endpoint=True)
+ expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
+ expected = expected[:-1]
+ ph.assert_array_elements("linspace", out=out, expected=expected)
+
+
+@given(dtype=hh.numeric_dtypes, data=st.data())
+def test_meshgrid(dtype, data):
+    # The number and size of generated arrays are arbitrarily limited to
+    # prevent meshgrid() running out of memory.
+ shapes = data.draw(
+ st.integers(1, 5).flatmap(
+ lambda n: hh.mutually_broadcastable_shapes(
+ n, min_dims=1, max_dims=1, max_side=4
+ )
+ ),
+ label="shapes",
+ )
+ arrays = []
+ for i, shape in enumerate(shapes, 1):
+ x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}")
+ arrays.append(x)
+ # sanity check
+ # assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
+ out = xp.meshgrid(*arrays)
+ for i, x in enumerate(out):
+ ph.assert_dtype("meshgrid", in_dtype=dtype, out_dtype=x.dtype, repr_name=f"out[{i}].dtype")
+
+
+def make_one(dtype: DataType) -> Scalar:
+ if dtype is None or dh.is_float_dtype(dtype):
+ return 1.0
+ elif dh.is_int_dtype(dtype):
+ return 1
+ else:
+ return True
- a = zeros(shape, **kwargs)
- if dtype is None:
- # TODO: Should it actually match the fill_value?
- # assert a.dtype in _floating_dtypes, "eye() should produce an array with the default floating point dtype"
- pass
+
+
+@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes))
+def test_ones(shape, kw):
+ out = xp.ones(shape, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_default_float("ones", out.dtype)
+ else:
+ ph.assert_kw_dtype("ones", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("ones", out_shape=out.shape, expected=shape,
+ kw={'shape': shape, **kw})
+ dtype = kw.get("dtype", None) or dh.default_float
+ ph.assert_fill("ones", fill_value=make_one(dtype), dtype=dtype, out=out, kw=kw)
+
+
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
+ kw=hh.kwargs(dtype=st.none() | hh.all_dtypes),
+)
+def test_ones_like(x, kw):
+ out = xp.ones_like(x, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_dtype("ones_like", in_dtype=x.dtype, out_dtype=out.dtype)
else:
- assert a.dtype == dtype
+ ph.assert_kw_dtype("ones_like", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("ones_like", out_shape=out.shape, expected=x.shape, kw=kw)
+ dtype = kw.get("dtype", None) or x.dtype
+ ph.assert_fill("ones_like", fill_value=make_one(dtype), dtype=dtype,
+ out=out, kw=kw)
+
+
+def make_zero(dtype: DataType) -> Scalar:
+ if dtype is None or dh.is_float_dtype(dtype):
+ return 0.0
+ elif dh.is_int_dtype(dtype):
+ return 0
+ else:
+ return False
- assert a.shape == shape, "zeros() produced an array with incorrect shape"
- assert all(equal(a, full((), 0, **kwargs))), "zeros() array did not equal 0"
-# TODO: implement zeros_like (requires hypothesis arrays support)
-def test_zeros_like():
- pass
+
+
+@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.all_dtypes))
+def test_zeros(shape, kw):
+ out = xp.zeros(shape, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_default_float("zeros", out_dtype=out.dtype)
+ else:
+ ph.assert_kw_dtype("zeros", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("zeros", out_shape=out.shape, expected=shape, kw={'shape': shape, **kw})
+ dtype = kw.get("dtype", None) or dh.default_float
+ ph.assert_fill("zeros", fill_value=make_zero(dtype), dtype=dtype, out=out,
+ kw=kw)
+
+
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
+ kw=hh.kwargs(dtype=st.none() | hh.all_dtypes),
+)
+def test_zeros_like(x, kw):
+ out = xp.zeros_like(x, **kw)
+ if kw.get("dtype", None) is None:
+ ph.assert_dtype("zeros_like", in_dtype=x.dtype, out_dtype=out.dtype)
+ else:
+ ph.assert_kw_dtype("zeros_like", kw_dtype=kw["dtype"], out_dtype=out.dtype)
+ ph.assert_shape("zeros_like", out_shape=out.shape, expected=x.shape,
+ kw=kw)
+ dtype = kw.get("dtype", None) or x.dtype
+ ph.assert_fill("zeros_like", fill_value=make_zero(dtype), dtype=dtype,
+ out=out, kw=kw)
diff --git a/array_api_tests/test_data_type_functions.py b/array_api_tests/test_data_type_functions.py
new file mode 100644
index 00000000..84e6f34c
--- /dev/null
+++ b/array_api_tests/test_data_type_functions.py
@@ -0,0 +1,289 @@
+import struct
+from typing import Union
+
+import pytest
+from hypothesis import given, assume
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xps
+from .typing import DataType
+
+
+# TODO: test with complex dtypes
+def non_complex_dtypes():
+ return xps.boolean_dtypes() | hh.real_dtypes
+
+
+def float32(n: Union[int, float]) -> float:
+ return struct.unpack("!f", struct.pack("!f", float(n)))[0]
+
+
+def _float_match_complex(complex_dtype):
+ if complex_dtype == xp.complex64:
+ return xp.float32
+ elif complex_dtype == xp.complex128:
+ return xp.float64
+ else:
+ return dh.default_float
+
+
+@given(
+ x_dtype=hh.all_dtypes,
+ dtype=hh.all_dtypes,
+ kw=hh.kwargs(copy=st.booleans()),
+ data=st.data(),
+)
+def test_astype(x_dtype, dtype, kw, data):
+ _complex_dtypes = (xp.complex64, xp.complex128)
+
+ if xp.bool in (x_dtype, dtype):
+ elements_strat = hh.from_dtype(x_dtype)
+ else:
+
+ if dh.is_int_dtype(x_dtype):
+ cast = int
+ elif x_dtype in (xp.float32, xp.complex64):
+ cast = float32
+ else:
+ cast = float
+
+ real_dtype = x_dtype
+ if x_dtype in _complex_dtypes:
+ real_dtype = _float_match_complex(x_dtype)
+ m1, M1 = dh.dtype_ranges[real_dtype]
+
+ real_dtype = dtype
+ if dtype in _complex_dtypes:
+            real_dtype = _float_match_complex(dtype)
+ m2, M2 = dh.dtype_ranges[real_dtype]
+
+ min_value = cast(max(m1, m2))
+ max_value = cast(min(M1, M2))
+
+ elements_strat = hh.from_dtype(
+ x_dtype,
+ min_value=min_value,
+ max_value=max_value,
+ allow_nan=False,
+ allow_infinity=False,
+ )
+ x = data.draw(
+ hh.arrays(dtype=x_dtype, shape=hh.shapes(), elements=elements_strat), label="x"
+ )
+
+ # according to the spec, "Casting a complex floating-point array to a real-valued
+ # data type should not be permitted."
+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
+ assume(not ((x_dtype in _complex_dtypes) and (dtype not in _complex_dtypes)))
+
+ out = xp.astype(x, dtype, **kw)
+
+ ph.assert_kw_dtype("astype", kw_dtype=dtype, out_dtype=out.dtype)
+ ph.assert_shape("astype", out_shape=out.shape, expected=x.shape, kw=kw)
+ # TODO: test values
+ # TODO: test copy
+
+
+@given(
+ shapes=st.integers(1, 5).flatmap(hh.mutually_broadcastable_shapes), data=st.data()
+)
+def test_broadcast_arrays(shapes, data):
+ arrays = []
+ for c, shape in enumerate(shapes, 1):
+ x = data.draw(hh.arrays(dtype=hh.all_dtypes, shape=shape), label=f"x{c}")
+ arrays.append(x)
+
+ out = xp.broadcast_arrays(*arrays)
+
+ expected_shape = sh.broadcast_shapes(*shapes)
+ for i, x in enumerate(arrays):
+ ph.assert_dtype(
+ "broadcast_arrays",
+ in_dtype=x.dtype,
+ out_dtype=out[i].dtype,
+ repr_name=f"out[{i}].dtype"
+ )
+ ph.assert_result_shape(
+ "broadcast_arrays",
+ in_shapes=shapes,
+ out_shape=out[i].shape,
+ expected=expected_shape,
+ repr_name=f"out[{i}].shape",
+ )
+ # TODO: test values
+
+
+@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
+def test_broadcast_to(x, data):
+ shape = data.draw(
+ hh.mutually_broadcastable_shapes(1, base_shape=x.shape)
+ .map(lambda S: S[0])
+ .filter(lambda s: sh.broadcast_shapes(x.shape, s) == s),
+ label="shape",
+ )
+
+ out = xp.broadcast_to(x, shape)
+
+ ph.assert_dtype("broadcast_to", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("broadcast_to", out_shape=out.shape, expected=shape)
+ # TODO: test values
+
+
+@given(_from=hh.all_dtypes, to=hh.all_dtypes)
+def test_can_cast(_from, to):
+ out = xp.can_cast(_from, to)
+
+ expected = False
+ for other in dh.all_dtypes:
+ if dh.promotion_table.get((_from, other)) == to:
+ expected = True
+ break
+
+ f_func = f"[can_cast({dh.dtype_to_name[_from]}, {dh.dtype_to_name[to]})]"
+ if expected:
+ # cross-kind casting is not explicitly disallowed. We can only test
+ # the cases where it should return True. TODO: if expected=False,
+ # check that the array library actually allows such casts.
+ assert out == expected, f"{out=}, but should be {expected} {f_func}"
+
+
+@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
+def test_finfo(dtype):
+ for arg in (
+ dtype,
+ xp.asarray(1, dtype=dtype),
+ # np.float64 and np.asarray(1, dtype=np.float64).dtype are different
+ xp.asarray(1, dtype=dtype).dtype,
+ ):
+ out = xp.finfo(arg)
+ assert isinstance(out.bits, int)
+ assert isinstance(out.eps, float)
+ assert isinstance(out.max, float)
+ assert isinstance(out.min, float)
+ assert isinstance(out.smallest_normal, float)
+
+
+@pytest.mark.min_version("2022.12")
+@pytest.mark.parametrize("dtype", dh.real_float_dtypes + dh.complex_dtypes)
+def test_finfo_dtype(dtype):
+ out = xp.finfo(dtype)
+
+ if dtype == xp.complex64:
+ assert out.dtype == xp.float32
+ elif dtype == xp.complex128:
+ assert out.dtype == xp.float64
+ else:
+ assert out.dtype == dtype
+
+ # Guard vs. numpy.dtype.__eq__ lax comparison
+ assert not isinstance(out.dtype, str)
+ assert out.dtype is not float
+ assert out.dtype is not complex
+
+
+@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
+def test_iinfo(dtype):
+ for arg in (
+ dtype,
+ xp.asarray(1, dtype=dtype),
+ # np.int64 and np.asarray(1, dtype=np.int64).dtype are different
+ xp.asarray(1, dtype=dtype).dtype,
+ ):
+ out = xp.iinfo(arg)
+ assert isinstance(out.bits, int)
+ assert isinstance(out.max, int)
+ assert isinstance(out.min, int)
+
+
+@pytest.mark.min_version("2022.12")
+@pytest.mark.parametrize("dtype", dh.int_dtypes + dh.uint_dtypes)
+def test_iinfo_dtype(dtype):
+ out = xp.iinfo(dtype)
+ assert out.dtype == dtype
+ # Guard vs. numpy.dtype.__eq__ lax comparison
+ assert not isinstance(out.dtype, str)
+ assert out.dtype is not int
+
+
+def atomic_kinds() -> st.SearchStrategy[Union[DataType, str]]:
+ return hh.all_dtypes | st.sampled_from(list(dh.kind_to_dtypes.keys()))
+
+
+@pytest.mark.min_version("2022.12")
+@given(
+ dtype=hh.all_dtypes,
+ kind=atomic_kinds() | st.lists(atomic_kinds(), min_size=1).map(tuple),
+)
+def test_isdtype(dtype, kind):
+ out = xp.isdtype(dtype, kind)
+
+ assert isinstance(out, bool), f"{type(out)=}, but should be bool [isdtype()]"
+ _kinds = kind if isinstance(kind, tuple) else (kind,)
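+    # Illustrative example of the expected semantics:
+    #
+    # >>> xp.isdtype(xp.int32, ("integral", xp.float64))
+    # True
+    #
+    # int32 belongs to the "integral" kind, even though int32 != float64.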
+ expected = False
+ for _kind in _kinds:
+ if isinstance(_kind, str):
+ if dtype in dh.kind_to_dtypes[_kind]:
+ expected = True
+ break
+ else:
+ if dtype == _kind:
+ expected = True
+ break
+ assert out == expected, f"{out=}, but should be {expected} [isdtype()]"
+
+
+@pytest.mark.min_version("2024.12")
+class TestResultType:
+ @given(dtypes=hh.mutually_promotable_dtypes(None))
+ def test_result_type(self, dtypes):
+ out = xp.result_type(*dtypes)
+ ph.assert_dtype("result_type", in_dtype=dtypes, out_dtype=out, repr_name="out")
+
+ @given(pair=hh.pair_of_mutually_promotable_dtypes(None))
+ def test_shuffled(self, pair):
+ """Test that result_type is insensitive to the order of arguments."""
+ s1, s2 = pair
+ out1 = xp.result_type(*s1)
+ out2 = xp.result_type(*s2)
+ assert out1 == out2
+
+ @given(pair=hh.pair_of_mutually_promotable_dtypes(2), data=st.data())
+ def test_arrays_and_dtypes(self, pair, data):
+ s1, s2 = pair
+ a2 = tuple(xp.empty(1, dtype=dt) for dt in s2)
+ a_and_dt = data.draw(st.permutations(s1 + a2))
+ out = xp.result_type(*a_and_dt)
+ ph.assert_dtype("result_type", in_dtype=s1+s2, out_dtype=out, repr_name="out")
+
+ @given(dtypes=hh.mutually_promotable_dtypes(2), data=st.data())
+ def test_with_scalars(self, dtypes, data):
+ out = xp.result_type(*dtypes)
+
+ if out == xp.bool:
+ scalars = [True]
+ elif out in dh.all_int_dtypes:
+ scalars = [1]
+ elif out in dh.real_dtypes:
+ scalars = [1, 1.0]
+ elif out in dh.numeric_dtypes:
+ scalars = [1, 1.0, 1j] # numeric_types - real_types == complex_types
+ else:
+ raise ValueError(f"unknown dtype {out = }.")
+
+ scalar = data.draw(st.sampled_from(scalars))
+ inputs = data.draw(st.permutations(dtypes + (scalar,)))
+
+ out_scalar = xp.result_type(*inputs)
+ assert out_scalar == out
+
+ # retry with arrays
+ arrays = tuple(xp.empty(1, dtype=dt) for dt in dtypes)
+ inputs = data.draw(st.permutations(arrays + (scalar,)))
+ out_scalar = xp.result_type(*inputs)
+ assert out_scalar == out
+
diff --git a/array_api_tests/test_fft.py b/array_api_tests/test_fft.py
new file mode 100644
index 00000000..358a8eef
--- /dev/null
+++ b/array_api_tests/test_fft.py
@@ -0,0 +1,307 @@
+import math
+from typing import List, Optional
+
+import pytest
+from hypothesis import assume, given
+from hypothesis import strategies as st
+
+from array_api_tests.typing import Array
+
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xp
+
+pytestmark = [
+ pytest.mark.xp_extension("fft"),
+ pytest.mark.min_version("2022.12"),
+]
+
+fft_shapes_strat = hh.shapes(min_dims=1).filter(lambda s: math.prod(s) > 1)
+
+
+def draw_n_axis_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
+ size = math.prod(x.shape)
+ n = data.draw(
+ st.none() | st.integers((size // 2), math.ceil(size * 1.5)), label="n"
+ )
+ axis = data.draw(st.integers(-1, x.ndim - 1), label="axis")
+ if size_gt_1:
+ _axis = x.ndim - 1 if axis == -1 else axis
+ assume(x.shape[_axis] > 1)
+ norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
+ kwargs = data.draw(
+ hh.specified_kwargs(
+ ("n", n, None),
+ ("axis", axis, -1),
+ ("norm", norm, "backward"),
+ ),
+ label="kwargs",
+ )
+ return n, axis, norm, kwargs
+
+
+def draw_s_axes_norm_kwargs(x: Array, data: st.DataObject, *, size_gt_1=False) -> tuple:
+ all_axes = list(range(x.ndim))
+ axes = data.draw(
+ st.none() | st.lists(st.sampled_from(all_axes), min_size=1, unique=True),
+ label="axes",
+ )
+ _axes = all_axes if axes is None else axes
+ axes_sides = [x.shape[axis] for axis in _axes]
+ s_strat = st.tuples(
+ *[st.integers(max(side // 2, 1), math.ceil(side * 1.5)) for side in axes_sides]
+ )
+ if axes is None:
+ s_strat = st.none() | s_strat
+ s = data.draw(s_strat, label="s")
+
+    # Passing s (not None) while axes is None is disallowed by the spec
+ assume(axes is not None or s is None)
+
+ norm = data.draw(st.sampled_from(["backward", "ortho", "forward"]), label="norm")
+ kwargs = data.draw(
+ hh.specified_kwargs(
+ ("s", s, None),
+ ("axes", axes, None),
+ ("norm", norm, "backward"),
+ ),
+ label="kwargs",
+ )
+ return s, axes, norm, kwargs
+
+
+def assert_n_axis_shape(
+ func_name: str,
+ *,
+ x: Array,
+ n: Optional[int],
+ axis: int,
+ out: Array,
+):
+ _axis = len(x.shape) - 1 if axis == -1 else axis
+ if n is None:
+ axis_side = x.shape[_axis]
+ else:
+ axis_side = n
+ expected = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
+ ph.assert_shape(func_name, out_shape=out.shape, expected=expected)
+
+
+def assert_s_axes_shape(
+ func_name: str,
+ *,
+ x: Array,
+ s: Optional[List[int]],
+ axes: Optional[List[int]],
+ out: Array,
+):
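+    # Illustrative example: x.shape=(2, 3, 4), s=(5, 6), axes=[0, 2] gives
+    # expected shape (5, 3, 6) -- axes 0 and 2 take their sides from s, while
+    # axis 1 keeps x.shape[1].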
+ _axes = sh.normalize_axis(axes, x.ndim)
+ _s = x.shape if s is None else s
+ expected = []
+ for i in range(x.ndim):
+ if i in _axes:
+ side = _s[_axes.index(i)]
+ else:
+ side = x.shape[i]
+ expected.append(side)
+ ph.assert_shape(func_name, out_shape=out.shape, expected=tuple(expected))
+
+
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_fft(x, data):
+ n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
+
+ out = xp.fft.fft(x, **kwargs)
+
+ ph.assert_dtype("fft", in_dtype=x.dtype, out_dtype=out.dtype)
+ assert_n_axis_shape("fft", x=x, n=n, axis=axis, out=out)
+
+
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_ifft(x, data):
+ n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
+
+ out = xp.fft.ifft(x, **kwargs)
+
+ ph.assert_dtype("ifft", in_dtype=x.dtype, out_dtype=out.dtype)
+ assert_n_axis_shape("ifft", x=x, n=n, axis=axis, out=out)
+
+
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_fftn(x, data):
+ s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
+
+ out = xp.fft.fftn(x, **kwargs)
+
+ ph.assert_dtype("fftn", in_dtype=x.dtype, out_dtype=out.dtype)
+ assert_s_axes_shape("fftn", x=x, s=s, axes=axes, out=out)
+
+
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_ifftn(x, data):
+ s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
+
+ out = xp.fft.ifftn(x, **kwargs)
+
+ ph.assert_dtype("ifftn", in_dtype=x.dtype, out_dtype=out.dtype)
+ assert_s_axes_shape("ifftn", x=x, s=s, axes=axes, out=out)
+
+
+@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_rfft(x, data):
+ n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
+
+ out = xp.fft.rfft(x, **kwargs)
+
+ ph.assert_float_to_complex_dtype("rfft", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ _axis = x.ndim - 1 if axis == -1 else axis
+ if n is None:
+ axis_side = x.shape[_axis] // 2 + 1
+ else:
+ axis_side = n // 2 + 1
+ expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
+ ph.assert_shape("rfft", out_shape=out.shape, expected=expected_shape)
+
+
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_irfft(x, data):
+ n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
+
+ out = xp.fft.irfft(x, **kwargs)
+
+ ph.assert_dtype(
+ "irfft",
+ in_dtype=x.dtype,
+ out_dtype=out.dtype,
+ expected=dh.dtype_components[x.dtype],
+ )
+
+ _axis = x.ndim - 1 if axis == -1 else axis
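+    # By default irfft over a length-m complex axis is expected to produce
+    # 2*(m - 1) real samples (inverse of the Hermitian-symmetric rfft), e.g.
+    # m=5 -> 8; an explicit n overrides this.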
+ if n is None:
+ axis_side = 2 * (x.shape[_axis] - 1)
+ else:
+ axis_side = n
+ expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
+ ph.assert_shape("irfft", out_shape=out.shape, expected=expected_shape)
+
+
+@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_rfftn(x, data):
+ s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
+
+ out = xp.fft.rfftn(x, **kwargs)
+
+ ph.assert_float_to_complex_dtype("rfftn", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ _axes = sh.normalize_axis(axes, x.ndim)
+ _s = x.shape if s is None else s
+ expected = []
+ for i in range(x.ndim):
+ if i in _axes:
+ side = _s[_axes.index(i)]
+ else:
+ side = x.shape[i]
+ expected.append(side)
+ expected[_axes[-1]] = _s[-1] // 2 + 1
+ ph.assert_shape("rfftn", out_shape=out.shape, expected=tuple(expected))
+
+
+@given(
+ x=hh.arrays(
+ dtype=hh.complex_dtypes, shape=fft_shapes_strat.filter(lambda s: s[-1] > 1)
+ ),
+ data=st.data(),
+)
+def test_irfftn(x, data):
+ s, axes, norm, kwargs = draw_s_axes_norm_kwargs(x, data)
+
+ out = xp.fft.irfftn(x, **kwargs)
+
+ ph.assert_dtype(
+ "irfftn",
+ in_dtype=x.dtype,
+ out_dtype=out.dtype,
+ expected=dh.dtype_components[x.dtype],
+ )
+
+ _axes = sh.normalize_axis(axes, x.ndim)
+ _s = x.shape if s is None else s
+ expected = []
+ for i in range(x.ndim):
+ if i in _axes:
+ side = _s[_axes.index(i)]
+ else:
+ side = x.shape[i]
+ expected.append(side)
+ expected[_axes[-1]] = 2*(_s[-1] - 1) if s is None else _s[-1]
+ ph.assert_shape("irfftn", out_shape=out.shape, expected=tuple(expected))
+
+
+@given(x=hh.arrays(dtype=hh.complex_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_hfft(x, data):
+ n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data, size_gt_1=True)
+
+ out = xp.fft.hfft(x, **kwargs)
+
+ ph.assert_dtype(
+ "hfft",
+ in_dtype=x.dtype,
+ out_dtype=out.dtype,
+ expected=dh.dtype_components[x.dtype],
+ )
+
+ _axis = x.ndim - 1 if axis == -1 else axis
+ if n is None:
+ axis_side = 2 * (x.shape[_axis] - 1)
+ else:
+ axis_side = n
+ expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
+ ph.assert_shape("hfft", out_shape=out.shape, expected=expected_shape)
+
+
+@given(x=hh.arrays(dtype=hh.real_floating_dtypes, shape=fft_shapes_strat), data=st.data())
+def test_ihfft(x, data):
+ n, axis, norm, kwargs = draw_n_axis_norm_kwargs(x, data)
+
+ out = xp.fft.ihfft(x, **kwargs)
+
+ ph.assert_float_to_complex_dtype("ihfft", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ _axis = x.ndim - 1 if axis == -1 else axis
+ if n is None:
+ axis_side = x.shape[_axis] // 2 + 1
+ else:
+ axis_side = n // 2 + 1
+ expected_shape = x.shape[:_axis] + (axis_side,) + x.shape[_axis + 1 :]
+ ph.assert_shape("ihfft", out_shape=out.shape, expected=expected_shape)
+
+
+@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
+def test_fftfreq(n, kw):
+ out = xp.fft.fftfreq(n, **kw)
+ ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})
+
+
+@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
+def test_rfftfreq(n, kw):
+ out = xp.fft.rfftfreq(n, **kw)
+ ph.assert_shape(
+ "rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n}
+ )
+
+
+@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
+@given(x=hh.arrays(hh.floating_dtypes, fft_shapes_strat), data=st.data())
+def test_shift_func(func_name, x, data):
+ func = getattr(xp.fft, func_name)
+ axes = data.draw(
+ st.none()
+ | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
+ label="axes",
+ )
+ out = func(x, axes=axes)
+ ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape)
diff --git a/array_api_tests/test_has_names.py b/array_api_tests/test_has_names.py
new file mode 100644
index 00000000..8e934781
--- /dev/null
+++ b/array_api_tests/test_has_names.py
@@ -0,0 +1,37 @@
+"""
+This is a very basic test to see what names are defined in a library. It
+does not even require functioning hypothesis array_api support.
+"""
+
+import pytest
+
+from . import xp
+from .stubs import (array_attributes, array_methods, category_to_funcs,
+ extension_to_funcs, EXTENSIONS)
+
+has_name_params = []
+for ext, stubs in extension_to_funcs.items():
+ for stub in stubs:
+ has_name_params.append(pytest.param(ext, stub.__name__))
+for cat, stubs in category_to_funcs.items():
+ for stub in stubs:
+ has_name_params.append(pytest.param(cat, stub.__name__))
+for meth in array_methods:
+ has_name_params.append(pytest.param('array_method', meth.__name__))
+for attr in array_attributes:
+ has_name_params.append(pytest.param('array_attribute', attr))
+
+@pytest.mark.parametrize("category, name", has_name_params)
+def test_has_names(category, name):
+ if category in EXTENSIONS:
+ ext_mod = getattr(xp, category)
+ assert hasattr(ext_mod, name), f"{xp.__name__} is missing the {category} extension function {name}()"
+ elif category.startswith('array_'):
+ # TODO: This would fail if ones() is missing.
+ arr = xp.ones((1, 1))
+ if category == 'array_attribute':
+ assert hasattr(arr, name), f"The {xp.__name__} array object is missing the attribute {name}"
+ else:
+ assert hasattr(arr, name), f"The {xp.__name__} array object is missing the method {name}()"
+ else:
+ assert hasattr(xp, name), f"{xp.__name__} is missing the {category} function {name}()"
diff --git a/array_api_tests/test_indexing.py b/array_api_tests/test_indexing.py
deleted file mode 100644
index c74410f7..00000000
--- a/array_api_tests/test_indexing.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-https://data-apis.github.io/array-api/latest/API_specification/indexing.html
-
-For these tests, we only need arrays where each element is distinct, so we use
-arange().
-"""
-
-from hypothesis import given
-from hypothesis.strategies import shared
-
-from .array_helpers import assert_exactly_equal
-from .hypothesis_helpers import (slices, sizes, integer_indices, shapes, prod,
- multiaxis_indices)
-from .pytest_helpers import raises
-from ._array_module import arange, reshape
-
-@given(shared(sizes, key='array_sizes'), integer_indices(shared(sizes, key='array_sizes')))
-def test_integer_indexing(size, idx):
- # Test that indices on single dimensional arrays give the same result as
- # Python lists.
-
- # Sanity check that the strategies are working properly
- assert -size <= idx <= max(0, size - 1), "Sanity check failed. This indicates a bug in the test suite"
-
- a = arange(size)
- l = list(range(size))
- sliced_list = l[idx]
- sliced_array = a[idx]
-
- assert sliced_array.shape == (), "Integer indices should reduce the dimension by 1"
- assert sliced_array.dtype == a.dtype, "Integer indices should not change the dtype"
- assert sliced_array == sliced_list, "Integer index did not give the correct entry"
-
-@given(shared(sizes, key='array_sizes'), slices(shared(sizes, key='array_sizes')))
-def test_slicing(size, s):
- # Test that slices on arrays give the same result as Python lists.
-
- # Sanity check that the strategies are working properly
- if s.start is not None:
- assert -size <= s.start <= max(0, size - 1), "Sanity check failed. This indicates a bug in the test suite"
- if s.stop is not None:
- assert -size <= s.stop <= size, "Sanity check failed. This indicates a bug in the test suite"
-
- a = arange(size)
- l = list(range(size))
- sliced_list = l[s]
- sliced_array = a[s]
-
- assert len(sliced_list) == sliced_array.size, "Slice index did not give the same number of elements as slicing an equivalent Python list"
- assert sliced_array.shape == (sliced_array.size,), "Slice index did not give the correct shape"
- assert sliced_array.dtype == a.dtype, "Slice indices should not change the dtype"
- for i in range(len(sliced_list)):
- assert sliced_array[i] == sliced_list[i], "Slice index did not give the same elements as slicing an equivalent Python list"
-
-@given(shared(shapes, key='array_shapes'),
- multiaxis_indices(shapes=shared(shapes, key='array_shapes')))
-def test_multiaxis_indexing(shape, idx):
- # NOTE: Out of bounds indices (both integer and slices) are out of scope
- # for the spec. If you get a (valid) out of bounds error, it indicates a
- # bug in the multiaxis_indices strategy, which should only generate
- # indices that are not out of bounds.
- size = prod(shape)
- a = reshape(arange(size), shape)
-
- n_ellipses = idx.count(...)
- if n_ellipses > 1:
- raises(IndexError, lambda: a[idx],
- "Indices with more than one ellipsis should raise IndexError")
- return
- elif len(idx) - n_ellipses > len(shape):
- raises(IndexError, lambda: a[idx],
- "Tuple indices with more single axis expressions than the shape should raise IndexError")
- return
-
- sliced_array = a[idx]
- equiv_idx = idx
- if n_ellipses or len(idx) < len(shape):
- ellipsis_i = idx.index(...) if n_ellipses else len(idx)
- equiv_idx = (idx[:ellipsis_i]
- + (slice(None, None, None),)*(len(shape) - len(idx) + n_ellipses)
- + idx[ellipsis_i + 1:])
- # Sanity check
- assert len(equiv_idx) == len(shape), "Sanity check failed. This indicates a bug in the test suite"
- sliced_array2 = a[equiv_idx]
- assert_exactly_equal(sliced_array, sliced_array2)
-
- # TODO: We don't check that the exact entries are correct. Instead we
- # check the shape and other properties, and assume the single dimension
- # tests above are sufficient for testing the exact behavior of integer
- # indices and slices.
-
- # Check that the new shape is what it should be
- newshape = []
- for i, s in enumerate(equiv_idx):
- # Slices should retain the dimension. Integers remove a dimension.
- if isinstance(s, slice):
- newshape.append(len(range(shape[i])[s]))
- assert sliced_array.shape == tuple(newshape), "Index did not give the correct resulting array shape"
-
- # Check that integer indices i chose the same elements as the slice i:i+1
- equiv_idx2 = []
- for i, size in zip(equiv_idx, shape):
- if isinstance(i, int):
- if i >= 0:
- i = slice(i, i + 1)
- else:
- i = slice(size + i, size + i + 1)
- equiv_idx2.append(i)
- equiv_idx2 = tuple(equiv_idx2)
-
- sliced_array2 = a[equiv_idx2]
- assert sliced_array2.size == sliced_array.size, "Integer index not choosing the same elements as an equivalent slice"
- assert_exactly_equal(reshape(sliced_array2, sliced_array.shape), sliced_array)
diff --git a/array_api_tests/test_indexing_functions.py b/array_api_tests/test_indexing_functions.py
new file mode 100644
index 00000000..a599d218
--- /dev/null
+++ b/array_api_tests/test_indexing_functions.py
@@ -0,0 +1,122 @@
+import pytest
+from hypothesis import given, note
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+
+
+@pytest.mark.unvectorized
+@pytest.mark.min_version("2022.12")
+@given(
+ x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)),
+ data=st.data(),
+)
+def test_take(x, data):
+ # TODO:
+ # * negative axis
+ # * negative indices
+ # * different dtypes for indices
+ axis = data.draw(st.integers(0, max(x.ndim - 1, 0)), label="axis")
+ _indices = data.draw(
+ st.lists(st.integers(0, x.shape[axis] - 1), min_size=1, unique=True),
+ label="_indices",
+ )
+ indices = xp.asarray(_indices, dtype=dh.default_int)
+ note(f"{indices=}")
+
+ out = xp.take(x, indices, axis=axis)
+
+ ph.assert_dtype("take", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(
+ "take",
+ out_shape=out.shape,
+ expected=x.shape[:axis] + (len(_indices),) + x.shape[axis + 1 :],
+ kw=dict(
+ x=x,
+ indices=indices,
+ axis=axis,
+ ),
+ )
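+    # Value test: out should be x with the `axis` dimension replaced by the
+    # selected indices, e.g. (illustrative) for axis=1 and indices=[2, 0]:
+    # out[i, 0] == x[i, 2] and out[i, 1] == x[i, 0] for each row i.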
+ out_indices = sh.ndindex(out.shape)
+ axis_indices = list(sh.axis_ndindex(x.shape, axis))
+ for axis_idx in axis_indices:
+ f_axis_idx = sh.fmt_idx("x", axis_idx)
+ for i in _indices:
+ f_take_idx = sh.fmt_idx(f_axis_idx, i)
+ indexed_x = x[axis_idx][i, ...]
+ for at_idx in sh.ndindex(indexed_x.shape):
+ out_idx = next(out_indices)
+ ph.assert_0d_equals(
+ "take",
+ x_repr=sh.fmt_idx(f_take_idx, at_idx),
+ x_val=indexed_x[at_idx],
+ out_repr=sh.fmt_idx("out", out_idx),
+ out_val=out[out_idx],
+ )
+ # sanity check
+ with pytest.raises(StopIteration):
+ next(out_indices)
+
+
+@pytest.mark.unvectorized
+@pytest.mark.min_version("2024.12")
+@given(
+ x=hh.arrays(hh.all_dtypes, hh.shapes(min_dims=1, min_side=1)),
+ data=st.data(),
+)
+def test_take_along_axis(x, data):
+    # TODO:
+    # 1. negative indices
+    # 2. different dtypes for indices
+    # 3. "broadcast-compatible" indices
+ axis = data.draw(
+ st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
+ label="axis"
+ )
+ if axis is None:
+ axis_kw = {}
+ n_axis = x.ndim - 1
+ else:
+ axis_kw = {"axis": axis}
+ n_axis = axis + x.ndim if axis < 0 else axis
+
+ new_len = data.draw(st.integers(0, 2*x.shape[n_axis]), label="new_len")
+ idx_shape = x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:]
+ indices = data.draw(
+ hh.arrays(
+ shape=idx_shape,
+ dtype=dh.default_int,
+ elements={"min_value": 0, "max_value": x.shape[n_axis]-1}
+ ),
+ label="indices"
+ )
+ note(f"{indices=} {idx_shape=}")
+
+ out = xp.take_along_axis(x, indices, **axis_kw)
+
+ ph.assert_dtype("take_along_axis", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(
+ "take_along_axis",
+ out_shape=out.shape,
+ expected=x.shape[:n_axis] + (new_len,) + x.shape[n_axis+1:],
+ kw=dict(
+ x=x,
+ indices=indices,
+ axis=axis,
+ ),
+ )
+
+ # value test: notation is from `np.take_along_axis` docstring
+ Ni, Nk = x.shape[:n_axis], x.shape[n_axis+1:]
+ for ii in sh.ndindex(Ni):
+ for kk in sh.ndindex(Nk):
+ a_1d = x[ii + (slice(None),) + kk]
+ i_1d = indices[ii + (slice(None),) + kk]
+ o_1d = out[ii + (slice(None),) + kk]
+ for j in range(new_len):
+ assert o_1d[j] == a_1d[i_1d[j]], f'{ii=}, {kk=}, {j=}'
diff --git a/array_api_tests/test_inspection_functions.py b/array_api_tests/test_inspection_functions.py
new file mode 100644
index 00000000..d210535e
--- /dev/null
+++ b/array_api_tests/test_inspection_functions.py
@@ -0,0 +1,81 @@
+import pytest
+from hypothesis import given, strategies as st
+from array_api_tests.dtype_helpers import available_kinds, dtype_names
+
+from . import xp
+
+pytestmark = pytest.mark.min_version("2023.12")
+
+
+class TestInspection:
+ def test_capabilities(self):
+ out = xp.__array_namespace_info__()
+
+ capabilities = out.capabilities()
+ assert isinstance(capabilities, dict)
+
+ expected_attr = {"boolean indexing": bool, "data-dependent shapes": bool}
+ if xp.__array_api_version__ >= "2024.12":
+ expected_attr.update(**{"max dimensions": type(None) | int})
+
+ for attr, typ in expected_attr.items():
+            assert attr in capabilities, f'capabilities is missing "{attr}".'
+ assert isinstance(capabilities[attr], typ)
+
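+        # "max dimensions" is only present from 2024.12; the fallback here
+        # (100500) is an arbitrary positive sentinel so the check below also
+        # passes when the key is absent.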
+ max_dims = capabilities.get("max dimensions", 100500)
+ assert (max_dims is None) or (max_dims > 0)
+
+ def test_devices(self):
+ out = xp.__array_namespace_info__()
+
+ assert hasattr(out, "devices")
+ assert hasattr(out, "default_device")
+
+ assert isinstance(out.devices(), list)
+ if out.default_device() is not None:
+ # Per https://github.com/data-apis/array-api/issues/923
+ # default_device() can return None. Otherwise, it must be a valid device.
+ assert out.default_device() in out.devices()
+
+ def test_default_dtypes(self):
+ out = xp.__array_namespace_info__()
+
+ for device in xp.__array_namespace_info__().devices():
+ default_dtypes = out.default_dtypes(device=device)
+ assert isinstance(default_dtypes, dict)
+ expected_subset = (
+ {"real floating", "complex floating", "integral"}
+ & available_kinds()
+ | {"indexing"}
+ )
+ assert expected_subset.issubset(set(default_dtypes.keys()))
+
+
+atomic_kinds = [
+ "bool",
+ "signed integer",
+ "unsigned integer",
+ "real floating",
+ "complex floating",
+]
+
+
+@given(
+ kind=st.one_of(
+ st.none(),
+ st.sampled_from(atomic_kinds + ["integral", "numeric"]),
+ st.lists(st.sampled_from(atomic_kinds), unique=True, min_size=1).map(tuple),
+ ),
+ device=st.one_of(
+ st.none(),
+ st.sampled_from(xp.__array_namespace_info__().devices())
+ )
+)
+def test_array_namespace_info_dtypes(kind, device):
+ out = xp.__array_namespace_info__().dtypes(kind=kind, device=device)
+ assert isinstance(out, dict)
+
+ for name, dtyp in out.items():
+ assert name in dtype_names
+ xp.empty(1, dtype=dtyp, device=device) # check `dtyp` is a valid dtype
+
diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py
new file mode 100644
index 00000000..6f4608da
--- /dev/null
+++ b/array_api_tests/test_linalg.py
@@ -0,0 +1,1022 @@
+"""
+Tests for linalg functions
+
+https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html
+
+and
+
+https://data-apis.org/array-api/latest/extensions/linear_algebra_functions.html
+
+Note: this file currently mixes both the required linear algebra functions and
+functions from the linalg extension. The functions in the latter are not
+required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25).
+
+"""
+import pytest
+from hypothesis import assume, given
+from hypothesis.strategies import (booleans, composite, tuples, floats,
+ integers, shared, sampled_from, one_of,
+ data)
+from ndindex import iter_indices
+
+import math
+import itertools
+from typing import Tuple
+
+from .array_helpers import assert_exactly_equal, asarray
+from .hypothesis_helpers import (arrays, all_floating_dtypes, all_dtypes,
+ numeric_dtypes, xps, shapes, kwargs,
+ matrix_shapes, square_matrix_shapes,
+ symmetric_matrices, SearchStrategy,
+ positive_definite_matrices, MAX_ARRAY_SIZE,
+ invertible_matrices, two_mutual_arrays,
+ mutually_promotable_dtypes, one_d_shapes,
+ two_mutually_broadcastable_shapes,
+ mutually_broadcastable_shapes,
+ SQRT_MAX_ARRAY_SIZE, finite_matrices,
+ rtol_shared_matrix_shapes, rtols, axes)
+from . import dtype_helpers as dh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import api_version
+from .typing import Array
+
+from . import _array_module
+from . import _array_module as xp
+from ._array_module import linalg
+
+
+def assert_equal(x, y, msg_extra=None):
+ extra = '' if not msg_extra else f' ({msg_extra})'
+ if x.dtype in dh.all_float_dtypes:
+ # It's too difficult to do an approximately equal test here because
+ # different routines can give completely different answers, and even
+ # when it does work, the elementwise comparisons are too slow. So for
+        # floating-point dtypes we only test the shape and dtype.
+
+ # assert_allclose(x, y)
+
+ assert x.shape == y.shape, f"The input arrays do not have the same shapes ({x.shape} != {y.shape}){extra}"
+ assert x.dtype == y.dtype, f"The input arrays do not have the same dtype ({x.dtype} != {y.dtype}){extra}"
+ else:
+ assert_exactly_equal(x, y, msg_extra=msg_extra)
+
+
+def _test_stacks(f, *args, res=None, dims=2, true_val=None,
+ matrix_axes=(-2, -1),
+ res_axes=None,
+ assert_equal=assert_equal, **kw):
+ """
+ Test that f(*args, **kw) maps across stacks of matrices
+
+    dims is the number of dimensions f(*args, **kw) should have for a single n
+ x m matrix stack.
+
+ matrix_axes are the axes along which matrices (or vectors) are stacked in
+ the input.
+
+ true_val may be a function such that true_val(*x_stacks, **kw) gives the
+ true value for f on a stack.
+
+ res should be the result of f(*args, **kw). It is computed if not passed
+ in.
+
+ """
+ if res is None:
+ res = f(*args, **kw)
+
+ shapes = [x.shape for x in args]
+
+ # Assume the result is stacked along the last 'dims' axes of matrix_axes.
+ # This holds for all the functions tested in this file
+ if res_axes is None:
+        if not (isinstance(matrix_axes, tuple) and all(isinstance(x, int) for x in matrix_axes)):
+ raise ValueError("res_axes must be specified if matrix_axes is not a tuple of integers")
+ res_axes = matrix_axes[::-1][:dims]
+
+ for (x_idxes, (res_idx,)) in zip(
+ iter_indices(*shapes, skip_axes=matrix_axes),
+ iter_indices(res.shape, skip_axes=res_axes)):
+ x_idxes = [x_idx.raw for x_idx in x_idxes]
+ res_idx = res_idx.raw
+
+ res_stack = res[res_idx]
+ x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)]
+ decomp_res_stack = f(*x_stacks, **kw)
+ msg_extra = f'{x_idxes = }, {res_idx = }'
+ assert_equal(res_stack, decomp_res_stack, msg_extra)
+ if true_val:
+ assert_equal(decomp_res_stack, true_val(*x_stacks, **kw), msg_extra)
+
+
+def _test_namedtuple(res, fields, func_name):
+ """
+ Test that res is a namedtuple with the correct fields.
+ """
+ # isinstance(namedtuple) doesn't work, and it could be either
+ # collections.namedtuple or typing.NamedTuple. So we just check that it is
+ # a tuple subclass with the right fields in the right order.
+
+ assert isinstance(res, tuple), f"{func_name}() did not return a tuple"
+ assert type(res) != tuple, f"{func_name}() did not return a namedtuple"
+ assert len(res) == len(fields), f"{func_name}() result tuple not the correct length (should have {len(fields)} elements)"
+ for i, field in enumerate(fields):
+ assert hasattr(res, field), f"{func_name}() result namedtuple doesn't have the '{field}' field"
+ assert res[i] is getattr(res, field), f"{func_name}() result namedtuple '{field}' field is not in position {i}"
+
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=positive_definite_matrices(),
+ kw=kwargs(upper=booleans())
+)
+def test_cholesky(x, kw):
+ res = linalg.cholesky(x, **kw)
+
+ ph.assert_dtype("cholesky", in_dtype=x.dtype, out_dtype=res.dtype)
+ ph.assert_result_shape("cholesky", in_shapes=[x.shape],
+ out_shape=res.shape, expected=x.shape)
+
+ _test_stacks(linalg.cholesky, x, **kw, res=res)
+
+ # Test that the result is upper or lower triangular
+ if kw.get('upper', False):
+ assert_exactly_equal(res, _array_module.triu(res))
+ else:
+ assert_exactly_equal(res, _array_module.tril(res))
+
+
+@composite
+def cross_args(draw, dtype_objects=dh.real_dtypes):
+ """
+ cross() requires two arrays with a size 3 in the 'axis' dimension
+
+ To do this, we generate a shape and an axis but change the shape to be 3
+ in the drawn axis.
+
+ """
+ shape1, shape2 = draw(two_mutually_broadcastable_shapes)
+ min_ndim = min(len(shape1), len(shape2))
+ assume(min_ndim > 0)
+
+ kw = draw(kwargs(axis=integers(-min_ndim, -1)))
+ axis = kw.get('axis', -1)
+ if draw(booleans()):
+        # Usually force valid size-3 shapes here; when this branch is skipped,
+        # the drawn shapes may be invalid, which tests that cross() raises
+ shape1 = list(shape1)
+ shape1[axis] = 3
+ shape1 = tuple(shape1)
+ shape2 = list(shape2)
+ shape2[axis] = 3
+ shape2 = tuple(shape2)
+
+ mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects))
+ arrays1 = arrays(
+ dtype=mutual_dtypes.map(lambda pair: pair[0]),
+ shape=shape1,
+ )
+ arrays2 = arrays(
+ dtype=mutual_dtypes.map(lambda pair: pair[1]),
+ shape=shape2,
+ )
+ return draw(arrays1), draw(arrays2), kw
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ cross_args()
+)
+def test_cross(x1_x2_kw):
+ x1, x2, kw = x1_x2_kw
+
+ axis = kw.get('axis', -1)
+ if not (x1.shape[axis] == x2.shape[axis] == 3):
+ ph.raises(Exception, lambda: xp.cross(x1, x2, **kw),
+ "cross did not raise an exception for invalid shapes")
+ return
+
+ res = linalg.cross(x1, x2, **kw)
+
+ broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
+
+ ph.assert_dtype("cross", in_dtype=[x1.dtype, x2.dtype],
+ out_dtype=res.dtype)
+ ph.assert_result_shape("cross", in_shapes=[x1.shape, x2.shape], out_shape=res.shape, expected=broadcasted_shape)
+
+ def exact_cross(a, b):
+ assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
+ return asarray(xp.stack([
+ a[1]*b[2] - a[2]*b[1],
+ a[2]*b[0] - a[0]*b[2],
+ a[0]*b[1] - a[1]*b[0],
+ ]), dtype=res.dtype)
+
+ # We don't want to pass in **kw here because that would pass axis to
+ # cross() on a single stack, but the axis is not meaningful on unstacked
+ # vectors.
+ _test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes),
+)
+def test_det(x):
+ res = linalg.det(x)
+
+ ph.assert_dtype("det", in_dtype=x.dtype, out_dtype=res.dtype)
+ ph.assert_result_shape("det", in_shapes=[x.shape], out_shape=res.shape,
+ expected=x.shape[:-2])
+
+ _test_stacks(linalg.det, x, res=res, dims=0)
+
+ # TODO: Test that res actually corresponds to the determinant of x
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=all_dtypes, shape=matrix_shapes()),
+ # offset may produce an overflow if it is too large. Supporting offsets
+ # that are way larger than the array shape isn't very important.
+ kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
+)
+def test_diagonal(x, kw):
+ res = linalg.diagonal(x, **kw)
+
+ ph.assert_dtype("diagonal", in_dtype=x.dtype, out_dtype=res.dtype)
+
+ n, m = x.shape[-2:]
+ offset = kw.get('offset', 0)
+ # Note: the spec does not specify that offset must be within the bounds of
+ # the matrix. A large offset should just produce a size 0 in the last
+ # dimension.
+ if offset < 0:
+ diag_size = min(n, m, max(n + offset, 0))
+ elif offset == 0:
+ diag_size = min(n, m)
+ else:
+ diag_size = min(n, m, max(m - offset, 0))
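+    # For example (illustrative): for a 3x5 matrix, offset=2 gives
+    # diag_size = min(3, 5, max(5 - 2, 0)) = 3, while offset=-2 gives
+    # diag_size = min(3, 5, max(3 - 2, 0)) = 1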
+
+ expected_shape = (*x.shape[:-2], diag_size)
+ ph.assert_result_shape("diagonal", in_shapes=[x.shape],
+ out_shape=res.shape, expected=expected_shape)
+
+ def true_diag(x_stack, offset=0):
+ if offset >= 0:
+ x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
+ else:
+ x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
+ return asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype)
+
+ _test_stacks(linalg.diagonal, x, **kw, res=res, dims=1, true_val=true_diag)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(x=symmetric_matrices(finite=True))
+def test_eigh(x):
+ res = linalg.eigh(x)
+
+ _test_namedtuple(res, ['eigenvalues', 'eigenvectors'], 'eigh')
+
+ eigenvalues = res.eigenvalues
+ eigenvectors = res.eigenvectors
+
+ ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvalues.dtype,
+ expected=x.dtype, repr_name="eigenvalues.dtype")
+ ph.assert_result_shape("eigh", in_shapes=[x.shape],
+ out_shape=eigenvalues.shape,
+ expected=x.shape[:-1],
+ repr_name="eigenvalues.shape")
+
+ ph.assert_dtype("eigh", in_dtype=x.dtype, out_dtype=eigenvectors.dtype,
+ expected=x.dtype, repr_name="eigenvectors.dtype")
+ ph.assert_result_shape("eigh", in_shapes=[x.shape],
+ out_shape=eigenvectors.shape, expected=x.shape,
+ repr_name="eigenvectors.shape")
+
+    # Note: _test_stacks here only tests the shape and dtype. The actual
+    # eigenvalues and eigenvectors may not be equal at all, since there are no
+    # requirements about how eigh computes an eigenbasis, or about the order
+    # of the eigenvalues.
+ _test_stacks(lambda x: linalg.eigh(x).eigenvalues, x,
+ res=eigenvalues, dims=1)
+
+ # TODO: Test that eigenvectors are orthonormal.
+
+ _test_stacks(lambda x: linalg.eigh(x).eigenvectors, x,
+ res=eigenvectors, dims=2)
+
+ # TODO: Test that res actually corresponds to the eigenvalues and
+ # eigenvectors of x
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(x=symmetric_matrices(finite=True))
+def test_eigvalsh(x):
+ res = linalg.eigvalsh(x)
+
+ ph.assert_dtype("eigvalsh", in_dtype=x.dtype, out_dtype=res.dtype)
+ ph.assert_result_shape("eigvalsh", in_shapes=[x.shape],
+ out_shape=res.shape, expected=x.shape[:-1])
+
+    # Note: _test_stacks here only tests the shape and dtype. The actual
+    # eigenvalues may not be equal at all, since there are no requirements
+    # about the order of the eigenvalues, and the stacking code may use a
+    # different code path.
+ _test_stacks(linalg.eigvalsh, x, res=res, dims=1)
+
+ # TODO: Should we test that the result is the same as eigh(x).eigenvalues?
+ # (probably no because the spec doesn't actually require that)
+
+ # TODO: Test that res actually corresponds to the eigenvalues of x
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(x=invertible_matrices())
+def test_inv(x):
+ res = linalg.inv(x)
+
+ ph.assert_dtype("inv", in_dtype=x.dtype, out_dtype=res.dtype)
+ ph.assert_result_shape("inv", in_shapes=[x.shape], out_shape=res.shape,
+ expected=x.shape)
+
+ _test_stacks(linalg.inv, x, res=res)
+
+ # TODO: Test that the result is actually the inverse
+
+def _test_matmul(namespace, x1, x2):
+ matmul = namespace.matmul
+
+ # TODO: Make this also test the @ operator
+ if (x1.shape == () or x2.shape == ()
+ or len(x1.shape) == len(x2.shape) == 1 and x1.shape != x2.shape
+ or len(x1.shape) == 1 and len(x2.shape) >= 2 and x1.shape[0] != x2.shape[-2]
+ or len(x2.shape) == 1 and len(x1.shape) >= 2 and x2.shape[0] != x1.shape[-1]
+ or len(x1.shape) >= 2 and len(x2.shape) >= 2 and x1.shape[-1] != x2.shape[-2]):
+ # The spec doesn't specify what kind of exception is used here. Most
+ # libraries will use a custom exception class.
+ ph.raises(Exception, lambda: _array_module.matmul(x1, x2),
+ "matmul did not raise an exception for invalid shapes")
+ return
+ else:
+ res = matmul(x1, x2)
+
+ ph.assert_dtype("matmul", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
+
+ if len(x1.shape) == len(x2.shape) == 1:
+ ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape, expected=())
+ elif len(x1.shape) == 1:
+ ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape,
+ expected=x2.shape[:-2] + x2.shape[-1:])
+ _test_stacks(matmul, x1, x2, res=res, dims=1,
+ matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
+ elif len(x2.shape) == 1:
+ ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape, expected=x1.shape[:-1])
+ _test_stacks(matmul, x1, x2, res=res, dims=1,
+ matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
+ else:
+ stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
+ ph.assert_result_shape("matmul", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape,
+ expected=stack_shape + (x1.shape[-2], x2.shape[-1]))
+ _test_stacks(matmul, x1, x2, res=res)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ *two_mutual_arrays(dh.real_dtypes)
+)
+def test_linalg_matmul(x1, x2):
+ return _test_matmul(linalg, x1, x2)
+
+@pytest.mark.unvectorized
+@given(
+ *two_mutual_arrays(dh.real_dtypes)
+)
+def test_matmul(x1, x2):
+ return _test_matmul(_array_module, x1, x2)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=finite_matrices(),
+ kw=kwargs(keepdims=booleans(),
+ ord=sampled_from([-float('inf'), -2, -1, 1, 2, float('inf'), 'fro', 'nuc']))
+)
+def test_matrix_norm(x, kw):
+ res = linalg.matrix_norm(x, **kw)
+
+ keepdims = kw.get('keepdims', False)
+ # TODO: Check that the ord values give the correct norms.
+ # ord = kw.get('ord', 'fro')
+
+ if keepdims:
+ expected_shape = x.shape[:-2] + (1, 1)
+ else:
+ expected_shape = x.shape[:-2]
+ ph.assert_complex_to_float_dtype("matrix_norm", in_dtype=x.dtype,
+ out_dtype=res.dtype)
+ ph.assert_result_shape("matrix_norm", in_shapes=[x.shape],
+ out_shape=res.shape, expected=expected_shape)
+
+ _test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0,
+ res=res)
+
+matrix_power_n = shared(integers(-100, 100), key='matrix_power n')
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ # Generate any square matrix if n >= 0 but only invertible matrices if n < 0
+ x=matrix_power_n.flatmap(lambda n: invertible_matrices() if n < 0 else
+ arrays(dtype=all_floating_dtypes(),
+ shape=square_matrix_shapes)),
+ n=matrix_power_n,
+)
+def test_matrix_power(x, n):
+ res = linalg.matrix_power(x, n)
+
+ ph.assert_dtype("matrix_power", in_dtype=x.dtype, out_dtype=res.dtype)
+ ph.assert_result_shape("matrix_power", in_shapes=[x.shape],
+ out_shape=res.shape, expected=x.shape)
+
+ if n == 0:
+ true_val = lambda x: _array_module.eye(x.shape[0], dtype=x.dtype)
+ else:
+ true_val = None
+ # _test_stacks only works with array arguments
+ func = lambda x: linalg.matrix_power(x, n)
+ _test_stacks(func, x, res=res, true_val=true_val)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=finite_matrices(shape=rtol_shared_matrix_shapes),
+ kw=kwargs(rtol=rtols)
+)
+def test_matrix_rank(x, kw):
+ linalg.matrix_rank(x, **kw)
+
+def _test_matrix_transpose(namespace, x):
+ matrix_transpose = namespace.matrix_transpose
+ res = matrix_transpose(x)
+ true_val = lambda a: _array_module.asarray([[a[i, j] for i in
+ range(a.shape[0])] for j in
+ range(a.shape[1])],
+ dtype=a.dtype)
+ shape = list(x.shape)
+ shape[-1], shape[-2] = shape[-2], shape[-1]
+ shape = tuple(shape)
+ ph.assert_dtype("matrix_transpose", in_dtype=x.dtype, out_dtype=res.dtype)
+ ph.assert_result_shape("matrix_transpose", in_shapes=[x.shape],
+ out_shape=res.shape, expected=shape)
+
+ _test_stacks(matrix_transpose, x, res=res, true_val=true_val)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=all_dtypes, shape=matrix_shapes()),
+)
+def test_linalg_matrix_transpose(x):
+ return _test_matrix_transpose(linalg, x)
+
+@pytest.mark.unvectorized
+@given(
+ x=arrays(dtype=all_dtypes, shape=matrix_shapes()),
+)
+def test_matrix_transpose(x):
+ return _test_matrix_transpose(_array_module, x)
+
+@pytest.mark.xp_extension('linalg')
+@given(
+ *two_mutual_arrays(dtypes=dh.real_dtypes,
+ two_shapes=tuples(one_d_shapes, one_d_shapes))
+)
+def test_outer(x1, x2):
+ # outer does not work on stacks. See
+ # https://github.com/data-apis/array-api/issues/242.
+ res = linalg.outer(x1, x2)
+
+ shape = (x1.shape[0], x2.shape[0])
+ ph.assert_dtype("outer", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
+ ph.assert_result_shape("outer", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape, expected=shape)
+
+ if 0 in shape:
+ true_res = _array_module.empty(shape, dtype=res.dtype)
+ else:
+ true_res = _array_module.asarray([[x1[i]*x2[j]
+ for j in range(x2.shape[0])]
+ for i in range(x1.shape[0])],
+ dtype=res.dtype)
+
+ assert_exactly_equal(res, true_res)
+
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=finite_matrices(shape=rtol_shared_matrix_shapes),
+ kw=kwargs(rtol=rtols)
+)
+def test_pinv(x, kw):
+ linalg.pinv(x, **kw)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=all_floating_dtypes(), shape=matrix_shapes()),
+ kw=kwargs(mode=sampled_from(['reduced', 'complete']))
+)
+def test_qr(x, kw):
+ res = linalg.qr(x, **kw)
+ mode = kw.get('mode', 'reduced')
+
+ M, N = x.shape[-2:]
+ K = min(M, N)
+
+ _test_namedtuple(res, ['Q', 'R'], 'qr')
+ Q = res.Q
+ R = res.R
+
+ ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=Q.dtype,
+ expected=x.dtype, repr_name="Q.dtype")
+ if mode == 'complete':
+ expected_Q_shape = x.shape[:-2] + (M, M)
+ else:
+ expected_Q_shape = x.shape[:-2] + (M, K)
+ ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=Q.shape,
+ expected=expected_Q_shape, repr_name="Q.shape")
+
+ ph.assert_dtype("qr", in_dtype=x.dtype, out_dtype=R.dtype,
+ expected=x.dtype, repr_name="R.dtype")
+ if mode == 'complete':
+ expected_R_shape = x.shape[:-2] + (M, N)
+ else:
+ expected_R_shape = x.shape[:-2] + (K, N)
+ ph.assert_result_shape("qr", in_shapes=[x.shape], out_shape=R.shape,
+ expected=expected_R_shape, repr_name="R.shape")
+
+ _test_stacks(lambda x: linalg.qr(x, **kw).Q, x, res=Q)
+ _test_stacks(lambda x: linalg.qr(x, **kw).R, x, res=R)
+
+ # TODO: Test that Q is orthonormal
+
+ # Check that R is upper-triangular.
+ assert_exactly_equal(R, _array_module.triu(R))
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes),
+)
+def test_slogdet(x):
+ res = linalg.slogdet(x)
+
+    _test_namedtuple(res, ['sign', 'logabsdet'], 'slogdet')
+
+ sign, logabsdet = res
+
+ ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=sign.dtype,
+ expected=x.dtype, repr_name="sign.dtype")
+ ph.assert_result_shape("slogdet", in_shapes=[x.shape],
+ out_shape=sign.shape,
+ expected=x.shape[:-2],
+ repr_name="sign.shape")
+ expected_dtype = dh.as_real_dtype(x.dtype)
+ ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=logabsdet.dtype,
+ expected=expected_dtype, repr_name="logabsdet.dtype")
+ ph.assert_result_shape("slogdet", in_shapes=[x.shape],
+ out_shape=logabsdet.shape,
+ expected=x.shape[:-2],
+ repr_name="logabsdet.shape")
+
+ _test_stacks(lambda x: linalg.slogdet(x).sign, x,
+ res=sign, dims=0)
+ _test_stacks(lambda x: linalg.slogdet(x).logabsdet, x,
+ res=logabsdet, dims=0)
+
+ # Check that when the determinant is 0, the sign and logabsdet are (0,
+ # -inf).
+ # TODO: This test does not necessarily hold exactly. Update it to test it
+ # approximately.
+ # d = linalg.det(x)
+ # zero_det = equal(d, zero(d.shape, d.dtype))
+ # assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype))
+ # assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype))
+
+ # More generally, det(x) should equal sign*exp(logabsdet), but this does
+ # not hold exactly due to floating-point loss of precision.
+
+ # TODO: Test this when we have tests for floating-point values.
+ # assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps)
+
+def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]:
+ """
+ Strategy for the x1 and x2 arguments to test_solve()
+
+ solve() takes x1, x2, where x1 is any stack of square invertible matrices
+ of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
+ where the ... parts of x1 and x2 are broadcast compatible.
+ """
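+    # For example (illustrative shapes): x1 of shape (2, 1, 4, 4) may be
+    # paired with x2 of shape (3, 4, 5); the stack shapes (2, 1) and (3,)
+    # broadcast to (2, 3), so solve(x1, x2) has shape (2, 3, 4, 5)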
+ mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dh.all_float_dtypes))
+
+ stack_shapes = shared(two_mutually_broadcastable_shapes)
+ # Don't worry about dtypes since all floating dtypes are type promotable
+ # with each other.
+ x1 = shared(invertible_matrices(
+ stack_shapes=stack_shapes.map(lambda pair: pair[0]),
+ dtypes=mutual_dtypes.map(lambda pair: pair[0])))
+
+ @composite
+ def _x2_shapes(draw):
+ base_shape = draw(stack_shapes)[1] + draw(x1).shape[-1:]
+ end = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(base_shape), 1)))
+ return base_shape + (end,)
+
+ x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
+ x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))
+ return x1, x2
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(*solve_args())
+def test_solve(x1, x2):
+ res = linalg.solve(x1, x2)
+
+ ph.assert_dtype("solve", in_dtype=[x1.dtype, x2.dtype], out_dtype=res.dtype)
+ if x2.ndim == 1:
+ expected_shape = x1.shape[:-2] + x2.shape[-1:]
+ _test_stacks(linalg.solve, x1, x2, res=res, dims=1,
+ matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
+ else:
+ stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
+ expected_shape = stack_shape + x2.shape[-2:]
+ _test_stacks(linalg.solve, x1, x2, res=res, dims=2)
+
+ ph.assert_result_shape("solve", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape, expected=expected_shape)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=finite_matrices(),
+ kw=kwargs(full_matrices=booleans())
+)
+def test_svd(x, kw):
+ res = linalg.svd(x, **kw)
+ full_matrices = kw.get('full_matrices', True)
+
+ *stack, M, N = x.shape
+ K = min(M, N)
+
+ _test_namedtuple(res, ['U', 'S', 'Vh'], 'svd')
+
+ U, S, Vh = res
+
+ ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=U.dtype,
+ expected=x.dtype, repr_name="U.dtype")
+ ph.assert_complex_to_float_dtype("svd", in_dtype=x.dtype,
+ out_dtype=S.dtype, repr_name="S.dtype")
+ ph.assert_dtype("svd", in_dtype=x.dtype, out_dtype=Vh.dtype,
+ expected=x.dtype, repr_name="Vh.dtype")
+
+ if full_matrices:
+ expected_U_shape = (*stack, M, M)
+ expected_Vh_shape = (*stack, N, N)
+ else:
+ expected_U_shape = (*stack, M, K)
+ expected_Vh_shape = (*stack, K, N)
+ ph.assert_result_shape("svd", in_shapes=[x.shape],
+ out_shape=U.shape,
+ expected=expected_U_shape,
+ repr_name="U.shape")
+ ph.assert_result_shape("svd", in_shapes=[x.shape],
+ out_shape=Vh.shape,
+ expected=expected_Vh_shape,
+ repr_name="Vh.shape")
+ ph.assert_result_shape("svd", in_shapes=[x.shape],
+ out_shape=S.shape,
+ expected=(*stack, K),
+ repr_name="S.shape")
+
+ # The values of s must be sorted from largest to smallest
+ if K >= 1:
+ assert _array_module.all(S[..., :-1] >= S[..., 1:]), "svd().S values are not sorted from largest to smallest"
+
+ _test_stacks(lambda x: linalg.svd(x, **kw).U, x, res=U)
+ _test_stacks(lambda x: linalg.svd(x, **kw).S, x, dims=1, res=S)
+ _test_stacks(lambda x: linalg.svd(x, **kw).Vh, x, res=Vh)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=finite_matrices(),
+)
+def test_svdvals(x):
+ res = linalg.svdvals(x)
+
+ *stack, M, N = x.shape
+ K = min(M, N)
+
+ ph.assert_complex_to_float_dtype("svdvals", in_dtype=x.dtype,
+ out_dtype=res.dtype)
+ ph.assert_result_shape("svdvals", in_shapes=[x.shape],
+ out_shape=res.shape,
+ expected=(*stack, K))
+
+ # SVD values must be sorted from largest to smallest
+ assert _array_module.all(res[..., :-1] >= res[..., 1:]), "svdvals() values are not sorted from largest to smallest"
+
+ _test_stacks(linalg.svdvals, x, dims=1, res=res)
+
+ # TODO: Check that svdvals() is the same as svd().s.
+
+_tensordot_pre_shapes = shared(two_mutually_broadcastable_shapes)
+
+@composite
+def _tensordot_axes(draw):
+ shape1, shape2 = draw(_tensordot_pre_shapes)
+ ndim1, ndim2 = len(shape1), len(shape2)
+ isint = draw(booleans())
+
+ if isint:
+ N = min(ndim1, ndim2)
+ return draw(integers(0, N))
+ else:
+ if ndim1 < ndim2:
+ first = draw(xps.valid_tuple_axes(ndim1))
+ second = draw(xps.valid_tuple_axes(ndim2, min_size=len(first),
+ max_size=len(first)))
+ else:
+ second = draw(xps.valid_tuple_axes(ndim2))
+ first = draw(xps.valid_tuple_axes(ndim1, min_size=len(second),
+ max_size=len(second)))
+ return (tuple(first), tuple(second))
+
+tensordot_kw = shared(kwargs(axes=_tensordot_axes()))
+
+@composite
+def tensordot_shapes(draw):
+ _shape1, _shape2 = map(list, draw(_tensordot_pre_shapes))
+ ndim1, ndim2 = len(_shape1), len(_shape2)
+ kw = draw(tensordot_kw)
+ if 'axes' not in kw:
+ assume(ndim1 >= 2 and ndim2 >= 2)
+ axes = kw.get('axes', 2)
+
+ if isinstance(axes, int):
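+        # e.g. axes=2 becomes [[-2, -1], [0, 1]]: the last two axes of shape1
+        # are contracted with the first two axes of shape2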
+ axes = [list(range(-axes, 0)), list(range(0, axes))]
+
+ first, second = axes
+    for i, j in zip(first, second):
+        if -ndim2 <= j < ndim2 and _shape2[j] != 1:
+            _shape1[i] = _shape2[j]
+        if -ndim1 <= i < ndim1 and _shape1[i] != 1:
+            _shape2[j] = _shape1[i]
+
+ shape1, shape2 = map(tuple, [_shape1, _shape2])
+ return (shape1, shape2)
+
+def _test_tensordot_stacks(x1, x2, kw, res):
+ """
+ Variant of _test_stacks for tensordot
+
+ tensordot doesn't stack directly along the non-contracted dimensions like
+ the other linalg functions. Rather, it is stacked along the product of
+ each non-contracted dimension. These dimensions are independent of one
+ another and do not broadcast.
+ """
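+    # For example (illustrative): with x1.shape=(2, 3, 4), x2.shape=(4, 5) and
+    # axes=1, the contracted axes are x1's last and x2's first (both size 4),
+    # and the result has shape (2, 3, 5): the non-contracted dimensions of x1
+    # followed by those of x2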
+ shape1, shape2 = x1.shape, x2.shape
+
+ axes = kw.get('axes', 2)
+
+ if isinstance(axes, int):
+ res_axes = axes
+ axes = [list(range(-axes, 0)), list(range(0, axes))]
+ else:
+ # Convert something like (0, 4, 2) into (0, 2, 1)
+ res_axes = []
+ for a, s in zip(axes, [shape1, shape2]):
+ indices = [range(len(s))[i] for i in a]
+ repl = dict(zip(sorted(indices), range(len(indices))))
+ res_axes.append(tuple(repl[i] for i in indices))
+ res_axes = tuple(res_axes)
+
+ for ((i,), (j,)), (res_idx,) in zip(
+ itertools.product(
+ iter_indices(shape1, skip_axes=axes[0]),
+ iter_indices(shape2, skip_axes=axes[1])),
+ iter_indices(res.shape)):
+ i, j, res_idx = i.raw, j.raw, res_idx.raw
+
+ res_stack = res[res_idx]
+ x1_stack = x1[i]
+ x2_stack = x2[j]
+ decomp_res_stack = xp.tensordot(x1_stack, x2_stack, axes=res_axes)
+ assert_equal(res_stack, decomp_res_stack)
+
+def _test_tensordot(namespace, x1, x2, kw):
+ tensordot = namespace.tensordot
+ res = tensordot(x1, x2, **kw)
+
+ ph.assert_dtype("tensordot", in_dtype=[x1.dtype, x2.dtype],
+ out_dtype=res.dtype)
+
+ axes = _axes = kw.get('axes', 2)
+
+ if isinstance(axes, int):
+ _axes = [list(range(-axes, 0)), list(range(0, axes))]
+
+ _shape1 = list(x1.shape)
+ _shape2 = list(x2.shape)
+ for i, j in zip(*_axes):
+ _shape1[i] = _shape2[j] = None
+ _shape1 = tuple([i for i in _shape1 if i is not None])
+ _shape2 = tuple([i for i in _shape2 if i is not None])
+ result_shape = _shape1 + _shape2
+ ph.assert_result_shape('tensordot', [x1.shape, x2.shape], res.shape,
+ expected=result_shape)
+ _test_tensordot_stacks(x1, x2, kw, res)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
+ tensordot_kw,
+)
+def test_linalg_tensordot(x1, x2, kw):
+ _test_tensordot(linalg, x1, x2, kw)
+
+@pytest.mark.unvectorized
+@given(
+ *two_mutual_arrays(dh.numeric_dtypes, two_shapes=tensordot_shapes()),
+ tensordot_kw,
+)
+def test_tensordot(x1, x2, kw):
+ _test_tensordot(_array_module, x1, x2, kw)
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=numeric_dtypes, shape=matrix_shapes()),
+ # offset may produce an overflow if it is too large. Supporting offsets
+ # that are way larger than the array shape isn't very important.
+ kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
+)
+def test_trace(x, kw):
+ res = linalg.trace(x, **kw)
+
+ dtype = kw.get("dtype", None)
+ expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
+ if expected_dtype is None:
+        # If a default uint cannot exist (e.g. in PyTorch, which doesn't support
+ # uint32 or uint64), we skip testing the output dtype.
+ # See https://github.com/data-apis/array-api-tests/issues/160
+ if x.dtype in dh.uint_dtypes:
+ assert dh.is_int_dtype(res.dtype) # sanity check
+ elif api_version < "2023.12": # TODO: update dtype assertion for >2023.12 - see #234
+ ph.assert_dtype("trace", in_dtype=x.dtype, out_dtype=res.dtype, expected=expected_dtype)
+
+ n, m = x.shape[-2:]
+ ph.assert_result_shape('trace', x.shape, res.shape, expected=x.shape[:-2])
+
+ def true_trace(x_stack, offset=0):
+ # Note: the spec does not specify that offset must be within the
+ # bounds of the matrix. A large offset should just produce a size 0
+ # diagonal in the last dimension (trace 0). See test_diagonal().
+ if offset < 0:
+ diag_size = min(n, m, max(n + offset, 0))
+ elif offset == 0:
+ diag_size = min(n, m)
+ else:
+ diag_size = min(n, m, max(m - offset, 0))
+
+ if offset >= 0:
+ x_stack_diag = [x_stack[i, i + offset] for i in range(diag_size)]
+ else:
+ x_stack_diag = [x_stack[i - offset, i] for i in range(diag_size)]
+ result = xp.asarray(xp.stack(x_stack_diag) if x_stack_diag else [], dtype=x.dtype)
+ return _array_module.sum(result)
+
+
+ _test_stacks(linalg.trace, x, **kw, res=res, dims=0, true_val=true_trace)
+
+
+def _conj(x):
+ # XXX: replace with xp.dtype when all array libraries implement it
+ if x.dtype in (xp.complex64, xp.complex128):
+ return xp.conj(x)
+ else:
+ return x
+
+
+def _test_vecdot(namespace, x1, x2, data):
+ vecdot = namespace.vecdot
+ broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
+ min_ndim = min(x1.ndim, x2.ndim)
+ ndim = len(broadcasted_shape)
+ kw = data.draw(kwargs(axis=integers(-min_ndim, -1)))
+ axis = kw.get('axis', -1)
+ x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
+ x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
+ if x1_shape[axis] != x2_shape[axis]:
+ ph.raises(Exception, lambda: vecdot(x1, x2, **kw),
+ "vecdot did not raise an exception for invalid shapes")
+ return
+ expected_shape = list(broadcasted_shape)
+ expected_shape.pop(axis)
+ expected_shape = tuple(expected_shape)
+
+ res = vecdot(x1, x2, **kw)
+
+ ph.assert_dtype("vecdot", in_dtype=[x1.dtype, x2.dtype],
+ out_dtype=res.dtype)
+ ph.assert_result_shape("vecdot", in_shapes=[x1.shape, x2.shape],
+ out_shape=res.shape, expected=expected_shape)
+
+ def true_val(x, y, axis=-1):
+ return xp.sum(xp.multiply(_conj(x), y), dtype=res.dtype)
+
+ _test_stacks(vecdot, x1, x2, res=res, dims=0,
+ matrix_axes=(axis,), true_val=true_val)
+
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
+ data(),
+)
+def test_linalg_vecdot(x1, x2, data):
+ _test_vecdot(linalg, x1, x2, data)
+
+
+@pytest.mark.unvectorized
+@given(
+ *two_mutual_arrays(dh.numeric_dtypes, mutually_broadcastable_shapes(2, min_dims=1)),
+ data(),
+)
+def test_vecdot(x1, x2, data):
+ _test_vecdot(_array_module, x1, x2, data)
+
+
+@pytest.mark.xp_extension('linalg')
+def test_vecdot_conj():
+ # no-hypothesis test to check that the 1st argument is in fact conjugated
+ x1 = xp.asarray([1j, 2j, 3j])
+ x2 = xp.asarray([1, 2j, 3])
+
+ import cmath
+ assert cmath.isclose(complex(xp.linalg.vecdot(x1, x2)), 4 - 10j)
+
+
+# Insanely large ord values might not work. There isn't a limit specified in
+# the spec, so we just limit to reasonable values here.
+max_ord = 100
+
+
+@pytest.mark.unvectorized
+@pytest.mark.xp_extension('linalg')
+@given(
+ x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)),
+ data=data(),
+)
+def test_vector_norm(x, data):
+ kw = data.draw(
+ # We use data because axes is parameterized on x.ndim
+ kwargs(axis=axes(x.ndim),
+ keepdims=booleans(),
+ ord=one_of(
+ sampled_from([2, 1, 0, -1, -2, float("inf"), float("-inf")]),
+ integers(-max_ord, max_ord),
+ floats(-max_ord, max_ord),
+ )), label="kw")
+
+
+ res = linalg.vector_norm(x, **kw)
+ axis = kw.get('axis', None)
+ keepdims = kw.get('keepdims', False)
+ # TODO: Check that the ord values give the correct norms.
+ # ord = kw.get('ord', 2)
+
+ _axes = sh.normalize_axis(axis, x.ndim)
+
+ ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape,
+ in_shape=x.shape, axes=_axes,
+ keepdims=keepdims, kw=kw)
+ expected_dtype = dh.as_real_dtype(x.dtype)
+ ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype,
+ out_dtype=res.dtype, expected=expected_dtype)
+
+ _kw = kw.copy()
+ _kw.pop('axis', None)
+ _test_stacks(linalg.vector_norm, x, res=res,
+ dims=x.ndim if keepdims else 0,
+ matrix_axes=_axes, **_kw
+ )
diff --git a/array_api_tests/test_manipulation_functions.py b/array_api_tests/test_manipulation_functions.py
new file mode 100644
index 00000000..754b507d
--- /dev/null
+++ b/array_api_tests/test_manipulation_functions.py
@@ -0,0 +1,513 @@
+import math
+from collections import deque
+from typing import Iterable, Iterator, Tuple, Union
+
+import pytest
+from hypothesis import assume, given
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xps
+from .typing import Array, Shape
+
+
+def shared_shapes(*args, **kwargs) -> st.SearchStrategy[Shape]:
+ key = "shape"
+ if args:
+        key += " " + " ".join(str(a) for a in args)
+    if kwargs:
+        key += " " + ph.fmt_kw(kwargs)
+    return st.shared(hh.shapes(*args, **kwargs), key=key)
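+# Note: strategies built by shared_shapes() with the same arguments share a
+# key, so within a single example they all draw the same shape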
+
+
+def assert_array_ndindex(
+ func_name: str,
+ x: Array,
+ *,
+ x_indices: Iterable[Union[int, Shape]],
+ out: Array,
+ out_indices: Iterable[Union[int, Shape]],
+ kw: dict = {},
+):
+ msg_suffix = f" [{func_name}({ph.fmt_kw(kw)})]\n {x=}\n{out=}"
+ for x_idx, out_idx in zip(x_indices, out_indices):
+ msg = f"out[{out_idx}]={out[out_idx]}, should be x[{x_idx}]={x[x_idx]}"
+ msg += msg_suffix
+ if dh.is_float_dtype(x.dtype) and xp.isnan(x[x_idx]):
+ assert xp.isnan(out[out_idx]), msg
+ else:
+ assert out[out_idx] == x[x_idx], msg
+
+
+@pytest.mark.unvectorized
+@given(
+ dtypes=hh.mutually_promotable_dtypes(None, dtypes=dh.numeric_dtypes),
+ base_shape=hh.shapes(),
+ data=st.data(),
+)
+def test_concat(dtypes, base_shape, data):
+ axis_strat = st.none()
+ ndim = len(base_shape)
+ if ndim > 0:
+ axis_strat |= st.integers(-ndim, ndim - 1)
+ kw = data.draw(
+ axis_strat.flatmap(lambda a: hh.specified_kwargs(("axis", a, 0))), label="kw"
+ )
+ axis = kw.get("axis", 0)
+ if axis is None:
+ _axis = None
+ shape_strat = hh.shapes()
+ else:
+ _axis = axis if axis >= 0 else len(base_shape) + axis
+ shape_strat = st.integers(0, hh.MAX_SIDE).map(
+ lambda i: base_shape[:_axis] + (i,) + base_shape[_axis + 1 :]
+ )
+ arrays = []
+ for i, dtype in enumerate(dtypes, 1):
+ x = data.draw(hh.arrays(dtype=dtype, shape=shape_strat), label=f"x{i}")
+ arrays.append(x)
+
+ out = xp.concat(arrays, **kw)
+
+ ph.assert_dtype("concat", in_dtype=dtypes, out_dtype=out.dtype)
+
+ shapes = tuple(x.shape for x in arrays)
+ if _axis is None:
+ size = sum(math.prod(s) for s in shapes)
+ shape = (size,)
+ else:
+ shape = list(shapes[0])
+ for other_shape in shapes[1:]:
+ shape[_axis] += other_shape[_axis]
+ shape = tuple(shape)
+ ph.assert_result_shape("concat", in_shapes=shapes, out_shape=out.shape, expected=shape, kw=kw)
+
+ if _axis is None:
+ out_indices = (i for i in range(math.prod(out.shape)))
+ for x_num, x in enumerate(arrays, 1):
+ for x_idx in sh.ndindex(x.shape):
+ out_i = next(out_indices)
+ ph.assert_0d_equals(
+ "concat",
+ x_repr=f"x{x_num}[{x_idx}]",
+ x_val=x[x_idx],
+ out_repr=f"out[{out_i}]",
+ out_val=out[out_i],
+ kw=kw,
+ )
+ else:
+ out_indices = sh.ndindex(out.shape)
+ for idx in sh.axis_ndindex(shapes[0], _axis):
+ f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
+ for x_num, x in enumerate(arrays, 1):
+ indexed_x = x[idx]
+ for x_idx in sh.ndindex(indexed_x.shape):
+ out_idx = next(out_indices)
+ ph.assert_0d_equals(
+ "concat",
+ x_repr=f"x{x_num}[{f_idx}][{x_idx}]",
+ x_val=indexed_x[x_idx],
+ out_repr=f"out[{out_idx}]",
+ out_val=out[out_idx],
+ kw=kw,
+ )
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()),
+ axis=shared_shapes().flatmap(
+ # Generate both valid and invalid axis
+ lambda s: st.integers(2 * (-len(s) - 1), 2 * len(s))
+ ),
+)
+def test_expand_dims(x, axis):
+ if axis < -x.ndim - 1 or axis > x.ndim:
+ with pytest.raises(IndexError):
+ xp.expand_dims(x, axis=axis)
+ return
+
+ out = xp.expand_dims(x, axis=axis)
+
+ ph.assert_dtype("expand_dims", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ shape = [side for side in x.shape]
+ index = axis if axis >= 0 else x.ndim + axis + 1
+ shape.insert(index, 1)
+ shape = tuple(shape)
+ ph.assert_result_shape("expand_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape)
+
+ assert_array_ndindex(
+ "expand_dims", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape)
+ )
+
+
+@pytest.mark.min_version("2023.12")
+@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data())
+def test_moveaxis(x, data):
+ source = data.draw(
+ st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim), label="source"
+ )
+ if isinstance(source, int):
+ destination = data.draw(st.integers(-x.ndim, x.ndim - 1), label="destination")
+ else:
+ assert isinstance(source, tuple) # sanity check
+ destination = data.draw(
+ st.lists(
+ st.integers(-x.ndim, x.ndim - 1),
+ min_size=len(source),
+ max_size=len(source),
+ unique_by=lambda n: n if n >= 0 else x.ndim + n,
+ ).map(tuple),
+ label="destination"
+ )
+
+ out = xp.moveaxis(x, source, destination)
+
+ ph.assert_dtype("moveaxis", in_dtype=x.dtype, out_dtype=out.dtype)
+
+
+ _source = sh.normalize_axis(source, x.ndim)
+ _destination = sh.normalize_axis(destination, x.ndim)
+
+ new_axes = [n for n in range(x.ndim) if n not in _source]
+
+ for dest, src in sorted(zip(_destination, _source)):
+ new_axes.insert(dest, src)
+
+ expected_shape = tuple(x.shape[i] for i in new_axes)
+
+ ph.assert_result_shape("moveaxis", in_shapes=[x.shape],
+ out_shape=out.shape, expected=expected_shape,
+ kw={"source": source, "destination": destination})
+
+ indices = list(sh.ndindex(x.shape))
+ permuted_indices = [tuple(idx[axis] for axis in new_axes) for idx in indices]
+ assert_array_ndindex(
+ "moveaxis", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=permuted_indices
+ )
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.all_dtypes, shape=hh.shapes(min_side=1).filter(lambda s: 1 in s)
+ ),
+ data=st.data(),
+)
+def test_squeeze(x, data):
+ axes = st.integers(-x.ndim, x.ndim - 1)
+ axis = data.draw(
+ axes
+ | st.lists(axes, unique_by=lambda i: i if i >= 0 else i + x.ndim).map(tuple),
+ label="axis",
+ )
+
+ axes = (axis,) if isinstance(axis, int) else axis
+ axes = sh.normalize_axis(axes, x.ndim)
+
+ squeezable_axes = [i for i, side in enumerate(x.shape) if side == 1]
+ if any(i not in squeezable_axes for i in axes):
+ with pytest.raises(ValueError):
+ xp.squeeze(x, axis)
+ return
+
+ out = xp.squeeze(x, axis)
+
+ ph.assert_dtype("squeeze", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ shape = []
+ for i, side in enumerate(x.shape):
+ if i not in axes:
+ shape.append(side)
+ shape = tuple(shape)
+ ph.assert_result_shape("squeeze", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axis=axis))
+
+ assert_array_ndindex("squeeze", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape))
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
+ data=st.data(),
+)
+def test_flip(x, data):
+ if x.ndim == 0:
+ axis_strat = st.none()
+ else:
+ axis_strat = (
+ st.none() | st.integers(-x.ndim, x.ndim - 1) | xps.valid_tuple_axes(x.ndim)
+ )
+ kw = data.draw(hh.kwargs(axis=axis_strat), label="kw")
+
+ out = xp.flip(x, **kw)
+
+ ph.assert_dtype("flip", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ for indices in sh.axes_ndindex(x.shape, _axes):
+ reverse_indices = indices[::-1]
+ assert_array_ndindex("flip", x, x_indices=indices, out=out,
+ out_indices=reverse_indices, kw=kw)
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
+ axes=shared_shapes(min_dims=1).flatmap(
+ lambda s: st.lists(
+ st.integers(0, len(s) - 1),
+ min_size=len(s),
+ max_size=len(s),
+ unique=True,
+ ).map(tuple)
+ ),
+)
+def test_permute_dims(x, axes):
+ out = xp.permute_dims(x, axes)
+
+ ph.assert_dtype("permute_dims", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ shape = [None for _ in range(len(axes))]
+ for i, dim in enumerate(axes):
+ side = x.shape[dim]
+ shape[i] = side
+ shape = tuple(shape)
+ ph.assert_result_shape("permute_dims", in_shapes=[x.shape], out_shape=out.shape, expected=shape, kw=dict(axes=axes))
+
+ indices = list(sh.ndindex(x.shape))
+ permuted_indices = [tuple(idx[axis] for axis in axes) for idx in indices]
+ assert_array_ndindex("permute_dims", x, x_indices=indices, out=out,
+ out_indices=permuted_indices)
+
+
+@pytest.mark.min_version("2023.12")
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
+ kw=hh.kwargs(
+ axis=st.none() | shared_shapes(min_dims=1).flatmap(
+ lambda s: st.integers(-len(s), len(s) - 1)
+ )
+ ),
+ data=st.data(),
+)
+def test_repeat(x, kw, data):
+ shape = x.shape
+ axis = kw.get("axis", None)
+ size = math.prod(shape) if axis is None else shape[axis]
+ repeat_strat = st.integers(1, 10)
+ repeats = data.draw(repeat_strat
+ | hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
+ shape=st.sampled_from([(1,), (size,)])),
+ label="repeats")
+    if isinstance(repeats, int):
+        n_repetitions = size*repeats
+    else:
+        if repeats.shape == (1,):
+            n_repetitions = size*int(repeats[0])
+        else:
+            n_repetitions = int(xp.sum(repeats))
+
+    assume(n_repetitions <= hh.SQRT_MAX_ARRAY_SIZE)
+
+    out = xp.repeat(x, repeats, **kw)
+    ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
+    if axis is None:
+        expected_shape = (n_repetitions,)
+    else:
+        expected_shape = list(shape)
+        expected_shape[axis] = n_repetitions
+        expected_shape = tuple(expected_shape)
+ ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
+
+ # Test values
+
+ if isinstance(repeats, int):
+ repeats_array = xp.full(size, repeats, dtype=xp.int32)
+ else:
+ repeats_array = repeats
+
+ if kw.get("axis") is None:
+ x = xp.reshape(x, (-1,))
+ axis = 0
+
+ for idx, in sh.iter_indices(x.shape, skip_axes=axis):
+ x_slice = x[idx]
+ out_slice = out[idx]
+ start = 0
+ for i, count in enumerate(repeats_array):
+ end = start + count
+ ph.assert_array_elements("repeat", out=out_slice[start:end],
+ expected=xp.full((count,), x_slice[i], dtype=x.dtype),
+ kw=kw)
+ start = end
+
+reshape_shape = st.shared(hh.shapes(), key="reshape_shape")
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=reshape_shape),
+ shape=hh.reshape_shapes(reshape_shape),
+)
+def test_reshape(x, shape):
+ out = xp.reshape(x, shape)
+
+ ph.assert_dtype("reshape", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ _shape = list(shape)
+ if any(side == -1 for side in shape):
+ size = math.prod(x.shape)
+ rsize = math.prod(shape) * -1
+        _shape[shape.index(-1)] = size // rsize
+ _shape = tuple(_shape)
+ ph.assert_result_shape("reshape", in_shapes=[x.shape], out_shape=out.shape, expected=_shape, kw=dict(shape=shape))
+
+ assert_array_ndindex("reshape", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=sh.ndindex(out.shape))
+
+
+def roll_ndindex(shape: Shape, shifts: Tuple[int, ...], axes: Tuple[int, ...]) -> Iterator[Shape]:
+ assert len(shifts) == len(axes) # sanity check
+ all_shifts = [0 for _ in shape]
+ for s, a in zip(shifts, axes):
+ all_shifts[a] = s
+ for idx in sh.ndindex(shape):
+ yield tuple((i + sh) % si for i, sh, si in zip(idx, all_shifts, shape))
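+# For example (illustrative): roll_ndindex((3,), shifts=(1,), axes=(0,))
+# yields (1,), (2,), (0,), i.e. x[0] ends up at out[1] after rolling by 1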
+
+
+@pytest.mark.unvectorized
+@given(hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes()), st.data())
+def test_roll(x, data):
+ shift_strat = st.integers(-hh.MAX_ARRAY_SIZE, hh.MAX_ARRAY_SIZE)
+ if x.ndim > 0:
+ shift_strat = shift_strat | st.lists(
+ shift_strat, min_size=1, max_size=x.ndim
+ ).map(tuple)
+ shift = data.draw(shift_strat, label="shift")
+ if isinstance(shift, tuple):
+ axis_strat = xps.valid_tuple_axes(x.ndim).filter(lambda t: len(t) == len(shift))
+ kw_strat = axis_strat.map(lambda t: {"axis": t})
+ else:
+ axis_strat = st.none()
+ if x.ndim != 0:
+ axis_strat |= st.integers(-x.ndim, x.ndim - 1)
+ kw_strat = hh.kwargs(axis=axis_strat)
+ kw = data.draw(kw_strat, label="kw")
+
+ out = xp.roll(x, shift, **kw)
+
+ kw = {"shift": shift, **kw} # for error messages
+
+ ph.assert_dtype("roll", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ ph.assert_result_shape("roll", in_shapes=[x.shape], out_shape=out.shape, kw=kw)
+
+ if kw.get("axis", None) is None:
+ assert isinstance(shift, int) # sanity check
+ indices = list(sh.ndindex(x.shape))
+ shifted_indices = deque(indices)
+ shifted_indices.rotate(-shift)
+ assert_array_ndindex("roll", x, x_indices=indices, out=out, out_indices=shifted_indices, kw=kw)
+ else:
+ shifts = (shift,) if isinstance(shift, int) else shift
+ axes = sh.normalize_axis(kw["axis"], x.ndim)
+ shifted_indices = roll_ndindex(x.shape, shifts, axes)
+ assert_array_ndindex("roll", x, x_indices=sh.ndindex(x.shape), out=out, out_indices=shifted_indices, kw=kw)
+
+
+@pytest.mark.unvectorized
+@given(
+ shape=shared_shapes(min_dims=1),
+ dtypes=hh.mutually_promotable_dtypes(None),
+ kw=hh.kwargs(
+ axis=shared_shapes(min_dims=1).flatmap(
+ lambda s: st.integers(-len(s), len(s) - 1)
+ )
+ ),
+ data=st.data(),
+)
+def test_stack(shape, dtypes, kw, data):
+ arrays = []
+ for i, dtype in enumerate(dtypes, 1):
+ x = data.draw(hh.arrays(dtype=dtype, shape=shape), label=f"x{i}")
+ arrays.append(x)
+
+ out = xp.stack(arrays, **kw)
+
+ ph.assert_dtype("stack", in_dtype=dtypes, out_dtype=out.dtype)
+
+ axis = kw.get("axis", 0)
+ _axis = axis if axis >= 0 else len(shape) + axis + 1
+ _shape = list(shape)
+ _shape.insert(_axis, len(arrays))
+ _shape = tuple(_shape)
+ ph.assert_result_shape(
+ "stack", in_shapes=tuple(x.shape for x in arrays), out_shape=out.shape, expected=_shape, kw=kw
+ )
+
+ out_indices = sh.ndindex(out.shape)
+ for idx in sh.axis_ndindex(arrays[0].shape, axis=_axis):
+ f_idx = ", ".join(str(i) if isinstance(i, int) else ":" for i in idx)
+ for x_num, x in enumerate(arrays, 1):
+ indexed_x = x[idx]
+ for x_idx in sh.ndindex(indexed_x.shape):
+ out_idx = next(out_indices)
+ ph.assert_0d_equals(
+ "stack",
+ x_repr=f"x{x_num}[{f_idx}][{x_idx}]",
+ x_val=indexed_x[x_idx],
+ out_repr=f"out[{out_idx}]",
+ out_val=out[out_idx],
+ kw=kw,
+ )
+
+
+@pytest.mark.min_version("2023.12")
+@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()), data=st.data())
+def test_tile(x, data):
+ repetitions = data.draw(
+ st.lists(st.integers(1, 4), min_size=1, max_size=x.ndim + 1).map(tuple),
+ label="repetitions"
+ )
+ out = xp.tile(x, repetitions)
+ ph.assert_dtype("tile", in_dtype=x.dtype, out_dtype=out.dtype)
+ # TODO: values testing
+
+ # shape check; the notation is from the Array API docs
+ N, M = len(x.shape), len(repetitions)
+ if N > M:
+ S = x.shape
+ R = (1,)*(N - M) + repetitions
+ else:
+ S = (1,)*(M - N) + x.shape
+ R = repetitions
+
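+    # For example (illustrative): x.shape=(2, 3) with repetitions=(4,) gives
+    # N=2 > M=1, so S=(2, 3) and R=(1, 4), hence out.shape == (2, 12)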
+ assert out.shape == tuple(r*s for r, s in zip(R, S))
+
+
+@pytest.mark.min_version("2023.12")
+@given(x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)), data=st.data())
+def test_unstack(x, data):
+ axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
+ kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
+ out = xp.unstack(x, **kw)
+
+ assert isinstance(out, tuple)
+ assert len(out) == x.shape[axis]
+ expected_shape = list(x.shape)
+ expected_shape.pop(axis)
+ expected_shape = tuple(expected_shape)
+ for i in range(x.shape[axis]):
+ arr = out[i]
+ ph.assert_result_shape("unstack", in_shapes=[x.shape],
+ out_shape=arr.shape, expected=expected_shape,
+ kw=kw, repr_name=f"out[{i}].shape")
+
+ ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
+ repr_name=f"out[{i}].dtype")
+
+ idx = [slice(None)] * x.ndim
+ idx[axis] = i
+ ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")
diff --git a/array_api_tests/test_operators_and_elementwise_functions.py b/array_api_tests/test_operators_and_elementwise_functions.py
new file mode 100644
index 00000000..7fefa151
--- /dev/null
+++ b/array_api_tests/test_operators_and_elementwise_functions.py
@@ -0,0 +1,1955 @@
+"""
+Test element-wise functions/operators against reference implementations.
+"""
+import cmath
+import math
+import operator
+import builtins
+from copy import copy
+from enum import Enum, auto
+from typing import Callable, List, NamedTuple, Optional, Sequence, TypeVar, Union
+
+import pytest
+from hypothesis import assume, given
+from hypothesis import strategies as st
+
+from . import _array_module as xp, api_version
+from . import array_helpers as ah
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xps
+from .typing import Array, DataType, Param, Scalar, ScalarType, Shape
+
+
+pytestmark = pytest.mark.unvectorized
+
+
+EPS32 = xp.finfo(xp.float32).eps
+
+
+def mock_int_dtype(n: int, dtype: DataType) -> int:
+ """Returns equivalent of `n` that mocks `dtype` behaviour."""
+ nbits = dh.dtype_nbits[dtype]
+ mask = (1 << nbits) - 1
+ n &= mask
+ if dh.dtype_signed[dtype]:
+ highest_bit = 1 << (nbits - 1)
+ if n & highest_bit:
+ n = -((~n & mask) + 1)
+ return n
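+# A couple of quick sanity examples (values computed by hand):
+#   mock_int_dtype(300, xp.uint8) == 44    (300 wraps modulo 2**8)
+#   mock_int_dtype(200, xp.int8) == -56    (two's-complement reinterpretation)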
+
+
+def isclose(
+ a: float,
+ b: float,
+ maximum: float,
+ *,
+ rel_tol: float = 0.25,
+ abs_tol: float = 1,
+) -> bool:
+ """Wraps math.isclose with very generous defaults.
+
+ This is useful for many floating-point operations where the spec does not
+ make accuracy requirements.
+ """
+ if math.isnan(a) or math.isnan(b):
+ raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
+ if math.isinf(a):
+ return math.isinf(b) or abs(b) > math.log(maximum)
+ elif math.isinf(b):
+ return math.isinf(a) or abs(a) > math.log(maximum)
+ return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
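+# For example (illustrative): isclose(100.0, 80.0, maximum=2**63) is True
+# (|100 - 80| <= max(0.25 * 100, 1)), while isclose(100.0, 10.0, maximum=2**63)
+# is False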
+
+
+def isclose_complex(
+ a: complex,
+ b: complex,
+ maximum: float,
+ *,
+ rel_tol: float = 0.25,
+ abs_tol: float = 1,
+) -> bool:
+ """Like isclose() but specifically for complex values."""
+ if cmath.isnan(a) or cmath.isnan(b):
+ raise ValueError(f"{a=} and {b=}, but input must be non-NaN")
+ if cmath.isinf(a):
+ return cmath.isinf(b) or abs(b) > 2**(math.log2(maximum)//2)
+ elif cmath.isinf(b):
+ return cmath.isinf(a) or abs(a) > 2**(math.log2(maximum)//2)
+ return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
+
+
+def default_filter(s: Scalar) -> bool:
+    """Returns False when s is a non-finite float or a (signed) float zero.
+
+ Used by default as these values are typically special-cased.
+ """
+ if isinstance(s, int): # note bools are ints
+ return True
+ else:
+ return math.isfinite(s) and s != 0
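+# For example: default_filter(5) is True, default_filter(0.0) and
+# default_filter(float("inf")) are False, and default_filter(False) is True
+# (bools count as ints)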
+
+
+T = TypeVar("T")
+
+
+def unary_assert_against_refimpl(
+ func_name: str,
+ in_: Array,
+ res: Array,
+ refimpl: Callable[[T], T],
+ *,
+ res_stype: Optional[ScalarType] = None,
+ filter_: Callable[[Scalar], bool] = default_filter,
+ strict_check: Optional[bool] = None,
+ expr_template: Optional[str] = None,
+):
+ """
+ Assert unary element-wise results are as expected.
+
+ We iterate through every element in the input and resulting arrays, casting
+    the respective elements (0-D arrays) to Python scalars, and asserting
+    against the expected output specified by the passed reference
+    implementation, e.g.
+
+ >>> x = xp.asarray([[0, 1], [2, 4]])
+ >>> out = xp.square(x)
+ >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
+
+ is equivalent to
+
+ >>> for idx in np.ndindex(x.shape):
+ ... expected = int(x[idx]) ** 2
+ ... assert int(out[idx]) == expected
+
+ Casting
+ -------
+
+ The input scalar type is inferred from the input array's dtype like so
+
+ Array dtypes | Python builtin type
+ ----------------- | ---------------------
+ xp.bool | bool
+ xp.int*, xp.uint* | int
+ xp.float* | float
+ xp.complex* | complex
+
+ If res_stype=None (the default), the result scalar type is the same as the
+ input scalar type. We can also specify the result scalar type ourselves, e.g.
+
+ >>> x = xp.asarray([42., xp.inf])
+ >>> out = xp.isinf(x) # should be [False, True]
+ >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
+
+ Filtering special-cased values
+ ------------------------------
+
+ Values which are special-cased can be present in the input array, but get
+ filtered before they can be asserted against refimpl.
+
+ If filter_=default_filter (the default), all non-finite and floating zero
+ values are filtered, e.g.
+
+ >>> unary_assert_against_refimpl('sin', x, out, math.sin)
+
+ is equivalent to
+
+ >>> for idx in np.ndindex(x.shape):
+ ... at_x = float(x[idx])
+    ...     if math.isfinite(at_x) and at_x != 0:
+ ... expected = math.sin(at_x)
+ ... assert math.isclose(float(out[idx]), expected)
+
+ We can also specify the filter function ourselves, e.g.
+
+ >>> def sqrt_filter(s: float) -> bool:
+ ... return math.isfinite(s) and s >= 0
+ >>> unary_assert_against_refimpl('sqrt', x, out, math.sqrt, filter_=sqrt_filter)
+
+ is equivalent to
+
+ >>> for idx in np.ndindex(x.shape):
+ ... at_x = float(x[idx])
+    ...     if math.isfinite(at_x) and at_x >= 0:
+    ...         expected = math.sqrt(at_x)
+ ... assert math.isclose(float(out[idx]), expected)
+
+    Note we leave special-cased values in the input arrays, so as to ensure
+    their presence doesn't affect the outputs for non-special-cased elements.
+    We specifically test special-case behaviour in test_special_cases.py.
+
+ Assertion strictness
+ --------------------
+
+ If strict_check=None (the default), integer elements are strictly asserted
+ against, and floating elements are loosely asserted against, e.g.
+
+ >>> unary_assert_against_refimpl('square', x, out, lambda s: s ** 2)
+
+ is equivalent to
+
+ >>> for idx in np.ndindex(x.shape):
+ ... expected = in_stype(x[idx]) ** 2
+ ... if in_stype == int:
+ ... assert int(out[idx]) == expected
+ ... else: # in_stype == float
+ ... assert math.isclose(float(out[idx]), expected)
+
+ Specifying strict_check as True or False will assert strictly/loosely
+ respectively, regardless of dtype. This is useful for testing functions that
+ have definitive outputs for floating inputs, i.e. rounding functions.
+
+ Expressions in errors
+ ---------------------
+
+ Assertion error messages include an expression, by default using func_name
+ like so
+
+ >>> x = xp.asarray([42., xp.inf])
+ >>> out = xp.isinf(x)
+ >>> out
+ [False, False]
+ >>> unary_assert_against_refimpl('isinf', x, out, math.isinf, res_stype=bool)
+ AssertionError: out[1]=False, but should be isinf(x[1])=True ...
+
+ We can specify the expression template ourselves, e.g.
+
+ >>> x = xp.asarray(True)
+ >>> out = xp.logical_not(x)
+ >>> out
+ True
+ >>> unary_assert_against_refimpl(
+ ... 'logical_not', x, out, expr_template='(not {})={}'
+ ... )
+ AssertionError: out=True, but should be (not True)=False ...
+
+ """
+ if in_.shape != res.shape:
+ raise ValueError(f"{res.shape=}, but should be {in_.shape=}")
+ if expr_template is None:
+ expr_template = func_name + "({})={}"
+ in_stype = dh.get_scalar_type(in_.dtype)
+ if res_stype is None:
+ res_stype = dh.get_scalar_type(res.dtype)
+ if res.dtype == xp.bool:
+ m, M = (None, None)
+ elif res.dtype in dh.complex_dtypes:
+ m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]]
+ else:
+ m, M = dh.dtype_ranges[res.dtype]
+ if in_.dtype in dh.complex_dtypes:
+ component_filter = copy(filter_)
+ filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
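+        # e.g. with the default filter, complex(1.0, math.inf) is filtered out
+        # because its imaginary component is non-finite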
+ for idx in sh.ndindex(in_.shape):
+ scalar_i = in_stype(in_[idx])
+ if not filter_(scalar_i):
+ continue
+ try:
+ expected = refimpl(scalar_i)
+ except OverflowError:
+ continue
+ if res.dtype != xp.bool:
+ if res.dtype in dh.complex_dtypes:
+ if expected.real <= m or expected.real >= M:
+ continue
+ if expected.imag <= m or expected.imag >= M:
+ continue
+ else:
+ if expected <= m or expected >= M:
+ continue
+ scalar_o = res_stype(res[idx])
+ f_i = sh.fmt_idx("x", idx)
+ f_o = sh.fmt_idx("out", idx)
+ expr = expr_template.format(f_i, expected)
+ # TODO: strict check floating results too
+ if strict_check == False or res.dtype in dh.all_float_dtypes:
+ msg = (
+ f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
+ f"{f_i}={scalar_i}"
+ )
+ if res.dtype in dh.complex_dtypes:
+ assert isclose_complex(scalar_o, expected, M), msg
+ else:
+ assert isclose(scalar_o, expected, M), msg
+ else:
+ assert scalar_o == expected, (
+ f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
+ f"{f_i}={scalar_i}"
+ )
+
+
+def binary_assert_against_refimpl(
+ func_name: str,
+ left: Array,
+ right: Array,
+ res: Array,
+ refimpl: Callable[[T, T], T],
+ *,
+ res_stype: Optional[ScalarType] = None,
+ filter_: Callable[[Scalar], bool] = default_filter,
+ strict_check: Optional[bool] = None,
+ left_sym: str = "x1",
+ right_sym: str = "x2",
+ res_name: str = "out",
+ expr_template: Optional[str] = None,
+):
+ """
+ Assert binary element-wise results are as expected.
+
+ See unary_assert_against_refimpl for more information.
+ """
+ if expr_template is None:
+ expr_template = func_name + "({}, {})={}"
+    in_stype = dh.get_scalar_type(left.dtype)
+    if res_stype is None:
+        res_stype = in_stype
+ if res.dtype == xp.bool:
+ m, M = (None, None)
+ elif res.dtype in dh.complex_dtypes:
+ m, M = dh.dtype_ranges[dh.dtype_components[res.dtype]]
+ else:
+ m, M = dh.dtype_ranges[res.dtype]
+ if left.dtype in dh.complex_dtypes:
+ component_filter = copy(filter_)
+ filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
+ for l_idx, r_idx, o_idx in sh.iter_indices(left.shape, right.shape, res.shape):
+ scalar_l = in_stype(left[l_idx])
+ scalar_r = in_stype(right[r_idx])
+ if not (filter_(scalar_l) and filter_(scalar_r)):
+ continue
+ try:
+ expected = refimpl(scalar_l, scalar_r)
+ except OverflowError:
+ continue
+ if res.dtype != xp.bool:
+ if res.dtype in dh.complex_dtypes:
+ if expected.real <= m or expected.real >= M:
+ continue
+ if expected.imag <= m or expected.imag >= M:
+ continue
+ else:
+ if expected <= m or expected >= M:
+ continue
+ scalar_o = res_stype(res[o_idx])
+ f_l = sh.fmt_idx(left_sym, l_idx)
+ f_r = sh.fmt_idx(right_sym, r_idx)
+ f_o = sh.fmt_idx(res_name, o_idx)
+ expr = expr_template.format(f_l, f_r, expected)
+ if strict_check == False or res.dtype in dh.all_float_dtypes:
+ msg = (
+ f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
+ f"{f_l}={scalar_l}, {f_r}={scalar_r}"
+ )
+ if res.dtype in dh.complex_dtypes:
+ assert isclose_complex(scalar_o, expected, M), msg
+ else:
+ assert isclose(scalar_o, expected, M), msg
+ else:
+ assert scalar_o == expected, (
+ f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
+ f"{f_l}={scalar_l}, {f_r}={scalar_r}"
+ )
+
+
+def right_scalar_assert_against_refimpl(
+ func_name: str,
+ left: Array,
+ right: Scalar,
+ res: Array,
+ refimpl: Callable[[T, T], T],
+ *,
+ res_stype: Optional[ScalarType] = None,
+ filter_: Callable[[Scalar], bool] = default_filter,
+ strict_check: Optional[bool] = None,
+ left_sym: str = "x1",
+ res_name: str = "out",
+    expr_template: Optional[str] = None,
+):
+ """
+ Assert binary element-wise results from scalar operands are as expected.
+
+ See unary_assert_against_refimpl for more information.
+ """
+ if expr_template is None:
+ expr_template = func_name + "({}, {})={}"
+ if left.dtype in dh.complex_dtypes:
+ component_filter = copy(filter_)
+ filter_ = lambda s: component_filter(s.real) and component_filter(s.imag)
+    if not filter_(right):
+        return  # short-circuit here as there will be nothing to test
+    in_stype = dh.get_scalar_type(left.dtype)
+    if res_stype is None:
+        res_stype = in_stype
+ if res.dtype == xp.bool:
+ m, M = (None, None)
+ elif left.dtype in dh.complex_dtypes:
+ m, M = dh.dtype_ranges[dh.dtype_components[left.dtype]]
+ else:
+ m, M = dh.dtype_ranges[left.dtype]
+ for idx in sh.ndindex(res.shape):
+ scalar_l = in_stype(left[idx])
+ if not (filter_(scalar_l) and filter_(right)):
+ continue
+ try:
+ expected = refimpl(scalar_l, right)
+ except OverflowError:
+ continue
+ if left.dtype != xp.bool:
+ if res.dtype in dh.complex_dtypes:
+ if expected.real <= m or expected.real >= M:
+ continue
+ if expected.imag <= m or expected.imag >= M:
+ continue
+ else:
+ if expected <= m or expected >= M:
+ continue
+ scalar_o = res_stype(res[idx])
+ f_l = sh.fmt_idx(left_sym, idx)
+ f_o = sh.fmt_idx(res_name, idx)
+ expr = expr_template.format(f_l, right, expected)
+ if strict_check == False or res.dtype in dh.all_float_dtypes:
+ msg = (
+ f"{f_o}={scalar_o}, but should be roughly {expr} [{func_name}()]\n"
+ f"{f_l}={scalar_l}"
+ )
+ if res.dtype in dh.complex_dtypes:
+ assert isclose_complex(scalar_o, expected, M), msg
+ else:
+ assert isclose(scalar_o, expected, M), msg
+ else:
+ assert scalar_o == expected, (
+ f"{f_o}={scalar_o}, but should be {expr} [{func_name}()]\n"
+ f"{f_l}={scalar_l}"
+ )
+
+
+# When appropriate, this module tests operators alongside their respective
+# elementwise methods. We do this by parametrizing a generalised test method
+# with every relevant method and operator.
+#
+# Notable arguments in the parameter's context object:
+# - The function object, which for operator test cases is a wrapper that allows
+# test logic to be generalised.
+# - The argument strategies, which can be used to draw arguments for the test
+# case. They may require additional filtering for certain test cases.
+# - right_is_scalar (binary parameters only), which denotes if the right
+# argument is a scalar in a test case. This can be used to appropriately
+# adjust draw filtering and test logic.
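+#
+# As an illustration (a hypothetical expansion; the exact cases depend on the
+# dtypes passed in), make_binary_params("add", ...) below yields contexts
+# covering add(x1, x2), x1 + x2, x + s, x1 += x2 and x += s, all exercised by
+# one generalised test body.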
+
+
+func_to_op = {v: k for k, v in dh.op_to_func.items()}
+all_op_to_symbol = {**dh.binary_op_to_symbol, **dh.inplace_op_to_symbol}
+finite_kw = {"allow_nan": False, "allow_infinity": False}
+
+
+class UnaryParamContext(NamedTuple):
+ func_name: str
+ func: Callable[[Array], Array]
+ strat: st.SearchStrategy[Array]
+
+ @property
+ def id(self) -> str:
+ return self.func_name
+
+ def __repr__(self):
+ return f"UnaryParamContext(<{self.id}>)"
+
+
+def make_unary_params(
+ elwise_func_name: str,
+ dtypes: Sequence[DataType],
+ *,
+ min_version: str = "2021.12",
+) -> List[Param[UnaryParamContext]]:
+ dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
+ assert len(dtypes) > 0 # sanity check
+ if api_version < "2022.12":
+ dtypes = [d for d in dtypes if d not in dh.complex_dtypes]
+ dtypes_strat = st.sampled_from(dtypes)
+ strat = hh.arrays(dtype=dtypes_strat, shape=hh.shapes())
+ func_ctx = UnaryParamContext(
+ func_name=elwise_func_name, func=getattr(xp, elwise_func_name), strat=strat
+ )
+ op_name = func_to_op[elwise_func_name]
+ op_ctx = UnaryParamContext(
+ func_name=op_name, func=lambda x: getattr(x, op_name)(), strat=strat
+ )
+ if api_version < min_version:
+ marks = pytest.mark.skip(
+ reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
+ )
+ else:
+ marks = ()
+ return [
+ pytest.param(func_ctx, id=func_ctx.id, marks=marks),
+ pytest.param(op_ctx, id=op_ctx.id, marks=marks),
+ ]
+
+
+class FuncType(Enum):
+ FUNC = auto()
+ OP = auto()
+ IOP = auto()
+
+
+shapes_kw = {"min_side": 1}
+
+
+class BinaryParamContext(NamedTuple):
+ func_name: str
+ func: Callable[[Array, Union[Scalar, Array]], Array]
+ left_sym: str
+ left_strat: st.SearchStrategy[Array]
+ right_sym: str
+ right_strat: st.SearchStrategy[Union[Scalar, Array]]
+ right_is_scalar: bool
+ res_name: str
+
+ @property
+ def id(self) -> str:
+ return f"{self.func_name}({self.left_sym}, {self.right_sym})"
+
+ def __repr__(self):
+ return f"BinaryParamContext(<{self.id}>)"
+
+
+def make_binary_params(
+ elwise_func_name: str, dtypes: Sequence[DataType]
+) -> List[Param[BinaryParamContext]]:
+ dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
+ assert len(dtypes) > 0 # sanity check
+ shared_oneway_dtypes = st.shared(hh.oneway_promotable_dtypes(dtypes))
+ left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype)
+ right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype)
+
+ def make_param(
+ func_name: str, func_type: FuncType, right_is_scalar: bool
+ ) -> Param[BinaryParamContext]:
+ if right_is_scalar:
+ left_sym = "x"
+ right_sym = "s"
+ else:
+ left_sym = "x1"
+ right_sym = "x2"
+
+ if right_is_scalar:
+ left_strat = hh.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw))
+ right_strat = right_dtypes.flatmap(lambda d: hh.from_dtype(d, **finite_kw))
+ else:
+ if func_type is FuncType.IOP:
+ shared_oneway_shapes = st.shared(hh.oneway_broadcastable_shapes())
+ left_strat = hh.arrays(
+ dtype=left_dtypes,
+ shape=shared_oneway_shapes.map(lambda S: S.result_shape),
+ )
+ right_strat = hh.arrays(
+ dtype=right_dtypes,
+ shape=shared_oneway_shapes.map(lambda S: S.input_shape),
+ )
+ else:
+ mutual_shapes = st.shared(
+ hh.mutually_broadcastable_shapes(2, **shapes_kw)
+ )
+ left_strat = hh.arrays(
+ dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
+ )
+ right_strat = hh.arrays(
+ dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
+ )
+
+ if func_type is FuncType.FUNC:
+ func = getattr(xp, func_name)
+ else:
+ op_sym = all_op_to_symbol[func_name]
+ expr = f"{left_sym} {op_sym} {right_sym}"
+ if func_type is FuncType.OP:
+
+ def func(l: Array, r: Union[Scalar, Array]) -> Array:
+ locals_ = {}
+ locals_[left_sym] = l
+ locals_[right_sym] = r
+ return eval(expr, locals_)
+
+ else:
+
+ def func(l: Array, r: Union[Scalar, Array]) -> Array:
+ locals_ = {}
+ locals_[left_sym] = xp.asarray(l, copy=True) # prevents mutating l
+ locals_[right_sym] = r
+ exec(expr, locals_)
+ return locals_[left_sym]
+
+ func.__name__ = func_name # for repr
+
+ if func_type is FuncType.IOP:
+ res_name = left_sym
+ else:
+ res_name = "out"
+
+ ctx = BinaryParamContext(
+ func_name,
+ func,
+ left_sym,
+ left_strat,
+ right_sym,
+ right_strat,
+ right_is_scalar,
+ res_name,
+ )
+ return pytest.param(ctx, id=ctx.id)
+
+ op_name = func_to_op[elwise_func_name]
+ params = [
+ make_param(elwise_func_name, FuncType.FUNC, False),
+ make_param(op_name, FuncType.OP, False),
+ make_param(op_name, FuncType.OP, True),
+ ]
+ iop_name = f"__i{op_name[2:]}"
+    if iop_name in dh.inplace_op_to_symbol:
+ params.append(make_param(iop_name, FuncType.IOP, False))
+ params.append(make_param(iop_name, FuncType.IOP, True))
+
+ return params
+
+
+def binary_param_assert_dtype(
+ ctx: BinaryParamContext,
+ left: Array,
+ right: Union[Array, Scalar],
+ res: Array,
+ expected: Optional[DataType] = None,
+):
+ if ctx.right_is_scalar:
+ in_dtypes = left.dtype
+ else:
+ in_dtypes = [left.dtype, right.dtype] # type: ignore
+ ph.assert_dtype(
+ ctx.func_name, in_dtype=in_dtypes, out_dtype=res.dtype, expected=expected, repr_name=f"{ctx.res_name}.dtype"
+ )
+
+
+def binary_param_assert_shape(
+ ctx: BinaryParamContext,
+ left: Array,
+ right: Union[Array, Scalar],
+ res: Array,
+ expected: Optional[Shape] = None,
+):
+ if ctx.right_is_scalar:
+ in_shapes = [left.shape]
+ else:
+ in_shapes = [left.shape, right.shape] # type: ignore
+ ph.assert_result_shape(
+ ctx.func_name, in_shapes=in_shapes, out_shape=res.shape, expected=expected, repr_name=f"{ctx.res_name}.shape"
+ )
+
+
+def binary_param_assert_against_refimpl(
+ ctx: BinaryParamContext,
+ left: Array,
+ right: Union[Array, Scalar],
+ res: Array,
+ op_sym: str,
+ refimpl: Callable[[T, T], T],
+ *,
+ res_stype: Optional[ScalarType] = None,
+ filter_: Callable[[Scalar], bool] = default_filter,
+ strict_check: Optional[bool] = None,
+):
+ expr_template = "({} " + op_sym + " {})={}"
+ if ctx.right_is_scalar:
+ right_scalar_assert_against_refimpl(
+ func_name=ctx.func_name,
+ left_sym=ctx.left_sym,
+ left=left,
+ right=right,
+ res_stype=res_stype,
+ res_name=ctx.res_name,
+ res=res,
+ refimpl=refimpl,
+ expr_template=expr_template,
+ filter_=filter_,
+ strict_check=strict_check,
+ )
+ else:
+ binary_assert_against_refimpl(
+ func_name=ctx.func_name,
+ left_sym=ctx.left_sym,
+ left=left,
+ right_sym=ctx.right_sym,
+ right=right,
+ res_stype=res_stype,
+ res_name=ctx.res_name,
+ res=res,
+ refimpl=refimpl,
+ expr_template=expr_template,
+ filter_=filter_,
+ strict_check=strict_check,
+ )
+
+
+def _convert_scalars_helper(x1, x2):
+ """Convert python scalar to arrays, record the shapes/dtypes of arrays.
+
+ For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
+ and all arguments converted to arrays.
+
+ dtypes are separate to help distinguishing between
+ `py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
+ """
+ if dh.is_scalar(x1):
+ in_dtypes = [x2.dtype]
+ in_shapes = [x2.shape]
+ x1a, x2a = xp.asarray(x1), x2
+ elif dh.is_scalar(x2):
+ in_dtypes = [x1.dtype]
+ in_shapes = [x1.shape]
+ x1a, x2a = x1, xp.asarray(x2)
+ else:
+ in_dtypes = [x1.dtype, x2.dtype]
+ in_shapes = [x1.shape, x2.shape]
+ x1a, x2a = x1, x2
+
+ return in_dtypes, in_shapes, (x1a, x2a)
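+
+
+# A rough illustration of the helper above (hypothetical values; assumes
+# float32 is available):
+#
+#   x = xp.asarray([1.0, 2.0], dtype=xp.float32)
+#   _convert_scalars_helper(1.0, x)
+#   # -> ([xp.float32], [(2,)], (xp.asarray(1.0), x))
+#
+# Only the array operand contributes a dtype, so `py_scalar + f32_array`
+# is checked against float32 rather than a promoted float64.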
+
+
+def _assert_correctness_binary(
+ name, func, in_dtypes, in_shapes, in_arrs, out, expected_dtype=None, **kwargs
+):
+ x1a, x2a = in_arrs
+ ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype)
+ ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
+    check_values = kwargs.pop('check_values', True)
+ if check_values:
+ binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
+
+
+@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
+@given(data=st.data())
+def test_abs(ctx, data):
+ x = data.draw(ctx.strat, label="x")
+ # abs of the smallest negative integer is out-of-scope
+ if x.dtype in dh.int_dtypes:
+ assume(xp.all(x > dh.dtype_ranges[x.dtype].min))
+
+ out = ctx.func(x)
+
+ if x.dtype in dh.complex_dtypes:
+ assert out.dtype == dh.dtype_components[x.dtype]
+ else:
+ ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl(
+ ctx.func_name,
+ x,
+ out,
+ abs, # type: ignore
+ res_stype=float if x.dtype in dh.complex_dtypes else None,
+ expr_template="abs({})={}",
+ # filter_=lambda s: (
+ # s == float("infinity") or (cmath.isfinite(s) and not ph.is_neg_zero(s))
+ # ),
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_acos(x):
+ out = xp.acos(x)
+ ph.assert_dtype("acos", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("acos", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.acos if x.dtype in dh.complex_dtypes else math.acos
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1
+ unary_assert_against_refimpl(
+ "acos", x, out, refimpl, filter_=filter_
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_acosh(x):
+ out = xp.acosh(x)
+ ph.assert_dtype("acosh", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("acosh", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.acosh if x.dtype in dh.complex_dtypes else math.acosh
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 1
+ unary_assert_against_refimpl(
+ "acosh", x, out, refimpl, filter_=filter_
+ )
+
+
+@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
+@given(data=st.data())
+def test_add(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ with hh.reject_overflow():
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "+", operator.add)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_asin(x):
+ out = xp.asin(x)
+ ph.assert_dtype("asin", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("asin", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.asin if x.dtype in dh.complex_dtypes else math.asin
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 <= s <= 1
+ unary_assert_against_refimpl(
+ "asin", x, out, refimpl, filter_=filter_
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_asinh(x):
+ out = xp.asinh(x)
+ ph.assert_dtype("asinh", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("asinh", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.asinh if x.dtype in dh.complex_dtypes else math.asinh
+ unary_assert_against_refimpl("asinh", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_atan(x):
+ out = xp.atan(x)
+ ph.assert_dtype("atan", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("atan", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.atan if x.dtype in dh.complex_dtypes else math.atan
+ unary_assert_against_refimpl("atan", x, out, refimpl)
+
+
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
+def test_atan2(x1, x2):
+ out = xp.atan2(x1, x2)
+ _assert_correctness_binary(
+ "atan",
+ cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out,
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_atanh(x):
+ out = xp.atanh(x)
+ ph.assert_dtype("atanh", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("atanh", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.atanh if x.dtype in dh.complex_dtypes else math.atanh
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and -1 < s < 1
+ unary_assert_against_refimpl(
+ "atanh",
+ x,
+ out,
+ refimpl,
+ filter_=filter_,
+ )
+
+
+@pytest.mark.parametrize(
+ "ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
+)
+@given(data=st.data())
+def test_bitwise_and(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ if left.dtype == xp.bool:
+ refimpl = operator.and_
+ else:
+ refimpl = lambda l, r: mock_int_dtype(l & r, res.dtype)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "&", refimpl)
+
+
+@pytest.mark.parametrize(
+ "ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
+)
+@given(data=st.data())
+def test_bitwise_left_shift(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+ if ctx.right_is_scalar:
+ assume(right >= 0)
+ else:
+ assume(not xp.any(ah.isnegative(right)))
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ nbits = dh.dtype_nbits[res.dtype]
+ binary_param_assert_against_refimpl(
+ ctx, left, right, res, "<<", lambda l, r: l << r if r < nbits else 0
+ )
+
+
+@pytest.mark.parametrize(
+ "ctx", make_unary_params("bitwise_invert", dh.bool_and_all_int_dtypes)
+)
+@given(data=st.data())
+def test_bitwise_invert(ctx, data):
+ x = data.draw(ctx.strat, label="x")
+
+ out = ctx.func(x)
+
+ ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
+ if x.dtype == xp.bool:
+ refimpl = operator.not_
+ else:
+ refimpl = lambda s: mock_int_dtype(~s, x.dtype)
+ unary_assert_against_refimpl(ctx.func_name, x, out, refimpl, expr_template="~{}={}")
+
+
+@pytest.mark.parametrize(
+ "ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
+)
+@given(data=st.data())
+def test_bitwise_or(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ if left.dtype == xp.bool:
+ refimpl = operator.or_
+ else:
+ refimpl = lambda l, r: mock_int_dtype(l | r, res.dtype)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "|", refimpl)
+
+
+@pytest.mark.parametrize(
+ "ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
+)
+@given(data=st.data())
+def test_bitwise_right_shift(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+ if ctx.right_is_scalar:
+ assume(right >= 0)
+ else:
+ assume(not xp.any(ah.isnegative(right)))
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, res, ">>", lambda l, r: mock_int_dtype(l >> r, res.dtype)
+ )
+
+
+@pytest.mark.parametrize(
+ "ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
+)
+@given(data=st.data())
+def test_bitwise_xor(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ if left.dtype == xp.bool:
+ refimpl = operator.xor
+ else:
+ refimpl = lambda l, r: mock_int_dtype(l ^ r, res.dtype)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "^", refimpl)
+
+
+@given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()))
+def test_ceil(x):
+ out = xp.ceil(x)
+ ph.assert_dtype("ceil", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("ceil", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl("ceil", x, out, math.ceil, strict_check=True)
+
+
+@pytest.mark.min_version("2023.12")
+@given(x=hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()), data=st.data())
+def test_clip(x, data):
+ # Ensure that if both min and max are arrays that all three of x, min, max
+ # are broadcast compatible.
+ shape1, shape2 = data.draw(hh.mutually_broadcastable_shapes(2,
+ base_shape=x.shape),
+ label="min.shape, max.shape")
+
+ min = data.draw(st.one_of(
+ st.none(),
+ hh.scalars(dtypes=st.just(x.dtype)),
+ hh.arrays(dtype=st.just(x.dtype), shape=shape1),
+ ), label="min")
+ max = data.draw(st.one_of(
+ st.none(),
+ hh.scalars(dtypes=st.just(x.dtype)),
+ hh.arrays(dtype=st.just(x.dtype), shape=shape2),
+ ), label="max")
+
+ # min > max is undefined (but allow nans)
+ assume(min is None or max is None or not xp.any(ah.less(xp.asarray(max), xp.asarray(min))))
+
+ kw = data.draw(
+ hh.specified_kwargs(
+ ("min", min, None),
+ ("max", max, None)),
+ label="kwargs")
+
+ out = xp.clip(x, **kw)
+
+ # min and max do not participate in type promotion
+ ph.assert_dtype("clip", in_dtype=x.dtype, out_dtype=out.dtype)
+
+ shapes = [x.shape]
+ if min is not None and not dh.is_scalar(min):
+ shapes.append(min.shape)
+ if max is not None and not dh.is_scalar(max):
+ shapes.append(max.shape)
+ expected_shape = sh.broadcast_shapes(*shapes)
+ ph.assert_shape("clip", out_shape=out.shape, expected=expected_shape)
+
+ # This is based on right_scalar_assert_against_refimpl and
+ # binary_assert_against_refimpl. clip() is currently the only ternary
+ # elementwise function and the only function that supports arrays and
+ # scalars. However, where() (in test_searching_functions) is similar
+ # and if scalar support is added to it, we may want to factor out and
+ # reuse this logic.
+
+ def refimpl(_x, _min, _max):
+ # Skip cases where _min and _max are integers whose values do not
+ # fit in the dtype of _x, since this behavior is unspecified.
+ if dh.is_int_dtype(x.dtype):
+ if _min is not None and _min not in dh.dtype_ranges[x.dtype]:
+ return None
+ if _max is not None and _max not in dh.dtype_ranges[x.dtype]:
+ return None
+
+ # If min or max are float64 and x is float32, they will need to be
+ # downcast to float32. This could result in a round in the wrong
+ # direction meaning the resulting clipped value might not actually be
+ # between min and max. This behavior is unspecified, so skip any cases
+ # where x is within the rounding error of downcasting min or max.
+ if x.dtype == xp.float32:
+ if min is not None and not dh.is_scalar(min) and min.dtype == xp.float64 and math.isfinite(_min):
+ _min_float32 = float(xp.asarray(_min, dtype=xp.float32))
+ if math.isinf(_min_float32):
+ return None
+ tol = abs(_min - _min_float32)
+ if math.isclose(_min, _min_float32, abs_tol=tol):
+ return None
+ if max is not None and not dh.is_scalar(max) and max.dtype == xp.float64 and math.isfinite(_max):
+ _max_float32 = float(xp.asarray(_max, dtype=xp.float32))
+ if math.isinf(_max_float32):
+ return None
+ tol = abs(_max - _max_float32)
+ if math.isclose(_max, _max_float32, abs_tol=tol):
+ return None
+
+ if (math.isnan(_x)
+ or (_min is not None and math.isnan(_min))
+ or (_max is not None and math.isnan(_max))):
+ return math.nan
+ if _min is _max is None:
+ return _x
+ if _max is None:
+ return builtins.max(_x, _min)
+ if _min is None:
+ return builtins.min(_x, _max)
+ return builtins.min(builtins.max(_x, _min), _max)
+
+ stype = dh.get_scalar_type(x.dtype)
+ min_shape = () if min is None or dh.is_scalar(min) else min.shape
+ max_shape = () if max is None or dh.is_scalar(max) else max.shape
+
+ for x_idx, min_idx, max_idx, o_idx in sh.iter_indices(
+ x.shape, min_shape, max_shape, out.shape):
+ x_val = stype(x[x_idx])
+ if min is None or dh.is_scalar(min):
+ min_val = min
+ else:
+ min_val = stype(min[min_idx])
+ if max is None or dh.is_scalar(max):
+ max_val = max
+ else:
+ max_val = stype(max[max_idx])
+ expected = refimpl(x_val, min_val, max_val)
+ if expected is None:
+ continue
+ out_val = stype(out[o_idx])
+ if math.isnan(expected):
+ assert math.isnan(out_val), (
+ f"out[{o_idx}]={out[o_idx]} but should be nan [clip()]\n"
+ f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
+ )
+ else:
+ if out.dtype == xp.float32:
+ # conversion to builtin float is prone to roundoff errors
+ close_enough = math.isclose(out_val, expected, rel_tol=EPS32)
+ else:
+ close_enough = out_val == expected
+
+ assert close_enough, (
+ f"out[{o_idx}]={out[o_idx]} but should be {expected} [clip()]\n"
+ f"x[{x_idx}]={x_val}, min[{min_idx}]={min_val}, max[{max_idx}]={max_val}"
+ )
+
+
+@pytest.mark.min_version("2022.12")
+@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
+@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
+def test_conj(x):
+ out = xp.conj(x)
+ ph.assert_dtype("conj", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("conj", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl("conj", x, out, operator.methodcaller("conjugate"))
+
+
+@pytest.mark.min_version("2023.12")
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
+def test_copysign(x1, x2):
+ out = xp.copysign(x1, x2)
+ ph.assert_dtype("copysign", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
+ ph.assert_result_shape("copysign", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
+ binary_assert_against_refimpl("copysign", x1, x2, out, math.copysign)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_cos(x):
+ out = xp.cos(x)
+ ph.assert_dtype("cos", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("cos", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.cos if x.dtype in dh.complex_dtypes else math.cos
+ unary_assert_against_refimpl("cos", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_cosh(x):
+ out = xp.cosh(x)
+ ph.assert_dtype("cosh", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("cosh", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.cosh if x.dtype in dh.complex_dtypes else math.cosh
+ unary_assert_against_refimpl("cosh", x, out, refimpl)
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.all_float_dtypes))
+@given(data=st.data())
+def test_divide(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+ if ctx.right_is_scalar:
+ assume # TODO: assume what?
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(
+ ctx,
+ left,
+ right,
+ res,
+ "/",
+ operator.truediv,
+ filter_=lambda s: cmath.isfinite(s) and s != 0,
+ )
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
+@given(data=st.data())
+def test_equal(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ out = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, out, xp.bool)
+ binary_param_assert_shape(ctx, left, right, out)
+ if not ctx.right_is_scalar:
+ # We manually promote the dtypes as incorrect internal type promotion
+ # could lead to false positives. For example
+ #
+ # >>> xp.equal(
+ # ... xp.asarray(1.0, dtype=xp.float32),
+ # ... xp.asarray(1.00000001, dtype=xp.float64),
+ # ... )
+ #
+ # would erroneously be True if float64 downcasted to float32.
+ promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
+ left = xp.astype(left, promoted_dtype)
+ right = xp.astype(right, promoted_dtype)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, out, "==", operator.eq, res_stype=bool
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_exp(x):
+ out = xp.exp(x)
+ ph.assert_dtype("exp", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("exp", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.exp if x.dtype in dh.complex_dtypes else math.exp
+ unary_assert_against_refimpl("exp", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_expm1(x):
+ out = xp.expm1(x)
+ ph.assert_dtype("expm1", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("expm1", out_shape=out.shape, expected=x.shape)
+ if x.dtype in dh.complex_dtypes:
+ def refimpl(z):
+ # There's no cmath.expm1. Use
+ #
+ # exp(x+yi) - 1
+ # = exp(x)exp(yi) - 1
+ # = exp(x)(cos(y) + sin(y)i) - 1
+ # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i
+ # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i
+ #
+ # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of
+ # significance near y = 0.
+ re, im = z.real, z.imag
+ return math.expm1(re)*math.cos(im) - 2*math.sin(im/2)**2 + 1j*math.exp(re)*math.sin(im)
+ else:
+ refimpl = math.expm1
+ unary_assert_against_refimpl("expm1", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.real_dtypes, shape=hh.shapes()))
+def test_floor(x):
+ out = xp.floor(x)
+ ph.assert_dtype("floor", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("floor", out_shape=out.shape, expected=x.shape)
+ if x.dtype in dh.complex_dtypes:
+ def refimpl(z):
+ return complex(math.floor(z.real), math.floor(z.imag))
+ else:
+ refimpl = math.floor
+ unary_assert_against_refimpl("floor", x, out, refimpl, strict_check=True)
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.real_dtypes))
+@given(data=st.data())
+def test_floor_divide(ctx, data):
+ left = data.draw(
+ ctx.left_strat.filter(lambda x: not xp.any(x == 0)), label=ctx.left_sym
+ )
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+ if ctx.right_is_scalar:
+ assume(right != 0)
+ else:
+ assume(not xp.any(right == 0))
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.real_dtypes))
+@given(data=st.data())
+def test_greater(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ out = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, out, xp.bool)
+ binary_param_assert_shape(ctx, left, right, out)
+ if not ctx.right_is_scalar:
+ # See test_equal note
+ promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
+ left = xp.astype(left, promoted_dtype)
+ right = xp.astype(right, promoted_dtype)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, out, ">", operator.gt, res_stype=bool
+ )
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.real_dtypes))
+@given(data=st.data())
+def test_greater_equal(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ out = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, out, xp.bool)
+ binary_param_assert_shape(ctx, left, right, out)
+ if not ctx.right_is_scalar:
+ # See test_equal note
+ promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
+ left = xp.astype(left, promoted_dtype)
+ right = xp.astype(right, promoted_dtype)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, out, ">=", operator.ge, res_stype=bool
+ )
+
+
+@pytest.mark.min_version("2023.12")
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
+def test_hypot(x1, x2):
+ out = xp.hypot(x1, x2)
+ _assert_correctness_binary(
+ "hypot",
+ math.hypot,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out
+ )
+
+
+@pytest.mark.min_version("2022.12")
+@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
+@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
+def test_imag(x):
+ out = xp.imag(x)
+ ph.assert_dtype("imag", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
+ ph.assert_shape("imag", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl("imag", x, out, operator.attrgetter("imag"))
+
+
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
+def test_isfinite(x):
+ out = xp.isfinite(x)
+ ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+ ph.assert_shape("isfinite", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.isfinite if x.dtype in dh.complex_dtypes else math.isfinite
+ unary_assert_against_refimpl("isfinite", x, out, refimpl, res_stype=bool)
+
+
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
+def test_isinf(x):
+ out = xp.isinf(x)
+ ph.assert_dtype("isfinite", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+ ph.assert_shape("isinf", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.isinf if x.dtype in dh.complex_dtypes else math.isinf
+ unary_assert_against_refimpl("isinf", x, out, refimpl, res_stype=bool)
+
+
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
+def test_isnan(x):
+ out = xp.isnan(x)
+ ph.assert_dtype("isnan", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+ ph.assert_shape("isnan", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.isnan if x.dtype in dh.complex_dtypes else math.isnan
+ unary_assert_against_refimpl("isnan", x, out, refimpl, res_stype=bool)
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("less", dh.real_dtypes))
+@given(data=st.data())
+def test_less(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ out = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, out, xp.bool)
+ binary_param_assert_shape(ctx, left, right, out)
+ if not ctx.right_is_scalar:
+ # See test_equal note
+ promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
+ left = xp.astype(left, promoted_dtype)
+ right = xp.astype(right, promoted_dtype)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, out, "<", operator.lt, res_stype=bool
+ )
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.real_dtypes))
+@given(data=st.data())
+def test_less_equal(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ out = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, out, xp.bool)
+ binary_param_assert_shape(ctx, left, right, out)
+ if not ctx.right_is_scalar:
+ # See test_equal note
+ promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
+ left = xp.astype(left, promoted_dtype)
+ right = xp.astype(right, promoted_dtype)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, out, "<=", operator.le, res_stype=bool
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_log(x):
+ out = xp.log(x)
+ ph.assert_dtype("log", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("log", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.log if x.dtype in dh.complex_dtypes else math.log
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
+ unary_assert_against_refimpl(
+ "log", x, out, refimpl, filter_=filter_
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_log1p(x):
+ out = xp.log1p(x)
+ ph.assert_dtype("log1p", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("log1p", out_shape=out.shape, expected=x.shape)
+ # There isn't a cmath.log1p, and implementing one isn't straightforward
+ # (see
+ # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy).
+ # For now, just use log(1+p) for complex inputs, which should hopefully be
+ # fine given the very loose numerical tolerances we use. If it isn't, we
+ # can try using something like a series expansion for small p.
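+    # A sketch of such a series fallback (illustrative only, not used here):
+    # for small |z|, log1p(z) = z - z**2/2 + z**3/3 - ..., so something like
+    #   sum((-1)**(k+1) * z**k / k for k in range(1, 10))
+    # would sidestep the cancellation in computing 1 + z directly.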
+ if x.dtype in dh.complex_dtypes:
+ refimpl = lambda z: cmath.log(1+z)
+ else:
+ refimpl = math.log1p
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > -1
+ unary_assert_against_refimpl(
+ "log1p", x, out, refimpl, filter_=filter_
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_log2(x):
+ out = xp.log2(x)
+ ph.assert_dtype("log2", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("log2", out_shape=out.shape, expected=x.shape)
+ if x.dtype in dh.complex_dtypes:
+ refimpl = lambda z: cmath.log(z)/math.log(2)
+ else:
+ refimpl = math.log2
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
+ unary_assert_against_refimpl(
+ "log2", x, out, refimpl, filter_=filter_
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_log10(x):
+ out = xp.log10(x)
+ ph.assert_dtype("log10", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("log10", out_shape=out.shape, expected=x.shape)
+ if x.dtype in dh.complex_dtypes:
+ refimpl = lambda z: cmath.log(z)/math.log(10)
+ else:
+ refimpl = math.log10
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s > 0
+ unary_assert_against_refimpl(
+ "log10", x, out, refimpl, filter_=filter_
+ )
+
+
+def logaddexp_refimpl(l: float, r: float) -> float:
+ try:
+ return math.log(math.exp(l) + math.exp(r))
+ except ValueError: # raised for log(0.)
+ raise OverflowError
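+
+
+# Note: raising OverflowError above is deliberate; the assert_against_refimpl
+# helpers catch OverflowError from a reference implementation and skip that
+# element, so log(0.) cases are simply not value-tested.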
+
+
+@pytest.mark.min_version("2023.12")
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
+def test_logaddexp(x1, x2):
+ out = xp.logaddexp(x1, x2)
+ _assert_correctness_binary(
+ "logaddexp",
+ logaddexp_refimpl,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out
+ )
+
+
+@given(hh.arrays(dtype=xp.bool, shape=hh.shapes()))
+def test_logical_not(x):
+ out = xp.logical_not(x)
+ ph.assert_dtype("logical_not", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("logical_not", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl(
+ "logical_not", x, out, operator.not_, expr_template="(not {})={}"
+ )
+
+
+@given(*hh.two_mutual_arrays([xp.bool]))
+def test_logical_and(x1, x2):
+ out = xp.logical_and(x1, x2)
+ _assert_correctness_binary(
+ "logical_and",
+ operator.and_,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out,
+ expr_template="({} and {})={}"
+ )
+
+
+@given(*hh.two_mutual_arrays([xp.bool]))
+def test_logical_or(x1, x2):
+ out = xp.logical_or(x1, x2)
+ _assert_correctness_binary(
+ "logical_or",
+ operator.or_,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out,
+ expr_template="({} or {})={}"
+ )
+
+
+@given(*hh.two_mutual_arrays([xp.bool]))
+def test_logical_xor(x1, x2):
+ out = xp.logical_xor(x1, x2)
+ _assert_correctness_binary(
+ "logical_xor",
+ operator.xor,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out,
+ expr_template="({} ^ {})={}"
+ )
+
+
+@pytest.mark.min_version("2023.12")
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
+def test_maximum(x1, x2):
+ out = xp.maximum(x1, x2)
+ _assert_correctness_binary(
+ "maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
+ )
+
+
+@pytest.mark.min_version("2023.12")
+@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
+def test_minimum(x1, x2):
+ out = xp.minimum(x1, x2)
+ _assert_correctness_binary(
+ "minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
+ )
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
+@given(data=st.data())
+def test_multiply(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "*", operator.mul)
+
+
+# TODO: clarify if uints are acceptable, adjust accordingly
+@pytest.mark.parametrize("ctx", make_unary_params("negative", dh.numeric_dtypes))
+@given(data=st.data())
+def test_negative(ctx, data):
+ x = data.draw(ctx.strat, label="x")
+ # negative of the smallest negative integer is out-of-scope
+ if x.dtype in dh.int_dtypes:
+ assume(xp.all(x > dh.dtype_ranges[x.dtype].min))
+
+ out = ctx.func(x)
+
+ ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl(
+ ctx.func_name, x, out, operator.neg, expr_template="-({})={}" # type: ignore
+ )
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
+@given(data=st.data())
+def test_not_equal(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ out = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, out, xp.bool)
+ binary_param_assert_shape(ctx, left, right, out)
+ if not ctx.right_is_scalar:
+ # See test_equal note
+ promoted_dtype = dh.promotion_table[left.dtype, right.dtype]
+ left = xp.astype(left, promoted_dtype)
+ right = xp.astype(right, promoted_dtype)
+ binary_param_assert_against_refimpl(
+ ctx, left, right, out, "!=", operator.ne, res_stype=bool
+ )
+
+
+@pytest.mark.min_version("2024.12")
+@given(
+ shapes=hh.two_mutually_broadcastable_shapes,
+ dtype=hh.real_floating_dtypes,
+ data=st.data()
+)
+def test_nextafter(shapes, dtype, data):
+ x1 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x1")
+ x2 = data.draw(hh.arrays(dtype=dtype, shape=shapes[0]), label="x2")
+
+ out = xp.nextafter(x1, x2)
+ _assert_correctness_binary(
+ "nextafter",
+ math.nextafter,
+ in_dtypes=[x1.dtype, x2.dtype],
+ in_shapes=[x1.shape, x2.shape],
+ in_arrs=[x1, x2],
+ out=out
+ )
+
+
+@pytest.mark.parametrize("ctx", make_unary_params("positive", dh.numeric_dtypes))
+@given(data=st.data())
+def test_positive(ctx, data):
+ x = data.draw(ctx.strat, label="x")
+
+ out = ctx.func(x)
+
+ ph.assert_dtype(ctx.func_name, in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape(ctx.func_name, out_shape=out.shape, expected=x.shape)
+ ph.assert_array_elements(ctx.func_name, out=out, expected=x)
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
+@given(data=st.data())
+def test_pow(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+ if ctx.right_is_scalar:
+ if isinstance(right, int):
+ assume(right >= 0)
+ else:
+ if dh.is_int_dtype(right.dtype):
+ assume(xp.all(right >= 0))
+
+ with hh.reject_overflow():
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+    # Value testing for pow is too finicky
+
+
+@pytest.mark.min_version("2022.12")
+@pytest.mark.skipif(hh.complex_dtypes.is_empty, reason="no complex data types to draw from")
+@given(hh.arrays(dtype=hh.complex_dtypes, shape=hh.shapes()))
+def test_real(x):
+ out = xp.real(x)
+ ph.assert_dtype("real", in_dtype=x.dtype, out_dtype=out.dtype, expected=dh.dtype_components[x.dtype])
+ ph.assert_shape("real", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl("real", x, out, operator.attrgetter("real"))
+
+
+@pytest.mark.min_version("2024.12")
+@given(hh.arrays(dtype=hh.floating_dtypes, shape=hh.shapes(), elements=finite_kw))
+def test_reciprocal(x):
+ out = xp.reciprocal(x)
+ ph.assert_dtype("reciprocal", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("reciprocal", out_shape=out.shape, expected=x.shape)
+ refimpl = lambda x: 1.0 / x
+ unary_assert_against_refimpl(
+ "reciprocal",
+ x,
+ out,
+ refimpl,
+ strict_check=True,
+ )
+
+
+@pytest.mark.skip(reason="flaky")
+@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.real_dtypes))
+@given(data=st.data())
+def test_remainder(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+ if ctx.right_is_scalar:
+ assume(right != 0)
+ else:
+ assume(not xp.any(right == 0))
+
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "%", operator.mod)
+
+
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
+def test_round(x):
+ out = xp.round(x)
+ ph.assert_dtype("round", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("round", out_shape=out.shape, expected=x.shape)
+ if x.dtype in dh.complex_dtypes:
+ refimpl = lambda z: complex(round(z.real), round(z.imag))
+ else:
+ refimpl = round
+ unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True)
+
+
+@pytest.mark.min_version("2023.12")
+@given(hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes()))
+def test_signbit(x):
+ out = xp.signbit(x)
+ ph.assert_dtype("signbit", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+ ph.assert_shape("signbit", out_shape=out.shape, expected=x.shape)
+ refimpl = lambda x: math.copysign(1.0, x) < 0
+ unary_assert_against_refimpl("round", x, out, refimpl, strict_check=True)
+
+
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes(), elements=finite_kw))
+def test_sign(x):
+ out = xp.sign(x)
+ ph.assert_dtype("sign", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("sign", out_shape=out.shape, expected=x.shape)
+ refimpl = lambda x: x / abs(x) if x != 0 else 0
+ unary_assert_against_refimpl(
+ "sign",
+ x,
+ out,
+ refimpl,
+ strict_check=True,
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_sin(x):
+ out = xp.sin(x)
+ ph.assert_dtype("sin", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("sin", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.sin if x.dtype in dh.complex_dtypes else math.sin
+ unary_assert_against_refimpl("sin", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_sinh(x):
+ out = xp.sinh(x)
+ ph.assert_dtype("sinh", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("sinh", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.sinh if x.dtype in dh.complex_dtypes else math.sinh
+ unary_assert_against_refimpl("sinh", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.numeric_dtypes, shape=hh.shapes()))
+def test_square(x):
+ out = xp.square(x)
+ ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("square", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl(
+ "square", x, out, lambda s: s*s, expr_template="{}²={}"
+ )
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_sqrt(x):
+ out = xp.sqrt(x)
+ ph.assert_dtype("sqrt", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("sqrt", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.sqrt if x.dtype in dh.complex_dtypes else math.sqrt
+ filter_ = default_filter if x.dtype in dh.complex_dtypes else lambda s: default_filter(s) and s >= 0
+ unary_assert_against_refimpl(
+ "sqrt", x, out, refimpl, filter_=filter_
+ )
+
+
+@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
+@given(data=st.data())
+def test_subtract(ctx, data):
+ left = data.draw(ctx.left_strat, label=ctx.left_sym)
+ right = data.draw(ctx.right_strat, label=ctx.right_sym)
+
+ with hh.reject_overflow():
+ res = ctx.func(left, right)
+
+ binary_param_assert_dtype(ctx, left, right, res)
+ binary_param_assert_shape(ctx, left, right, res)
+ binary_param_assert_against_refimpl(ctx, left, right, res, "-", operator.sub)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_tan(x):
+ out = xp.tan(x)
+ ph.assert_dtype("tan", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("tan", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.tan if x.dtype in dh.complex_dtypes else math.tan
+ unary_assert_against_refimpl("tan", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
+def test_tanh(x):
+ out = xp.tanh(x)
+ ph.assert_dtype("tanh", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("tanh", out_shape=out.shape, expected=x.shape)
+ refimpl = cmath.tanh if x.dtype in dh.complex_dtypes else math.tanh
+ unary_assert_against_refimpl("tanh", x, out, refimpl)
+
+
+@given(hh.arrays(dtype=hh.real_dtypes, shape=xps.array_shapes()))
+def test_trunc(x):
+ out = xp.trunc(x)
+ ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape)
+ unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True)
+
+
+def _check_binary_with_scalars(func_data, x1x2):
+ x1, x2 = x1x2
+ func_name, refimpl, kwds, expected_dtype = func_data
+ func = getattr(xp, func_name)
+ out = func(x1, x2)
+ in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
+ _assert_correctness_binary(
+ func_name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds
+ )
+
+
+def _filter_zero(x):
+ return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0))
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+ # func_name, refimpl, kwargs, expected_dtype
+ [
+ ("add", operator.add, {}, None),
+ ("atan2", math.atan2, {}, None),
+ ("copysign", math.copysign, {}, None),
+ ("divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
+ ("hypot", math.hypot, {}, None),
+ ("logaddexp", logaddexp_refimpl, {}, None),
+ ("nextafter", math.nextafter, {}, None),
+ ("maximum", max, {'strict_check': True}, None),
+ ("minimum", min, {'strict_check': True}, None),
+ ("multiply", operator.mul, {}, None),
+ ("subtract", operator.sub, {}, None),
+
+ ("equal", operator.eq, {}, xp.bool),
+ ("not_equal", operator.ne, {}, xp.bool),
+ ("less", operator.lt, {}, xp.bool),
+ ("less_equal", operator.le, {}, xp.bool),
+ ("greater", operator.gt, {}, xp.bool),
+ ("greater_equal", operator.ge, {}, xp.bool),
+ ("pow", operator.pow, {'check_values': False}, None) # value tests are too finicky for pow
+ ],
+ ids=lambda func_data: func_data[0] # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
+def test_binary_with_scalars_real(func_data, x1x2):
+ _check_binary_with_scalars(func_data, x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+ # func_name, refimpl, kwargs, expected_dtype
+ [
+ ("logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None),
+ ("logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
+ ("logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None),
+ ],
+ ids=lambda func_data: func_data[0] # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.bool]))
+def test_binary_with_scalars_bool(func_data, x1x2):
+ _check_binary_with_scalars(func_data, x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+ # func_name, refimpl, kwargs, expected_dtype
+ [
+ ("floor_divide", operator.floordiv, {}, None),
+ ("remainder", operator.mod, {}, None),
+ ],
+ ids=lambda func_data: func_data[0] # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.int64]))
+def test_binary_with_scalars_int(func_data, x1x2):
+    assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
+ _check_binary_with_scalars(func_data, x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+ # func_name, refimpl, kwargs, expected_dtype
+ [
+ ("bitwise_and", operator.and_, {}, None),
+ ("bitwise_or", operator.or_, {}, None),
+ ("bitwise_xor", operator.xor, {}, None),
+ ],
+ ids=lambda func_data: func_data[0] # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.int32]))
+def test_binary_with_scalars_bitwise(func_data, x1x2):
+ func_name, refimpl, kwargs, expected = func_data
+ # repack the refimpl
+    refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32)
+ _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.parametrize('func_data',
+ # func_name, refimpl, kwargs, expected_dtype
+ [
+ ("bitwise_left_shift", operator.lshift, {}, None),
+ ("bitwise_right_shift", operator.rshift, {}, None),
+ ],
+ ids=lambda func_data: func_data[0] # use names for test IDs
+)
+@given(x1x2=hh.array_and_py_scalar([xp.int32], positive=True, mM=(1, 3)))
+def test_binary_with_scalars_bitwise_shifts(func_data, x1x2):
+ func_name, refimpl, kwargs, expected = func_data
+ # repack the refimpl
+    refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32)
+ _check_binary_with_scalars((func_name, refimpl_, kwargs, expected), x1x2)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@given(
+ x1x2=hh.array_and_py_scalar([xp.int32]),
+ data=st.data()
+)
+def test_where_with_scalars(x1x2, data):
+ x1, x2 = x1x2
+
+ if dh.is_scalar(x1):
+ dtype, shape = x2.dtype, x2.shape
+ x1_arr, x2_arr = xp.broadcast_to(xp.asarray(x1), shape), x2
+ else:
+ dtype, shape = x1.dtype, x1.shape
+ x1_arr, x2_arr = x1, xp.broadcast_to(xp.asarray(x2), shape)
+
+ condition = data.draw(hh.arrays(shape=shape, dtype=xp.bool))
+
+ out = xp.where(condition, x1, x2)
+
+ assert out.dtype == dtype, f"where: got {out.dtype = } for {dtype=}, {x1=} and {x2=}"
+ assert out.shape == shape, f"where: got {out.shape = } for {shape=}, {x1=} and {x2=}"
+
+ # value test
+ for idx in sh.ndindex(shape):
+ if condition[idx]:
+ assert out[idx] == x1_arr[idx]
+ else:
+ assert out[idx] == x2_arr[idx]
+
+
diff --git a/array_api_tests/test_searching_functions.py b/array_api_tests/test_searching_functions.py
new file mode 100644
index 00000000..b72c8030
--- /dev/null
+++ b/array_api_tests/test_searching_functions.py
@@ -0,0 +1,249 @@
+import math
+
+import pytest
+from hypothesis import given, note, assume
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from . import xps
+
+
+pytestmark = pytest.mark.unvectorized
+
+
+@given(
+ x=hh.arrays(
+ dtype=hh.real_dtypes,
+ shape=hh.shapes(min_dims=1, min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_argmax(x, data):
+ kw = data.draw(
+ hh.kwargs(
+ axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
+ keepdims=st.booleans(),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.argmax(x, **kw)
+
+ ph.assert_default_index("argmax", out.dtype)
+ axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "argmax", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
+ max_i = int(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = max(range(len(elements)), key=elements.__getitem__)
+ ph.assert_scalar_equals("argmax", type_=int, idx=out_idx, out=max_i,
+ expected=expected, kw=kw)
+
+
+@given(
+ x=hh.arrays(
+ dtype=hh.real_dtypes,
+ shape=hh.shapes(min_dims=1, min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_argmin(x, data):
+ kw = data.draw(
+ hh.kwargs(
+ axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
+ keepdims=st.booleans(),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.argmin(x, **kw)
+
+ ph.assert_default_index("argmin", out.dtype)
+ axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "argmin", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
+ min_i = int(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = min(range(len(elements)), key=elements.__getitem__)
+ ph.assert_scalar_equals("argmin", type_=int, idx=out_idx, out=min_i, expected=expected)
+
+
+# XXX: the strategy for x is problematic on JAX unless JAX_ENABLE_X64 is on.
+# The problem is that for ints > iinfo(int32).max it runs into essentially this:
+# >>> jnp.asarray([2147483648], dtype=jnp.int64)
+# .... https://github.com/jax-ml/jax/pull/6047 ...
+# Explicitly limiting the range in elements(...) runs into problems with
+# hypothesis where floating-point numbers are not exactly representable.
+@pytest.mark.min_version("2024.12")
+@given(
+ x=hh.arrays(
+ dtype=hh.all_dtypes,
+ shape=hh.shapes(min_dims=1, min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_count_nonzero(x, data):
+ kw = data.draw(
+ hh.kwargs(
+ axis=hh.axes(x.ndim),
+ keepdims=st.booleans(),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ assume(kw.get("axis", None) != ()) # TODO clarify in the spec
+
+ out = xp.count_nonzero(x, **kw)
+
+ ph.assert_default_index("count_nonzero", out.dtype)
+ axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "count_nonzero", in_shape=x.shape, out_shape=out.shape, axes=axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(x.dtype)
+
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, axes), sh.ndindex(out.shape)):
+ count = int(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = sum(el != 0 for el in elements)
+ ph.assert_scalar_equals("count_nonzero", type_=int, idx=out_idx, out=count, expected=expected)
+
+
+@given(hh.arrays(dtype=hh.all_dtypes, shape=()))
+def test_nonzero_zerodim_error(x):
+ with pytest.raises(Exception):
+ xp.nonzero(x)
+
+
+@pytest.mark.data_dependent_shapes
+@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1, min_side=1)))
+def test_nonzero(x):
+ out = xp.nonzero(x)
+ assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
+ out_size = math.prod(out[0].shape)
+ for i in range(len(out)):
+ assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"
+ size_at = math.prod(out[i].shape)
+ assert size_at == out_size, (
+ f"prod(out[{i}].shape)={size_at}, "
+ f"but should be prod(out[0].shape)={out_size}"
+ )
+ ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
+ indices = []
+ if x.dtype == xp.bool:
+ for idx in sh.ndindex(x.shape):
+ if x[idx]:
+ indices.append(idx)
+ else:
+ for idx in sh.ndindex(x.shape):
+ if x[idx] != 0:
+ indices.append(idx)
+ if x.ndim == 0:
+ assert out_size == len(
+ indices
+ ), f"prod(out[0].shape)={out_size}, but should be {len(indices)}"
+ else:
+ for i in range(out_size):
+            idx = tuple(int(arr[i]) for arr in out)
+            f_idx = f"Extrapolated index (arr[{i}] for arr in out)={idx}"
+ f_element = f"x[{idx}]={x[idx]}"
+ assert idx in indices, f"{f_idx} results in {f_element}, a zero element"
+ assert (
+ idx == indices[i]
+ ), f"{f_idx} is in the wrong position, should be {indices.index(idx)}"
+
+
+@given(
+ shapes=hh.mutually_broadcastable_shapes(3),
+ dtypes=hh.mutually_promotable_dtypes(),
+ data=st.data(),
+)
+def test_where(shapes, dtypes, data):
+ cond = data.draw(hh.arrays(dtype=xp.bool, shape=shapes[0]), label="condition")
+ x1 = data.draw(hh.arrays(dtype=dtypes[0], shape=shapes[1]), label="x1")
+ x2 = data.draw(hh.arrays(dtype=dtypes[1], shape=shapes[2]), label="x2")
+
+ out = xp.where(cond, x1, x2)
+
+ shape = sh.broadcast_shapes(*shapes)
+ ph.assert_shape("where", out_shape=out.shape, expected=shape)
+ # TODO: generate indices without broadcasting arrays
+ _cond = xp.broadcast_to(cond, shape)
+ _x1 = xp.broadcast_to(x1, shape)
+ _x2 = xp.broadcast_to(x2, shape)
+ for idx in sh.ndindex(shape):
+ if _cond[idx]:
+ ph.assert_0d_equals(
+ "where",
+ x_repr=f"_x1[{idx}]",
+ x_val=_x1[idx],
+ out_repr=f"out[{idx}]",
+ out_val=out[idx]
+ )
+ else:
+ ph.assert_0d_equals(
+ "where",
+ x_repr=f"_x2[{idx}]",
+ x_val=_x2[idx],
+ out_repr=f"out[{idx}]",
+ out_val=out[idx]
+ )
+
+
+@pytest.mark.min_version("2023.12")
+@given(data=st.data())
+def test_searchsorted(data):
+ # TODO: test side="right"
+ # TODO: Allow different dtypes for x1 and x2
+ _x1 = data.draw(
+ st.lists(xps.from_dtype(dh.default_float), min_size=1, unique=True),
+ label="_x1",
+ )
+ x1 = xp.asarray(_x1, dtype=dh.default_float)
+ if data.draw(st.booleans(), label="use sorter?"):
+ sorter = xp.argsort(x1)
+ else:
+ sorter = None
+ x1 = xp.sort(x1)
+ note(f"{x1=}")
+ x2 = data.draw(
+ st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
+ lambda o: xp.asarray(o, dtype=dh.default_float)
+ ),
+ label="x2",
+ )
+
+ out = xp.searchsorted(x1, x2, sorter=sorter)
+
+ ph.assert_dtype(
+ "searchsorted",
+ in_dtype=[x1.dtype, x2.dtype],
+ out_dtype=out.dtype,
+ expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
+ )
+ # TODO: shapes and values testing
diff --git a/array_api_tests/test_set_functions.py b/array_api_tests/test_set_functions.py
new file mode 100644
index 00000000..c9abaad1
--- /dev/null
+++ b/array_api_tests/test_set_functions.py
@@ -0,0 +1,239 @@
+# TODO: disable if opted out, refactor things
+import cmath
+import math
+from collections import Counter, defaultdict
+
+import pytest
+from hypothesis import assume, given
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+
+pytestmark = [pytest.mark.data_dependent_shapes, pytest.mark.unvectorized]
+
+
+@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
+def test_unique_all(x):
+ out = xp.unique_all(x)
+
+ assert hasattr(out, "values")
+ assert hasattr(out, "indices")
+ assert hasattr(out, "inverse_indices")
+ assert hasattr(out, "counts")
+
+ ph.assert_dtype(
+ "unique_all", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype"
+ )
+ ph.assert_default_index(
+ "unique_all", out.indices.dtype, repr_name="out.indices.dtype"
+ )
+ ph.assert_default_index(
+ "unique_all", out.inverse_indices.dtype, repr_name="out.inverse_indices.dtype"
+ )
+ ph.assert_default_index(
+ "unique_all", out.counts.dtype, repr_name="out.counts.dtype"
+ )
+
+ assert (
+ out.indices.shape == out.values.shape
+ ), f"{out.indices.shape=}, but should be {out.values.shape=}"
+ ph.assert_shape(
+ "unique_all",
+ out_shape=out.inverse_indices.shape,
+ expected=x.shape,
+ repr_name="out.inverse_indices.shape",
+ )
+ assert (
+ out.counts.shape == out.values.shape
+ ), f"{out.counts.shape=}, but should be {out.values.shape=}"
+
+ scalar_type = dh.get_scalar_type(out.values.dtype)
+ counts = defaultdict(int)
+ firsts = {}
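+    # Map each distinct scalar to its multiplicity and its first flat index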
+ for i, idx in enumerate(sh.ndindex(x.shape)):
+ val = scalar_type(x[idx])
+ if counts[val] == 0:
+ firsts[val] = i
+ counts[val] += 1
+
+ for idx in sh.ndindex(out.indices.shape):
+ val = scalar_type(out.values[idx])
+ if cmath.isnan(val):
+ break
+ i = int(out.indices[idx])
+ expected = firsts[val]
+ assert i == expected, (
+ f"out.values[{idx}]={val} and out.indices[{idx}]={i}, "
+ f"but first occurence of {val} is at {expected}"
+ )
+
+ for idx in sh.ndindex(out.inverse_indices.shape):
+ ridx = int(out.inverse_indices[idx])
+ val = out.values[ridx]
+ expected = x[idx]
+ msg = (
+ f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
+ f"but should result in x[{idx}]={expected}"
+ )
+ if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
+ assert xp.isnan(val), msg
+ else:
+ assert val == expected, msg
+
+ vals_idx = {}
+ nans = 0
+ for idx in sh.ndindex(out.values.shape):
+ val = scalar_type(out.values[idx])
+ count = int(out.counts[idx])
+ if cmath.isnan(val):
+ nans += 1
+ assert count == 1, (
+ f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
+ "but count should be 1 as NaNs are distinct"
+ )
+ else:
+ expected = counts[val]
+ assert (
+ expected > 0
+ ), f"out.values[{idx}]={val}, but {val} not in input array"
+ count = int(out.counts[idx])
+ assert count == expected, (
+ f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
+ f"but should be {expected}"
+ )
+ assert (
+ val not in vals_idx.keys()
+ ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
+ vals_idx[val] = idx
+
+ if dh.is_float_dtype(out.values.dtype):
+ assume(math.prod(x.shape) <= 128) # may not be representable
+ expected = sum(v for k, v in counts.items() if cmath.isnan(k))
+ assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
+
+
+@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
+def test_unique_counts(x):
+ out = xp.unique_counts(x)
+ assert hasattr(out, "values")
+ assert hasattr(out, "counts")
+ ph.assert_dtype(
+ "unique_counts", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype"
+ )
+ ph.assert_default_index(
+ "unique_counts", out.counts.dtype, repr_name="out.counts.dtype"
+ )
+ assert (
+ out.counts.shape == out.values.shape
+ ), f"{out.counts.shape=}, but should be {out.values.shape=}"
+ scalar_type = dh.get_scalar_type(out.values.dtype)
+ counts = Counter(scalar_type(x[idx]) for idx in sh.ndindex(x.shape))
+ vals_idx = {}
+ nans = 0
+ for idx in sh.ndindex(out.values.shape):
+ val = scalar_type(out.values[idx])
+ count = int(out.counts[idx])
+ if cmath.isnan(val):
+ nans += 1
+ assert count == 1, (
+ f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
+ "but count should be 1 as NaNs are distinct"
+ )
+ else:
+ expected = counts[val]
+ assert (
+ expected > 0
+ ), f"out.values[{idx}]={val}, but {val} not in input array"
+ count = int(out.counts[idx])
+ assert count == expected, (
+ f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
+ f"but should be {expected}"
+ )
+ assert (
+ val not in vals_idx.keys()
+ ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
+ vals_idx[val] = idx
+ if dh.is_float_dtype(out.values.dtype):
+ assume(math.prod(x.shape) <= 128) # may not be representable
+ expected = sum(v for k, v in counts.items() if cmath.isnan(k))
+ assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
+
+
+@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
+def test_unique_inverse(x):
+ out = xp.unique_inverse(x)
+ assert hasattr(out, "values")
+ assert hasattr(out, "inverse_indices")
+ ph.assert_dtype(
+ "unique_inverse", in_dtype=x.dtype, out_dtype=out.values.dtype, repr_name="out.values.dtype"
+ )
+ ph.assert_default_index(
+ "unique_inverse",
+ out.inverse_indices.dtype,
+ repr_name="out.inverse_indices.dtype",
+ )
+ ph.assert_shape(
+ "unique_inverse",
+ out_shape=out.inverse_indices.shape,
+ expected=x.shape,
+ repr_name="out.inverse_indices.shape",
+ )
+ scalar_type = dh.get_scalar_type(out.values.dtype)
+ distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape))
+ vals_idx = {}
+ nans = 0
+ for idx in sh.ndindex(out.values.shape):
+ val = scalar_type(out.values[idx])
+ if cmath.isnan(val):
+ nans += 1
+ else:
+ assert (
+ val in distinct
+ ), f"out.values[{idx}]={val}, but {val} not in input array"
+ assert (
+ val not in vals_idx.keys()
+ ), f"out.values[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
+ vals_idx[val] = idx
+ for idx in sh.ndindex(out.inverse_indices.shape):
+ ridx = int(out.inverse_indices[idx])
+ val = out.values[ridx]
+ expected = x[idx]
+ msg = (
+ f"out.inverse_indices[{idx}]={ridx} results in out.values[{ridx}]={val}, "
+ f"but should result in x[{idx}]={expected}"
+ )
+ if dh.is_float_dtype(out.values.dtype) and xp.isnan(expected):
+ assert xp.isnan(val), msg
+ else:
+ assert val == expected, msg
+ if dh.is_float_dtype(out.values.dtype):
+ assume(math.prod(x.shape) <= 128) # may not be representable
+ expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
+ assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}"
+
+
+@given(hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)))
+def test_unique_values(x):
+ out = xp.unique_values(x)
+ ph.assert_dtype("unique_values", in_dtype=x.dtype, out_dtype=out.dtype)
+ scalar_type = dh.get_scalar_type(x.dtype)
+ distinct = set(scalar_type(x[idx]) for idx in sh.ndindex(x.shape))
+ vals_idx = {}
+ nans = 0
+ for idx in sh.ndindex(out.shape):
+ val = scalar_type(out[idx])
+ if cmath.isnan(val):
+ nans += 1
+ else:
+ assert val in distinct, f"out[{idx}]={val}, but {val} not in input array"
+ assert (
+ val not in vals_idx.keys()
+ ), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
+ vals_idx[val] = idx
+ if dh.is_float_dtype(out.dtype):
+ assume(math.prod(x.shape) <= 128) # may not be representable
+ expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
+ assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
diff --git a/array_api_tests/test_signatures.py b/array_api_tests/test_signatures.py
index 997b59c8..1c9a8ef6 100644
--- a/array_api_tests/test_signatures.py
+++ b/array_api_tests/test_signatures.py
@@ -1,178 +1,323 @@
-import inspect
+"""
+Tests for function/method signatures compliance
+
+We're not interested in being 100% strict - instead we focus on areas which
+could affect interop, e.g. with
+
+ def add(x1, x2, /):
+ ...
+
+x1 and x2 don't need to be pos-only for the purposes of interoperability, but with
+
+ def squeeze(x, /, axis):
+ ...
+
+axis has to be pos-or-keyword to support both styles
+
+ >>> squeeze(x, 0)
+ ...
+ >>> squeeze(x, axis=0)
+ ...
+
+"""
+from collections import defaultdict
+from copy import copy
+from inspect import Parameter, Signature, signature
+from types import FunctionType
+from typing import Any, Callable, Dict, Literal, get_args
+from warnings import warn
import pytest
-from ._array_module import mod, mod_name, ones, eye, float64
-from .pytest_helpers import raises, doesnt_raise
-
-from . import function_stubs
-
-
-def stub_module(name):
- submodules = [m for m in dir(function_stubs) if
- inspect.ismodule(getattr(function_stubs, m)) and not
- m.startswith('_')]
- for m in submodules:
- if name in getattr(function_stubs, m).__all__:
- return m
-
-def array_method(name):
- return stub_module(name) == 'array_object'
-
-def function_category(name):
- return stub_module(name).rsplit('_', 1)[0].replace('_', ' ')
-
-def example_argument(arg, func_name):
- """
- Get an example argument for the argument arg for the function func_name
-
- The full tests for function behavior is in other files. We just need to
- have an example input for each argument name that should work so that we
- can check if the argument is implemented at all.
-
- """
- # Note: for keyword arguments that have a default, this should be
- # different from the default, as the default argument is tested separately
- # (it can have the same behavior as the default, just not literally the
- # same value).
- known_args = dict(
- M=1,
- N=1,
- arrays=(ones((1, 3, 3)), ones((1, 3, 3))),
- # These cannot be the same as each other, which is why all our test
- # arrays have to have at least 3 dimensions.
- axis1=2,
- axis2=2,
- axis=1,
- axes=(2, 1, 0),
- condition=ones((1, 3, 3), dtype=bool),
- correction=1.0,
- descending=True,
- dtype=float64,
- endpoint=False,
- fill_value=1.0,
- k=1,
- keepdims=True,
- key=0,
- num=2,
- offset=1,
- ord=1,
- return_counts=True,
- return_index=True,
- return_inverse=True,
- shape=(1, 3, 3),
- shift=1,
- sorted=False,
- stable=False,
- start=0,
- step=2,
- stop=1,
- value=0,
- x1=ones((1, 3, 3)),
- x2=ones((1, 3, 3)),
- x=ones((1, 3, 3)),
+from . import dtype_helpers as dh
+from . import xp
+from .stubs import (array_methods, category_to_funcs, extension_to_funcs,
+ name_to_func, info_funcs)
+
+ParameterKind = Literal[
+ Parameter.POSITIONAL_ONLY,
+ Parameter.VAR_POSITIONAL,
+ Parameter.POSITIONAL_OR_KEYWORD,
+ Parameter.KEYWORD_ONLY,
+ Parameter.VAR_KEYWORD,
+]
+ALL_KINDS = get_args(ParameterKind)
+VAR_KINDS = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
+kind_to_str: Dict[ParameterKind, str] = {
+ Parameter.POSITIONAL_OR_KEYWORD: "pos or kw argument",
+ Parameter.POSITIONAL_ONLY: "pos-only argument",
+ Parameter.KEYWORD_ONLY: "keyword-only argument",
+ Parameter.VAR_POSITIONAL: "star-args (i.e. *args) argument",
+ Parameter.VAR_KEYWORD: "star-kwargs (i.e. **kwargs) argument",
+}
+
+
+def _test_inspectable_func(sig: Signature, stub_sig: Signature):
+ params = list(sig.parameters.values())
+ stub_params = list(stub_sig.parameters.values())
+
+ non_kwonly_stub_params = [
+ p for p in stub_params if p.kind != Parameter.KEYWORD_ONLY
+ ]
+ # sanity check
+ assert non_kwonly_stub_params == stub_params[: len(non_kwonly_stub_params)]
+ # We're not interested if the array module has additional arguments, so we
+ # only iterate through the arguments listed in the spec.
+ for i, stub_param in enumerate(non_kwonly_stub_params):
+ assert (
+ len(params) >= i + 1
+ ), f"Argument '{stub_param.name}' missing from signature"
+ param = params[i]
+
+ # We're not interested in the name if it isn't actually used
+ if stub_param.kind not in [Parameter.POSITIONAL_ONLY, *VAR_KINDS]:
+ assert (
+ param.name == stub_param.name
+ ), f"Expected argument '{param.name}' to be named '{stub_param.name}'"
+
+ if stub_param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:
+ f_stub_kind = kind_to_str[stub_param.kind]
+ assert param.kind == stub_param.kind, (
+ f"{param.name} is a {kind_to_str[param.kind]}, "
+ f"but should be a {f_stub_kind}"
+ )
+
+ kwonly_stub_params = stub_params[len(non_kwonly_stub_params) :]
+ for stub_param in kwonly_stub_params:
+ assert (
+ stub_param.name in sig.parameters.keys()
+ ), f"Argument '{stub_param.name}' missing from signature"
+ param = next(p for p in params if p.name == stub_param.name)
+ f_stub_kind = kind_to_str[stub_param.kind]
+        assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD], (
+            f"{param.name} is a {kind_to_str[param.kind]}, "
+            f"but should be a {f_stub_kind} "
+            f"(or at least a {kind_to_str[Parameter.POSITIONAL_OR_KEYWORD]})"
+        )
+
+
+def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
+ f_sig = f"{func_name}("
+ f_sig += ", ".join(str(a) for a in args)
+ if len(kwargs) != 0:
+ if len(args) != 0:
+ f_sig += ", "
+ f_sig += ", ".join(f"{k}={v}" for k, v in kwargs.items())
+ f_sig += ")"
+ return f_sig
+
+
+# We test uninspectable signatures by passing valid, manually-defined arguments
+# to the signature's function/method.
+#
+# Arguments which require use of the array module are specified as string
+# expressions to be eval()'d on runtime. This is as opposed to just using the
+# array module whilst setting up the tests, which is prone to halt the entire
+# test suite if an array module doesn't support a given expression.
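+#
+# For example (illustrative only), an adopting module without complex number
+# support would fail just the test that evaluates such an expression at
+# runtime, rather than erroring out while collecting the whole suite:
+#
+#     arg = eval("xp.ones((5,), dtype=xp.complex128)", {"xp": xp})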
+func_to_specified_args = defaultdict(
+ dict,
+ {
+ "permute_dims": {"axes": 0},
+ "reshape": {"shape": (1, 5)},
+ "broadcast_to": {"shape": (1, 5)},
+ "asarray": {"obj": [0, 1, 2, 3, 4]},
+ "full_like": {"fill_value": 42},
+ "matrix_power": {"n": 2},
+ },
+)
+func_to_specified_arg_exprs = defaultdict(
+ dict,
+ {
+ "stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
+ "iinfo": {"type": "xp.int64"},
+ "finfo": {"type": "xp.float64"},
+ "cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
+ "inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
+ "solve": {
+ a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
+ },
+ "outer": {"x1": "xp.ones((5,))", "x2": "xp.ones((5,))"},
+ },
+)
+# We default most array arguments heuristically. As functions/methods work only
+# with arrays of certain dtypes and shapes, we specify only supported arrays
+# respective to the function.
+casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
+matrixy_names = [
+ f.__name__
+ for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
+]
+matrixy_names += ["__matmul__", "triu", "tril"]
+for func_name, func in name_to_func.items():
+ stub_sig = signature(func)
+ array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
+ if func in array_methods:
+ array_argnames.add("self")
+ array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
+ if len(array_argnames) > 0:
+ in_dtypes = dh.func_in_dtypes[func_name]
+ for dtype_name in ["float64", "bool", "int64", "complex128"]:
+ # We try float64 first because uninspectable numerical functions
+ # tend to support float inputs first-and-foremost (i.e. PyTorch)
+ try:
+ dtype = getattr(xp, dtype_name)
+ except AttributeError:
+ pass
+ else:
+ if dtype in in_dtypes:
+ if func_name in casty_names:
+ shape = ()
+ elif func_name in matrixy_names:
+ shape = (3, 3)
+ else:
+ shape = (5,)
+ fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
+ break
+ else:
+            warn(
+                f"dh.func_in_dtypes['{func_name}']={in_dtypes} seemingly does "
+                "not contain any assumed dtypes, so skipping specifying fallback array."
+            )
+ continue
+ for argname in array_argnames:
+ func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr
+
+
+def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
+ params = list(stub_sig.parameters.values())
+
+ if len(params) == 0:
+ func()
+ return
+
+ uninspectable_msg = (
+ f"Note {func_name}() is not inspectable so arguments are passed "
+ "manually to test the signature."
)
- if arg in known_args:
- # Special cases:
-
- # squeeze() requires an axis of size 1, but other functions such as
- # cross() require axes of size >1
- if func_name == 'squeeze' and arg == 'axis':
- return 0
- # ones() is not invertible
- elif func_name == 'inv' and arg == 'x':
- return eye(3)
- return known_args[arg]
- else:
- raise RuntimeError(f"Don't know how to test argument {arg}. Please update test_signatures.py")
-
-@pytest.mark.parametrize('name', function_stubs.__all__)
-def test_has_names(name):
- if array_method(name):
- arr = ones((1,))
- if getattr(function_stubs.array_object, name) is None:
- assert hasattr(arr, name), f"The array object is missing the attribute {name}"
+ argname_to_arg = copy(func_to_specified_args[func_name])
+ argname_to_expr = func_to_specified_arg_exprs[func_name]
+ for argname, expr in argname_to_expr.items():
+ assert argname not in argname_to_arg.keys() # sanity check
+ try:
+ argname_to_arg[argname] = eval(expr, {"xp": xp})
+ except Exception as e:
+ pytest.skip(
+ f"Exception occured when evaluating {argname}={expr}: {e}\n"
+ f"{uninspectable_msg}"
+ )
+
+ posargs = []
+ posorkw_args = {}
+ kwargs = {}
+ no_arg_msg = (
+ "We have no argument specified for '{}'. Please ensure you're using "
+ "the latest version of array-api-tests, then open an issue if one "
+ f"doesn't already exist. {uninspectable_msg}"
+ )
+ for param in params:
+ if param.kind == Parameter.POSITIONAL_ONLY:
+ try:
+ posargs.append(argname_to_arg[param.name])
+ except KeyError:
+ pytest.skip(no_arg_msg.format(param.name))
+ elif param.kind == Parameter.POSITIONAL_OR_KEYWORD:
+ if param.default == Parameter.empty:
+ try:
+ posorkw_args[param.name] = argname_to_arg[param.name]
+ except KeyError:
+ pytest.skip(no_arg_msg.format(param.name))
+ else:
+ assert argname_to_arg[param.name]
+ posorkw_args[param.name] = param.default
+ elif param.kind == Parameter.KEYWORD_ONLY:
+ assert param.default != Parameter.empty # sanity check
+ kwargs[param.name] = param.default
else:
- assert hasattr(arr, name), f"The array object is missing the method {name}()"
- else:
- assert hasattr(mod, name), f"{mod_name} is missing the {function_category(name)} function {name}()"
-
-@pytest.mark.parametrize('name', function_stubs.__all__)
-def test_function_positional_args(name):
- # Note: We can't actually test that positional arguments are
- # positional-only, as that would require knowing the argument name and
- # checking that it can't be used as a keyword argument. But argument name
- # inspection does not work for most array library functions that are not
- # written in pure Python (e.g., it won't work for numpy ufuncs).
- if array_method(name):
- _mod = ones((1,))
- else:
- _mod = mod
-
- if not hasattr(_mod, name):
- pytest.skip(f"{mod_name} does not have {name}(), skipping.")
- stub_func = getattr(function_stubs, name)
- if stub_func is None:
- # TODO: Can we make this skip the parameterization entirely?
- pytest.skip(f"{name} is not a function, skipping.")
- mod_func = getattr(_mod, name)
- argspec = inspect.getfullargspec(stub_func)
- args = argspec.args
- if name.startswith('__'):
- args = args[1:]
- nargs = len(args)
- if argspec.defaults:
- raise RuntimeError(f"Unexpected non-keyword-only keyword argument for {name}. Please update test_signatures.py")
-
- args = [example_argument(arg, name) for arg in args]
- if not args:
- args = [example_argument('x', name)]
+ assert param.kind in VAR_KINDS # sanity check
+ pytest.skip(no_arg_msg.format(param.name))
+ if len(posorkw_args) == 0:
+ func(*posargs, **kwargs)
else:
- # Duplicate the last positional argument for the n+1 test.
- args = args + [args[-1]]
+ posorkw_name_to_arg_pairs = list(posorkw_args.items())
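+        # Try every split of the pos-or-kw args between positional and keyword
+        # passing, e.g. f(a, b), then f(a, b=b), then f(a=a, b=b)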
+ for i in range(len(posorkw_name_to_arg_pairs), -1, -1):
+ extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]]
+ extra_kwargs = dict(posorkw_name_to_arg_pairs[i:])
+ func(*posargs, *extra_posargs, **kwargs, **extra_kwargs)
- for n in range(nargs+2):
- if n == nargs:
- doesnt_raise(lambda: mod_func(*args[:n]))
- else:
- # NumPy ufuncs raise ValueError instead of TypeError
- raises((TypeError, ValueError), lambda: mod_func(*args[:n]), f"{name}() should not accept {n} positional arguments")
-@pytest.mark.parametrize('name', function_stubs.__all__)
-def test_function_keyword_only_args(name):
- if array_method(name):
- _mod = ones((1,))
+def _test_func_signature(func: Callable, stub: FunctionType, is_method=False):
+ stub_sig = signature(stub)
+ # If testing against array, ignore 'self' arg in stub as it won't be present
+ # in func (which should be a method).
+ if is_method:
+ stub_params = list(stub_sig.parameters.values())
+ if stub_params[0].name == "self":
+ del stub_params[0]
+ stub_sig = Signature(
+ parameters=stub_params, return_annotation=stub_sig.return_annotation
+ )
+
+ try:
+ sig = signature(func)
+ except ValueError:
+ try:
+ _test_uninspectable_func(stub.__name__, func, stub_sig)
+ except Exception as e:
+ raise e from None # suppress parent exception for cleaner pytest output
else:
- _mod = mod
-
- if not hasattr(_mod, name):
- pytest.skip(f"{mod_name} does not have {name}(), skipping.")
- stub_func = getattr(function_stubs, name)
- if stub_func is None:
- # TODO: Can we make this skip the parameterization entirely?
- pytest.skip(f"{name} is not a function, skipping.")
- mod_func = getattr(_mod, name)
- argspec = inspect.getfullargspec(stub_func)
- args = argspec.args
- if name.startswith('__'):
- args = args[1:]
- kwonlyargs = argspec.kwonlyargs
- kwonlydefaults = argspec.kwonlydefaults
-
- args = [example_argument(arg, name) for arg in args]
-
- for arg in kwonlyargs:
- value = example_argument(arg, name)
- # The "only" part of keyword-only is tested by the positional test above.
- doesnt_raise(lambda: mod_func(*args, **{arg: value}),
- f"{name}() should accept the keyword-only argument {arg!r}")
-
- # Make sure the default is accepted. These tests are not granular
- # enough to test that the default is actually the default, i.e., gives
- # the same value if the keyword isn't passed. That is tested in the
- # specific function tests.
- if arg in kwonlydefaults:
- default_value = kwonlydefaults[arg]
- doesnt_raise(lambda: mod_func(*args, **{arg: default_value}),
- f"{name}() should accept the default value {default_value!r} for the keyword-only argument {arg!r}")
+ _test_inspectable_func(sig, stub_sig)
+
+
+@pytest.mark.parametrize(
+ "stub",
+ [s for stubs in category_to_funcs.values() for s in stubs],
+ ids=lambda f: f.__name__,
+)
+def test_func_signature(stub: FunctionType):
+ assert hasattr(xp, stub.__name__), f"{stub.__name__} not found in array module"
+ func = getattr(xp, stub.__name__)
+ _test_func_signature(func, stub)
+
+
+extension_and_stub_params = []
+for ext, stubs in extension_to_funcs.items():
+ for stub in stubs:
+ p = pytest.param(
+ ext, stub, id=f"{ext}.{stub.__name__}", marks=pytest.mark.xp_extension(ext)
+ )
+ extension_and_stub_params.append(p)
+
+
+@pytest.mark.parametrize("extension, stub", extension_and_stub_params)
+def test_extension_func_signature(extension: str, stub: FunctionType):
+ mod = getattr(xp, extension)
+ assert hasattr(
+ mod, stub.__name__
+ ), f"{stub.__name__} not found in {extension} extension"
+ func = getattr(mod, stub.__name__)
+ _test_func_signature(func, stub)
+
+
+@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
+def test_array_method_signature(stub: FunctionType):
+ x_expr = func_to_specified_arg_exprs[stub.__name__]["self"]
+ try:
+ x = eval(x_expr, {"xp": xp})
+ except Exception as e:
+ pytest.skip(f"Exception occured when evaluating x={x_expr}: {e}")
+ assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
+ method = getattr(x, stub.__name__)
+ _test_func_signature(method, stub, is_method=True)
+
+if info_funcs: # pytest fails collecting if info_funcs is empty
+ @pytest.mark.min_version("2023.12")
+ @pytest.mark.parametrize("stub", info_funcs, ids=lambda f: f.__name__)
+ def test_info_func_signature(stub: FunctionType):
+ try:
+ info_namespace = xp.__array_namespace_info__()
+ except Exception as e:
+ raise AssertionError(f"Could not get info namespace from xp.__array_namespace_info__(): {e}")
+
+ func = getattr(info_namespace, stub.__name__)
+ _test_func_signature(func, stub)
diff --git a/array_api_tests/test_sorting_functions.py b/array_api_tests/test_sorting_functions.py
new file mode 100644
index 00000000..3d25798c
--- /dev/null
+++ b/array_api_tests/test_sorting_functions.py
@@ -0,0 +1,138 @@
+import cmath
+from typing import Set
+
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+from hypothesis.control import assume
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from .typing import Scalar, Shape
+
+
+def assert_scalar_in_set(
+ func_name: str,
+ idx: Shape,
+ out: Scalar,
+ set_: Set[Scalar],
+ kw={},
+):
+ out_repr = "out" if idx == () else f"out[{idx}]"
+ if cmath.isnan(out):
+ raise NotImplementedError()
+ msg = f"{out_repr}={out}, but should be in {set_} [{func_name}({ph.fmt_kw(kw)})]"
+ assert out in set_, msg
+
+
+# TODO: Test with signed zeros and NaNs (and ignore them somehow)
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.real_dtypes,
+ shape=hh.shapes(min_dims=1, min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_argsort(x, data):
+ if dh.is_float_dtype(x.dtype):
+ assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
+
+ kw = data.draw(
+ hh.kwargs(
+ axis=st.integers(-x.ndim, x.ndim - 1),
+ descending=st.booleans(),
+ stable=st.booleans(),
+ ),
+ label="kw",
+ )
+
+ out = xp.argsort(x, **kw)
+
+ ph.assert_default_index("argsort", out.dtype)
+ ph.assert_shape("argsort", out_shape=out.shape, expected=x.shape, kw=kw)
+ axis = kw.get("axis", -1)
+ axes = sh.normalize_axis(axis, x.ndim)
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for indices in sh.axes_ndindex(x.shape, axes):
+ elements = [scalar_type(x[idx]) for idx in indices]
+ orders = list(range(len(elements)))
+ sorders = sorted(
+ orders, key=elements.__getitem__, reverse=kw.get("descending", False)
+ )
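+        # Python's sorted() is guaranteed stable, so sorders gives exactly the
+        # indices argsort should return when stable=True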
+ if kw.get("stable", True):
+ for idx, o in zip(indices, sorders):
+ ph.assert_scalar_equals("argsort", type_=int, idx=idx, out=int(out[idx]), expected=o, kw=kw)
+ else:
+ idx_elements = dict(zip(indices, elements))
+ idx_orders = dict(zip(indices, orders))
+ element_orders = {}
+ for e in set(elements):
+ element_orders[e] = [
+ idx_orders[idx] for idx in indices if idx_elements[idx] == e
+ ]
+ selements = [elements[o] for o in sorders]
+ for idx, e in zip(indices, selements):
+ expected_orders = element_orders[e]
+ out_o = int(out[idx])
+ if len(expected_orders) == 1:
+ ph.assert_scalar_equals(
+ "argsort", type_=int, idx=idx, out=out_o, expected=expected_orders[0], kw=kw
+ )
+ else:
+ assert_scalar_in_set(
+ "argsort", idx=idx, out=out_o, set_=set(expected_orders), kw=kw
+ )
+
+
+@pytest.mark.unvectorized
+# TODO: Test with signed zeros and NaNs (and ignore them somehow)
+@given(
+ x=hh.arrays(
+ dtype=hh.real_dtypes,
+ shape=hh.shapes(min_dims=1, min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_sort(x, data):
+ if dh.is_float_dtype(x.dtype):
+ assume(not xp.any(x == -0.0) and not xp.any(x == +0.0))
+
+ kw = data.draw(
+ hh.kwargs(
+ axis=st.integers(-x.ndim, x.ndim - 1),
+ descending=st.booleans(),
+ stable=st.booleans(),
+ ),
+ label="kw",
+ )
+
+ out = xp.sort(x, **kw)
+
+ ph.assert_dtype("sort", out_dtype=out.dtype, in_dtype=x.dtype)
+ ph.assert_shape("sort", out_shape=out.shape, expected=x.shape, kw=kw)
+ axis = kw.get("axis", -1)
+ axes = sh.normalize_axis(axis, x.ndim)
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for indices in sh.axes_ndindex(x.shape, axes):
+ elements = [scalar_type(x[idx]) for idx in indices]
+ size = len(elements)
+ orders = sorted(
+ range(size), key=elements.__getitem__, reverse=kw.get("descending", False)
+ )
+ for out_idx, o in zip(indices, orders):
+ x_idx = indices[o]
+ # TODO: error message when unstable should not imply just one idx
+ ph.assert_0d_equals(
+ "sort",
+ x_repr=f"x[{x_idx}]",
+ x_val=x[x_idx],
+ out_repr=f"out[{out_idx}]",
+ out_val=out[out_idx],
+ kw=kw,
+ )
diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py
new file mode 100644
index 00000000..bf05a262
--- /dev/null
+++ b/array_api_tests/test_special_cases.py
@@ -0,0 +1,1353 @@
+"""
+Tests for special cases.
+
+Most test cases for special casing are built at runtime via the parametrized
+tests test_unary/test_binary/test_iop. Most of this file consists of utility
+classes and functions, all brought together to create the test cases (pytest
+params), to finally be run through generalised test logic.
+
+TODO: test integer arrays for relevant special cases
+"""
+# We use __future__ for forward-reference type hints - this works even on py3.8.0
+# See https://stackoverflow.com/a/33533514/5193926
+from __future__ import annotations
+
+import inspect
+import math
+import operator
+import re
+from dataclasses import dataclass, field
+from decimal import ROUND_HALF_EVEN, Decimal
+from enum import Enum, auto
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Literal
+from warnings import warn, filterwarnings, catch_warnings
+
+import pytest
+from hypothesis import given, note, settings, assume
+from hypothesis import strategies as st
+from hypothesis.errors import NonInteractiveExampleWarning
+
+from array_api_tests.typing import Array, DataType
+
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import xp, xps
+from .stubs import category_to_funcs
+
+UnaryCheck = Callable[[float], bool]
+BinaryCheck = Callable[[float, float], bool]
+
+
+def make_strict_eq(v: float) -> UnaryCheck:
+ if math.isnan(v):
+ return math.isnan
+ if v == 0:
+ if ph.is_pos_zero(v):
+ return ph.is_pos_zero
+ else:
+ return ph.is_neg_zero
+
+ def strict_eq(i: float) -> bool:
+ return i == v
+
+ return strict_eq
+
+
+def make_strict_neq(v: float) -> UnaryCheck:
+ strict_eq = make_strict_eq(v)
+
+ def strict_neq(i: float) -> bool:
+ return not strict_eq(i)
+
+ return strict_neq
+
+
+def make_rough_eq(v: float) -> UnaryCheck:
+ assert math.isfinite(v) # sanity check
+
+ def rough_eq(i: float) -> bool:
+ return math.isclose(i, v, abs_tol=0.01)
+
+ return rough_eq
+
+
+def make_gt(v: float) -> UnaryCheck:
+ assert not math.isnan(v) # sanity check
+
+ def gt(i: float) -> bool:
+ return i > v
+
+ return gt
+
+
+def make_lt(v: float) -> UnaryCheck:
+ assert not math.isnan(v) # sanity check
+
+ def lt(i: float) -> bool:
+ return i < v
+
+ return lt
+
+
+def make_or(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck:
+ def or_(i: float) -> bool:
+ return cond1(i) or cond2(i)
+
+ return or_
+
+
+def make_and(cond1: UnaryCheck, cond2: UnaryCheck) -> UnaryCheck:
+ def and_(i: float) -> bool:
+        return cond1(i) and cond2(i)
+
+ return and_
+
+
+def make_not_cond(cond: UnaryCheck) -> UnaryCheck:
+ def not_cond(i: float) -> bool:
+ return not cond(i)
+
+ return not_cond
+
+
+def absify_cond(cond: UnaryCheck) -> UnaryCheck:
+ def abs_cond(i: float) -> bool:
+ return cond(abs(i))
+
+ return abs_cond
+
+
+repr_to_value = {
+ "NaN": float("nan"),
+ "infinity": float("inf"),
+ "0": 0.0,
+ "1": 1.0,
+ "False": 0.0,
+ "True": 1.0,
+}
+r_value = re.compile(r"([+-]?)(.+)")
+r_pi = re.compile(r"(\d?)π(?:/(\d))?")
+
+
+@dataclass
+class ParseError(ValueError):
+ value: str
+
+
+def parse_value(value_str: str) -> float:
+ """
+ Parses a value string to return a float, e.g.
+
+ >>> parse_value('1')
+    1.0
+    >>> parse_value('-infinity')
+    -inf
+ >>> parse_value('3π/4')
+ 2.356194490192345
+
+ """
+ m = r_value.match(value_str)
+ if m is None:
+ raise ParseError(value_str)
+ if pi_m := r_pi.match(m.group(2)):
+ value = math.pi
+ if numerator := pi_m.group(1):
+ value *= int(numerator)
+ if denominator := pi_m.group(2):
+ value /= int(denominator)
+ else:
+ try:
+ value = repr_to_value[m.group(2)]
+ except KeyError as e:
+ raise ParseError(value_str) from e
+ if sign := m.group(1):
+ if sign == "-":
+ value *= -1
+ return value
+
+
+r_code = re.compile(r"``([^\s]+)``")
+r_approx_value = re.compile(
+ rf"an implementation-dependent approximation to {r_code.pattern}"
+)
+r_not = re.compile("not (.+)")
+r_equal_to = re.compile(f"equal to {r_code.pattern}")
+r_array_element = re.compile(r"``([+-]?)x([12])_i``")
+r_either_code = re.compile(f"either {r_code.pattern} or {r_code.pattern}")
+r_gt = re.compile(f"greater than {r_code.pattern}")
+r_lt = re.compile(f"less than {r_code.pattern}")
+
+
+class FromDtypeFunc(Protocol):
+ """
+ Type hint for functions that return an elements strategy for arrays of the
+ given dtype, e.g. xps.from_dtype().
+ """
+
+ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
+ ...
+
+
+@dataclass
+class BoundFromDtype(FromDtypeFunc):
+ """
+ A xps.from_dtype()-like callable with bounded kwargs, filters and base function.
+
+ We can bound:
+
+ 1. Keyword arguments that xps.from_dtype() can use, e.g.
+
+ >>> from_dtype = BoundFromDtype(kwargs={'min_value': 0, 'allow_infinity': False})
+ >>> strategy = from_dtype(xp.float64)
+
+ is equivalent to
+
+ >>> strategy = xps.from_dtype(xp.float64, min_value=0, allow_infinity=False)
+
+ i.e. a strategy that generates finite floats above 0
+
+ 2. Functions that filter the elements strategy that xps.from_dtype() returns, e.g.
+
+ >>> from_dtype = BoundFromDtype(filter=lambda i: i != 0)
+ >>> strategy = from_dtype(xp.float64)
+
+ is equivalent to
+
+ >>> strategy = xps.from_dtype(xp.float64).filter(lambda i: i != 0)
+
+ i.e. a strategy that generates any float except +0 and -0
+
+ 3. The underlying function that returns an elements strategy from a dtype, e.g.
+
+ >>> from_dtype = BoundFromDtype(
+ ... from_dtype=lambda d: st.integers(
+ ... math.ceil(xp.finfo(d).min), math.floor(xp.finfo(d).max)
+ ... )
+ ... )
+ >>> strategy = from_dtype(xp.float64)
+
+ is equivalent to
+
+ >>> strategy = st.integers(
+ ... math.ceil(xp.finfo(xp.float64).min), math.floor(xp.finfo(xp.float64).max)
+ ... )
+
+ i.e. a strategy that generates integers (within the dtype's range)
+
+ This is useful to avoid translating special case conditions into either a
+ dict, filter or "base func", and instead allows us to generalise these three
+ components into a callable equivalent of xps.from_dtype().
+
+ Additionally, BoundFromDtype instances can be added together. This allows us
+ to keep parsing each condition individually - so we don't need to duplicate
+ complicated parsing code - as ultimately we can represent (and subsequently
+ test for) special cases which have more than one condition per array, e.g.
+
+ "If x1_i is greater than 0 and x1_i is not 42, ..."
+
+ could be translated as
+
+ >>> gt_0_from_dtype = BoundFromDtype(kwargs={'min_value': 0})
+ >>> not_42_from_dtype = BoundFromDtype(filter=lambda i: i != 42)
+ >>> gt_0_from_dtype + not_42_from_dtype
+    BoundFromDtype(kwargs={'min_value': 0}, filter_=<lambda>(i))
+
+ """
+
+ kwargs: Dict[str, Any] = field(default_factory=dict)
+ filter_: Optional[Callable[[Array], bool]] = None
+ base_func: Optional[FromDtypeFunc] = None
+
+ def __call__(self, dtype: DataType, **kw) -> st.SearchStrategy[float]:
+ assert len(kw) == 0 # sanity check
+ from_dtype = self.base_func or hh.from_dtype
+ strat = from_dtype(dtype, **self.kwargs)
+ if self.filter_ is not None:
+ strat = strat.filter(self.filter_)
+ return strat
+
+ def __add__(self, other: BoundFromDtype) -> BoundFromDtype:
+ for k in self.kwargs.keys():
+ if k in other.kwargs.keys():
+ assert self.kwargs[k] == other.kwargs[k] # sanity check
+ kwargs = {**self.kwargs, **other.kwargs}
+
+ if self.filter_ is not None and other.filter_ is not None:
+ filter_ = lambda i: self.filter_(i) and other.filter_(i)
+ else:
+ if self.filter_ is not None:
+ filter_ = self.filter_
+ elif other.filter_ is not None:
+ filter_ = other.filter_
+ else:
+ filter_ = None
+
+ # sanity check
+ assert not (self.base_func is not None and other.base_func is not None)
+ if self.base_func is not None:
+ base_func = self.base_func
+ elif other.base_func is not None:
+ base_func = other.base_func
+ else:
+ base_func = None
+
+ return BoundFromDtype(kwargs, filter_, base_func)
+
+
+def wrap_strat_as_from_dtype(strat: st.SearchStrategy[float]) -> FromDtypeFunc:
+ """
+ Wraps an elements strategy as a xps.from_dtype()-like function
+ """
+
+ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
+ assert len(kw) == 0 # sanity check
+ return strat
+
+ return from_dtype
+
+
+def parse_cond(cond_str: str) -> Tuple[UnaryCheck, str, BoundFromDtype]:
+ """
+ Parses a Sphinx-formatted condition string to return:
+
+ 1. A function which takes an input and returns True if it meets the
+ condition, otherwise False.
+ 2. A string template for expressing the condition.
+ 3. A xps.from_dtype()-like function which returns a strategy that generates
+ elements that meet the condition.
+
+ e.g.
+
+ >>> cond, expr_template, from_dtype = parse_cond('greater than ``0``')
+ >>> cond(42)
+ True
+ >>> cond(-123)
+ False
+ >>> expr_template.replace('{}', 'x_i')
+ 'x_i > 0'
+ >>> strategy = from_dtype(xp.float64)
+ >>> for _ in range(5):
+ ... print(strategy.example())
+    1.0
+ 0.1
+ 1.7976931348623155e+179
+ inf
+ 124.978
+
+ """
+ # We first identify whether the condition starts with "not". If so, we note
+ # this but parse the condition as if it was not negated.
+ if m := r_not.match(cond_str):
+ cond_str = m.group(1)
+ not_cond = True
+ else:
+ not_cond = False
+
+ # We parse the condition to identify the condition function, expression
+ # template, and xps.from_dtype()-like condition strategy.
+ kwargs = {}
+ filter_ = None
+ from_dtype = None # type: ignore
+ if m := r_code.match(cond_str):
+ value = parse_value(m.group(1))
+ cond = make_strict_eq(value)
+ expr_template = "{} is " + m.group(1)
+ from_dtype = wrap_strat_as_from_dtype(st.just(value))
+ elif m := r_either_code.match(cond_str):
+ v1 = parse_value(m.group(1))
+ v2 = parse_value(m.group(2))
+ cond = make_or(make_strict_eq(v1), make_strict_eq(v2))
+ expr_template = "({} is " + m.group(1) + " or {} == " + m.group(2) + ")"
+ from_dtype = wrap_strat_as_from_dtype(st.sampled_from([v1, v2]))
+ elif m := r_equal_to.match(cond_str):
+ value = parse_value(m.group(1))
+ if math.isnan(value):
+ raise ParseError(cond_str)
+ cond = lambda i: i == value
+ expr_template = "{} == " + m.group(1)
+ elif m := r_gt.match(cond_str):
+ value = parse_value(m.group(1))
+ cond = make_gt(value)
+ expr_template = "{} > " + m.group(1)
+ kwargs = {"min_value": value, "exclude_min": True}
+ elif m := r_lt.match(cond_str):
+ value = parse_value(m.group(1))
+ cond = make_lt(value)
+ expr_template = "{} < " + m.group(1)
+ kwargs = {"max_value": value, "exclude_max": True}
+ elif cond_str in ["finite", "a finite number"]:
+ cond = math.isfinite
+ expr_template = "isfinite({})"
+ kwargs = {"allow_nan": False, "allow_infinity": False}
+ elif cond_str in "a positive (i.e., greater than ``0``) finite number":
+ cond = lambda i: math.isfinite(i) and i > 0
+ expr_template = "isfinite({}) and {} > 0"
+ kwargs = {
+ "allow_nan": False,
+ "allow_infinity": False,
+ "min_value": 0,
+ "exclude_min": True,
+ }
+ elif cond_str == "a negative (i.e., less than ``0``) finite number":
+ cond = lambda i: math.isfinite(i) and i < 0
+ expr_template = "isfinite({}) and {} < 0"
+ kwargs = {
+ "allow_nan": False,
+ "allow_infinity": False,
+ "max_value": 0,
+ "exclude_max": True,
+ }
+ elif cond_str == "positive":
+ cond = lambda i: math.copysign(1, i) == 1
+ expr_template = "copysign(1, {}) == 1"
+        # We assume (positive) zero is special cased separately
+ kwargs = {"min_value": 0, "exclude_min": True}
+ elif cond_str == "negative":
+ cond = lambda i: math.copysign(1, i) == -1
+ expr_template = "copysign(1, {}) == -1"
+        # We assume (negative) zero is special cased separately
+ kwargs = {"max_value": 0, "exclude_max": True}
+ elif "nonzero finite" in cond_str:
+ cond = lambda i: math.isfinite(i) and i != 0
+ expr_template = "isfinite({}) and {} != 0"
+ kwargs = {"allow_nan": False, "allow_infinity": False}
+ filter_ = lambda n: n != 0
+ elif cond_str == "an integer value":
+ cond = lambda i: i.is_integer()
+ expr_template = "{}.is_integer()"
+ from_dtype = integers_from_dtype # type: ignore
+ elif cond_str == "an odd integer value":
+ cond = lambda i: i.is_integer() and i % 2 == 1
+ expr_template = "{}.is_integer() and {} % 2 == 1"
+ if not_cond:
+ expr_template = f"({expr_template})"
+
+ def from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
+ return integers_from_dtype(dtype, **kw).filter(lambda n: n % 2 == 1)
+
+ else:
+ raise ParseError(cond_str)
+
+ if not_cond:
+        # We handle negated conditions by simply negating the condition function
+ # and using it as a filter for xps.from_dtype() (or an equivalent).
+ cond = make_not_cond(cond)
+ expr_template = f"not {expr_template}"
+ filter_ = cond
+ return cond, expr_template, BoundFromDtype(filter_=filter_)
+ else:
+ return cond, expr_template, BoundFromDtype(kwargs, filter_, from_dtype)
+
+
+def parse_result(result_str: str) -> Tuple[UnaryCheck, str]:
+ """
+ Parses a Sphinx-formatted result string to return:
+
+ 1. A function which takes an input and returns True if it is the expected
+ result (or meets the condition of the expected result), otherwise False.
+ 2. A string that expresses the result.
+
+ e.g.
+
+ >>> check_result, expr = parse_result('``42``')
+ >>> check_result(7)
+ False
+ >>> check_result(42)
+ True
+ >>> expr
+ '42'
+
+ """
+ if m := r_code.match(result_str):
+ value = parse_value(m.group(1))
+ check_result = make_strict_eq(value) # type: ignore
+ expr = m.group(1)
+ elif m := r_approx_value.match(result_str):
+ value = parse_value(m.group(1))
+ check_result = make_rough_eq(value) # type: ignore
+ repr_ = m.group(1).replace("π", "pi") # for pytest param names
+ expr = f"roughly {repr_}"
+ elif "positive" in result_str:
+
+ def check_result(result: float) -> bool:
+ if math.isnan(result):
+ # The sign of NaN is out-of-scope
+ return True
+ return math.copysign(1, result) == 1
+
+ expr = "positive sign"
+ elif "negative" in result_str:
+
+ def check_result(result: float) -> bool:
+ if math.isnan(result):
+ # The sign of NaN is out-of-scope
+ return True
+ return math.copysign(1, result) == -1
+
+ expr = "negative sign"
+ else:
+ raise ParseError(result_str)
+
+ return check_result, expr
+
+
+class Case(Protocol):
+ cond_expr: str
+ result_expr: str
+ raw_case: Optional[str]
+
+ def cond(self, *args) -> bool:
+ ...
+
+ def check_result(self, *args) -> bool:
+ ...
+
+ def __str__(self) -> str:
+ return f"{self.cond_expr} -> {self.result_expr}"
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(<{self}>)"
+
+
+r_case_block = re.compile(
+ r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*"
+ r"(?:.+\n--+)?(?:\.\. versionchanged.*)?"
+)
+r_case = re.compile(r"\s+-\s*(.*)\.")
+
+
+class UnaryCond(Protocol):
+ def __call__(self, i: float) -> bool:
+ ...
+
+
+class UnaryResultCheck(Protocol):
+ def __call__(self, i: float, result: float) -> bool:
+ ...
+
+
+@dataclass(repr=False)
+class UnaryCase(Case):
+ cond_expr: str
+ result_expr: str
+ cond_from_dtype: FromDtypeFunc
+ cond: UnaryCheck
+ check_result: UnaryResultCheck
+ raw_case: Optional[str] = field(default=None)
+
+
+r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
+r_already_int_case = re.compile(
+ "If ``x_i`` is already integer-valued, the result is ``x_i``"
+)
+r_even_round_halves_case = re.compile(
+ "If two integers are equally close to ``x_i``, "
+ "the result is the even integer closest to ``x_i``"
+)
+r_nan_signbit = re.compile(
+ "If ``x_i`` is ``NaN`` and the sign bit of ``x_i`` is ``(.+)``, "
+ "the result is ``(.+)``"
+)
+
+
+def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
+ """
+ Returns a strategy that generates float-casted integers within the bounds of dtype.
+ """
+ for k in kw.keys():
+ # sanity check
+ assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
+ m, M = dh.dtype_ranges[dtype]
+ if "min_value" in kw.keys():
+ m = kw["min_value"]
+ if "exclude_min" in kw.keys():
+ m += 1
+ if "max_value" in kw.keys():
+ M = kw["max_value"]
+ if "exclude_max" in kw.keys():
+ M -= 1
+ return st.integers(math.ceil(m), math.floor(M)).map(float)
+
+
+def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
+ """
+ Returns a strategy that generates floats that end with .5 and are within the
+ bounds of dtype.
+ """
+ # We bound our base integers strategy to a range of values which should be
+ # able to represent a decimal 5 when .5 is added or subtracted.
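+    # e.g. for float32 this generates values such as -9999.5, -0.5 and 42.5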
+ if dtype == xp.float32:
+ abs_max = 10**4
+ else:
+ abs_max = 10**16
+ return st.sampled_from([0.5, -0.5]).flatmap(
+ lambda half: st.integers(-abs_max, abs_max).map(lambda n: n + half)
+ )
+
+
+already_int_case = UnaryCase(
+ cond_expr="x_i.is_integer()",
+ cond=lambda i: i.is_integer(),
+ cond_from_dtype=integers_from_dtype,
+ result_expr="x_i",
+ check_result=lambda i, result: i == result,
+)
+even_round_halves_case = UnaryCase(
+ cond_expr="modf(i)[0] == 0.5",
+ cond=lambda i: math.modf(i)[0] == 0.5,
+ cond_from_dtype=trailing_halves_from_dtype,
+ result_expr="Decimal(i).to_integral_exact(ROUND_HALF_EVEN)",
+ check_result=lambda i, result: (
+ result == float(Decimal(i).to_integral_exact(ROUND_HALF_EVEN))
+ ),
+)
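+# ROUND_HALF_EVEN is banker's rounding, e.g. both
+# float(Decimal(1.5).to_integral_exact(ROUND_HALF_EVEN)) and
+# float(Decimal(2.5).to_integral_exact(ROUND_HALF_EVEN)) equal 2.0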
+
+
+def make_nan_signbit_case(signbit: Literal[0, 1], expected: bool) -> UnaryCase:
+ if signbit:
+ sign = -1
+ nan_expr = "-NaN"
+ float_arg = "-nan"
+ else:
+ sign = 1
+ nan_expr = "+NaN"
+ float_arg = "nan"
+
+ return UnaryCase(
+ cond_expr=f"x_i is {nan_expr}",
+ cond=lambda i: math.isnan(i) and math.copysign(1, i) == sign,
+ cond_from_dtype=lambda _: st.just(float(float_arg)),
+ result_expr=str(expected),
+ check_result=lambda _, result: result == float(expected),
+ )
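+# Note: this relies on CPython's float("-nan") producing a NaN with its sign
+# bit set, so copysign() can distinguish the two NaNs.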
+
+
+def make_unary_check_result(check_just_result: UnaryCheck) -> UnaryResultCheck:
+ def check_result(i: float, result: float) -> bool:
+ return check_just_result(result)
+
+ return check_result
+
+
+def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
+ """
+ Parses a Sphinx-formatted docstring of a unary function to return a list of
+ codified unary cases, e.g.
+
+ >>> def sqrt(x):
+ ... '''
+ ... Calculates the square root
+ ...
+ ... **Special Cases**
+ ...
+ ... For floating-point operands,
+ ...
+ ... - If ``x_i`` is less than ``0``, the result is ``NaN``.
+ ... - If ``x_i`` is ``NaN``, the result is ``NaN``.
+ ... - If ``x_i`` is ``+0``, the result is ``+0``.
+ ... - If ``x_i`` is ``-0``, the result is ``-0``.
+ ... - If ``x_i`` is ``+infinity``, the result is ``+infinity``.
+ ...
+ ... Parameters
+ ... ----------
+ ... x: array
+ ... input array
+ ...
+ ... Returns
+ ... -------
+ ... out: array
+ ... an array containing the square root of each element in ``x``
+ ... '''
+ ...
+ >>> case_block = r_case_block.search(sqrt.__doc__).group(1)
+ >>> unary_cases = parse_unary_case_block(case_block, 'sqrt')
+ >>> for case in unary_cases:
+ ... print(repr(case))
+    UnaryCase(<x_i < 0 -> NaN>)
+    UnaryCase(<x_i is NaN -> NaN>)
+    UnaryCase(<x_i is +0 -> +0>)
+    UnaryCase(<x_i is -0 -> -0>)
+    UnaryCase(<x_i is +infinity -> +infinity>)
+ >>> lt_0_case = unary_cases[0]
+ >>> lt_0_case.cond(-123)
+ True
+ >>> lt_0_case.check_result(-123, float('nan'))
+ True
+
+ """
+ cases = []
+ for case_m in r_case.finditer(case_block):
+ case_str = case_m.group(1)
+ if r_already_int_case.search(case_str):
+ cases.append(already_int_case)
+ elif r_even_round_halves_case.search(case_str):
+ cases.append(even_round_halves_case)
+ elif m := r_nan_signbit.search(case_str):
+ signbit = parse_value(m.group(1))
+ expected = bool(parse_value(m.group(2)))
+ cases.append(make_nan_signbit_case(signbit, expected))
+ elif m := r_unary_case.search(case_str):
+ try:
+ cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
+ _check_result, result_expr = parse_result(m.group(2))
+ except ParseError as e:
+ warn(f"case for {func_name} not machine-readable: '{e.value}'")
+ continue
+ cond_expr = cond_expr_template.replace("{}", "x_i")
+ # Do not define check_result in this function's body - see
+ # parse_binary_case comment.
+ check_result = make_unary_check_result(_check_result)
+ case = UnaryCase(
+ cond_expr=cond_expr,
+ cond=cond,
+ cond_from_dtype=cond_from_dtype,
+ result_expr=result_expr,
+ check_result=check_result,
+ raw_case=case_str,
+ )
+ cases.append(case)
+ else:
+ if not r_remaining_case.search(case_str):
+ warn(f"case for {func_name} not machine-readable: '{case_str}'")
+ return cases
+
+
+class BinaryCond(Protocol):
+ def __call__(self, i1: float, i2: float) -> bool:
+ ...
+
+
+class BinaryResultCheck(Protocol):
+ def __call__(self, i1: float, i2: float, result: float) -> bool:
+ ...
+
+
+@dataclass(repr=False)
+class BinaryCase(Case):
+ cond_expr: str
+ result_expr: str
+ x1_cond_from_dtype: FromDtypeFunc
+ x2_cond_from_dtype: FromDtypeFunc
+ cond: BinaryCond
+ check_result: BinaryResultCheck
+ raw_case: Optional[str] = field(default=None)
+
+
+r_binary_case = re.compile("If (.+), the result (.+)")
+r_remaining_case = re.compile("In the remaining cases.+")
+r_cond_sep = re.compile(r"(? float:
+ return n
+
+
+def make_binary_cond(
+ cond_arg: BinaryCondArg,
+ unary_cond: UnaryCheck,
+ *,
+ input_wrapper: Optional[Callable[[float], float]] = None,
+) -> BinaryCond:
+ """
+ Wraps a unary condition as a binary condition, e.g.
+
+ >>> unary_cond = lambda i: i == 42
+ >>> binary_cond_first = make_binary_cond(BinaryCondArg.FIRST, unary_cond)
+ >>> binary_cond_first(42, 0)
+ True
+ >>> binary_cond_second = make_binary_cond(BinaryCondArg.SECOND, unary_cond)
+ >>> binary_cond_second(42, 0)
+ False
+ >>> binary_cond_second(0, 42)
+ True
+ >>> binary_cond_both = make_binary_cond(BinaryCondArg.BOTH, unary_cond)
+ >>> binary_cond_both(42, 0)
+ False
+ >>> binary_cond_both(42, 42)
+ True
+ >>> binary_cond_either = make_binary_cond(BinaryCondArg.EITHER, unary_cond)
+ >>> binary_cond_either(0, 0)
+ False
+ >>> binary_cond_either(42, 0)
+ True
+ >>> binary_cond_either(0, 42)
+ True
+ >>> binary_cond_either(42, 42)
+ True
+
+ """
+ if input_wrapper is None:
+ input_wrapper = noop
+
+ if cond_arg == BinaryCondArg.FIRST:
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ return unary_cond(input_wrapper(i1))
+
+ elif cond_arg == BinaryCondArg.SECOND:
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ return unary_cond(input_wrapper(i2))
+
+ elif cond_arg == BinaryCondArg.BOTH:
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ return unary_cond(input_wrapper(i1)) and unary_cond(input_wrapper(i2))
+
+ else:
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ return unary_cond(input_wrapper(i1)) or unary_cond(input_wrapper(i2))
+
+ return partial_cond
+
+
+def make_eq_input_check_result(
+ eq_to: BinaryCondArg, *, eq_neg: bool = False
+) -> BinaryResultCheck:
+ """
+ Returns a result checker for cases where the result equals an array element
+
+ >>> check_result_first = make_eq_input_check_result(BinaryCondArg.FIRST)
+    >>> check_result_first(42, 0, 42)
+    True
+    >>> check_result_second = make_eq_input_check_result(BinaryCondArg.SECOND)
+    >>> check_result_second(42, 0, 42)
+    False
+    >>> check_result_second(0, 42, 42)
+    True
+ >>> check_result_neg_first = make_eq_input_check_result(BinaryCondArg.FIRST, eq_neg=True)
+ >>> check_result_neg_first(42, 0, 42)
+ False
+ >>> check_result_neg_first(42, 0, -42)
+ True
+
+ """
+ if eq_neg:
+ input_wrapper = lambda i: -i
+ else:
+ input_wrapper = noop
+
+ if eq_to == BinaryCondArg.FIRST:
+
+ def check_result(i1: float, i2: float, result: float) -> bool:
+ eq = make_strict_eq(input_wrapper(i1))
+ return eq(result)
+
+ elif eq_to == BinaryCondArg.SECOND:
+
+ def check_result(i1: float, i2: float, result: float) -> bool:
+ eq = make_strict_eq(input_wrapper(i2))
+ return eq(result)
+
+ else:
+ raise ValueError(f"{eq_to=} must be FIRST or SECOND")
+
+ return check_result
+
+
+def make_binary_check_result(check_just_result: UnaryCheck) -> BinaryResultCheck:
+ def check_result(i1: float, i2: float, result: float) -> bool:
+ return check_just_result(result)
+
+ return check_result
+
+
+def parse_binary_case(case_str: str) -> BinaryCase:
+ """
+ Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
+
+ >>> case_str = (
+ ... "If ``x1_i`` is greater than ``0``, ``x1_i`` is a finite number, "
+ ... "and ``x2_i`` is ``+infinity``, the result is ``NaN``."
+ ... )
+ >>> case = parse_binary_case(case_str)
+ >>> case
+    BinaryCase(<x1_i > 0 and isfinite(x1_i) and x2_i is +infinity -> NaN>)
+ >>> case.cond(42, float('inf'))
+ True
+ >>> case.check_result(42, float('inf'), float('nan'))
+ True
+
+ """
+ case_m = r_binary_case.match(case_str)
+ assert case_m is not None # sanity check
+ cond_strs = r_cond_sep.split(case_m.group(1))
+
+ partial_conds = []
+ partial_exprs = []
+ x1_cond_from_dtypes = []
+ x2_cond_from_dtypes = []
+ for cond_str in cond_strs:
+ if m := r_input_is_array_element.match(cond_str):
+ in_sign, in_no, other_sign, other_no = m.groups()
+ if in_sign != "" or other_no == in_no:
+ raise ParseError(cond_str)
+ partial_expr = f"{in_sign}x{in_no}_i == {other_sign}x{other_no}_i"
+
+            # For these scenarios, we want both array elements to be generated
+            # relative to one another, so we use a shared strategy.
+ shared_from_dtype = lambda d, **kw: st.shared(
+ xps.from_dtype(d, **kw), key=cond_str
+ )
+            input_wrapper = lambda i: -i if other_sign == "-" else i
+ if other_no == "1":
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ eq = make_strict_eq(input_wrapper(i1))
+ return eq(i2)
+
+ _x2_cond_from_dtype = shared_from_dtype # type: ignore
+
+ def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
+ return shared_from_dtype(dtype, **kw).map(input_wrapper)
+
+ elif other_no == "2":
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ eq = make_strict_eq(input_wrapper(i2))
+ return eq(i1)
+
+ _x1_cond_from_dtype = shared_from_dtype # type: ignore
+
+ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
+ return shared_from_dtype(dtype, **kw).map(input_wrapper)
+
+ else:
+ raise ParseError(cond_str)
+
+ x1_cond_from_dtypes.append(BoundFromDtype(base_func=_x1_cond_from_dtype))
+ x2_cond_from_dtypes.append(BoundFromDtype(base_func=_x2_cond_from_dtype))
+
+ elif m := r_both_inputs_are_value.match(cond_str):
+ unary_cond, expr_template, cond_from_dtype = parse_cond(m.group(1))
+ left_expr = expr_template.replace("{}", "x1_i")
+ right_expr = expr_template.replace("{}", "x2_i")
+ partial_expr = f"{left_expr} and {right_expr}"
+ partial_cond = make_binary_cond( # type: ignore
+ BinaryCondArg.BOTH, unary_cond
+ )
+ x1_cond_from_dtypes.append(cond_from_dtype)
+ x2_cond_from_dtypes.append(cond_from_dtype)
+ else:
+ cond_m = r_cond.match(cond_str)
+ if cond_m is None:
+ raise ParseError(cond_str)
+ input_str, value_str = cond_m.groups()
+
+ if value_str == "the same mathematical sign":
+ partial_expr = "copysign(1, x1_i) == copysign(1, x2_i)"
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ return math.copysign(1, i1) == math.copysign(1, i2)
+
+ x1_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
+ x2_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
+
+ elif value_str == "different mathematical signs":
+ partial_expr = "copysign(1, x1_i) != copysign(1, x2_i)"
+
+ def partial_cond(i1: float, i2: float) -> bool:
+ return math.copysign(1, i1) != math.copysign(1, i2)
+
+ x1_cond_from_dtypes.append(BoundFromDtype(kwargs={"min_value": 1}))
+ x2_cond_from_dtypes.append(BoundFromDtype(kwargs={"max_value": -1}))
+
+ else:
+ unary_cond, expr_template, cond_from_dtype = parse_cond(value_str)
+                # Do not define partial_cond via the def keyword or a lambda
+                # expression here, as one partial_cond definition can clobber
+                # previous definitions in the partial_conds list. This is a
+                # hard limitation of local functions that share the same name
+                # and close over the same outer variables (i.e. unary_cond).
+                # Using def inside a called function avoids this problem.
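+                # (For illustration, the classic late-binding pitfall this
+                # avoids:
+                #
+                #     conds = [lambda i: i == v for v in (0, 1)]
+                #
+                # Both lambdas close over the *variable* v, which is 1 once
+                # the loop ends, so conds[0](1) is True. A factory such as
+                # make_binary_cond() binds each value in its own scope.)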
+ input_wrapper = None
+ if m := r_input.match(input_str):
+ x_no = m.group(1)
+ partial_expr = expr_template.replace("{}", f"x{x_no}_i")
+ cond_arg = BinaryCondArg.from_x_no(x_no)
+ elif m := r_abs_input.match(input_str):
+ x_no = m.group(1)
+ partial_expr = expr_template.replace("{}", f"abs(x{x_no}_i)")
+ cond_arg = BinaryCondArg.from_x_no(x_no)
+ input_wrapper = abs
+ elif r_and_input.match(input_str):
+ left_expr = expr_template.replace("{}", "x1_i")
+ right_expr = expr_template.replace("{}", "x2_i")
+ partial_expr = f"{left_expr} and {right_expr}"
+ cond_arg = BinaryCondArg.BOTH
+ elif r_or_input.match(input_str):
+ left_expr = expr_template.replace("{}", "x1_i")
+ right_expr = expr_template.replace("{}", "x2_i")
+ partial_expr = f"{left_expr} or {right_expr}"
+ if len(cond_strs) != 1:
+ partial_expr = f"({partial_expr})"
+ cond_arg = BinaryCondArg.EITHER
+ else:
+ raise ParseError(input_str)
+ partial_cond = make_binary_cond( # type: ignore
+ cond_arg, unary_cond, input_wrapper=input_wrapper
+ )
+ if cond_arg == BinaryCondArg.FIRST:
+ x1_cond_from_dtypes.append(cond_from_dtype)
+ elif cond_arg == BinaryCondArg.SECOND:
+ x2_cond_from_dtypes.append(cond_from_dtype)
+ elif cond_arg == BinaryCondArg.BOTH:
+ x1_cond_from_dtypes.append(cond_from_dtype)
+ x2_cond_from_dtypes.append(cond_from_dtype)
+ else:
+ # For "either x1_i or x2_i is " cases, we want to
+ # test three scenarios:
+ #
+ # 1. x1_i is
+ # 2. x2_i is
+ # 3. x1_i AND x2_i is
+ #
+ # This is achieved by a shared base strategy that picks one
+ # of these scenarios to determine whether each array will
+ # use either cond_from_dtype() (i.e. meet the condition), or
+                    # simply hh.from_dtype() (i.e. be any value).
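+                    #
+                    # (e.g. if the shared draw is (True, False), x1 is built
+                    # from cond_from_dtype() while x2 is unconstrained; since
+                    # the base strategy is shared, both flatmaps below observe
+                    # the same tuple within a single example.)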
+
+ use_x1_or_x2_strat = st.shared(
+ st.sampled_from([(True, False), (False, True), (True, True)])
+ )
+
+ def _x1_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
+ assert len(kw) == 0 # sanity check
+ return use_x1_or_x2_strat.flatmap(
+ lambda t: cond_from_dtype(dtype)
+ if t[0]
+ else hh.from_dtype(dtype)
+ )
+
+ def _x2_cond_from_dtype(dtype, **kw) -> st.SearchStrategy[float]:
+ assert len(kw) == 0 # sanity check
+ return use_x1_or_x2_strat.flatmap(
+ lambda t: cond_from_dtype(dtype)
+ if t[1]
+ else hh.from_dtype(dtype)
+ )
+
+ x1_cond_from_dtypes.append(
+ BoundFromDtype(base_func=_x1_cond_from_dtype)
+ )
+ x2_cond_from_dtypes.append(
+ BoundFromDtype(base_func=_x2_cond_from_dtype)
+ )
+
+ partial_conds.append(partial_cond)
+ partial_exprs.append(partial_expr)
+
+ result_m = r_result.match(case_m.group(2))
+ if result_m is None:
+ raise ParseError(case_m.group(2))
+ result_str = result_m.group(1)
+ # Like with partial_cond, do not define check_result in this function's body.
+ if m := r_array_element.match(result_str):
+ sign, x_no = m.groups()
+ result_expr = f"{sign}x{x_no}_i"
+ check_result = make_eq_input_check_result( # type: ignore
+ BinaryCondArg.from_x_no(x_no), eq_neg=sign == "-"
+ )
+ else:
+ _check_result, result_expr = parse_result(result_m.group(1))
+ check_result = make_binary_check_result(_check_result)
+
+ cond_expr = " and ".join(partial_exprs)
+
+ def cond(i1: float, i2: float) -> bool:
+ return all(pc(i1, i2) for pc in partial_conds)
+
+ x1_cond_from_dtype = sum(x1_cond_from_dtypes, start=BoundFromDtype())
+ x2_cond_from_dtype = sum(x2_cond_from_dtypes, start=BoundFromDtype())
+
+ return BinaryCase(
+ cond_expr=cond_expr,
+ cond=cond,
+ x1_cond_from_dtype=x1_cond_from_dtype,
+ x2_cond_from_dtype=x2_cond_from_dtype,
+ result_expr=result_expr,
+ check_result=check_result,
+ raw_case=case_str,
+ )
+
+
+r_redundant_case = re.compile("result.+determined by the rule already stated above")
+
+
+def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]:
+ """
+ Parses a Sphinx-formatted docstring of a binary function to return a list of
+ codified binary cases, e.g.
+
+ >>> def logaddexp(x1, x2):
+ ... '''
+ ... Calculates the logarithm of the sum of exponentiations
+ ...
+ ... **Special Cases**
+ ...
+ ... For floating-point operands,
+ ...
+ ... - If either ``x1_i`` or ``x2_i`` is ``NaN``, the result is ``NaN``.
+ ... - If ``x1_i`` is ``+infinity`` and ``x2_i`` is not ``NaN``, the result is ``+infinity``.
+ ... - If ``x1_i`` is not ``NaN`` and ``x2_i`` is ``+infinity``, the result is ``+infinity``.
+ ...
+ ... Parameters
+ ... ----------
+ ... x1: array
+ ... first input array
+ ... x2: array
+ ... second input array
+ ...
+ ... Returns
+ ... -------
+ ... out: array
+ ... an array containing the results
+ ... '''
+ ...
+ >>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
+ >>> binary_cases = parse_binary_case_block(case_block, 'logaddexp')
+ >>> for case in binary_cases:
+ ... print(repr(case))
+    BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
+    BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
+    BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)
+
+ """
+ cases = []
+ for case_m in r_case.finditer(case_block):
+ case_str = case_m.group(1)
+ if r_redundant_case.search(case_str):
+ continue
+ if r_binary_case.match(case_str):
+ try:
+ case = parse_binary_case(case_str)
+ cases.append(case)
+ except ParseError as e:
+ warn(f"case for {func_name} not machine-readable: '{e.value}'")
+ else:
+ if not r_remaining_case.match(case_str):
+ warn(f"case for {func_name} not machine-readable: '{case_str}'")
+ return cases
+
+
+unary_params = []
+binary_params = []
+iop_params = []
+func_to_op: Dict[str, str] = {v: k for k, v in dh.op_to_func.items()}
+for stub in category_to_funcs["elementwise"]:
+ func_name = stub.__name__
+ if stub.__doc__ is None:
+ warn(f"{func_name}() stub has no docstring")
+ continue
+ if m := r_case_block.search(stub.__doc__):
+ case_block = m.group(1)
+ else:
+ continue
+ marks = []
+ try:
+ func = getattr(xp, func_name)
+ except AttributeError:
+ marks.append(
+ pytest.mark.skip(reason=f"{func_name} not found in array module")
+ )
+ func = None
+ sig = inspect.signature(stub)
+ param_names = list(sig.parameters.keys())
+ if len(sig.parameters) == 0:
+ warn(f"{func=} has no parameters")
+ continue
+ if param_names[0] == "x":
+ if cases := parse_unary_case_block(case_block, func_name):
+ name_to_func = {func_name: func}
+ if func_name in func_to_op.keys():
+ op_name = func_to_op[func_name]
+ op = getattr(operator, op_name)
+ name_to_func[op_name] = op
+ for func_name, func in name_to_func.items():
+ for case in cases:
+ id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
+ p = pytest.param(func_name, func, case, id=id_)
+ unary_params.append(p)
+ else:
+ warn(f"Special cases found for {func_name} but none were parsed")
+ continue
+ if len(sig.parameters) == 1:
+ warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
+ continue
+ if param_names[0] == "x1" and param_names[1] == "x2":
+ if cases := parse_binary_case_block(case_block, func_name):
+ name_to_func = {func_name: func}
+ if func_name in func_to_op.keys():
+ op_name = func_to_op[func_name]
+ op = getattr(operator, op_name)
+ name_to_func[op_name] = op
+            # We collect inplace operator test cases separately
+ if "equal" in func_name:
+ continue
+ iop_name = "__i" + op_name[2:]
+ iop = getattr(operator, iop_name)
+ for case in cases:
+ id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
+ p = pytest.param(iop_name, iop, case, id=id_)
+ iop_params.append(p)
+ for func_name, func in name_to_func.items():
+ for case in cases:
+ id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
+ p = pytest.param(func_name, func, case, id=id_)
+ binary_params.append(p)
+ else:
+ warn(f"Special cases found for {func_name} but none were parsed")
+ continue
+ else:
+ warn(
+ f"{func=} starts with two parameters '{param_names[0]}' and "
+ f"'{param_names[1]}', which are not named 'x1' and 'x2'"
+ )
+
+
+# test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not
+# meet the condition that is being tested. We then forcibly make the array meet
+# the condition by picking a random index to insert an acceptable element.
+#
+# good_example is a flag that tells us whether Hypothesis generated an array
+# with at least one element that is special-cased. We reject the example when
+# it's False; Hypothesis will complain if we reject too many examples, thus
+# indicating we've done something wrong.
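+#
+# (A minimal sketch of that rejection pattern, assuming a generated array x
+# and a per-element predicate cond:
+#
+#     good_example = any(cond(float(x[idx])) for idx in sh.ndindex(x.shape))
+#     assume(good_example)  # Hypothesis retries with a fresh draw when False
+# )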
+
+# sanity checks
+assert len(unary_params) != 0
+assert len(binary_params) != 0
+assert len(iop_params) != 0
+
+
+@pytest.mark.parametrize("func_name, func, case", unary_params)
+def test_unary(func_name, func, case):
+ with catch_warnings():
+        # XXX: We are using example() here to generate one example draw, but
+        # Hypothesis issues a warning for this. We should consider either
+        # drawing multiple examples like a normal test, or just hard-coding a
+        # single example test case without using Hypothesis.
+ filterwarnings('ignore', category=NonInteractiveExampleWarning)
+ in_value = case.cond_from_dtype(xp.float64).example()
+ x = xp.asarray(in_value, dtype=xp.float64)
+ out = func(x)
+ out_value = float(out)
+ assert case.check_result(in_value, out_value), (
+ f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
+ )
+
+
+@pytest.mark.parametrize("func_name, func, case", binary_params)
+@settings(max_examples=1)
+@given(data=st.data())
+def test_binary(func_name, func, case, data):
+    # We don't use example() like in test_unary because the internal shared
+    # strategies used for both x1 and x2 don't "sync" across example() draws.
+ x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
+ x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
+ x1 = xp.asarray(x1_value, dtype=xp.float64)
+ x2 = xp.asarray(x2_value, dtype=xp.float64)
+
+ out = func(x1, x2)
+ out_value = float(out)
+
+ assert case.check_result(x1_value, x2_value, out_value), (
+ f"out={out_value}, but should be {case.result_expr} [{func_name}()]\n"
+ f"condition: {case}\n"
+ f"x1={x1_value}, x2={x2_value}"
+ )
+
+
+@pytest.mark.parametrize("iop_name, iop, case", iop_params)
+@settings(max_examples=1)
+@given(data=st.data())
+def test_iop(iop_name, iop, case, data):
+ # See test_binary comment
+ x1_value = data.draw(case.x1_cond_from_dtype(xp.float64), label="x1_value")
+ x2_value = data.draw(case.x2_cond_from_dtype(xp.float64), label="x2_value")
+ x1 = xp.asarray(x1_value, dtype=xp.float64)
+ x2 = xp.asarray(x2_value, dtype=xp.float64)
+
+ res = iop(x1, x2)
+ res_value = float(res)
+
+ assert case.check_result(x1_value, x2_value, res_value), (
+ f"x1={res}, but should be {case.result_expr} [{func_name}()]\n"
+ f"condition: {case}\n"
+ f"x1={x1_value}, x2={x2_value}"
+ )
+
+
+@pytest.mark.parametrize(
+ "func_name, expected",
+ [
+ ("mean", float("nan")),
+ ("prod", 1),
+ ("std", float("nan")),
+ ("sum", 0),
+ ("var", float("nan")),
+ ],
+ ids=["mean", "prod", "std", "sum", "var"],
+)
+def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get expected
+ func = getattr(xp, func_name)
+ out = func(xp.asarray([], dtype=dh.default_float))
+ ph.assert_shape(func_name, out_shape=out.shape, expected=()) # sanity check
+ msg = f"{out=!r}, but should be {expected}"
+ if math.isnan(expected):
+ assert xp.isnan(out), msg
+ else:
+ assert out == expected, msg
+
+
+@pytest.mark.parametrize(
+ "func_name", [f.__name__ for f in category_to_funcs["statistical"]
+ if f.__name__ not in ['cumulative_sum', 'cumulative_prod']]
+)
+@given(
+ x=hh.arrays(dtype=hh.real_floating_dtypes, shape=hh.shapes(min_side=1)),
+ data=st.data(),
+)
+def test_nan_propagation(func_name, x, data):
+ func = getattr(xp, func_name)
+ nan_positions = data.draw(
+ hh.arrays(dtype=hh.bool_dtype, shape=x.shape), label="nan_positions"
+ )
+ assume(xp.any(nan_positions))
+ x = xp.where(nan_positions, xp.asarray(float("nan")), x)
+ note(f"{x=}")
+
+ out = func(x)
+
+ ph.assert_shape(func_name, out_shape=out.shape, expected=()) # sanity check
+ assert xp.isnan(out), f"{out=!r}, but should be NaN"
diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py
new file mode 100644
index 00000000..0e3aa9d4
--- /dev/null
+++ b/array_api_tests/test_statistical_functions.py
@@ -0,0 +1,452 @@
+import cmath
+import math
+from typing import Optional
+
+import pytest
+from hypothesis import assume, given
+from hypothesis import strategies as st
+from ndindex import iter_indices
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+from ._array_module import _UndefinedStub
+from .typing import DataType
+
+
+@pytest.mark.min_version("2023.12")
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.numeric_dtypes,
+ shape=hh.shapes(min_dims=1)),
+ data=st.data(),
+)
+def test_cumulative_sum(x, data):
+ axes = st.integers(-x.ndim, x.ndim - 1)
+ if x.ndim == 1:
+ axes = axes | st.none()
+ axis = data.draw(axes, label='axis')
+ _axis, = sh.normalize_axis(axis, x.ndim)
+ dtype = data.draw(kwarg_dtypes(x.dtype))
+ include_initial = data.draw(st.booleans(), label="include_initial")
+
+ kw = data.draw(
+ hh.specified_kwargs(
+ ("axis", axis, None),
+ ("dtype", dtype, None),
+ ("include_initial", include_initial, False),
+ ),
+ label="kw",
+ )
+
+ out = xp.cumulative_sum(x, **kw)
+
+ expected_shape = list(x.shape)
+ if include_initial:
+ expected_shape[_axis] += 1
+ expected_shape = tuple(expected_shape)
+ ph.assert_shape("cumulative_sum", out_shape=out.shape, expected=expected_shape)
+
+ expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
+ if expected_dtype is None:
+        # If a default uint cannot exist (e.g. in PyTorch, which doesn't
+        # support uint32 or uint64), we skip testing the output dtype.
+ # See https://github.com/data-apis/array-api-tests/issues/106
+ if x.dtype in dh.uint_dtypes:
+ assert dh.is_int_dtype(out.dtype) # sanity check
+ else:
+ ph.assert_dtype("cumulative_sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
+
+ scalar_type = dh.get_scalar_type(out.dtype)
+
+    for x_idx, out_idx in iter_indices(x.shape, expected_shape, skip_axes=_axis):
+ x_arr = x[x_idx.raw]
+ out_arr = out[out_idx.raw]
+
+ if include_initial:
+ ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=0)
+
+ for n in range(x.shape[_axis]):
+ start = 1 if include_initial else 0
+ out_val = out_arr[n + start]
+ assume(cmath.isfinite(out_val))
+ elements = []
+ for idx in range(n + 1):
+ s = scalar_type(x_arr[idx])
+ elements.append(s)
+ expected = sum(elements)
+ if dh.is_int_dtype(out.dtype):
+ m, M = dh.dtype_ranges[out.dtype]
+ assume(m <= expected <= M)
+ ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
+ idx=out_idx.raw, out=out_val,
+ expected=expected)
+ else:
+ condition_number = _sum_condition_number(elements)
+ assume(condition_number < 1e6)
+ ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type,
+ idx=out_idx.raw, out=out_val,
+ expected=expected)
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.numeric_dtypes,
+ shape=hh.shapes(min_dims=1)),
+ data=st.data(),
+)
+def test_cumulative_prod(x, data):
+ axes = st.integers(-x.ndim, x.ndim - 1)
+ if x.ndim == 1:
+ axes = axes | st.none()
+ axis = data.draw(axes, label='axis')
+ _axis, = sh.normalize_axis(axis, x.ndim)
+ dtype = data.draw(kwarg_dtypes(x.dtype))
+ include_initial = data.draw(st.booleans(), label="include_initial")
+
+ kw = data.draw(
+ hh.specified_kwargs(
+ ("axis", axis, None),
+ ("dtype", dtype, None),
+ ("include_initial", include_initial, False),
+ ),
+ label="kw",
+ )
+
+ out = xp.cumulative_prod(x, **kw)
+
+ expected_shape = list(x.shape)
+ if include_initial:
+ expected_shape[_axis] += 1
+ expected_shape = tuple(expected_shape)
+ ph.assert_shape("cumulative_prod", out_shape=out.shape, expected=expected_shape)
+
+ expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
+ if expected_dtype is None:
+        # If a default uint cannot exist (e.g. in PyTorch, which doesn't
+        # support uint32 or uint64), we skip testing the output dtype.
+ # See https://github.com/data-apis/array-api-tests/issues/106
+ if x.dtype in dh.uint_dtypes:
+ assert dh.is_int_dtype(out.dtype) # sanity check
+ else:
+ ph.assert_dtype("cumulative_prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
+
+ scalar_type = dh.get_scalar_type(out.dtype)
+
+    for x_idx, out_idx in iter_indices(x.shape, expected_shape, skip_axes=_axis):
+        # x_arr = x[x_idx.raw]
+ out_arr = out[out_idx.raw]
+
+ if include_initial:
+ ph.assert_scalar_equals("cumulative_prod", type_=scalar_type, idx=out_idx.raw, out=out_arr[0], expected=1)
+
+    # TODO: add value testing of cumulative_prod
+
+
+def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
+ dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
+ dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
+ assert len(dtypes) > 0 # sanity check
+ return st.none() | st.sampled_from(dtypes)
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.real_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_max(x, data):
+ kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.max(x, **kw)
+
+ ph.assert_dtype("max", in_dtype=x.dtype, out_dtype=out.dtype)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "max", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(out.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+ max_ = scalar_type(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = max(elements)
+ ph.assert_scalar_equals("max", type_=scalar_type, idx=out_idx, out=max_, expected=expected)
+
+
+@given(
+ x=hh.arrays(
+ dtype=hh.real_floating_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_mean(x, data):
+ kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.mean(x, **kw)
+
+ ph.assert_dtype("mean", in_dtype=x.dtype, out_dtype=out.dtype)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "mean", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+    # Value testing for mean is too finicky
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.real_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_min(x, data):
+ kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.min(x, **kw)
+
+ ph.assert_dtype("min", in_dtype=x.dtype, out_dtype=out.dtype)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "min", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(out.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+ min_ = scalar_type(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = min(elements)
+ ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)
+
+
+def _prod_condition_number(elements):
+ # Relative condition number using the infinity norm
+ abs_max = max([abs(i) for i in elements])
+ abs_min = min([abs(i) for i in elements])
+
+ if abs_min == 0:
+ return float('inf')
+
+ return abs_max / abs_min
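+
+# (Worked example, for illustration: elements [1e8, 2.0, 1e-8] give
+# abs_max / abs_min = 1e8 / 1e-8 = 1e16, which test_prod's
+# assume(condition_number < 1e15) then rejects as ill-conditioned.)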
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.numeric_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_prod(x, data):
+ kw = data.draw(
+ hh.kwargs(
+ axis=hh.axes(x.ndim),
+ dtype=kwarg_dtypes(x.dtype),
+ keepdims=st.booleans(),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ with hh.reject_overflow():
+ out = xp.prod(x, **kw)
+
+ dtype = kw.get("dtype", None)
+ expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
+ if expected_dtype is None:
+        # If a default uint cannot exist (e.g. in PyTorch, which doesn't
+        # support uint32 or uint64), we skip testing the output dtype.
+ # See https://github.com/data-apis/array-api-tests/issues/106
+ if x.dtype in dh.uint_dtypes:
+ assert dh.is_int_dtype(out.dtype) # sanity check
+ else:
+ ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(out.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+ prod = scalar_type(out[out_idx])
+ assume(cmath.isfinite(prod))
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = math.prod(elements)
+ if dh.is_int_dtype(out.dtype):
+ m, M = dh.dtype_ranges[out.dtype]
+ assume(m <= expected <= M)
+ ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx,
+ out=prod, expected=expected)
+ else:
+ condition_number = _prod_condition_number(elements)
+ assume(condition_number < 1e15)
+ ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx,
+ out=prod, expected=expected)
+
+
+@pytest.mark.skip(reason="flaky") # TODO: fix!
+@given(
+ x=hh.arrays(
+ dtype=hh.real_floating_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ).filter(lambda x: math.prod(x.shape) >= 2),
+ data=st.data(),
+)
+def test_std(x, data):
+ axis = data.draw(hh.axes(x.ndim), label="axis")
+ _axes = sh.normalize_axis(axis, x.ndim)
+ N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
+ correction = data.draw(
+ st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
+ label="correction",
+ )
+ _keepdims = data.draw(st.booleans(), label="keepdims")
+ kw = data.draw(
+ hh.specified_kwargs(
+ ("axis", axis, None),
+ ("correction", correction, 0.0),
+ ("keepdims", _keepdims, False),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.std(x, **kw)
+
+ ph.assert_dtype("std", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_keepdimable_shape(
+ "std", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ # We can't easily test the result(s) as standard deviation methods vary a lot
+
+
+def _sum_condition_number(elements):
+ sum_abs = sum([abs(i) for i in elements])
+ abs_sum = abs(sum(elements))
+
+ if abs_sum == 0:
+ return float('inf')
+
+ return sum_abs / abs_sum
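+
+# (Worked example, for illustration: elements [1e8, -1e8, 1.0] give
+# sum_abs of about 2e8 but abs_sum of 1.0, a condition number of ~2e8,
+# so the assume(condition_number < 1e6) check in test_sum rejects such
+# catastrophic cancellation.)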
+
+# @pytest.mark.unvectorized
+@given(
+ x=hh.arrays(
+ dtype=hh.numeric_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ),
+ data=st.data(),
+)
+def test_sum(x, data):
+ kw = data.draw(
+ hh.kwargs(
+ axis=hh.axes(x.ndim),
+ dtype=kwarg_dtypes(x.dtype),
+ keepdims=st.booleans(),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ with hh.reject_overflow():
+ out = xp.sum(x, **kw)
+
+ dtype = kw.get("dtype", None)
+ expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
+ if expected_dtype is None:
+        # If a default uint cannot exist (e.g. in PyTorch, which doesn't
+        # support uint32 or uint64), we skip testing the output dtype.
+        # See https://github.com/data-apis/array-api-tests/issues/106
+ if x.dtype in dh.uint_dtypes:
+ assert dh.is_int_dtype(out.dtype) # sanity check
+ else:
+ ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(out.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+ sum_ = scalar_type(out[out_idx])
+ assume(cmath.isfinite(sum_))
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = sum(elements)
+ if dh.is_int_dtype(out.dtype):
+ m, M = dh.dtype_ranges[out.dtype]
+ assume(m <= expected <= M)
+ ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx,
+ out=sum_, expected=expected)
+ else:
+ # Avoid value testing for ill conditioned summations. See
+ # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and
+ # https://en.wikipedia.org/wiki/Condition_number.
+ condition_number = _sum_condition_number(elements)
+ assume(condition_number < 1e6)
+ ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
+
+
+@pytest.mark.unvectorized
+@pytest.mark.skip(reason="flaky") # TODO: fix!
+@given(
+ x=hh.arrays(
+ dtype=hh.real_floating_dtypes,
+ shape=hh.shapes(min_side=1),
+ elements={"allow_nan": False},
+ ).filter(lambda x: math.prod(x.shape) >= 2),
+ data=st.data(),
+)
+def test_var(x, data):
+ axis = data.draw(hh.axes(x.ndim), label="axis")
+ _axes = sh.normalize_axis(axis, x.ndim)
+ N = sum(side for axis, side in enumerate(x.shape) if axis not in _axes)
+ correction = data.draw(
+ st.floats(0.0, N, allow_infinity=False, allow_nan=False) | st.integers(0, N),
+ label="correction",
+ )
+ _keepdims = data.draw(st.booleans(), label="keepdims")
+ kw = data.draw(
+ hh.specified_kwargs(
+ ("axis", axis, None),
+ ("correction", correction, 0.0),
+ ("keepdims", _keepdims, False),
+ ),
+ label="kw",
+ )
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.var(x, **kw)
+
+ ph.assert_dtype("var", in_dtype=x.dtype, out_dtype=out.dtype)
+ ph.assert_keepdimable_shape(
+ "var", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ # We can't easily test the result(s) as variance methods vary a lot
diff --git a/array_api_tests/test_type_promotion.py b/array_api_tests/test_type_promotion.py
deleted file mode 100644
index 2f4be85e..00000000
--- a/array_api_tests/test_type_promotion.py
+++ /dev/null
@@ -1,213 +0,0 @@
-"""
-https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html
-"""
-
-import pytest
-
-from hypothesis import given, example
-from hypothesis.strategies import from_type, data
-
-from .hypothesis_helpers import shapes
-from .pytest_helpers import nargs
-from .array_helpers import assert_exactly_equal
-
-from .function_stubs import elementwise_functions
-from ._array_module import (ones, int8, int16, int32, int64, uint8,
- uint16, uint32, uint64, float32, float64)
-from . import _array_module
-
-dtype_mapping = {
- 'i1': int8,
- 'i2': int16,
- 'i4': int32,
- 'i8': int64,
- 'u1': uint8,
- 'u2': uint16,
- 'u4': uint32,
- 'u8': uint64,
- 'f4': float32,
- 'f8': float64,
-}
-
-signed_integer_promotion_table = {
- ('i1', 'i1'): 'i1',
- ('i1', 'i2'): 'i2',
- ('i1', 'i4'): 'i4',
- ('i1', 'i8'): 'i8',
- ('i2', 'i1'): 'i2',
- ('i2', 'i2'): 'i2',
- ('i2', 'i4'): 'i4',
- ('i2', 'i8'): 'i8',
- ('i4', 'i1'): 'i4',
- ('i4', 'i2'): 'i4',
- ('i4', 'i4'): 'i4',
- ('i4', 'i8'): 'i8',
- ('i8', 'i1'): 'i8',
- ('i8', 'i2'): 'i8',
- ('i8', 'i4'): 'i8',
- ('i8', 'i8'): 'i8',
-}
-
-unsigned_integer_promotion_table = {
- ('u1', 'u1'): 'u1',
- ('u1', 'u2'): 'u2',
- ('u1', 'u4'): 'u4',
- ('u1', 'u8'): 'u8',
- ('u2', 'u1'): 'u2',
- ('u2', 'u2'): 'u2',
- ('u2', 'u4'): 'u4',
- ('u2', 'u8'): 'u8',
- ('u4', 'u1'): 'u4',
- ('u4', 'u2'): 'u4',
- ('u4', 'u4'): 'u4',
- ('u4', 'u8'): 'u8',
- ('u8', 'u1'): 'u8',
- ('u8', 'u2'): 'u8',
- ('u8', 'u4'): 'u8',
- ('u8', 'u8'): 'u8',
-}
-
-mixed_signed_unsigned_promotion_table = {
- ('i1', 'u1'): 'i2',
- ('i1', 'u2'): 'i4',
- ('i1', 'u4'): 'i8',
- ('i2', 'u1'): 'i2',
- ('i2', 'u2'): 'i4',
- ('i2', 'u4'): 'i8',
- ('i4', 'u1'): 'i4',
- ('i4', 'u2'): 'i4',
- ('i4', 'u4'): 'i8',
-}
-
-flipped_mixed_signed_unsigned_promotion_table = {(u, i): p for (i, u), p in mixed_signed_unsigned_promotion_table.items()}
-
-float_promotion_table = {
- ('f4', 'f4'): 'f4',
- ('f4', 'f8'): 'f8',
- ('f8', 'f4'): 'f8',
- ('f8', 'f8'): 'f8',
-}
-
-promotion_table = {
- **signed_integer_promotion_table,
- **unsigned_integer_promotion_table,
- **mixed_signed_unsigned_promotion_table,
- **flipped_mixed_signed_unsigned_promotion_table,
- **float_promotion_table,
-}
-
-
-binary_operators = {
- '__add__': '+',
- '__and__': '&',
- '__eq__': '==',
- '__floordiv__': '//',
- '__ge__': '>=',
- '__gt__': '>',
- '__le__': '<=',
- '__lshift__': '<<',
- '__lt__': '<',
- '__matmul__': '@',
- '__mod__': '%',
- '__mul__': '*',
- '__ne__': '!=',
- '__or__': '|',
- '__pow__': '**',
- '__rshift__': '>>',
- '__sub__': '-',
- '__truediv__': '/',
- '__xor__': '^',
-}
-
-unary_operators = {
- '__invert__': '~',
- '__neg__': '-',
- '__pos__': '+',
-}
-
-dtypes_to_scalar = {
- _array_module.bool: bool,
- _array_module.int8: int,
- _array_module.int16: int,
- _array_module.int32: int,
- _array_module.int64: int,
- _array_module.uint8: int,
- _array_module.uint16: int,
- _array_module.uint32: int,
- _array_module.uint64: int,
- _array_module.float32: float,
- _array_module.float64: float,
-}
-
-scalar_to_dtype = {s: [d for d, _s in dtypes_to_scalar.items() if _s == s] for
- s in dtypes_to_scalar.values()}
-
-# TODO: Extend this to all functions (not just elementwise), and handle
-# functions that take more than 2 args
-@pytest.mark.parametrize('func_name', [i for i in
- elementwise_functions.__all__ if
- nargs(i) > 1])
-@pytest.mark.parametrize('dtypes', promotion_table.items())
-# The spec explicitly requires type promotion to work for shape 0
-@example(shape=(0,))
-@given(shape=shapes)
-def test_promotion(func_name, shape, dtypes):
- assert nargs(func_name) == 2
- func = getattr(_array_module, func_name)
-
-
- (type1, type2), res_type = dtypes
- dtype1 = dtype_mapping[type1]
- dtype2 = dtype_mapping[type2]
- res_dtype = dtype_mapping[res_type]
-
- for i in [func, dtype1, dtype2, res_dtype]:
- if isinstance(i, _array_module._UndefinedStub):
- func._raise()
-
- a1 = ones(shape, dtype=dtype1)
- a2 = ones(shape, dtype=dtype2)
- res = func(a1, a2)
-
- assert res.dtype == res_dtype, f"{func_name}({dtype1}, {dtype2}) promoted to {res.dtype}, should have promoted to {res_dtype} (shape={shape})"
-
-@pytest.mark.parametrize('binary_op', sorted(set(binary_operators.values()) - {'@'}))
-@pytest.mark.parametrize('scalar_type,dtype', [(s, d) for s in scalar_to_dtype
- for d in scalar_to_dtype[s]])
-@given(shape=shapes, scalars=data())
-def test_operator_scalar_promotion(binary_op, scalar_type, dtype, shape, scalars):
- """
- See https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
- """
- if binary_op == '@':
- pytest.skip("matmul (@) is not supported for scalars")
- a = ones(shape, dtype=dtype)
- s = scalars.draw(from_type(scalar_type))
- scalar_as_array = _array_module.full((), s, dtype=dtype)
- get_locals = lambda: dict(a=a, s=s, scalar_as_array=scalar_as_array)
-
- # As per the spec:
-
- # The expected behavior is then equivalent to:
- #
- # 1. Convert the scalar to a 0-D array with the same dtype as that of the
- # array used in the expression.
- #
- # 2. Execute the operation for `array 0-D array` (or `0-D array
- # array` if `scalar` was the left-hand argument).
-
- array_scalar = f'a {binary_op} s'
- array_scalar_expected = f'a {binary_op} scalar_as_array'
- res = eval(array_scalar, get_locals())
- expected = eval(array_scalar_expected, get_locals())
- assert_exactly_equal(res, expected)
-
- scalar_array = f's {binary_op} a'
- scalar_array_expected = f'scalar_as_array {binary_op} a'
- res = eval(scalar_array, get_locals())
- expected = eval(scalar_array_expected, get_locals())
- assert_exactly_equal(res, expected)
-
-if __name__ == '__main__':
- for (i, j), p in promotion_table.items():
- print(f"({i}, {j}) -> {p}")
diff --git a/array_api_tests/test_utility_functions.py b/array_api_tests/test_utility_functions.py
new file mode 100644
index 00000000..b6e0a4fe
--- /dev/null
+++ b/array_api_tests/test_utility_functions.py
@@ -0,0 +1,141 @@
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+
+from . import _array_module as xp
+from . import dtype_helpers as dh
+from . import hypothesis_helpers as hh
+from . import pytest_helpers as ph
+from . import shape_helpers as sh
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_side=1)),
+ data=st.data(),
+)
+def test_all(x, data):
+ kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.all(x, **kw)
+
+ ph.assert_dtype("all", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "all", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
+ )
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+ result = bool(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = all(elements)
+ ph.assert_scalar_equals("all", type_=scalar_type, idx=out_idx,
+ out=result, expected=expected, kw=kw)
+
+
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes()),
+ data=st.data(),
+)
+def test_any(x, data):
+ kw = data.draw(hh.kwargs(axis=hh.axes(x.ndim), keepdims=st.booleans()), label="kw")
+ keepdims = kw.get("keepdims", False)
+
+ out = xp.any(x, **kw)
+
+ ph.assert_dtype("any", in_dtype=x.dtype, out_dtype=out.dtype, expected=xp.bool)
+ _axes = sh.normalize_axis(kw.get("axis", None), x.ndim)
+ ph.assert_keepdimable_shape(
+ "any", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw,
+ )
+ scalar_type = dh.get_scalar_type(x.dtype)
+ for indices, out_idx in zip(sh.axes_ndindex(x.shape, _axes), sh.ndindex(out.shape)):
+ result = bool(out[out_idx])
+ elements = []
+ for idx in indices:
+ s = scalar_type(x[idx])
+ elements.append(s)
+ expected = any(elements)
+ ph.assert_scalar_equals("any", type_=scalar_type, idx=out_idx,
+ out=result, expected=expected, kw=kw)
+
+
+@pytest.mark.unvectorized
+@pytest.mark.min_version("2024.12")
+@given(
+ x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
+ data=st.data(),
+)
+def test_diff(x, data):
+ axis = data.draw(
+ st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
+ label="axis"
+ )
+ if axis is None:
+ axis_kw = {"axis": -1}
+ n_axis = x.ndim - 1
+ else:
+ axis_kw = {"axis": axis}
+ n_axis = axis + x.ndim if axis < 0 else axis
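+        # e.g. axis=-1 with x.ndim == 3 normalizes to n_axis == 2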
+
+ n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
+
+ out = xp.diff(x, **axis_kw, n=n)
+
+ expected_shape = list(x.shape)
+ expected_shape[n_axis] -= n
+
+ assert out.shape == tuple(expected_shape)
+
+ # value test
+ if n == 1:
+ for idx in sh.ndindex(out.shape):
+ l = list(idx)
+ l[n_axis] += 1
+ assert out[idx] == x[tuple(l)] - x[idx], f"diff failed with {idx = }"
+
+
+@pytest.mark.min_version("2024.12")
+@pytest.mark.unvectorized
+@given(
+ x=hh.arrays(hh.numeric_dtypes, hh.shapes(min_dims=1, min_side=1)),
+ data=st.data(),
+)
+def test_diff_append_prepend(x, data):
+ axis = data.draw(
+ st.integers(-x.ndim, max(x.ndim - 1, 0)) | st.none(),
+ label="axis"
+ )
+ if axis is None:
+ axis_kw = {"axis": -1}
+ n_axis = x.ndim - 1
+ else:
+ axis_kw = {"axis": axis}
+ n_axis = axis + x.ndim if axis < 0 else axis
+
+ n = data.draw(st.integers(1, min(x.shape[n_axis], 3)))
+
+ append_shape = list(x.shape)
+ append_axis_len = data.draw(st.integers(1, 2*append_shape[n_axis]), label="append_axis")
+ append_shape[n_axis] = append_axis_len
+ append = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(append_shape)), label="append")
+
+ prepend_shape = list(x.shape)
+ prepend_axis_len = data.draw(st.integers(1, 2*prepend_shape[n_axis]), label="prepend_axis")
+ prepend_shape[n_axis] = prepend_axis_len
+ prepend = data.draw(hh.arrays(dtype=x.dtype, shape=tuple(prepend_shape)), label="prepend")
+
+ out = xp.diff(x, **axis_kw, n=n, append=append, prepend=prepend)
+
+ in_1 = xp.concat((prepend, x, append), **axis_kw)
+ out_1 = xp.diff(in_1, **axis_kw, n=n)
+
+ assert out.shape == out_1.shape
+ for idx in sh.ndindex(out.shape):
+ assert out[idx] == out_1[idx], f"{idx = }"
+
diff --git a/array_api_tests/typing.py b/array_api_tests/typing.py
new file mode 100644
index 00000000..84311ff3
--- /dev/null
+++ b/array_api_tests/typing.py
@@ -0,0 +1,21 @@
+from typing import Any, Tuple, Type, Union
+
+__all__ = [
+ "DataType",
+ "Scalar",
+ "ScalarType",
+ "Array",
+ "Shape",
+ "AtomicIndex",
+ "Index",
+ "Param",
+]
+
+DataType = Type[Any]
+Scalar = Union[bool, int, float, complex]
+ScalarType = Union[Type[bool], Type[int], Type[float], Type[complex]]
+Array = Any
+Shape = Tuple[int, ...]
+AtomicIndex = Union[int, "ellipsis", slice, None] # noqa
+Index = Union[AtomicIndex, Tuple[AtomicIndex, ...]]
+Param = Tuple
diff --git a/conftest.py b/conftest.py
index cd188b3d..05baebc1 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,45 +1,285 @@
+from functools import lru_cache
+from pathlib import Path
+import argparse
+import warnings
+import os
+
from hypothesis import settings
+from hypothesis.errors import InvalidArgument
+from pytest import mark
+
+from array_api_tests import _array_module as xp
+from array_api_tests import api_version
+from array_api_tests._array_module import _UndefinedStub
+from array_api_tests.stubs import EXTENSIONS
+from array_api_tests import xp_name, xp as array_module
+
+from reporting import pytest_metadata, pytest_json_modifyreport, add_extra_json_metadata # noqa
-# Add a --hypothesis-max-examples flag to pytest. See
-# https://github.com/HypothesisWorks/hypothesis/issues/2434#issuecomment-630309150
+def pytest_report_header(config):
+ disabled_extensions = config.getoption("--disable-extension")
+ enabled_extensions = sorted({
+ ext for ext in EXTENSIONS + ['fft'] if ext not in disabled_extensions and xp_has_ext(ext)
+ })
+
+ try:
+ array_module_version = array_module.__version__
+ except AttributeError:
+ array_module_version = "version unknown"
+
+ return f"Array API Tests Module: {xp_name} ({array_module_version}). API Version: {api_version}. Enabled Extensions: {', '.join(enabled_extensions)}"
def pytest_addoption(parser):
- # Add an option to change the Hypothesis max_examples setting.
+ # Hypothesis max examples
+ # See https://github.com/HypothesisWorks/hypothesis/issues/2434
parser.addoption(
"--hypothesis-max-examples",
"--max-examples",
action="store",
- default=None,
+ default=100,
+ type=int,
help="set the Hypothesis max_examples setting",
)
-
- # Add an option to disable the Hypothesis deadline
+ # Hypothesis deadline
parser.addoption(
"--hypothesis-disable-deadline",
"--disable-deadline",
action="store_true",
help="disable the Hypothesis deadline",
)
+ # Hypothesis derandomize
+ parser.addoption(
+ "--hypothesis-derandomize",
+ "--derandomize",
+ action="store_true",
+ help="set the Hypothesis derandomize parameter",
+ )
+ # disable extensions
+ parser.addoption(
+ "--disable-extension",
+ metavar="ext",
+ nargs="+",
+ default=[],
+ help="disable testing for Array API extension(s)",
+ )
+ # data-dependent shape
+ parser.addoption(
+ "--disable-data-dependent-shapes",
+ "--disable-dds",
+ action="store_true",
+ help="disable testing functions with output shapes dependent on input",
+ )
+ # CI
+ parser.addoption("--ci", action="store_true", help=argparse.SUPPRESS ) # deprecated
+ parser.addoption(
+ "--skips-file",
+ action="store",
+ help="file with tests to skip. Defaults to skips.txt"
+ )
+ parser.addoption(
+ "--xfails-file",
+ action="store",
+ help="file with tests to skip. Defaults to xfails.txt"
+ )
def pytest_configure(config):
- # Set Hypothesis max_examples.
- hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
- disable_deadline = config.getoption('--hypothesis-disable-deadline')
- profile_settings = {}
- if hypothesis_max_examples is not None:
- profile_settings['max_examples'] = int(hypothesis_max_examples)
- if disable_deadline is not None:
- profile_settings['deadline'] = None
- if profile_settings:
- import hypothesis
-
- hypothesis.settings.register_profile(
- "array-api-tests-hypothesis-overridden", **profile_settings,
+ config.addinivalue_line(
+ "markers", "xp_extension(ext): tests an Array API extension"
+ )
+ config.addinivalue_line(
+ "markers", "data_dependent_shapes: output shapes are dependent on inputs"
+ )
+ config.addinivalue_line(
+ "markers",
+ "min_version(api_version): run when greater or equal to api_version",
+ )
+ config.addinivalue_line(
+ "markers",
+ "unvectorized: asserts against values via element-wise iteration (not performative!)",
+ )
+ # Hypothesis
+ deadline = None if config.getoption("--hypothesis-disable-deadline") else 800
+ settings.register_profile(
+ "array-api-tests",
+ max_examples=config.getoption("--hypothesis-max-examples"),
+ derandomize=config.getoption("--hypothesis-derandomize"),
+ deadline=deadline,
+ )
+ settings.load_profile("array-api-tests")
+ # CI
+ if config.getoption("--ci"):
+ warnings.warn(
+ "Custom pytest option --ci is deprecated as any tests not for CI "
+ "are now located in meta_tests/"
+ )
+
+
+@lru_cache
+def xp_has_ext(ext: str) -> bool:
+ try:
+ return not isinstance(getattr(xp, ext), _UndefinedStub)
+ except AttributeError:
+ return False
+
+
+def check_id_match(id_, pattern):
+ id_ = id_.removeprefix('array-api-tests/')
+
+ if id_ == pattern:
+ return True
+
+ if id_.startswith(pattern.removesuffix("/") + "/"):
+ return True
+
+ if pattern.endswith(".py") and id_.startswith(pattern):
+ return True
+
+ if id_.split("::", maxsplit=2)[0] == pattern:
+ return True
+
+ if id_.split("[", maxsplit=2)[0] == pattern:
+ return True
+
+ return False
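+
+# (For illustration, all of these ids match the pattern
+# "array_api_tests/test_foo.py", assuming such a test file exists:
+#
+#     array_api_tests/test_foo.py
+#     array_api_tests/test_foo.py::test_bar
+#     array_api_tests/test_foo.py::test_bar[some-case]
+#
+# while the narrower pattern "array_api_tests/test_foo.py::test_bar" still
+# matches the parametrized id via the final split("[") comparison.)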
+
+
+def get_xfail_mark():
+ """Skip or xfail tests from the xfails-file.txt."""
+ m = os.environ.get("ARRAY_API_TESTS_XFAIL_MARK", "xfail")
+ if m == "xfail":
+ return mark.xfail
+ elif m == "skip":
+ return mark.skip
+ else:
+ raise ValueError(
+            f'ARRAY_API_TESTS_XFAIL_MARK value should be one of "skip" or "xfail", '
+ f'got {m} instead.'
)
- hypothesis.settings.load_profile("array-api-tests-hypothesis-overridden")
+def pytest_collection_modifyitems(config, items):
+ # 1. Prepare for iterating over items
+ # -----------------------------------
+
+ skips_file = skips_path = config.getoption('--skips-file')
+ if skips_file is None:
+ skips_file = Path(__file__).parent / "skips.txt"
+ if skips_file.exists():
+ skips_path = skips_file
+
+ skip_ids = []
+ if skips_path:
+ with open(os.path.expanduser(skips_path)) as f:
+ for line in f:
+ if line.startswith("array_api_tests"):
+ id_ = line.strip("\n")
+ skip_ids.append(id_)
-settings.register_profile('array_api_tests_hypothesis_profile', deadline=800)
-settings.load_profile('array_api_tests_hypothesis_profile')
+ xfails_file = xfails_path = config.getoption('--xfails-file')
+ if xfails_file is None:
+ xfails_file = Path(__file__).parent / "xfails.txt"
+ if xfails_file.exists():
+ xfails_path = xfails_file
+
+ xfail_ids = []
+ if xfails_path:
+ with open(os.path.expanduser(xfails_path)) as f:
+ for line in f:
+ if not line.strip() or line.startswith('#'):
+ continue
+ id_ = line.strip("\n")
+ xfail_ids.append(id_)
+
+ skip_id_matched = {id_: False for id_ in skip_ids}
+ xfail_id_matched = {id_: False for id_ in xfail_ids}
+
+ disabled_exts = config.getoption("--disable-extension")
+ disabled_dds = config.getoption("--disable-data-dependent-shapes")
+ unvectorized_max_examples = max(1, config.getoption("--hypothesis-max-examples")//10)
+
+ # 2. Iterate through items and apply markers accordingly
+ # ------------------------------------------------------
+
+ xfail_mark = get_xfail_mark()
+
+ for item in items:
+ markers = list(item.iter_markers())
+ # skip if specified in skips file
+ for id_ in skip_ids:
+ if check_id_match(item.nodeid, id_):
+ item.add_marker(mark.skip(reason=f"--skips-file ({skips_file})"))
+ skip_id_matched[id_] = True
+ break
+ # xfail if specified in xfails file
+ for id_ in xfail_ids:
+ if check_id_match(item.nodeid, id_):
+ item.add_marker(xfail_mark(reason=f"--xfails-file ({xfails_file})"))
+ xfail_id_matched[id_] = True
+ break
+ # skip if disabled or non-existent extension
+ ext_mark = next((m for m in markers if m.name == "xp_extension"), None)
+ if ext_mark is not None:
+ ext = ext_mark.args[0]
+ if ext in disabled_exts:
+ item.add_marker(
+ mark.skip(reason=f"{ext} disabled in --disable-extensions")
+ )
+ elif not xp_has_ext(ext):
+ item.add_marker(mark.skip(reason=f"{ext} not found in array module"))
+ # skip if disabled by dds flag
+ if disabled_dds:
+ for m in markers:
+ if m.name == "data_dependent_shapes":
+ item.add_marker(
+ mark.skip(reason="disabled via --disable-data-dependent-shapes")
+ )
+ break
+ # skip if test is for greater api_version
+ ver_mark = next((m for m in markers if m.name == "min_version"), None)
+ if ver_mark is not None:
+ min_version = ver_mark.args[0]
+ if api_version < min_version:
+ item.add_marker(
+ mark.skip(
+ reason=f"requires ARRAY_API_TESTS_VERSION >= {min_version}"
+ )
+ )
+        # reduce max generated Hypothesis examples for unvectorized tests
+ if any(m.name == "unvectorized" for m in markers):
+ # TODO: limit generated examples when settings already applied
+ if not hasattr(item.obj, "_hypothesis_internal_settings_applied"):
+ try:
+ item.obj = settings(max_examples=unvectorized_max_examples)(item.obj)
+ except InvalidArgument as e:
+ warnings.warn(
+ f"Tried decorating {item.name} with settings() but got "
+ f"hypothesis.errors.InvalidArgument: {e}"
+ )
+
+ # 3. Warn on bad skipped/xfailed ids
+ # ----------------------------------
+
+ bad_ids_end_msg = (
+ "Note the relevant tests might not have been collected by pytest, or "
+ "another specified id might have already matched a test."
+ )
+ bad_skip_ids = [id_ for id_, matched in skip_id_matched.items() if not matched]
+ if bad_skip_ids:
+ f_bad_ids = "\n".join(f" {id_}" for id_ in bad_skip_ids)
+ warnings.warn(
+ f"{len(bad_skip_ids)} ids in skips file don't match any collected tests: \n"
+ f"{f_bad_ids}\n"
+ f"(skips file: {skips_file})\n"
+ f"{bad_ids_end_msg}"
+ )
+ bad_xfail_ids = [id_ for id_, matched in xfail_id_matched.items() if not matched]
+ if bad_xfail_ids:
+ f_bad_ids = "\n".join(f" {id_}" for id_ in bad_xfail_ids)
+ warnings.warn(
+ f"{len(bad_xfail_ids)} ids in xfails file don't match any collected tests: \n"
+ f"{f_bad_ids}\n"
+ f"(xfails file: {xfails_file})\n"
+ f"{bad_ids_end_msg}"
+ )
diff --git a/generate_stubs.py b/generate_stubs.py
deleted file mode 100755
index c6e1d6d3..00000000
--- a/generate_stubs.py
+++ /dev/null
@@ -1,669 +0,0 @@
-#!/usr/bin/env python
-"""
-Generate stub files for the tests.
-
-To run the script, first clone the https://github.com/data-apis/array-api
-repo, then run
-
-./generate_stubs.py path/to/clone/of/array-api
-
-This will update the stub files in array_api_tests/function_stubs/
-"""
-import argparse
-import os
-import sys
-import ast
-from collections import defaultdict
-
-import regex
-from removestar.removestar import fix_code
-
-FUNCTION_HEADER_RE = regex.compile(r'\(function-(.*?)\)')
-HEADER_RE = regex.compile(r'\((?:function|method|constant|attribute)-(.*?)\)')
-FUNCTION_RE = regex.compile(r'\(function-.*\)=\n#+ ?(.*\(.*\))')
-METHOD_RE = regex.compile(r'\(method-.*\)=\n#+ ?(.*\(.*\))')
-CONSTANT_RE = regex.compile(r'\(constant-.*\)=\n#+ ?(.*)')
-ATTRIBUTE_RE = regex.compile(r'\(attribute-.*\)=\n#+ ?(.*)')
-NAME_RE = regex.compile(r'(.*)\(.*\)')
-
-STUB_FILE_HEADER = '''\
-"""
-Function stubs for {title}.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-
-See
-https://github.com/data-apis/array-api/blob/master/spec/API_specification/{filename}
-"""
-
-from __future__ import annotations
-
-from ._types import *
-from .constants import *
-'''
-# ^ Constants are used in some of the type annotations
-
-INIT_HEADER = '''\
-"""
-Stub definitions for functions defined in the spec
-
-These are used to test function signatures.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-__all__ = []
-'''
-
-SPECIAL_CASES_HEADER = '''\
-"""
-Special cases tests for {func}.
-
-These tests are generated from the special cases listed in the spec.
-
-NOTE: This file is generated automatically by the generate_stubs.py script. Do
-not modify it directly.
-"""
-
-from ..array_helpers import *
-from ..hypothesis_helpers import numeric_arrays
-from .._array_module import {func}
-
-from hypothesis import given
-
-'''
-
-TYPES_HEADER = '''\
-"""
-This file defines the types for type annotations.
-
-The type variables should be replaced with the actual types for a given
-library, e.g., for NumPy TypeVar('array') would be replaced with ndarray.
-"""
-
-from typing import Literal, Optional, Tuple, Union, TypeVar
-
-array = TypeVar('array')
-device = TypeVar('device')
-dtype = TypeVar('dtype')
-
-__all__ = ['Literal', 'Optional', 'Tuple', 'Union', 'array', 'device', 'dtype']
-
-'''
-def main():
- parser = argparse.ArgumentParser(__doc__)
- parser.add_argument('array_api_repo', help="Path to clone of the array-api repository")
- parser.add_argument('--no-write', help="""Print what it would do but don't
- write any files""", action='store_false', dest='write')
- parser.add_argument('--quiet', help="""Don't print any output to the terminal""", action='store_true', dest='quiet')
- args = parser.parse_args()
-
- types_path = os.path.join('array_api_tests', 'function_stubs', '_types.py')
- if args.write:
- with open(types_path, 'w') as f:
- f.write(TYPES_HEADER)
-
- spec_dir = os.path.join(args.array_api_repo, 'spec', 'API_specification')
- modules = {}
- for filename in sorted(os.listdir(spec_dir)):
- with open(os.path.join(spec_dir, filename)) as f:
- text = f.read()
- functions = FUNCTION_RE.findall(text)
- methods = METHOD_RE.findall(text)
- constants = CONSTANT_RE.findall(text)
- attributes = ATTRIBUTE_RE.findall(text)
- if not (functions or methods or constants or attributes):
- continue
- if not args.quiet:
- print(f"Found signatures in {filename}")
- if not args.write:
- continue
- py_file = filename.replace('.md', '.py')
- py_path = os.path.join('array_api_tests', 'function_stubs', py_file)
- title = filename.replace('.md', '').replace('_', ' ')
- module_name = py_file.replace('.py', '')
- modules[module_name] = []
- if not args.quiet:
- print(f"Writing {py_path}")
-
- annotations = parse_annotations(text, verbose=not args.quiet)
-
- sigs = {}
- code = ""
- code += STUB_FILE_HEADER.format(filename=filename, title=title)
- for sig in functions + methods:
- ismethod = sig in methods
- sig = sig.replace(r'\_', '_')
- func_name = NAME_RE.match(sig).group(1)
- doc = ""
- if ismethod:
- doc = f'''
- """
- Note: {func_name} is a method of the array object.
- """'''
- if func_name not in annotations:
- print(f"Warning: No annotations found for {func_name}")
- annotated_sig = sig
- else:
- annotated_sig = add_annotation(sig, annotations[func_name])
- if not args.quiet:
- print(f"Writing stub for {annotated_sig}")
- code += f"""
-def {annotated_sig}:{doc}
- pass
-"""
- modules[module_name].append(func_name)
- sigs[func_name] = sig
- for const in constants + attributes:
- if not args.quiet:
- print(f"Writing stub for {const}")
- isattr = const in attributes
- if isattr:
- code += f"\n# Note: {const} is an attribute of the array object."
- code += f"\n{const} = None\n"
- modules[module_name].append(const)
-
- code += '\n__all__ = ['
- code += ', '.join(f"'{i}'" for i in modules[module_name])
- code += ']\n'
-
- code = fix_code(code, file=py_path, verbose=False, quiet=False)
- with open(py_path, 'w') as f:
- f.write(code)
- if filename == 'elementwise_functions.md':
- special_cases = parse_special_cases(text, verbose=not args.quiet)
- for func in special_cases:
- py_path = os.path.join('array_api_tests', 'special_cases', f'test_{func}.py')
- tests = []
- for typ in special_cases[func]:
- multiple = len(special_cases[func][typ]) > 1
- for i, m in enumerate(special_cases[func][typ], 1):
- test_name_extra = typ.lower()
- if multiple:
- test_name_extra += f"_{i}"
- try:
- test = generate_special_case_test(func, typ, m,
- test_name_extra, sigs)
- if test is None:
- raise NotImplementedError("Special case test not implemented")
- tests.append(test)
- except:
- print(f"Error with {func}() {typ}: {m.group(0)}:\n", file=sys.stderr)
- raise
- if tests:
- code = SPECIAL_CASES_HEADER.format(func=func) + '\n'.join(tests)
- # quiet=False will make it print a warning if a name is not found (indicating an error)
- code = fix_code(code, file=py_path, verbose=False, quiet=False)
- if args.write:
- with open(py_path, 'w') as f:
- f.write(code)
-
- init_path = os.path.join('array_api_tests', 'function_stubs', '__init__.py')
- if args.write:
- with open(init_path, 'w') as f:
- f.write(INIT_HEADER)
- for module_name in modules:
- f.write(f"\nfrom .{module_name} import ")
- f.write(', '.join(modules[module_name]))
- f.write('\n\n')
- f.write('__all__ += [')
- f.write(', '.join(f"'{i}'" for i in modules[module_name]))
- f.write(']\n')
-
-# (?|...) is a branch reset (regex module only feature). It works like (?:...)
-# except only the matched alternative is assigned group numbers, so \1, \2, and
-# so on will always refer to a single match from _value.
-_value = r"(?|`([^`]*)`|a (finite) number|a (positive \(i\.e\., greater than `0`\) finite) number|a (negative \(i\.e\., less than `0`\) finite) number|(finite)|(positive)|(negative)|(nonzero)|(?:a )?(nonzero finite) numbers?|an (integer) value|already (integer)-valued|an (odd integer) value|(even integer closest to `x_i`)|an implementation-dependent approximation to `([^`]*)`(?: \(rounded\))?|a (signed (?:infinity|zero)) with the mathematical sign determined by the rule already stated above|(positive mathematical sign)|(negative mathematical sign))"
-SPECIAL_CASE_REGEXS = dict(
- ONE_ARG_EQUAL = regex.compile(rf'^- +If `x_i` is {_value}, the result is {_value}\.$'),
- ONE_ARG_GREATER = regex.compile(rf'^- +If `x_i` is greater than {_value}, the result is {_value}\.$'),
- ONE_ARG_LESS = regex.compile(rf'^- +If `x_i` is less than {_value}, the result is {_value}\.$'),
- ONE_ARG_EITHER = regex.compile(rf'^- +If `x_i` is either {_value} or {_value}, the result is {_value}\.$'),
- ONE_ARG_TWO_INTEGERS_EQUALLY_CLOSE = regex.compile(rf'^- +If two integers are equally close to `x_i`, the result is the {_value}\.$'),
-
- TWO_ARGS_EQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_GREATER__EQUAL = regex.compile(rf'^- +If `x1_i` is greater than {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_GREATER_EQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is greater than {_value}, `x1_i` is {_value}, and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_LESS__EQUAL = regex.compile(rf'^- +If `x1_i` is less than {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_LESS_EQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is less than {_value}, `x1_i` is {_value}, and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_LESS_EQUAL__EQUAL_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is less than {_value}, `x1_i` is {_value}, `x2_i` is {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__GREATER = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is greater than {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__LESS = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is less than {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is not (?:equal to )?{_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__LESS_EQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is less than {_value}, and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__LESS_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is less than {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__GREATER_EQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is greater than {_value}, and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__GREATER_NOTEQUAL = regex.compile(rf'^- +If `x1_i` is {_value}, `x2_i` is greater than {_value}, and `x2_i` is not {_value}, the result is {_value}\.$'),
- TWO_ARGS_NOTEQUAL__EQUAL = regex.compile(rf'^- +If `x1_i` is not equal to {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_ABSEQUAL__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_ABSGREATER__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is greater than {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_ABSLESS__EQUAL = regex.compile(rf'^- +If `abs\(x1_i\)` is less than {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_EITHER = regex.compile(rf'^- +If either `x1_i` or `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_EITHER__EQUAL = regex.compile(rf'^- +If `x1_i` is either {_value} or {_value} and `x2_i` is {_value}, the result is {_value}\.$'),
- TWO_ARGS_EQUAL__EITHER = regex.compile(rf'^- +If `x1_i` is {_value} and `x2_i` is either {_value} or {_value}, the result is {_value}\.$'),
- TWO_ARGS_EITHER__EITHER = regex.compile(rf'^- +If `x1_i` is either {_value} or {_value} and `x2_i` is either {_value} or {_value}, the result is {_value}\.$'),
- TWO_ARGS_SAME_SIGN = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign, the result has a {_value}\.$'),
- TWO_ARGS_SAME_SIGN_EXCEPT = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign, the result has a {_value}, unless the result is {_value}\. If the result is {_value}, the "sign" of {_value} is implementation-defined\.$'),
- TWO_ARGS_SAME_SIGN_BOTH = regex.compile(rf'^- +If `x1_i` and `x2_i` have the same mathematical sign and are both {_value}, the result has a {_value}\.$'),
- TWO_ARGS_DIFFERENT_SIGNS = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs, the result has a {_value}\.$'),
- TWO_ARGS_DIFFERENT_SIGNS_EXCEPT = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs, the result has a {_value}, unless the result is {_value}\. If the result is {_value}, the "sign" of {_value} is implementation-defined\.$'),
- TWO_ARGS_DIFFERENT_SIGNS_BOTH = regex.compile(rf'^- +If `x1_i` and `x2_i` have different mathematical signs and are both {_value}, the result has a {_value}\.$'),
- TWO_ARGS_EVEN_IF = regex.compile(rf'^- +If `x2_i` is {_value}, the result is {_value}, even if `x1_i` is {_value}\.$'),
-
- REMAINING = regex.compile(r"^- +In the remaining cases, (.*)$"),
-)
-
-
-def parse_value(value, arg):
- if value == 'NaN':
- return f"NaN({arg}.shape, {arg}.dtype)"
- elif value == "+infinity":
- return f"infinity({arg}.shape, {arg}.dtype)"
- elif value == "-infinity":
- return f"-infinity({arg}.shape, {arg}.dtype)"
- elif value in ["0", "+0"]:
- return f"zero({arg}.shape, {arg}.dtype)"
- elif value == "-0":
- return f"-zero({arg}.shape, {arg}.dtype)"
- elif value in ["1", "+1"]:
- return f"one({arg}.shape, {arg}.dtype)"
- elif value == "-1":
- return f"-one({arg}.shape, {arg}.dtype)"
- # elif value == 'signed infinity':
- elif value == 'signed zero':
- return f"zero({arg}.shape, {arg}.dtype))"
- elif 'π' in value:
- value = regex.sub(r'(\d+)π', r'\1*π', value)
- return value.replace('π', f'π({arg}.shape, {arg}.dtype)')
- elif 'x1_i' in value or 'x2_i' in value:
- return value
- elif value.startswith('where('):
- return value
- elif value in ['finite', 'nonzero', 'nonzero finite',
- "integer", "odd integer", "positive",
- "negative", "positive mathematical sign",
- "negative mathematical sign"]:
- return value
- # There's no way to remove the parenthetical from the matching group in
- # the regular expression.
- elif value == "positive (i.e., greater than `0`) finite":
- return "positive finite"
- elif value == 'negative (i.e., less than `0`) finite':
- return "negative finite"
- else:
- raise RuntimeError(f"Unexpected input value {value!r}")
-
-def _check_exactly_equal(typ, value):
- if not typ == 'exactly_equal':
- raise RuntimeError(f"Unexpected mask type {typ}: {value}")
-
-def get_mask(typ, arg, value):
- if typ.startswith("not"):
- if value.startswith('zero('):
- return f"notequal({arg}, {value})"
- return f"logical_not({get_mask(typ[len('not'):], arg, value)})"
- if typ.startswith("abs"):
- return get_mask(typ[len("abs"):], f"abs({arg})", value)
- if value == 'finite':
- _check_exactly_equal(typ, value)
- return f"isfinite({arg})"
- elif value == 'nonzero':
- _check_exactly_equal(typ, value)
- return f"non_zero({arg})"
- elif value == 'positive finite':
- _check_exactly_equal(typ, value)
- return f"logical_and(isfinite({arg}), ispositive({arg}))"
- elif value == 'negative finite':
- _check_exactly_equal(typ, value)
- return f"logical_and(isfinite({arg}), isnegative({arg}))"
- elif value == 'nonzero finite':
- _check_exactly_equal(typ, value)
- return f"logical_and(isfinite({arg}), non_zero({arg}))"
- elif value == 'positive':
- _check_exactly_equal(typ, value)
- return f"ispositive({arg})"
- elif value == 'positive mathematical sign':
- _check_exactly_equal(typ, value)
- return f"positive_mathematical_sign({arg})"
- elif value == 'negative':
- _check_exactly_equal(typ, value)
- return f"isnegative({arg})"
- elif value == 'negative mathematical sign':
- _check_exactly_equal(typ, value)
- return f"negative_mathematical_sign({arg})"
- elif value == 'integer':
- _check_exactly_equal(typ, value)
- return f"isintegral({arg})"
- elif value == 'odd integer':
- _check_exactly_equal(typ, value)
- return f"isodd({arg})"
- elif 'x_i' in value:
- return f"{typ}({arg}, {value.replace('x_i', 'arg1')})"
- elif 'x1_i' in value:
- return f"{typ}({arg}, {value.replace('x1_i', 'arg1')})"
- elif 'x2_i' in value:
- return f"{typ}({arg}, {value.replace('x2_i', 'arg2')})"
- return f"{typ}({arg}, {value})"
-
-def get_assert(typ, result):
- # TODO: Refactor this so typ is actually what it should be
- if result == "signed infinity":
- _check_exactly_equal(typ, result)
- return "assert_isinf(res[mask])"
- elif result == "positive":
- _check_exactly_equal(typ, result)
- return "assert_positive(res[mask])"
- elif result == "positive mathematical sign":
- _check_exactly_equal(typ, result)
- return "assert_positive_mathematical_sign(res[mask])"
- elif result == "negative":
- _check_exactly_equal(typ, result)
- return "assert_negative(res[mask])"
- elif result == "negative mathematical sign":
- _check_exactly_equal(typ, result)
- return "assert_negative_mathematical_sign(res[mask])"
- elif result == 'even integer closest to `x_i`':
- _check_exactly_equal(typ, result)
- return "assert_iseven(res[mask])\n assert_positive(subtract(one(arg1[mask].shape, arg1[mask].dtype), abs(subtract(arg1[mask], res[mask]))))"
- elif 'x_i' in result:
- return f"assert_{typ}(res[mask], ({result.replace('x_i', 'arg1')})[mask])"
- elif 'x1_i' in result:
- return f"assert_{typ}(res[mask], ({result.replace('x1_i', 'arg1')})[mask])"
- elif 'x2_i' in result:
- return f"assert_{typ}(res[mask], ({result.replace('x2_i', 'arg2')})[mask])"
-
- # TODO: Use something better than arg1 here for the arg
- result = parse_value(result, "arg1")
- try:
- # This won't catch all unknown values, but will catch some.
- ast.parse(result)
- except SyntaxError:
- raise RuntimeError(f"Unexpected result value {result!r} for {typ} (bad syntax)")
- return f"assert_{typ}(res[mask], ({result})[mask])"
-
-ONE_ARG_TEMPLATE = """
-{decorator}
-def test_{func}_special_cases_{test_name_extra}(arg1):
- {doc}
- res = {func}(arg1)
- mask = {mask}
- {assertion}
-"""
-
-TWO_ARGS_TEMPLATE = """
-{decorator}
-def test_{func}_special_cases_{test_name_extra}(arg1, arg2):
- {doc}
- res = {func}(arg1, arg2)
- mask = {mask}
- {assertion}
-"""
-
-REMAINING_TEMPLATE = """# TODO: Implement REMAINING test for:
-# {text}
-"""
-
-def generate_special_case_test(func, typ, m, test_name_extra, sigs):
-
- doc = f'''"""
- Special case test for `{sigs[func]}`:
-
- {m.group(0)}
-
- """'''
- if typ.startswith("ONE_ARG"):
- decorator = "@given(numeric_arrays)"
- if typ == "ONE_ARG_EQUAL":
- value1, result = m.groups()
- value1 = parse_value(value1, 'arg1')
- mask = get_mask("exactly_equal", "arg1", value1)
- elif typ == "ONE_ARG_GREATER":
- value1, result = m.groups()
- value1 = parse_value(value1, 'arg1')
- mask = get_mask("greater", "arg1", value1)
- elif typ == "ONE_ARG_LESS":
- value1, result = m.groups()
- value1 = parse_value(value1, 'arg1')
- mask = get_mask("less", "arg1", value1)
- elif typ == "ONE_ARG_EITHER":
- value1, value2, result = m.groups()
- value1 = parse_value(value1, 'arg1')
- value2 = parse_value(value2, 'arg1')
- mask1 = get_mask("exactly_equal", "arg1", value1)
- mask2 = get_mask("exactly_equal", "arg1", value2)
- mask = f"logical_or({mask1}, {mask2})"
- elif typ == "ONE_ARG_ALREADY_INTEGER_VALUED":
- result, = m.groups()
- mask = parse_value("integer", "arg1")
- elif typ == "ONE_ARG_TWO_INTEGERS_EQUALLY_CLOSE":
- result, = m.groups()
- mask = "logical_and(not_equal(floor(arg1), ceil(arg1)), equal(subtract(arg1, floor(arg1)), subtract(ceil(arg1), arg1)))"
- else:
- raise ValueError(f"Unrecognized special value type {typ}")
- assertion = get_assert("exactly_equal", result)
- return ONE_ARG_TEMPLATE.format(
- decorator=decorator,
- func=func,
- test_name_extra=test_name_extra,
- doc=doc,
- mask=mask,
- assertion=assertion,
- )
-
- elif typ.startswith("TWO_ARGS"):
- decorator = "@given(numeric_arrays, numeric_arrays)"
- if typ in [
- "TWO_ARGS_EQUAL__EQUAL",
- "TWO_ARGS_GREATER__EQUAL",
- "TWO_ARGS_LESS__EQUAL",
- "TWO_ARGS_EQUAL__GREATER",
- "TWO_ARGS_EQUAL__LESS",
- "TWO_ARGS_EQUAL__NOTEQUAL",
- "TWO_ARGS_NOTEQUAL__EQUAL",
- "TWO_ARGS_ABSEQUAL__EQUAL",
- "TWO_ARGS_ABSGREATER__EQUAL",
- "TWO_ARGS_ABSLESS__EQUAL",
- "TWO_ARGS_GREATER_EQUAL__EQUAL",
- "TWO_ARGS_LESS_EQUAL__EQUAL",
- "TWO_ARGS_EQUAL__LESS_EQUAL",
- "TWO_ARGS_EQUAL__LESS_NOTEQUAL",
- "TWO_ARGS_EQUAL__GREATER_EQUAL",
- "TWO_ARGS_EQUAL__GREATER_NOTEQUAL",
- "TWO_ARGS_LESS_EQUAL__EQUAL_NOTEQUAL",
- "TWO_ARGS_EITHER__EQUAL",
- "TWO_ARGS_EQUAL__EITHER",
- "TWO_ARGS_EITHER__EITHER",
- ]:
- arg1typs, arg2typs = [i.split('_') for i in typ[len("TWO_ARGS_"):].split("__")]
- if arg1typs == ["EITHER"]:
- arg1typs = ["EITHER_EQUAL", "EITHER_EQUAL"]
- if arg2typs == ["EITHER"]:
- arg2typs = ["EITHER_EQUAL", "EITHER_EQUAL"]
- *values, result = m.groups()
- if len(values) != len(arg1typs) + len(arg2typs):
- raise RuntimeError(f"Unexpected number of parsed values for {typ}: len({values}) != len({arg1typs}) + len({arg2typs})")
- arg1values, arg2values = values[:len(arg1typs)], values[len(arg1typs):]
- arg1values = [parse_value(value, 'arg1') for value in arg1values]
- arg2values = [parse_value(value, 'arg2') for value in arg2values]
-
- tomask = lambda t: t.lower().replace("either_equal", "equal").replace("equal", "exactly_equal")
- value1masks = [get_mask(tomask(t), 'arg1', v) for t, v in
- zip(arg1typs, arg1values)]
- value2masks = [get_mask(tomask(t), 'arg2', v) for t, v in
- zip(arg2typs, arg2values)]
- if len(value1masks) > 1:
- if arg1typs[0] == "EITHER_EQUAL":
- mask1 = f"logical_or({value1masks[0]}, {value1masks[1]})"
- else:
- mask1 = f"logical_and({value1masks[0]}, {value1masks[1]})"
- else:
- mask1 = value1masks[0]
- if len(value2masks) > 1:
- if arg2typs[0] == "EITHER_EQUAL":
- mask2 = f"logical_or({value2masks[0]}, {value2masks[1]})"
- else:
- mask2 = f"logical_and({value2masks[0]}, {value2masks[1]})"
- else:
- mask2 = value2masks[0]
-
- mask = f"logical_and({mask1}, {mask2})"
- assertion = get_assert("exactly_equal", result)
-
- elif typ == "TWO_ARGS_EITHER":
- value, result = m.groups()
- value = parse_value(value, "arg1")
- mask1 = get_mask("exactly_equal", "arg1", value)
- mask2 = get_mask("exactly_equal", "arg2", value)
- mask = f"logical_or({mask1}, {mask2})"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_SAME_SIGN":
- result, = m.groups()
- mask = "same_sign(arg1, arg2)"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_SAME_SIGN_EXCEPT":
- result, value, value1, value2 = m.groups()
- assert value == value1 == value2
- value = parse_value(value, "res")
- mask = f"logical_and(same_sign(arg1, arg2), logical_not(exactly_equal(res, {value})))"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_SAME_SIGN_BOTH":
- value, result = m.groups()
- mask1 = get_mask("exactly_equal", "arg1", value)
- mask2 = get_mask("exactly_equal", "arg2", value)
- mask = f"logical_and(same_sign(arg1, arg2), logical_and({mask1}, {mask2}))"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_DIFFERENT_SIGNS":
- result, = m.groups()
- mask = "logical_not(same_sign(arg1, arg2))"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_DIFFERENT_SIGNS_EXCEPT":
- result, value, value1, value2 = m.groups()
- assert value == value1 == value2
- value = parse_value(value, "res")
- mask = f"logical_and(logical_not(same_sign(arg1, arg2)), logical_not(exactly_equal(res, {value})))"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_DIFFERENT_SIGNS_BOTH":
- value, result = m.groups()
- mask1 = get_mask("exactly_equal", "arg1", value)
- mask2 = get_mask("exactly_equal", "arg2", value)
- mask = f"logical_and(logical_not(same_sign(arg1, arg2)), logical_and({mask1}, {mask2}))"
- assertion = get_assert("exactly_equal", result)
- elif typ == "TWO_ARGS_EVEN_IF":
- value1, result, value2 = m.groups()
- value1 = parse_value(value1, "arg2")
- mask = get_mask("exactly_equal", "arg2", value1)
- assertion = get_assert("exactly_equal", result)
- else:
- raise ValueError(f"Unrecognized special value type {typ}")
- return TWO_ARGS_TEMPLATE.format(
- decorator=decorator,
- func=func,
- test_name_extra=test_name_extra,
- doc=doc,
- mask=mask,
- assertion=assertion,
- )
-
- elif typ == "REMAINING":
- return REMAINING_TEMPLATE.format(text=m.group(0))
- else:
- raise RuntimeError(f"Unexpected type {typ}")
-
-def parse_special_cases(spec_text, verbose=False):
- special_cases = {}
- in_block = False
- for line in spec_text.splitlines():
- m = FUNCTION_HEADER_RE.match(line)
- if m:
- name = m.group(1)
- special_cases[name] = defaultdict(list)
- continue
- if line == '#### Special Cases':
- in_block = True
- continue
- elif line.startswith('#'):
- in_block = False
- continue
- if in_block:
- if '- ' not in line:
- continue
- for typ, reg in SPECIAL_CASE_REGEXS.items():
- m = reg.match(line)
- if m:
- if verbose:
- print(f"Matched {typ} for {name}: {m.groups()}")
- special_cases[name][typ].append(m)
- break
- else:
- raise ValueError(f"Unrecognized special case string for '{name}':\n{line}")
-
- return special_cases
-
-PARAMETER_RE = regex.compile(r"- +\*\*(.*)\*\*: _(.*)_")
-def parse_annotations(spec_text, verbose=False):
- annotations = defaultdict(dict)
- in_block = False
- for line in spec_text.splitlines():
- m = HEADER_RE.match(line)
- if m:
- name = m.group(1)
- continue
- if line == '#### Parameters':
- in_block = True
- continue
- elif line == '#### Returns':
- in_block = True
- continue
- elif line.startswith('#'):
- in_block = False
- continue
- if in_block:
- if not line.startswith('- '):
- continue
- m = PARAMETER_RE.match(line)
- if m:
- param, typ = m.groups()
- typ = clean_type(typ)
- if verbose:
- print(f"Matched parameter for {name}: {param}: {typ}")
- annotations[name][param] = typ
- else:
- raise ValueError(f"Unrecognized special case string for '{name}':\n{line}")
-
- return annotations
-
-def clean_type(typ):
- # TODO: How to handle dtypes in annotations? For now, we just remove the
- # one that exists (from where()).
- typ = typ.replace('<bool>', '')
- typ = typ.replace('<', '')
- typ = typ.replace('>', '')
- typ = typ.replace('\\', '')
- typ = typ.replace(' ', '')
- typ = typ.replace(',', ', ')
- return typ
-
-def add_annotation(sig, annotation):
- if 'out' not in annotation:
- raise RuntimeError(f"No return annotation for {sig}")
- for param, typ in annotation.items():
- if param == 'out':
- sig = f"{sig} -> {typ}"
- continue
- PARAM_DEFAULT = regex.compile(rf"([\( ]{param})=")
- sig2 = PARAM_DEFAULT.sub(rf'\1: {typ} = ', sig)
- if sig2 != sig:
- sig = sig2
- continue
- PARAM = regex.compile(rf"([\( ]{param})([,\)])")
- sig2 = PARAM.sub(rf'\1: {typ}\2', sig)
- if sig2 != sig:
- sig = sig2
- continue
- raise RuntimeError(f"Parameter {param} not found in {sig}")
- return sig
-
-if __name__ == '__main__':
- main()
diff --git a/meta_tests/README.md b/meta_tests/README.md
new file mode 100644
index 00000000..fb563cf6
--- /dev/null
+++ b/meta_tests/README.md
@@ -0,0 +1 @@
+Testing the utilities used in `array_api_tests/`
\ No newline at end of file
diff --git a/array_api_tests/meta_tests/__init__.py b/meta_tests/__init__.py
similarity index 100%
rename from array_api_tests/meta_tests/__init__.py
rename to meta_tests/__init__.py
diff --git a/meta_tests/test_array_helpers.py b/meta_tests/test_array_helpers.py
new file mode 100644
index 00000000..c46df50d
--- /dev/null
+++ b/meta_tests/test_array_helpers.py
@@ -0,0 +1,35 @@
+from hypothesis import given
+from hypothesis import strategies as st
+
+from array_api_tests import _array_module as xp
+from array_api_tests.hypothesis_helpers import (int_dtypes, arrays,
+ two_mutually_broadcastable_shapes)
+from array_api_tests.shape_helpers import iter_indices, broadcast_shapes
+from array_api_tests.array_helpers import exactly_equal, notequal, less
+
+# TODO: These meta-tests currently only work with NumPy
+
+def test_exactly_equal():
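+ # exactly_equal distinguishes -0.0 from +0.0 and treats two NaNs as equal,
+ # unlike plain ==; the expected results below encode exactly that.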
+ a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
+ b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
+
+ res = xp.asarray([True, False, True, False, True, False, True, False])
+ assert xp.all(xp.equal(exactly_equal(a, b), res))
+
+def test_notequal():
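+ # Note that notequal is not simply the negation of exactly_equal: per the
+ # expected results below, -0.0 vs +0.0 is neither exactly equal nor "not equal".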
+ a = xp.asarray([0, 0., -0., -0., xp.nan, xp.nan, 1, 1])
+ b = xp.asarray([0, -1, -0., 0., xp.nan, 1, 1, 2])
+
+ res = xp.asarray([False, True, False, False, False, True, False, True])
+ assert xp.all(xp.equal(notequal(a, b), res))
+
+
+@given(two_mutually_broadcastable_shapes, int_dtypes, int_dtypes, st.data())
+def test_less(shapes, dtype1, dtype2, data):
+ x = data.draw(arrays(shape=shapes[0], dtype=dtype1))
+ y = data.draw(arrays(shape=shapes[1], dtype=dtype2))
+
+ res = less(x, y)
+
+ for i, j, k in iter_indices(x.shape, y.shape, broadcast_shapes(x.shape, y.shape)):
+ assert res[k] == (int(x[i]) < int(y[j]))
diff --git a/meta_tests/test_broadcasting.py b/meta_tests/test_broadcasting.py
new file mode 100644
index 00000000..2f6310c1
--- /dev/null
+++ b/meta_tests/test_broadcasting.py
@@ -0,0 +1,35 @@
+"""
+https://github.com/data-apis/array-api/blob/master/spec/API_specification/broadcasting.md
+"""
+
+import pytest
+
+from array_api_tests import shape_helpers as sh
+
+
+@pytest.mark.parametrize(
+ "shape1, shape2, expected",
+ [
+ [(8, 1, 6, 1), (7, 1, 5), (8, 7, 6, 5)],
+ [(5, 4), (1,), (5, 4)],
+ [(5, 4), (4,), (5, 4)],
+ [(15, 3, 5), (15, 1, 5), (15, 3, 5)],
+ [(15, 3, 5), (3, 5), (15, 3, 5)],
+ [(15, 3, 5), (3, 1), (15, 3, 5)],
+ ],
+)
+def test_broadcast_shapes(shape1, shape2, expected):
+ assert sh._broadcast_shapes(shape1, shape2) == expected
+
+
+@pytest.mark.parametrize(
+ "shape1, shape2",
+ [
+ [(3,), (4,)], # dimension does not match
+ [(2, 1), (8, 4, 3)], # second dimension does not match
+ [(15, 3, 5), (15, 3)], # singleton dimensions can only be prepended
+ ],
+)
+def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
+ with pytest.raises(sh.BroadcastError):
+ sh._broadcast_shapes(shape1, shape2)
diff --git a/meta_tests/test_equality_mapping.py b/meta_tests/test_equality_mapping.py
new file mode 100644
index 00000000..8ac481f6
--- /dev/null
+++ b/meta_tests/test_equality_mapping.py
@@ -0,0 +1,37 @@
+import pytest
+
+from array_api_tests.dtype_helpers import EqualityMapping
+
+
+def test_raises_on_distinct_eq_key():
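+ # float("nan") != float("nan"), so a NaN key could never be retrieved by
+ # equality and should be rejected outright.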
+ with pytest.raises(ValueError):
+ EqualityMapping([(float("nan"), "value")])
+
+
+def test_raises_on_indistinct_eq_keys():
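+ # Keys that compare equal to everything make equality-based lookup
+ # ambiguous, so constructing the mapping should fail.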
+ class AlwaysEq:
+ def __init__(self, hash):
+ self._hash = hash
+
+ def __eq__(self, other):
+ return True
+
+ def __hash__(self):
+ return self._hash
+
+ with pytest.raises(ValueError):
+ EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")])
+
+
+def test_key_error():
+ mapping = EqualityMapping([("key", "value")])
+ with pytest.raises(KeyError):
+ mapping["nonexistent key"]
+
+
+def test_iter():
+ mapping = EqualityMapping([("key", "value")])
+ it = iter(mapping)
+ assert next(it) == "key"
+ with pytest.raises(StopIteration):
+ next(it)
diff --git a/meta_tests/test_hypothesis_helpers.py b/meta_tests/test_hypothesis_helpers.py
new file mode 100644
index 00000000..b14b728c
--- /dev/null
+++ b/meta_tests/test_hypothesis_helpers.py
@@ -0,0 +1,170 @@
+from math import prod
+from typing import Type
+
+import pytest
+from hypothesis import given, settings
+from hypothesis import strategies as st
+from hypothesis.errors import Unsatisfiable
+
+from array_api_tests import _array_module as xp
+from array_api_tests import array_helpers as ah
+from array_api_tests import dtype_helpers as dh
+from array_api_tests import hypothesis_helpers as hh
+from array_api_tests import shape_helpers as sh
+from array_api_tests import xps
+from array_api_tests._array_module import _UndefinedStub
+
+UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
+pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
+
+@given(hh.mutually_promotable_dtypes(dtypes=dh.real_float_dtypes))
+def test_mutually_promotable_dtypes(pair):
+ assert pair in (
+ (xp.float32, xp.float32),
+ (xp.float32, xp.float64),
+ (xp.float64, xp.float32),
+ (xp.float64, xp.float64),
+ )
+
+
+@given(
+ hh.mutually_promotable_dtypes(
+ dtypes=[xp.uint8, _UndefinedStub("uint16"), xp.uint32]
+ )
+)
+def test_partial_mutually_promotable_dtypes(pair):
+ assert pair in (
+ (xp.uint8, xp.uint8),
+ (xp.uint8, xp.uint32),
+ (xp.uint32, xp.uint8),
+ (xp.uint32, xp.uint32),
+ )
+
+
+def valid_shape(shape) -> bool:
+ return (
+ all(isinstance(side, int) for side in shape)
+ and all(side >= 0 for side in shape)
+ and prod(shape) < hh.MAX_ARRAY_SIZE
+ )
+
+
+@given(hh.shapes())
+def test_shapes(shape):
+ assert valid_shape(shape)
+
+
+@given(hh.two_mutually_broadcastable_shapes)
+def test_two_mutually_broadcastable_shapes(pair):
+ for shape in pair:
+ assert valid_shape(shape)
+
+
+@given(hh.two_broadcastable_shapes())
+def test_two_broadcastable_shapes(pair):
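+ # By construction, the second shape should broadcast to the first, so the
+ # broadcast result must equal pair[0].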
+ for shape in pair:
+ assert valid_shape(shape)
+ assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0]
+
+
+@given(*hh.two_mutual_arrays())
+def test_two_mutual_arrays(x1, x2):
+ assert (x1.dtype, x2.dtype) in dh.promotion_table.keys()
+
+
+def test_two_mutual_arrays_raises_on_bad_dtypes():
+ with pytest.raises(TypeError):
+ hh.two_mutual_arrays(dtypes=xps.scalar_dtypes())
+
+
+def test_kwargs():
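+ # kwargs() should generate every subset of the given keywords; the asserts
+ # below check that dicts of size 0, 1 and 2 all occur across 100 examples.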
+ results = []
+
+ @given(hh.kwargs(n=st.integers(0, 10), c=st.from_regex("[a-f]")))
+ @settings(max_examples=100)
+ def run(kw):
+ results.append(kw)
+ run()
+
+ assert all(isinstance(kw, dict) for kw in results)
+ for size in [0, 1, 2]:
+ assert any(len(kw) == size for kw in results)
+
+ n_results = [kw for kw in results if "n" in kw]
+ assert len(n_results) > 0
+ assert all(isinstance(kw["n"], int) for kw in n_results)
+
+ c_results = [kw for kw in results if "c" in kw]
+ assert len(c_results) > 0
+ assert all(isinstance(kw["c"], str) for kw in c_results)
+
+
+def test_specified_kwargs():
+ results = []
+
+ @given(n=st.integers(0, 10), d=st.none() | xps.scalar_dtypes(), data=st.data())
+ @settings(max_examples=100)
+ def run(n, d, data):
+ kw = data.draw(
+ hh.specified_kwargs(
+ hh.KVD("n", n, 0),
+ hh.KVD("d", d, None),
+ ),
+ label="kw",
+ )
+ results.append(kw)
+ run()
+
+ assert all(isinstance(kw, dict) for kw in results)
+
+ assert any(len(kw) == 0 for kw in results)
+
+ assert any("n" not in kw.keys() for kw in results)
+ assert any("n" in kw.keys() and kw["n"] == 0 for kw in results)
+ assert any("n" in kw.keys() and kw["n"] != 0 for kw in results)
+
+ assert any("d" not in kw.keys() for kw in results)
+ assert any("d" in kw.keys() and kw["d"] is None for kw in results)
+ assert any("d" in kw.keys() and kw["d"] is xp.float64 for kw in results)
+
+
+@given(finite=st.booleans(), dtype=xps.floating_dtypes(), data=st.data())
+def test_symmetric_matrices(finite, dtype, data):
+ m = data.draw(hh.symmetric_matrices(st.just(dtype), finite=finite), label="m")
+ assert m.dtype == dtype
+ # TODO: This part of this test should be part of the .mT test
+ ah.assert_exactly_equal(m, m.mT)
+
+ if finite:
+ ah.assert_finite(m)
+
+
+@given(dtype=xps.floating_dtypes(), data=st.data())
+def test_positive_definite_matrices(dtype, data):
+ m = data.draw(hh.positive_definite_matrices(st.just(dtype)), label="m")
+ assert m.dtype == dtype
+ # TODO: Test that it actually is positive definite
+
+
+def make_raising_func(cls: Type[Exception], msg: str):
+ def raises():
+ raise cls(msg)
+
+ return raises
+
+@pytest.mark.parametrize(
+ "func",
+ [
+ make_raising_func(OverflowError, "foo"),
+ make_raising_func(RuntimeError, "Overflow when unpacking long"),
+ make_raising_func(Exception, "Got an overflow"),
+ ]
+)
+def test_reject_overflow(func):
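+ # reject_overflow() should turn each overflow-flavoured exception into a
+ # hypothesis rejection, so rejecting every example leaves the test Unsatisfiable.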
+ @given(data=st.data())
+ def test_case(data):
+ with hh.reject_overflow():
+ func()
+
+ with pytest.raises(Unsatisfiable):
+ test_case()
diff --git a/meta_tests/test_linalg.py b/meta_tests/test_linalg.py
new file mode 100644
index 00000000..82794b6c
--- /dev/null
+++ b/meta_tests/test_linalg.py
@@ -0,0 +1,16 @@
+import pytest
+
+from hypothesis import given
+
+from array_api_tests.hypothesis_helpers import symmetric_matrices
+from array_api_tests import array_helpers as ah
+from array_api_tests import _array_module as xp
+
+@pytest.mark.xp_extension('linalg')
+@given(x=symmetric_matrices(finite=True))
+def test_symmetric_matrices(x):
+ upper = xp.triu(x)
+ lower = xp.tril(x)
+ lowerT = ah._matrix_transpose(lower)
+
+ ah.assert_exactly_equal(upper, lowerT)
diff --git a/meta_tests/test_partial_adopters.py b/meta_tests/test_partial_adopters.py
new file mode 100644
index 00000000..de3a7e76
--- /dev/null
+++ b/meta_tests/test_partial_adopters.py
@@ -0,0 +1,18 @@
+import pytest
+from hypothesis import given
+
+from array_api_tests import dtype_helpers as dh
+from array_api_tests import hypothesis_helpers as hh
+from array_api_tests import _array_module as xp
+from array_api_tests._array_module import _UndefinedStub
+
+
+# e.g. PyTorch only supports uint8 currently
+@pytest.mark.skipif(isinstance(xp.uint8, _UndefinedStub), reason="uint8 not defined")
+@pytest.mark.skipif(
+ not all(isinstance(d, _UndefinedStub) for d in dh.uint_dtypes[1:]),
+ reason="uints defined",
+)
+@given(hh.mutually_promotable_dtypes(dtypes=dh.uint_dtypes))
+def test_mutually_promotable_dtypes(pair):
+ assert pair == (xp.uint8, xp.uint8)
diff --git a/meta_tests/test_pytest_helpers.py b/meta_tests/test_pytest_helpers.py
new file mode 100644
index 00000000..a0aa0930
--- /dev/null
+++ b/meta_tests/test_pytest_helpers.py
@@ -0,0 +1,29 @@
+from pytest import raises
+
+from array_api_tests import xp as _xp
+from array_api_tests import _array_module as xp
+from array_api_tests import pytest_helpers as ph
+
+
+def test_assert_dtype():
+ ph.assert_dtype("promoted_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.int16)
+ with raises(AssertionError):
+ ph.assert_dtype("bad_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.float32)
+ ph.assert_dtype("bool_func", in_dtype=[xp.uint8, xp.int8], out_dtype=xp.bool, expected=xp.bool)
+ ph.assert_dtype("single_promoted_func", in_dtype=[xp.uint8], out_dtype=xp.uint8)
+ ph.assert_dtype("single_bool_func", in_dtype=[xp.uint8], out_dtype=xp.bool, expected=xp.bool)
+
+
+def test_assert_array_elements():
+ ph.assert_array_elements("int zeros", out=xp.asarray(0), expected=xp.asarray(0))
+ ph.assert_array_elements("pos zeros", out=xp.asarray(0.0), expected=xp.asarray(0.0))
+ ph.assert_array_elements("neg zeros", out=xp.asarray(-0.0), expected=xp.asarray(-0.0))
+ if hasattr(_xp, "signbit"):
+ with raises(AssertionError):
+ ph.assert_array_elements("mixed sign zeros", out=xp.asarray(0.0), expected=xp.asarray(-0.0))
+ with raises(AssertionError):
+ ph.assert_array_elements("mixed sign zeros", out=xp.asarray(-0.0), expected=xp.asarray(0.0))
+
+ ph.assert_array_elements("nans", out=xp.asarray(float("nan")), expected=xp.asarray(float("nan")))
+ with raises(AssertionError):
+ ph.assert_array_elements("nan and zero", out=xp.asarray(float("nan")), expected=xp.asarray(0.0))
diff --git a/meta_tests/test_signatures.py b/meta_tests/test_signatures.py
new file mode 100644
index 00000000..937f73f3
--- /dev/null
+++ b/meta_tests/test_signatures.py
@@ -0,0 +1,67 @@
+from inspect import Parameter, Signature, signature
+
+import pytest
+
+from array_api_tests.test_signatures import _test_inspectable_func
+
+
+def stub(foo, /, bar=None, *, baz=None):
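+ # Reference signature: foo is positional-only, bar may be passed either
+ # way, and baz is keyword-only.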
+ pass
+
+
+stub_sig = signature(stub)
+
+
+@pytest.mark.parametrize(
+ "sig",
+ [
+ Signature(
+ [
+ Parameter("foo", Parameter.POSITIONAL_ONLY),
+ Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
+ Parameter("baz", Parameter.KEYWORD_ONLY),
+ ]
+ ),
+ Signature(
+ [
+ Parameter("foo", Parameter.POSITIONAL_ONLY),
+ Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
+ Parameter("baz", Parameter.POSITIONAL_OR_KEYWORD),
+ ]
+ ),
+ Signature(
+ [
+ Parameter("foo", Parameter.POSITIONAL_ONLY),
+ Parameter("bar", Parameter.POSITIONAL_OR_KEYWORD),
+ Parameter("qux", Parameter.KEYWORD_ONLY),
+ Parameter("baz", Parameter.KEYWORD_ONLY),
+ ]
+ ),
+ ],
+)
+def test_good_sig_passes(sig):
+ _test_inspectable_func(sig, stub_sig)
+
+
+@pytest.mark.parametrize(
+ "sig",
+ [
+ Signature(
+ [
+ Parameter("foo", Parameter.POSITIONAL_ONLY),
+ Parameter("bar", Parameter.POSITIONAL_ONLY),
+ Parameter("baz", Parameter.KEYWORD_ONLY),
+ ]
+ ),
+ Signature(
+ [
+ Parameter("foo", Parameter.POSITIONAL_ONLY),
+ Parameter("bar", Parameter.KEYWORD_ONLY),
+ Parameter("baz", Parameter.KEYWORD_ONLY),
+ ]
+ ),
+ ],
+)
+def test_raises_on_bad_sig(sig):
+ with pytest.raises(AssertionError):
+ _test_inspectable_func(sig, stub_sig)
diff --git a/meta_tests/test_special_cases.py b/meta_tests/test_special_cases.py
new file mode 100644
index 00000000..40c7806c
--- /dev/null
+++ b/meta_tests/test_special_cases.py
@@ -0,0 +1,10 @@
+import math
+
+from array_api_tests.test_special_cases import parse_result
+
+
+def test_parse_result():
+ check_result, _ = parse_result(
+ "an implementation-dependent approximation to ``+3π/4``"
+ )
+ assert check_result(3 * math.pi / 4)
diff --git a/meta_tests/test_utils.py b/meta_tests/test_utils.py
new file mode 100644
index 00000000..911ba899
--- /dev/null
+++ b/meta_tests/test_utils.py
@@ -0,0 +1,120 @@
+import pytest
+from hypothesis import given
+from hypothesis import strategies as st
+
+from array_api_tests import _array_module as xp
+from array_api_tests import dtype_helpers as dh
+from array_api_tests import hypothesis_helpers as hh
+from array_api_tests import shape_helpers as sh
+from array_api_tests import xps
+from array_api_tests.test_creation_functions import frange
+from array_api_tests.test_manipulation_functions import roll_ndindex
+from array_api_tests.test_operators_and_elementwise_functions import mock_int_dtype
+
+
+@pytest.mark.parametrize(
+ "r, size, elements",
+ [
+ (frange(0, 1, 1), 1, [0]),
+ (frange(1, 0, -1), 1, [1]),
+ (frange(0, 1, -1), 0, []),
+ (frange(0, 1, 2), 1, [0]),
+ ],
+)
+def test_frange(r, size, elements):
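+ # frange mirrors range() semantics for floats: stepping away from the stop
+ # value (e.g. frange(0, 1, -1) above) yields an empty range.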
+ assert len(r) == size
+ assert list(r) == elements
+
+
+@pytest.mark.parametrize(
+ "shape, expected",
+ [((), [()])],
+)
+def test_ndindex(shape, expected):
+ assert list(sh.ndindex(shape)) == expected
+
+
+@pytest.mark.parametrize(
+ "shape, axis, expected",
+ [
+ ((1,), 0, [(slice(None, None),)]),
+ ((1, 2), 0, [(slice(None, None), slice(None, None))]),
+ (
+ (2, 4),
+ 1,
+ [(0, slice(None, None)), (1, slice(None, None))],
+ ),
+ ],
+)
+def test_axis_ndindex(shape, axis, expected):
+ assert list(sh.axis_ndindex(shape, axis)) == expected
+
+
+@pytest.mark.parametrize(
+ "shape, axes, expected",
+ [
+ ((), (), [[()]]),
+ ((1,), (0,), [[(0,)]]),
+ (
+ (2, 2),
+ (0,),
+ [
+ [(0, 0), (1, 0)],
+ [(0, 1), (1, 1)],
+ ],
+ ),
+ ],
+)
+def test_axes_ndindex(shape, axes, expected):
+ assert list(sh.axes_ndindex(shape, axes)) == expected
+
+
+@pytest.mark.parametrize(
+ "shape, shifts, axes, expected",
+ [
+ ((1, 1), (0,), (0,), [(0, 0)]),
+ ((2, 1), (1, 1), (0, 1), [(1, 0), (0, 0)]),
+ ((2, 2), (1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]),
+ ((2, 2), (-1, 1), (0, 1), [(1, 1), (1, 0), (0, 1), (0, 0)]),
+ ],
+)
+def test_roll_ndindex(shape, shifts, axes, expected):
+ assert list(roll_ndindex(shape, shifts, axes)) == expected
+
+
+@pytest.mark.parametrize(
+ "idx, expected",
+ [
+ ((), "x"),
+ (42, "x[42]"),
+ ((42,), "x[42]"),
+ ((42, 7), "x[42, 7]"),
+ (slice(None, 2), "x[:2]"),
+ (slice(2, None), "x[2:]"),
+ (slice(0, 2), "x[0:2]"),
+ (slice(0, 2, -1), "x[0:2:-1]"),
+ (slice(None, None, -1), "x[::-1]"),
+ (slice(None, None), "x[:]"),
+ (..., "x[...]"),
+ ((None, 42), "x[None, 42]"),
+ ],
+)
+def test_fmt_idx(idx, expected):
+ assert sh.fmt_idx("x", idx) == expected
+
+
+@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
+def test_int_to_dtype(x, dtype):
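+ # mock_int_dtype should agree with how the array library itself casts a
+ # Python int to the given integer dtype; overflowing draws are rejected.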
+ with hh.reject_overflow():
+ d = xp.asarray(x, dtype=dtype)
+ assert mock_int_dtype(x, dtype) == d
+
+
+@given(hh.oneway_promotable_dtypes(dh.all_dtypes))
+def test_oneway_promotable_dtypes(D):
+ assert D.result_dtype == dh.result_type(*D)
+
+
+@given(hh.oneway_broadcastable_shapes())
+def test_oneway_broadcastable_shapes(S):
+ assert S.result_shape == sh.broadcast_shapes(*S)
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..935c67fe
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,8 @@
+[pytest]
+filterwarnings =
+ # Ignore floating-point warnings from NumPy
+ ignore:invalid value encountered in:RuntimeWarning
+ ignore:overflow encountered in:RuntimeWarning
+ ignore:divide by zero encountered in:RuntimeWarning
+
+
diff --git a/reporting.py b/reporting.py
new file mode 100644
index 00000000..579aa211
--- /dev/null
+++ b/reporting.py
@@ -0,0 +1,111 @@
+from array_api_tests.dtype_helpers import dtype_to_name
+from array_api_tests import _array_module as xp
+from array_api_tests import __version__
+
+from collections import Counter
+from types import BuiltinFunctionType, FunctionType
+import dataclasses
+import json
+import warnings
+
+from hypothesis.strategies import SearchStrategy
+
+from pytest import hookimpl, fixture
+try:
+ import pytest_jsonreport # noqa
+except ImportError:
+ raise ImportError("pytest-json-report is required to run the array API tests")
+
+def to_json_serializable(o):
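+ # Recursively convert report objects (dtypes, functions, dataclasses,
+ # hypothesis strategies, containers) into JSON-friendly primitives.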
+ if o in dtype_to_name:
+ return dtype_to_name[o]
+ if isinstance(o, (BuiltinFunctionType, FunctionType, type)):
+ return o.__name__
+ if dataclasses.is_dataclass(o):
+ return to_json_serializable(dataclasses.asdict(o))
+ if isinstance(o, SearchStrategy):
+ return repr(o)
+ if isinstance(o, dict):
+ return {to_json_serializable(k): to_json_serializable(v) for k, v in o.items()}
+ if isinstance(o, tuple):
+ if hasattr(o, '_asdict'): # namedtuple
+ return to_json_serializable(o._asdict())
+ return tuple(to_json_serializable(i) for i in o)
+ if isinstance(o, list):
+ return [to_json_serializable(i) for i in o]
+ if callable(o):
+ return repr(o)
+
+ # Ensure everything is JSON serializable. If this warning is issued, it
+ # means the given type needs to be added above if possible.
+ try:
+ json.dumps(o)
+ except TypeError:
+ warnings.warn(f"{o!r} (of type {type(o)}) is not JSON-serializable. Using the repr instead.")
+ return repr(o)
+
+ return o
+
+@hookimpl(optionalhook=True)
+def pytest_metadata(metadata):
+ """
+ Additional global metadata for --json-report.
+ """
+ metadata['array_api_tests_module'] = xp.__name__
+ metadata['array_api_tests_version'] = __version__
+
+@fixture(autouse=True)
+def add_extra_json_metadata(request, json_metadata):
+ """
+ Additional per-test metadata for --json-report
+ """
+ def add_metadata(name, obj):
+ obj = to_json_serializable(obj)
+ json_metadata[name] = obj
+
+ test_module = request.module.__name__
+ if test_module.startswith('array_api_tests.meta'):
+ return
+
+ test_function = request.function.__name__
+ assert test_function.startswith('test_'), 'unexpected test function name'
+
+ if test_module == 'array_api_tests.test_has_names':
+ array_api_function_name = None
+ else:
+ array_api_function_name = test_function[len('test_'):]
+
+ add_metadata('test_module', test_module)
+ add_metadata('test_function', test_function)
+ add_metadata('array_api_function_name', array_api_function_name)
+
+ if hasattr(request.node, 'callspec'):
+ params = request.node.callspec.params
+ add_metadata('params', params)
+
+ def finalizer():
+ # TODO: This metadata is all in the form of error strings. It might be
+ # nice to extract the hypothesis failing inputs directly somehow.
+ if hasattr(request.node, 'hypothesis_report_information'):
+ add_metadata('hypothesis_report_information', request.node.hypothesis_report_information)
+ if hasattr(request.node, 'hypothesis_statistics'):
+ add_metadata('hypothesis_statistics', request.node.hypothesis_statistics)
+
+ request.addfinalizer(finalizer)
+
+@hookimpl(optionalhook=True)
+def pytest_json_modifyreport(json_report):
+ # Deduplicate warnings. These duplicate warnings can cause the file size
+ # to become huge. For instance, a warning from np.bool, which is emitted
+ # every time hypothesis runs (over a million times), causes the warnings
+ # JSON for a plain numpy namespace run to be over 500MB.
+
+ # This will lose information about what order the warnings were issued in,
+ # but that isn't particularly helpful anyway since the warning metadata
+ # doesn't store a full stack of where it was issued from. The resulting
+ # warnings will be in order of the first time each warning is issued since
+ # collections.Counter is ordered just like dict().
+ counted_warnings = Counter([frozenset(i.items()) for i in json_report.get('warnings', dict())])
+ deduped_warnings = [{**dict(i), 'count': counted_warnings[i]} for i in counted_warnings]
+
+ json_report['warnings'] = deduped_warnings
diff --git a/requirements.txt b/requirements.txt
index caeab0aa..c5508119 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
pytest
-hypothesis
-regex
+pytest-json-report
+hypothesis>=6.130.5
+ndindex>=1.8
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 00000000..2549aa6d
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,12 @@
+
+# See the docstring in versioneer.py for instructions. Note that you must
+# re-run 'versioneer.py setup' after changing this section, and commit the
+# resulting files.
+
+[versioneer]
+VCS = git
+style = pep440
+versionfile_source = array_api_tests/_version.py
+versionfile_build = array_api_tests/_version.py
+tag_prefix =
+parentdir_prefix =