8000 Refactoring: Removing axis parameter from scales by vagnermcj · Pull Request #29988 · matplotlib/matplotlib · GitHub
[go: up one dir, main page]

Skip to content

Refactoring: Removing axis parameter from scales #29988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions lib/matplotlib/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import inspect
import textwrap
from functools import wraps

import numpy as np

Expand Down Expand Up @@ -103,14 +104,58 @@
return vmin, vmax


def handle_axis_parameter(init_func):
"""
Decorator to handle the optional *axis* parameter in scale constructors.

This decorator ensures backward compatibility for scale classes that
previously required an *axis* parameter. It allows constructors to work
seamlessly with or without the *axis* parameter.

Parameters
----------
init_func : callable
The original __init__ method of a scale class.

Returns
-------
callable
A wrapped version of *init_func* that handles the optional *axis*.

Notes
-----
If the wrapped constructor defines *axis* as its first argument, the
parameter is preserved when present. Otherwise, the value `None` is injected
as the first argument.

Examples
--------
>>> from matplotlib.scale import ScaleBase
>>> class CustomScale(ScaleBase):
... @handle_axis_parameter
... def __init__(self, axis, custom_param=1):
... self.custom_param = custom_param
"""
@wraps(init_func)
def wrapper(self, *args, **kwargs):
if args and isinstance(args[0], mpl.axis.Axis):
return init_func(self, *args, **kwargs)
else:
# Remove 'axis' from kwargs to avoid double assignment
kwargs.pop('axis', None)
return init_func(self, None, *args, **kwargs)
return wrapper


class LinearScale(ScaleBase):
"""
The default linear scale.
"""

name = 'linear'

def __init__(self, axis):
@handle_axis_parameter
def __init__(self, axis=None):
# This method is present only to prevent inheritance of the base class'
# constructor docstring, which would otherwise end up interpolated into
# the docstring of Axis.set_scale.
Expand Down Expand Up @@ -180,6 +225,7 @@

name = 'function'

@handle_axis_parameter
def __init__(self, axis, functions):
"""
Parameters
Expand Down Expand Up @@ -279,7 +325,8 @@
"""
name = 'log'

def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"):
@handle_axis_parameter
def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"):
"""
Parameters
----------
Expand Down Expand Up @@ -330,6 +377,7 @@

name = 'functionlog'

@handle_axis_parameter
def __init__(self, axis, functions, base=10):
"""
Parameters
Expand Down Expand Up @@ -455,7 +503,8 @@
"""
name = 'symlog'

def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1):
@handle_axis_parameter
def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1):
self._transform = SymmetricalLogTransform(base, linthresh, linscale)
self.subs = subs

Expand Down Expand Up @@ -547,7 +596,8 @@
1024: (256, 512)
}

def __init__(self, axis, *, linear_width=1.0,
@handle_axis_parameter
def __init__(self, axis=None, *, linear_width=1.0,
base=10, subs='auto', **kwargs):
"""
Parameters
Expand Down Expand Up @@ -645,7 +695,8 @@
"""
name = 'logit'

def __init__(self, axis, nonpositive='mask', *,
@handle_axis_parameter
def __init__(self, axis=None, nonpositive='mask', *,
one_half=r"\frac{1}{2}", use_overline=False):
r"""
Parameters
Expand Down Expand Up @@ -725,7 +776,12 @@
axis : `~matplotlib.axis.Axis`
"""
scale_cls = _api.check_getitem(_scale_mapping, scale=scale)
return scale_cls(axis, **kwargs)
try:
return scale_cls(axis, **kwargs)
except TypeError as e:
if 'unexpected keyword argument' in str(e) or 'positional argument' in str(e):
return scale_cls(**kwargs)
raise

Check warning on line 784 in lib/matplotlib/scale.py

View check run for this annotation

Codecov / codecov/patch

lib/matplotlib/scale.py#L784

Added line #L784 was not covered by tests


if scale_factory.__doc__:
Expand Down
38 changes: 20 additions & 18 deletions lib/matplotlib/scale.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from matplotlib.axis import Axis
from matplotlib.transforms import Transform

from collections.abc import Callable, Iterable
from typing import Literal
from typing import Literal, Union
from numpy.typing import ArrayLike

class ScaleBase:
Expand All @@ -15,6 +15,7 @@ class ScaleBase:

class LinearScale(ScaleBase):
name: str
def __init__(self: ScaleBase, axis: Union[Axis, None] = None) -> None: ...

class FuncTransform(Transform):
input_dims: int
Expand Down Expand Up @@ -56,12 +57,12 @@ class LogScale(ScaleBase):
name: str
subs: Iterable[int] | None
def __init__(
self,
axis: Axis | None,
self: LogScale,
axis: Union[Axis, None] = None,
*,
base: float = ...,
subs: Iterable[int] | None = ...,
nonpositive: Literal["clip", "mask"] = ...
base: float = 10,
subs: Union[Iterable[int], None] = None,
nonpositive: Union[Literal['clip'], Literal['mask']] = 'clip'
) -> None: ...
@property
def base(self) -> float: ...
Expand Down Expand Up @@ -103,13 +104,13 @@ class SymmetricalLogScale(ScaleBase):
name: str
subs: Iterable[int] | None
def __init__(
self,
axis: Axis | None,
self: SymmetricalLogScale,
axis: Union[Axis, None] = None,
*,
base: float = ...,
linthresh: float = ...,
subs: Iterable[int] | None = ...,
linscale: float = ...
base: float = 10,
linthresh: float = 2,
subs: Union[Iterable[int], None] = None,
linscale: float = 1
) -> None: ...
@property
def base(self) -> float: ...
Expand Down Expand Up @@ -138,7 +139,7 @@ class AsinhScale(ScaleBase):
auto_tick_multipliers: dict[int, tuple[int, ...]]
def __init__(
self,
axis: Axis | None,
axis: Union[Axis, None] = None,
*,
linear_width: float = ...,
base: float = ...,
Expand All @@ -164,15 +165,16 @@ class LogisticTransform(Transform):
class LogitScale(ScaleBase):
name: str
def __init__(
self,
axis: Axis | None,
nonpositive: Literal["mask", "clip"] = ...,
self: LogitScale,
axis: Union[Axis, None] = None,
nonpositive: Union[Literal['mask'], Literal['clip']] = 'mask',
*,
one_half: str = ...,
use_overline: bool = ...
one_half: str = '\\frac{1}{2}',
use_overline: bool = False
) -> None: ...
def get_transform(self) -> LogitTransform: ...

def get_scale_names() -> list[str]: ...
def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ...
def register_scale(scale_class: type[ScaleBase]) -> None: ...
def handle_axis_parameter(init_func: Callable[..., None]) -> Callable[..., None]: ...
Loading
0