[Bug]: ValueError when plotting 2D pytorch tensor using matplotlib==3.8.0 · Issue #26806 · matplotlib/matplotlib · GitHub
Closed
@CMGeldenhuys

Description

Bug summary

A ValueError is raised when plotting a 2D PyTorch tensor with matplotlib==3.8.0. The error does not occur with matplotlib==3.7.3.

Code for reproduction

# Using matplotlib==3.8.0
>>> import torch
>>> import matplotlib as mplt
>>> mplt.__version__
'3.8.0'
>>> import matplotlib.pyplot as plt
>>> a = torch.randn(185,5)
>>> plt.plot(a)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../lib/python3.11/site-packages/matplotlib/pyplot.py", line 3578, in plot
    return gca().plot(
           ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/axes/_axes.py", line 1721, in plot
    lines = [*self._get_lines(self, *args, data=data, **kwargs)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/axes/_base.py", line 303, in __call__
    yield from self._plot_args(
               ^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/axes/_base.py", line 496, in _plot_args
    axes.yaxis.update_units(y)
  File ".../lib/python3.11/site-packages/matplotlib/axis.py", line 1706, in update_units
    converter = munits.registry.get_converter(data)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/units.py", line 183, in get_converter
    first = cbook._safe_first_finite(x)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/cbook.py", line 1730, in _safe_first_finite
    if safe_isfinite(val):
       ^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/matplotlib/cbook.py", line 1699, in safe_isfinite
    return math.isfinite(val)
           ^^^^^^^^^^^^^^^^^^
ValueError: only one element tensors can be converted to Python scalars

>>> b = torch.randn(185)
>>> plt.plot(b)
[<matplotlib.lines.Line2D object at 0x7f27d7ccc190>]

# Using matplotlib==3.7.3
>>> import torch
>>> import matplotlib as mplt
>>> mplt.__version__
'3.7.3'
>>> import matplotlib.pyplot as plt
>>> a = torch.randn(185,5)
>>> plt.plot(a)
[<matplotlib.lines.Line2D object at 0x7f19762d7910>, <matplotlib.lines.Line2D object at 0x7f1976684250>, <matplotlib.lines.Line2D object at 0x7f19764d5150>, <matplotlib.lines.Line2D object at 0x7f197598f9d0>, <matplotlib.lines.Line2D object at 0x7f19762bff50>]
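The 3.8.0 traceback suggests the failure path: units.get_converter calls cbook._safe_first_finite on the tensor, which walks its elements starting from the first and passes each to safe_isfinite, i.e. math.isfinite(val). Iterating a 2D tensor yields 1D rows, and math.isfinite must convert such a row to a float, which PyTorch only allows for one-element tensors; the traceback shows the resulting ValueError is not caught. Iterating the 1D tensor b instead yields one-element tensors, which convert cleanly. The sketch below reproduces this with the standard library only. `FakeTensor`, `safe_isfinite_sketch` and `first_element_check` are hypothetical names: `FakeTensor` imitates just the two torch behaviours involved, and `safe_isfinite_sketch` is a simplification of matplotlib's check, not its actual code.

```python
import math


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor (NOT the real class).

    It mimics only two behaviours relevant to the traceback:
    iterating yields sub-tensors along the first dimension, and
    float() succeeds only for one-element tensors.
    """

    def __init__(self, data):
        # Nested lists (2D), a flat list (1D) or a bare number (0-d).
        self.data = data

    def numel(self):
        # Total number of scalar elements, like torch's numel().
        if isinstance(self.data, list):
            return sum(FakeTensor(item).numel() for item in self.data)
        return 1

    def __iter__(self):
        # A 2D tensor yields 1D rows; a 1D tensor yields 0-d tensors.
        for item in self.data:
            yield FakeTensor(item)

    def __float__(self):
        # math.isfinite() calls this; multi-element tensors refuse.
        if self.numel() != 1:
            raise ValueError(
                "only one element tensors can be converted to Python scalars")
        value = self.data
        while isinstance(value, list):
            value = value[0]
        return float(value)


def safe_isfinite_sketch(val):
    """Simplified check: math.isfinite(val) with only TypeError handled,
    so a ValueError from __float__ propagates (as in the traceback)."""
    if val is None:
        return False
    try:
        return math.isfinite(val)
    except TypeError:
        return True  # simplification: treat unconvertible objects as finite


def first_element_check(x):
    """Mimic the first step of _safe_first_finite: test the first element."""
    first = next(iter(x))
    return safe_isfinite_sketch(first)


# 1D case, like b = torch.randn(185): the first element has one value.
b = FakeTensor([0.3, -1.2, 0.8])
print(first_element_check(b))  # True

# 2D case, like a = torch.randn(185, 5): the first element is a 5-value row.
a = FakeTensor([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0]])
try:
    first_element_check(a)
except ValueError as exc:
    print("ValueError:", exc)
```

The 1D call returns True, while the 2D call raises the same message as the real traceback, matching the b versus a behaviour in the 3.8.0 session.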

Actual outcome

(Included in the REPL output above.)

Expected outcome

Five line series, one per column of the tensor (i.e. one for each index along its second dimension), as produced by matplotlib==3.7.3.
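For reference, the plot documentation states that when y is 2D, a separate data set is drawn for every column, so a (185, 5) input should yield five Line2D objects, matching the 3.7.3 output above. The pure-Python sketch below (no torch or matplotlib; the random data is a hypothetical stand-in for torch.randn(185, 5)) shows how that shape splits into five 185-point series plotted against the row index.

```python
import random

# Hypothetical data shaped like torch.randn(185, 5): 185 rows, 5 columns.
n_rows, n_cols = 185, 5
random.seed(0)
y = [[random.gauss(0.0, 1.0) for _ in range(n_cols)] for _ in range(n_rows)]

# Column-wise split: each column of a 2D y becomes one line series.
series = [list(col) for col in zip(*y)]

# With no x given, each series is drawn against x = 0 .. n_rows - 1.
x = list(range(n_rows))

print(len(series))     # 5 line series
print(len(series[0]))  # 185 points each
```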

Additional information

The regression appears in 3.8.0; torch==2.0.1 was used in both cases.

Operating system

Ubuntu

Matplotlib Version

3.8.0 and 3.7.3

Matplotlib Backend

TkAgg

Python version

Python 3.11.5

Jupyter version

N/A

Installation

pip
