8000 Add typing to test_plot.py by Illviljan · Pull Request #8889 · pydata/xarray · GitHub
[go: up one dir, main page]

Skip to content

Add typing to test_plot.py #8889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ module = [
"xarray.tests.test_merge",
"xarray.tests.test_missing",
"xarray.tests.test_parallelcompat",
"xarray.tests.test_plot",
"xarray.tests.test_sparse",
"xarray.tests.test_ufuncs",
"xarray.tests.test_units",
Expand Down
72 changes: 40 additions & 32 deletions xarray/tests/test_plot.py
A3E2
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import contextlib
import inspect
import math
from collections.abc import Hashable
from collections.abc import Generator, Hashable
from copy import copy
from datetime import date, datetime, timedelta
from typing import Any, Callable, Literal
Expand Down Expand Up @@ -85,52 +85,54 @@ def test_all_figures_closed():

@pytest.mark.flaky
@pytest.mark.skip(reason="maybe flaky")
def text_in_fig():
def text_in_fig() -> set[str]:
"""
Return the set of all text in the figure
"""
return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)}
return {t.get_text() for t in plt.gcf().findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?


def find_possible_colorbars():
def find_possible_colorbars() -> list[mpl.collections.QuadMesh]:
# nb. this function also matches meshes from pcolormesh
return plt.gcf().findobj(mpl.collections.QuadMesh)
return plt.gcf().findobj(mpl.collections.QuadMesh) # type: ignore[return-value] # mpl error?
Copy link
Contributor Author
@Illviljan Illviljan Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the typing for findobj is wrong. Looks to me it mostly returns the same type as the input.
https://github.com/matplotlib/matplotlib/blob/78b6bdc04f89dafe4d157855a9023826cab8a0fd/lib/matplotlib/artist.pyi#L131-L135
Probably should be something like TypeVar("T_Artist", bound=Artist).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @ksunden if you're still interested.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thank you @ksunden !



def substring_in_axes(substring, ax):
def substring_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
"""
Return True if a substring is found anywhere in an axes
"""
alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)}
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
for txt in alltxt:
if substring in txt:
return True
return False


def substring_not_in_axes(substring, ax):
def substring_not_in_axes(substring: str, ax: mpl.axes.Axes) -> bool:
"""
Return True if a substring is not found anywhere in an axes
"""
alltxt = {t.get_text() for t in ax.findobj(mpl.text.Text)}
alltxt: set[str] = {t.get_text() for t in ax.findobj(mpl.text.Text)} # type: ignore[attr-defined] # mpl error?
check = [(substring not in txt) for txt in alltxt]
return all(check)


def property_in_axes_text(property, property_str, target_txt, ax):
def property_in_axes_text(
property, property_str, target_txt, ax: mpl.axes.Axes
) -> bool:
"""
Return True if the specified text in an axes
has the property assigned to property_str
"""
alltxt = ax.findobj(mpl.text.Text)
alltxt: list[mpl.text.Text] = ax.findobj(mpl.text.Text) # type: ignore[assignment]
check = []
for t in alltxt:
if t.get_text() == target_txt:
check.append(plt.getp(t, property) == property_str)
return all(check)


