Implement to_numpy method to speed up matplotlib with PyTorch arrays
#101795
Labels:
enhancement (Not as big of a feature, but technically not a bug. Should be easy to fix)
module: numpy (Related to numpy support, and also numpy compatibility of our operators)
needs research (We need to decide whether or not this merits inclusion, based on research world)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
🚀 The feature, motivation and pitch
Hi,
As discussed in this issue and the corresponding PR on matplotlib, PyTorch arrays can be significantly slower to plot than NumPy arrays when passed directly to matplotlib. This is because matplotlib has no easy way to convert PyTorch arrays to NumPy arrays before plotting, so it expects other libraries to provide a to_numpy() method. I think a to_numpy() implementation in PyTorch would be useful for PyTorch users who might be passing PyTorch arrays directly to matplotlib without knowing that this can be very slow.
Alternatives
In the matplotlib PR, we considered adding a specific check for inputs of type torch.Tensor and then converting them to NumPy with the .numpy() method, but adding a string-based check does not seem like a good idea. Relying on the __array__ method to convert both JAX and PyTorch arrays to NumPy was also considered, but it does not work well with some other objects that have an __array__ method.
Additional context
Here is the code to reproduce the plotting delay issue:
I am open to a diverse set of suggestions to fix this issue.
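As one illustration, here is a minimal sketch of how a plotting library could combine the proposed to_numpy() hook with the name-based torch.Tensor check described under Alternatives. It is not matplotlib's or PyTorch's actual code. The helper names unpack_to_numpy, looks_like_torch_tensor and ToyTensor are hypothetical. The sketch also assumes that to_numpy() returns an np.ndarray and that torch.Tensor reports its class as module "torch", name "Tensor".

```python
import numpy as np


def looks_like_torch_tensor(x):
    """Hypothetical name-based check that avoids importing torch.

    Assumption: torch.Tensor reports __module__ == "torch" and
    __name__ == "Tensor". Subclasses (e.g. torch.nn.Parameter) report a
    different module/name and are NOT matched, which is one reason such
    string checks are considered fragile.
    """
    cls = type(x)
    return cls.__module__ == "torch" and cls.__name__ == "Tensor"


def unpack_to_numpy(x):
    """Hypothetical helper: return an np.ndarray for array-likes before plotting."""
    if isinstance(x, np.ndarray):
        return x  # already NumPy, nothing to convert
    to_numpy = getattr(x, "to_numpy", None)
    if callable(to_numpy):
        # The duck-typed hook proposed in this issue (pandas objects have it).
        # Assumption: to_numpy() returns an np.ndarray.
        return to_numpy()
    if looks_like_torch_tensor(x):
        # Alternative: explicit conversion. detach() and cpu() come first
        # because .numpy() refuses tensors that require grad or live on a GPU.
        return x.detach().cpu().numpy()
    # Generic fallback: np.asarray uses __array__ or the sequence protocol.
    return np.asarray(x)


class ToyTensor:
    """Hypothetical stand-in for a tensor type that implements to_numpy()."""

    def __init__(self, values):
        self._data = np.asarray(values, dtype=np.float32)

    def to_numpy(self):
        return self._data


if __name__ == "__main__":
    print(unpack_to_numpy(ToyTensor([1, 2, 3])))  # takes the to_numpy() branch
    print(unpack_to_numpy([4, 5, 6]))  # falls back to np.asarray
```

If PyTorch tensors implemented to_numpy(), they would take the first branch, just as pandas objects do. The name-based branch would then only be a fallback. A torch.nn.Parameter would not be recognized by looks_like_torch_tensor, which is one concrete way a string check can miss valid inputs.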
cc @mruberry @rgommers