8000 Add plotting module with heatmap function by amueller · Pull Request #8082 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Add plotting module with heatmap function #8082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from

Conversation

amueller
Copy link
Member
@amueller amueller commented Dec 19, 2016

This adds a plotting module and a first function to plot heatmaps with values inside.
This is a slight generalization of the confusion matrix plot.

One of the improvements is that the color of the text takes the colormap into account, which is actually somewhat tricky to do as you can see.

Here's what the grid-search plot looks before
grid_before
and after:
grid_new

I could add another keyword for the direction of the y-axis but I'm not sure if that's overkill. The current direction makes sense for confusion matrices, with the origin in the top right.

Slight caveat: I'm not sure it's easy to add a colorbar to a plot like this.
I think that's less important given how explicit the plot is, though.


print(title)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure if this function is still helpful for copy & paste. It mostly labels the axis and does normalization, which might be helpful... I left it for now.

ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

ax.set_xlim(0, values.shape[1])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why this is necessary, but otherwise the grid-search plot adds another row / column

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plt.pcolor(np.random.uniform(size=(13, 13)), snap=True) gives a 14x14 plot... huh

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be necessary, or at least on master, it isn't necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's not necessary on 2.0.0rc2, so we only have to support it for like 8 1/2 more years or something (requirements are two ubuntu LTS releases ago)

@amueller
Copy link
Member Author

also pcolor and pcolormesh seem to have different opinions on where the center of a pixel / square is, so I'm using the slower pcolor for now.

8000
try:
import matplotlib
except ImportError:
raise SkipTest("Not testing plot_heatmap, matplotlib not installed.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be module-level?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how would that work with nose?

Copy link
Member
@jnothman jnothman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My initial thoughts:

How do we name and organise these things? How specific should they be to ML? to scikit-learn? really this is just a plot for a continuous function of two categorical inputs. Is that essentially identical to a heatmap?

Do we tie these more closely to API and applications, e.g. by taking a cv_results_ as input, together with a pair of parameter names?

How do we scale up this interface to the case where we want to produce a matrix of all bivariate and univariate parameter-score plots for some grid search?

Also, every plot_ function should be used in at least one example.

@amueller
Copy link
Member Author
amueller commented Dec 20, 2016 via email

@amueller
Copy link
Member Author

also, working for general cv_results_ is kinda tricky. I think you remember my attempt a couple of years ago. I think for the book I have a function that handles 1d and 2d and barfs otherwise. We could do that, too. That would already be useful.

@amueller
Copy link
Member Author

Any further thoughts? I think having more "high-level" functions is probably good. There was just a submission to scikit-learn-contrib that did something similar.
@GaelVaroquaux any thoughts?

@jnothman
Copy link
Member
jnothman commented Jan 10, 2017 via email

@amueller
Copy link
Member Author

I think actually I want more integrated functions, so I'd add a function plotting 2d grid search and one for confusion matrix.

@amueller
Copy link
Member Author

note to self: the alignment is messed up, it should use pcolormesh, I need to double check the flipping of the axes and vmin and vmax are not passed along

@thismlguy
Copy link
Contributor

Continued here - #9173. Please close this PR.

@NelleV
Copy link
Member
NelleV commented Aug 11, 2017

Closing this pull request.

@NelleV NelleV closed this Aug 11, 2017
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
0