-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
Copy pathplot_helpers.py
108 lines (79 loc) · 3.05 KB
/
plot_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, colorConverter, LinearSegmentedColormap
# Discrete palettes, given as hex colors.
cm_cycle = ListedColormap(
    ['#0000aa', '#ff5050', '#50ff50', '#9040a0', '#fff000'])
cm3 = ListedColormap(['#0000aa', '#ff2020', '#50ff50'])
cm2 = ListedColormap(['#0000aa', '#ff2020'])

# Create a smooth transition from the first to the second color of cm2
# (the same two colors that open cm3). This is similar to RdBu but uses
# our own red and blue, and it does not pass through white, which is
# really bad for greyscale.
# Every channel gets two (x, y_left, y_right) anchors: the value leaving
# x=0 is that channel of cm2(0), the value arriving at x=1 is that channel
# of cm2(1).
cdict = {
    channel: [(0.0, 0.0, cm2(0)[idx]),
              (1.0, cm2(1)[idx], 1.0)]
    for idx, channel in enumerate(('red', 'green', 'blue'))
}
ReBl = LinearSegmentedColormap("ReBl", cdict)
def discrete_scatter(x1, x2, y=None, markers=None, s=10, ax=None,
                     labels=None, padding=.2, alpha=1, c=None,
                     markeredgewidth=None):
    """Adaption of matplotlib.pyplot.scatter to plot classes or clusters.

    Every distinct value in ``y`` is drawn with its own ``ax.plot`` call
    using a marker-only format, so each class gets its own marker, color
    and legend label.

    Parameters
    ----------
    x1 : array-like
        input data, first axis
    x2 : array-like
        input data, second axis
    y : array-like, optional
        input data, discrete labels. If None, all points are treated as
        belonging to a single class.
    markers : sequence of string, optional
        Markers to use, one per class (in the order of ``np.unique(y)``).
        If None, a built-in list of distinct markers starting with 'o' is
        used. If there are fewer markers than classes, the markers are
        reused cyclically.
    s : int or float
        Size of the marker
    ax : matplotlib axes, optional
        Axes to draw into. Defaults to the current axes (``plt.gca()``).
    labels : sequence, optional
        Legend label for each class, in the order of ``np.unique(y)``.
        Defaults to the unique values of ``y`` themselves.
    padding : float
        Multiple of each coordinate's standard deviation that is added
        around the data when widening the axes limits. The existing
        limits are only ever extended, never shrunk. Use 0 to leave the
        limits untouched.
    alpha : float
        Alpha value for all points.
    c : color string or sequence of colors, optional
        If None, colors are taken from the ``axes.prop_cycle`` rcParam.
        A string is used as the color of every class. A sequence with one
        entry is used for every class; a longer sequence is indexed per
        class. Note that a bare RGB(A) tuple is therefore interpreted as
        a sequence of per-class colors; wrap it in a list to use it as a
        single color.
    markeredgewidth : float, optional
        Passed through to ``ax.plot`` unchanged.

    Returns
    -------
    lines : list
        The first artist returned by each ``ax.plot`` call, one per class.
    """
    if ax is None:
        ax = plt.gca()

    # Coerce to arrays so that plain lists support the boolean masking and
    # the .std()/.min()/.max() calls below.
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)

    if y is None:
        y = np.zeros(len(x1))
    y = np.asarray(y)

    unique_y = np.unique(y)

    if markers is None:
        markers = ['o', '^', 'v', 'D', 's', '*', 'p', 'h', 'H', '8', '<', '>'] * 10

    if labels is None:
        labels = unique_y

    # lines in the matplotlib sense, not actual lines
    lines = []

    current_cycler = mpl.rcParams['axes.prop_cycle']

    # Calling the cycler yields its entries endlessly, so zip() stops when
    # the classes run out.
    for i, (yy, cycle) in enumerate(zip(unique_y, current_cycler())):
        mask = y == yy
        if c is None:
            # if c is none, use color cycle
            color = cycle['color']
        elif isinstance(c, str):
            # A string is one color spec ('red', '#ff0000'); indexing it
            # would split it into single characters.
            color = c
        elif len(c) > 1:
            color = c[i]
        elif len(c) == 1:
            # A single-entry sequence is shared by all classes.
            color = c[0]
        else:
            color = c
        # use light edge for dark markers
        if np.mean(colorConverter.to_rgb(color)) < .4:
            markeredgecolor = "grey"
        else:
            markeredgecolor = "black"

        # Wrap around so that short marker lists (including a single
        # marker) cover any number of classes.
        marker = markers[i % len(markers)]

        lines.append(ax.plot(x1[mask], x2[mask], marker, markersize=s,
                             label=labels[i], alpha=alpha, c=color,
                             markeredgewidth=markeredgewidth,
                             markeredgecolor=markeredgecolor)[0])

    # Skip the padding for empty data, where min()/max() would raise.
    if padding != 0 and x1.size:
        pad1 = x1.std() * padding
        pad2 = x2.std() * padding
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.set_xlim(min(x1.min() - pad1, xlim[0]), max(x1.max() + pad1, xlim[1]))
        ax.set_ylim(min(x2.min() - pad2, ylim[0]), max(x2.max() + pad2, ylim[1]))

    return lines