def easy_array(shape, start=0, stop=1):
def easy_array(shape: tuple[int, ...], start: float = 0, stop: float = 1) -> np.ndarray:
"""
Make an array with desired shape using np.linspace

Expand All @@ -140,7 +142,7 @@ def easy_array(shape, start=0, stop=1):
return a.reshape(shape)


def get_colorbar_label(colorbar):
def get_colorbar_label(colorbar) -> str:
if colorbar.orientation == "vertical":
return colorbar.ax.get_ylabel()
else:
Expand All @@ -150,27 +152,27 @@ def get_colorbar_label(colorbar):
@requires_matplotlib
class PlotTestCase:
@pytest.fixture(autouse=True)
def setup(self):
def setup(self) -> Generator:
yield
# Remove all matplotlib figures
plt.close("all")

def pass_in_axis(self, plotmethod, subplot_kw=None):
def pass_in_axis(self, plotmethod, subplot_kw=None) -> None:
fig, axs = plt.subplots(ncols=2, subplot_kw=subplot_kw)
plotmethod(ax=axs[0])
assert axs[0].has_data()

@pytest.mark.slow
def imshow_called(self, plotmethod):
def imshow_called(self, plotmethod) -> bool:
plotmethod()
images = plt.gca().findobj(mpl.image.AxesImage)
return len(images) > 0

def contourf_called(self, plotmethod):
def contourf_called(self, plotmethod) -> bool:
plotmethod()

# Compatible with mpl before (PathCollection) and after (QuadContourSet) 3.8
def matchfunc(x):
def matchfunc(x) -> bool:
return isinstance(
x, (mpl.collections.PathCollection, mpl.contour.QuadContourSet)
)
Expand Down Expand Up @@ -1248,14 +1250,16 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None:
def test_discrete_colormap_provided_boundary_norm(self) -> None:
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
primitive = self.darray.plot.contourf(norm=norm)
np.testing.assert_allclose(primitive.levels, norm.boundaries)
np.testing.assert_allclose(list(primitive.levels), norm.boundaries)

def test_discrete_colormap_provided_boundary_norm_matching_cmap_levels(
self,
) -> None:
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
primitive = self.darray.plot.contourf(norm=norm)
assert primitive.colorbar.norm.Ncmap == primitive.colorbar.norm.N
cbar = primitive.colorbar
assert cbar is not None
assert cbar.norm.Ncmap == cbar.norm.N # type: ignore[attr-defined] # Exists, debatable if public though.


class Common2dMixin:
Expand Down Expand Up @@ -2532,7 +2536,7 @@ def test_default_labels(self) -> None:

# Leftmost column should have array name
for ax in g.axs[:, 0]:
assert substring_in_axes(self.darray.name, ax)
assert substring_in_axes(str(self.darray.name), ax)

def test_test_empty_cell(self) -> None:
g = (
Expand Down Expand Up @@ -2635,7 +2639,7 @@ def test_facetgrid(self) -> None:
(True, "continuous", False, True),
],
)
def test_add_guide(self, add_guide, hue_style, legend, colorbar):
def test_add_guide(self, add_guide, hue_style, legend, colorbar) -> None:
meta_data = _infer_meta_data(
self.ds,
x="x",
Expand Down Expand Up @@ -2811,7 +2815,7 @@ def test_bad_args(
add_legend: bool | None,
add_colorbar: bool | None,
error_type: type[Exception],
):
) -> None:
with pytest.raises(error_type):
self.ds.plot.scatter(
x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar
Expand Down Expand Up @@ -3011,20 +3015,22 @@ def test_ncaxis_notinstalled_line_plot(self) -> None:
@requires_matplotlib
class TestAxesKwargs:
@pytest.fixture(params=[1, 2, 3])
def data_array(self, request):
def data_array(self, request) -> DataArray:
"""
Return a simple DataArray
"""
dims = request.param
if dims == 1:
return DataArray(easy_array((10,)))
if dims == 2:
elif dims == 2:
return DataArray(easy_array((10, 3)))
if dims == 3:
elif dims == 3:
return DataArray(easy_array((10, 3, 2)))
else:
raise ValueError(f"No DataArray implemented for {dims=}.")

@pytest.fixture(params=[1, 2])
def data_array_logspaced(self, request):
def data_array_logspaced(self, request) -> DataArray:
"""
Return a simple DataArray with logspaced coordinates
"""
Expand All @@ -3033,12 +3039,14 @@ def data_array_logspaced(self, request):
return DataArray(
np.arange(7), dims=("x",), coords={"x": np.logspace(-3, 3, 7)}
)
if dims == 2:
elif dims == 2:
return DataArray(
np.arange(16).reshape(4, 4),
dims=("y", "x"),
coords={"x": np.logspace(-1, 2, 4), "y": np.logspace(-5, -1, 4)},
)
else:
raise ValueError(f"No DataArray implemented for {dims=}.")

@pytest.mark.parametrize("xincrease", [True, False])
def test_xincrease_kwarg(self, data_array, xincrease) -> None:
Expand Down Expand Up @@ -3146,16 +3154,16 @@ def test_facetgrid_single_contour() -> None:


@requires_matplotlib
def test_get_axis_raises():
def test_get_axis_raises() -> None:
# test get_axis raises an error if trying to do invalid things

# cannot provide both ax and figsize
with pytest.raises(ValueError, match="both `figsize` and `ax`"):
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something")
get_axis(figsize=[4, 4], size=None, aspect=None, ax="something") # type: ignore[arg-type]

# cannot provide both ax and size
with pytest.raises(ValueError, match="both `size` and `ax`"):
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something")
get_axis(figsize=None, size=200, aspect=4 / 3, ax="something") # type: ignore[arg-type]

# cannot provide both size and figsize
with pytest.raises(ValueError, match="both `figsize` and `size`"):
Expand All @@ -3167,7 +3175,7 @@ def test_get_axis_raises():

# cannot provide axis and subplot_kws
with pytest.raises(ValueError, match="cannot use subplot_kws with existing ax"):
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5)
get_axis(figsize=None, size=None, aspect=None, ax=1, something_else=5) # type: ignore[arg-type]


@requires_matplotlib
Expand Down
0