-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
Copy pathtools.py
94 lines (82 loc) · 3.43 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.tree import export_graphviz
import matplotlib.pyplot as plt
from .plot_2d_separator import plot_2d_separator, plot_2d_classification, plot_2d_scores
from .plot_helpers import cm2 as cm, discrete_scatter
def visualize_coefficients(coefficients, feature_names, n_top_features=25):
    """Draw a bar chart of the most negative and most positive coefficients.

    A new 15x5 figure is created; the ``n_top_features`` smallest
    coefficients are drawn first, followed by the ``n_top_features`` largest,
    each in ascending order.

    Parameters
    ----------
    coefficients : ndarray
        Coefficient array; it is flattened with ``ravel`` before use.
    feature_names : sequence
        Names indexed by position in the flattened coefficient array.
    n_top_features : int, default=25
        How many coefficients to take from each end of the sorted order.
        When fewer coefficients than this are available, all of them are
        used for each end instead of raising a shape-mismatch error.
    """
    coef = coefficients.ravel()
    # sort once and take both ends of the ordering
    order = np.argsort(coef)
    positive_coefficients = order[-n_top_features:]
    negative_coefficients = order[:n_top_features]
    interesting_coefficients = np.hstack(
        [negative_coefficients, positive_coefficients])
    # Derive the bar count from what was actually selected: slicing clamps
    # silently when there are fewer than n_top_features coefficients, so a
    # hard-coded 2 * n_top_features would not match the number of heights.
    n_bars = len(interesting_coefficients)
    selected = coef[interesting_coefficients]

    # plot them
    plt.figure(figsize=(15, 5))
    # negative values get cm(1), everything else cm(0)
    colors = [cm(1) if c < 0 else cm(0) for c in selected]
    plt.bar(np.arange(n_bars), selected, color=colors)
    feature_names = np.array(feature_names)
    plt.subplots_adjust(bottom=0.3)
    # NOTE(review): ticks sit one unit to the right of the bar positions,
    # kept as in the original; combined with ha="right" this is presumably
    # intended to place the rotated labels under their bars -- confirm
    # against the Matplotlib version in use.
    plt.xticks(np.arange(1, 1 + n_bars),
               feature_names[interesting_coefficients],
               rotation=60, ha="right")
    plt.ylabel("Coefficient magnitude")
    plt.xlabel("Feature")
