8000 Add 3D scatter plots, allow more than one 3d plot on the same figure … · chhlab/matplotlib-cpp@9d19657 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d19657

Browse files
Baggins800Benno Evers
authored andcommitted
Add 3D scatter plots, allow more than one 3d plot on the same figure and make rcparams changeable.
1 parent 80bc9cd commit 9d19657

File tree

1 file changed

+167
-18
lines changed

1 file changed

+167
-18
lines changed

matplotlibcpp.h

Lines changed: 167 additions & 18 deletions
< A3E2 tr class="diff-line-row">
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ struct _interpreter {
9999
PyObject *s_python_function_barh;
100100
PyObject *s_python_function_colorbar;
101101
PyObject *s_python_function_subplots_adjust;
102+
PyObject *s_python_function_rcparams;
102103

103104

104105
/* For now, _interpreter is implemented as a singleton since its currently not possible to have
@@ -189,6 +190,7 @@ struct _interpreter {
189190
}
190191

191192
PyObject* matplotlib = PyImport_Import(matplotlibname);
193+
192194
Py_DECREF(matplotlibname);
193195
if (!matplotlib) {
194196
PyErr_Print();
@@ -201,6 +203,8 @@ struct _interpreter {
201203
PyObject_CallMethod(matplotlib, const_cast<char*>("use"), const_cast<char*>("s"), s_backend.c_str());
202204
}
203205

206+
207+
204208
PyObject* pymod = PyImport_Import(pyplotname);
205209
Py_DECREF(pyplotname);
206210
if (!pymod) { throw std::runtime_error("Error loading module matplotlib.pyplot!"); }
@@ -264,6 +268,7 @@ struct _interpreter {
264268
s_python_function_barh = safe_import(pymod, "barh");
265269
s_python_function_colorbar = PyObject_GetAttrString(pymod, "colorbar");
266270
s_python_function_subplots_adjust = safe_import(pymod,"subplots_adjust");
271+
s_python_function_rcparams = PyObject_GetAttrString(pymod, "rcParams");
267272
#ifndef WITHOUT_NUMPY
268273
s_python_function_imshow = safe_import(pymod, "imshow");
269274
#endif
@@ -464,6 +469,7 @@ template <typename Numeric>
464469
void plot_surface(const std::vector<::std::vector<Numeric>> &x,
465470
const std::vector<::std::vector<Numeric>> &y,
466471
const std::vector<::std::vector<Numeric>> &z,
472+
const long fig_number=0,
467473
const std::map<std::string, std::string> &keywords =
468474
std::map<std::string, std::string>())
469475
{
@@ -516,14 +522,29 @@ void plot_surface(const std::vector<::std::vector<Numeric>> &x,
516522

517523
for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
518524
it != keywords.end(); ++it) {
519-
PyDict_SetItemString(kwargs, it->first.c_str(),
520-
PyString_FromString(it->second.c_str()));
525+
if (it->first == "linewidth" || it->first == "alpha") {
526+
PyDict_SetItemString(kwargs, it->first.c_str(),
527+
PyFloat_FromDouble(std::stod(it->second)));
528+
} else {
529+
PyDict_SetItemString(kwargs, it->first.c_str(),
530+
PyString_FromString(it->second.c_str()));
531+
}
521532
}
522533

523-
524-
PyObject *fig =
525-
PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
526-
detail::_interpreter::get().s_python_empty_tuple);
534+
PyObject *fig_args = PyTuple_New(1);
535+
PyObject* fig = nullptr;
8000 536+
PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
537+
PyObject *fig_exists =
538+
PyObject_CallObject(
539+
detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
540+
if (!PyObject_IsTrue(fig_exists)) {
541+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
542+
detail::_interpreter::get().s_python_empty_tuple);
543+
} else {
544+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
545+
fig_args);
546+
}
547+
Py_DECREF(fig_exists);
527548
if (!fig) throw std::runtime_error("Call to figure() failed.");
528549

529550
PyObject *gca_kwargs = PyDict_New();
@@ -559,6 +580,7 @@ template <typename Numeric>
559580
void plot3(const std::vector<Numeric> &x,
560581
const std::vector<Numeric> &y,
561582
const std::vector<Numeric> &z,
583+
const long fig_number=0,
562584
const std::map<std::string, std::string> &keywords =
563585
std::map<std::string, std::string>())
564586
{
@@ -607,9 +629,18 @@ void plot3(const std::vector<Numeric> &x,
607629
PyString_FromString(it->second.c_str()));
608630
}
609631

610-
PyObject *fig =
611-
PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
612-
detail::_interpreter::get().s_python_empty_tuple);
632+
PyObject *fig_args = PyTuple_New(1);
633+
PyObject* fig = nullptr;
634+
PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
635+
PyObject *fig_exists =
636+
PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
637+
if (!PyObject_IsTrue(fig_exists)) {
638+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
639+
detail::_interpreter::get().s_python_empty_tuple);
640+
} else {
641+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
642+
fig_args);
643+
}
613644
if (!fig) throw std::runtime_error("Call to figure() failed.");
614645

615646
PyObject *gca_kwargs = PyDict_New();
@@ -911,6 +942,103 @@ bool scatter(const std::vector<NumericX>& x,
911942
return res;
912943
}
913944

945+
template<typename NumericX, typename NumericY, typename NumericZ>
946+
bool scatter(const std::vector<NumericX>& x,
947+
const std::vector<NumericY>& y,
948+
const std::vector<NumericZ>& z,
949+
const double s=1.0, // The marker size in points**2
950+
const long fig_number=0,
951+
const std::map<std::string, std::string> & keywords = {}) {
952+
detail::_interpreter::get();
953+
954+
// Same as with plot_surface: We lazily load the modules here the first time
955+
// this function is called because I'm not sure that we can assume "matplotlib
956+
// installed" implies "mpl_toolkits installed" on all platforms, and we don't
957+
// want to require it for people who don't need 3d plots.
958+
static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr;
959+
if (!mpl_toolkitsmod) {
960+
detail::_interpreter::get();
961+
962+
PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits");
963+
PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d");
964+
if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); }
965+
966+
mpl_toolkitsmod = PyImport_Import(mpl_toolkits);
967+
Py_DECREF(mpl_toolkits);
968+
if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); }
969+
970+
axis3dmod = PyImport_Import(axis3d);
971+
Py_DECREF(axis3d);
972+
if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); }
973+
}
974+
975+
assert(x.size() == y.size());
976+
assert(y.size() == z.size());
977+
978+
PyObject *xarray = detail::get_array(x);
979+
PyObject *yarray = detail::get_array(y);
980+
PyObject *zarray = detail::get_array(z);
981+
982+
// construct positional args
983+
PyObject *args = PyTuple_New(3);
984+
PyTuple_SetItem(args, 0, xarray);
985+
PyTuple_SetItem(args, 1, yarray);
986+
PyTuple_SetItem(args, 2, zarray);
987+
988+
// Build up the kw args.
989+
PyObject *kwargs = PyDict_New();
990+
991+
for (std::map<std::string, std::string>::const_iterator it = keywords.begin();
992+
it != keywords.end(); ++it) {
993+
PyDict_SetItemString(kwargs, it->first.c_str(),
994+
PyString_FromString(it->second.c_str()));
995+
}
996+
PyObject *fig_args = PyTuple_New(1);
997+
PyObject* fig = nullptr;
998+
PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number));
999+
PyObject *fig_exists =
1000+
PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args);
1001+
if (!PyObject_IsTrue(fig_exists)) {
1002+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
1003+
detail::_interpreter::get().s_python_empty_tuple);
1004+
} else {
1005+
fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure,
1006+
fig_args);
1007+
}
1008+
Py_DECREF(fig_exists);
1009+
if (!fig) throw std::runtime_error("Call to figure() failed.");
1010+
1011+
PyObject *gca_kwargs = PyDict_New();
1012+
PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d"));
1013+
1014+
PyObject *gca = PyObject_GetAttrString(fig, "gca");
1015+
if (!gca) throw std::runtime_error("No gca");
1016+
Py_INCREF(gca);
1017+
PyObject *axis = PyObject_Call(
1018+
gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs);
1019+
1020+
if (!axis) throw std::runtime_error("No axis");
1021+
Py_INCREF(axis);
1022+
1023+
Py_DECREF(gca);
1024+
Py_DECREF(gca_kwargs);
1025+
1026+
PyObject *plot3 = PyObject_GetAttrString(axis, "scatter");
1027+
if (!plot3) throw std::runtime_error("No 3D line plot");
1028+
Py_INCREF(plot3);
1029+
PyObject *res = PyObject_Call(plot3, args, kwargs);
1030+
if (!res) throw std::runtime_error("Failed 3D line plot");
1031+
Py_DECREF(plot3);
1032+
1033+
Py_DECREF(axis);
1034+
Py_DECREF(args);
1035+
Py_DECREF(kwargs);
1036+
Py_DECREF(fig);
1037+
if (res) Py_DECREF(res);
1038+
return res;
1039+
1040+
}
1041+
9141042
template<typename Numeric>
9151043
bool boxplot(const std::vector<std::vector<Numeric>>& data,
9161044
const std::vector<std::string>& labels = {},
@@ -1139,9 +1267,9 @@ bool contour(const std::vector<NumericX>& x, const std::vector<NumericY>& y,
11391267
const std::map<std::string, std::string>& keywords = {}) {
11401268
assert(x.size() == y.size() && x.size() == z.size());
11411269

1142-
PyObject* xarray = get_array(x);
1143-
PyObject* yarray = get_array(y);
1144-
PyObject* zarray = get_array(z);
1270+
PyObject* xarray = detail::get_array(x);
1271+
PyObject* yarray = detail::get_array(y);
1272+
PyObject* zarray = detail::get_array(z);
11451273

11461274
PyObject* plot_args = PyTuple_New(3);
11471275
PyTuple_SetItem(plot_args, 0, xarray);
@@ -2094,12 +2222,14 @@ inline void axvspan(double xmin, double xmax, double ymin = 0., double ymax = 1.
20942222

20952223
// construct keyword args
20962224
PyObject* kwargs = PyDict_New();
2097-
for(std::map<std::string, std::string>::const_iterator it = keywords.begin(); it != keywords.end(); ++it)
2098-
{
2099-
if (it->first == "linewidth" || it->first == "alpha")
2100-
PyDict_SetItemString(kwargs, it->first.c_str(), PyFloat_FromDouble(std::stod(it->second)));
2101-
else
2102-
PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str()));
2225+
for (auto it = keywords.begin(); it != keywords.end(); ++it) {
2226+
if (it->first == "linewidth" || it->first == "alpha") {
2227+
PyDict_SetItemString(kwargs, it->first.c_str(),
2228+
PyFloat_FromDouble(std::stod(it->second)));
2229+
} else {
2230+
PyDict_SetItemString(kwargs, it->first.c_str(),
2231+
PyString_FromString(it->second.c_str()));
2232+
}
21032233
}
21042234

21052235
PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_axvspan, args, kwargs);
@@ -2319,6 +2449,25 @@ inline void save(const std::string& filename)
23192449
Py_DECREF(res);
23202450
}
23212451

2452+
inline void rcparams(const std::map<std::string, std::string>& keywords = {}) {
2453+
detail::_interpreter::get();
2454+
PyObject* args = PyTuple_New(0);
2455+
PyObject* kwargs = PyDict_New();
2456+
for (auto it = keywords.begin(); it != keywords.end(); ++it) {
2457+
if ("text.usetex" == it->first)
2458+
PyDict_SetItemString(kwargs, it->first.c_str(), PyLong_FromLong(std::stoi(it->second.c_str())));
2459+
else PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str()));
2460+
}
2461+
2462+
PyObject * update = PyObject_GetAttrString(detail::_interpreter::get().s_python_function_rcparams, "update");
2463+
PyObject * res = PyObject_Call(update, args, kwargs);
2 6B7B 464+
if(!res) throw std::runtime_error("Call to rcParams.update() failed.");
2465+
Py_DECREF(args);
2466+
Py_DECREF(kwargs);
2467+
Py_DECREF(update);
2468+
Py_DECREF(res);
2469+
}
2470+
23222471
inline void clf() {
23232472
detail::_interpreter::get();
23242473

0 commit comments

Comments
 (0)
0