def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
            vmin=None, vmax=None, ax=None, fmt="%0.2f"):
    """Draw ``values`` as an annotated heat map and return the collection.

    Each cell is drawn with ``ax.pcolor`` and labelled with its value
    formatted through ``fmt``. The label is black on cells whose mean RGB
    face colour is above 0.5 and white otherwise.

    Parameters
    ----------
    values : 2-D array-like
        Data passed to ``pcolor``.
    xlabel, ylabel : str
        Axis labels.
    xticklabels, yticklabels : sequence
        Tick labels; ticks are placed at ``index + 0.5``.
    cmap : optional
        Colormap forwarded to ``pcolor``.
    vmin, vmax : optional
        Colour-scale limits forwarded to ``pcolor``.
    ax : optional
        Axes to draw on; ``plt.gca()`` is used when omitted.
    fmt : str, default="%0.2f"
        ``%``-style format applied to every cell value.

    Returns
    -------
    The object returned by ``ax.pcolor``.
    """
    if ax is None:
        ax = plt.gca()
    # Forward the caller's colour limits (previously None was always passed,
    # so the vmin/vmax arguments had no effect).
    img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
    # compute the face colours now so they can be read before drawing
    img.update_scalarmappable()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(np.arange(len(xticklabels)) + .5)
    ax.set_yticks(np.arange(len(yticklabels)) + .5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_aspect(1)

    # Flatten so that one scalar is paired with each path; newer Matplotlib
    # releases may keep the pcolor data as a 2-D array.
    cell_values = img.get_array().ravel()
    for p, color, value in zip(img.get_paths(), img.get_facecolors(),
                               cell_values):
        # Average the vertices without the last two (presumably the
        # repeated closing points) to get the label position.
        x, y = p.vertices[:-2, :].mean(0)
        c = 'k' if np.mean(color[:3]) > 0.5 else 'w'
        ax.text(x, y, fmt % value, color=c, ha="center", va="center")
    return img
def make_handcrafted_dataset():
    """Return a small two-centre blob dataset with hand-picked edits.

    Starts from ``make_blobs(centers=2, random_state=4, n_samples=30)``,
    sets the labels of samples 7 and 27 to 0, then removes samples
    0, 1, 5 and 26.

    Returns
    -------
    X, y
        The remaining 26 samples and their labels.
    """
    X, y = make_blobs(centers=2, random_state=4, n_samples=30)
    # relabel two hand-picked points
    y[np.array([7, 27])] = 0
    # builtin bool: the np.bool alias was removed from NumPy
    mask = np.ones(len(X), dtype=bool)
    # drop four hand-picked points
    mask[np.array([0, 1, 5, 26])] = 0
    X, y = X[mask], y[mask]
    return X, y
def print_topics(topics, feature_names, sorting, topics_per_chunk=6, n_words=20):
    """Print the top words of each topic in side-by-side columns.

    Topics are printed in chunks of ``topics_per_chunk`` columns. Each chunk
    has a header row, a separator row, up to ``n_words`` rows of words and a
    trailing blank block.

    Parameters
    ----------
    topics : sequence
        Topic indices; used both as column headers and as row indices
        into ``sorting``.
    feature_names : array-like
        Words, indexed by the entries of ``sorting``. Plain lists are
        accepted (converted with ``np.asarray``).
    sorting : array-like
        2-D index array; ``sorting[t, k]`` is the index into
        ``feature_names`` of the word printed in row ``k`` for topic ``t``.
    topics_per_chunk : int, default=6
        Number of topics printed next to each other.
    n_words : int, default=20
        Number of word rows per chunk. Rows whose indices are out of range
        are skipped.
    """
    # Fancy indexing below needs arrays; without this a list silently
    # produced headers with no words.
    feature_names = np.asarray(feature_names)
    sorting = np.asarray(sorting)

    for start in range(0, len(topics), topics_per_chunk):
        # for each chunk:
        these_topics = topics[start: start + topics_per_chunk]
        # maybe we have less than topics_per_chunk left
        len_this_chunk = len(these_topics)
        # print topic headers
        print(("topic {:<8}" * len_this_chunk).format(*these_topics))
        print(("-------- {0:<5}" * len_this_chunk).format(""))
        row_format = "{:<14}" * len_this_chunk
        # print top n_words frequent words
        for word_rank in range(n_words):
            try:
                words = feature_names[sorting[these_topics, word_rank]]
            except IndexError:
                # best effort: skip rows that fall outside the available
                # data (e.g. n_words larger than the number of columns)
                continue
            print(row_format.format(*words))
        print("\n")
def get_tree(tree, **kwargs):
    """Export ``tree`` with ``export_graphviz`` and wrap it for graphviz.

    The DOT text is written to an in-memory buffer and then handed to
    ``graphviz.Source``. Extra keyword arguments are forwarded unchanged
    to ``export_graphviz``.

    Returns
    -------
    graphviz.Source
        Built from the exported DOT text.
    """
    try:
        # Python 3 location of StringIO
        from io import StringIO as _TextBuffer
    except ImportError:
        # Python 2 fallback
        from StringIO import StringIO as _TextBuffer

    dot_buffer = _TextBuffer()
    export_graphviz(tree, dot_buffer, **kwargs)

    # imported here so graphviz is only required when this helper is called
    import graphviz
    dot_text = dot_buffer.getvalue()
    return graphviz.Source(dot_text)
# Names exported by a star-import of this module. make_handcrafted_dataset
# and get_tree are not listed and must be accessed explicitly.
__all__ = [
    'plot_2d_separator',
    'plot_2d_classification',
    'plot_2d_scores',
    'cm',
    'visualize_coefficients',
    'print_topics',
    'heatmap',
    'discrete_scatter',
]