-
Notifications
You must be signed in to change notification settings - Fork 4.6k
/
Copy pathtools.py
128 lines (112 loc) · 4.69 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.tree import export_graphviz
import matplotlib.pyplot as plt
from .plot_2d_separator import (plot_2d_separator, plot_2d_classification,
plot_2d_scores)
from .plot_helpers import cm2 as cm, discrete_scatter
def _top_coefficient_indices(coef, n_top_features):
    """Return indices of the most negative and most positive coefficients.

    The result is the ``n_top_features`` most negative coefficient indices
    followed by the ``n_top_features`` most positive ones, each group in
    ascending order of value.  If ``2 * n_top_features`` covers every
    coefficient, every index is returned exactly once (ascending by value)
    so that no coefficient is plotted twice.
    """
    order = np.argsort(coef)
    if 2 * n_top_features >= len(order):
        return order
    # slice with an explicit start so that n_top_features == 0 yields an
    # empty selection (order[-0:] would select everything)
    return np.hstack([order[:n_top_features],
                      order[len(order) - n_top_features:]])


def visualize_coefficients(coefficients, feature_names, n_top_features=25):
    """Visualize coefficients of a linear model.

    Parameters
    ----------
    coefficients : nd-array, shape (n_features,)
        Model coefficients. A row or column vector (e.g. shape
        ``(1, n_features)``) or a plain list is also accepted.

    feature_names : list or nd-array of strings, shape (n_features,)
        Feature names for labeling the coefficients.

    n_top_features : int, default=25
        How many features to show. The function will show the largest (most
        positive) and smallest (most negative) n_top_features coefficients,
        for a total of 2 * n_top_features coefficients. If there are fewer
        than 2 * n_top_features coefficients, all of them are shown once.

    Raises
    ------
    ValueError
        If ``coefficients`` is not a 1d array / row / column vector, if its
        length differs from ``len(feature_names)``, or if
        ``n_top_features`` is negative.
    """
    # asarray accepts plain lists; squeeze drops singleton axes so row and
    # column vectors are treated as 1d
    coef = np.asarray(coefficients).squeeze()
    if coef.ndim > 1:
        # this is not a row or column vector
        raise ValueError("coefficients must be 1d array or column vector, got"
                         " shape {}".format(coef.shape))
    # ravel turns a 0d result (single coefficient) back into 1d
    coef = coef.ravel()
    if len(coef) != len(feature_names):
        raise ValueError("Number of coefficients {} doesn't match number of"
                         " feature names {}.".format(len(coef),
                                                     len(feature_names)))
    if n_top_features < 0:
        raise ValueError("n_top_features must be non-negative, got"
                         " {}".format(n_top_features))

    # get coefficients with large absolute values
    interesting_coefficients = _top_coefficient_indices(coef, n_top_features)
    interesting_values = coef[interesting_coefficients]
    # one slot per selected coefficient; derived from the selection rather
    # than 2 * n_top_features so the bar count always matches the data
    positions = np.arange(len(interesting_coefficients))

    # plot them
    plt.figure(figsize=(15, 5))
    colors = [cm(1) if c < 0 else cm(0) for c in interesting_values]
    plt.bar(positions, interesting_values, color=colors)
    feature_names = np.array(feature_names)
    plt.subplots_adjust(bottom=0.3)
    # bars are centered on ``positions`` (plt.bar's default align="center"),
    # so the tick labels must use the same positions to sit under their bar
    plt.xticks(positions, feature_names[interesting_coefficients],
               rotation=60, ha="right")
    plt.ylabel("Coefficient magnitude")
    plt.xlabel("Feature")
def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
            vmin=None, vmax=None, ax=None, fmt="%0.2f"):
    """Draw a 2d array as a colored grid and write each value in its cell.

    Parameters
    ----------
    values : 2d array-like
        Values to display. Entry ``values[i, j]`` is drawn in the cell
        spanning ``[j, j + 1] x [i, i + 1]`` (pcolor's default grid).
    xlabel, ylabel : str
        Axis labels.
    xticklabels, yticklabels : sequence
        Labels for the columns / rows, placed at the cell centers.
    cmap, vmin, vmax :
        Forwarded to ``ax.pcolor``.
    ax : matplotlib Axes, optional
        Axes to draw into; defaults to ``plt.gca()``.
    fmt : str, default="%0.2f"
        %-style format string used for the text in each cell.

    Returns
    -------
    img
        The collection returned by ``ax.pcolor``.
    """
    if ax is None:
        ax = plt.gca()
    # draw the values as a colored grid
    img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
    # resolve norm/colormap now so to_rgba below matches the drawn colors
    img.update_scalarmappable()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(np.arange(len(xticklabels)) + .5)
    ax.set_yticks(np.arange(len(yticklabels)) + .5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_aspect(1)

    # Annotate by grid index instead of zipping img.get_paths() with
    # img.get_array(): since matplotlib 3.8 pcolor returns a PolyQuadMesh
    # whose get_array() is 2d, and the path vertex layout is version
    # dependent. Without explicit X/Y, pcolor puts cell (row, col) on the
    # unit square whose center is (col + .5, row + .5).
    data = np.ma.masked_invalid(values)
    mask = np.ma.getmaskarray(data)
    for (row, col), value in np.ndenumerate(np.ma.getdata(data)):
        if mask[row, col]:
            # pcolor leaves masked / non-finite cells blank; do the same
            continue
        color = img.to_rgba(value)
        # dark text on light cells, white text on dark cells
        text_color = 'k' if np.mean(color[:3]) > 0.5 else 'w'
        ax.text(col + .5, row + .5, fmt % value, color=text_color,
                ha="center", va="center")
    return img
def make_handcrafted_dataset():
    """Return a small, hand-tuned two-class toy dataset.

    Draws 30 samples from ``make_blobs(centers=2, random_state=4)``,
    relabels the samples at positions 7 and 27 as class 0, and removes the
    samples at positions 0, 1, 5 and 26, leaving 26 samples.

    Returns
    -------
    X : ndarray, 26 rows
        Sample coordinates as produced by ``make_blobs``.
    y : ndarray, shape (26,)
        Class labels.
    """
    X, y = make_blobs(centers=2, random_state=4, n_samples=30)
    # hand-picked relabeling of two samples
    y[np.array([7, 27])] = 0
    # keep-mask: ``np.bool`` was removed in NumPy 1.24, the builtin ``bool``
    # is the equivalent dtype on every NumPy version
    mask = np.ones(len(X), dtype=bool)
    mask[np.array([0, 1, 5, 26])] = False
    X, y = X[mask], y[mask]
    return X, y
def print_topics(topics, feature_names, sorting, topics_per_chunk=6,
                 n_words=20):
    """Print the top words of several topics side by side.

    Topics are printed in tables of at most ``topics_per_chunk`` columns;
    each table has a header row, a separator row, then one row per rank.

    Parameters
    ----------
    topics : sequence of int
        Indices of the topics (rows of ``sorting``) to print.
    feature_names : list or nd-array of str
        Vocabulary; ``feature_names[k]`` is the word for feature index ``k``.
    sorting : 2d array-like of int
        ``sorting[t, r]`` is the feature index shown at rank ``r`` for topic
        ``t``. Presumably sorted by decreasing importance (e.g.
        ``np.argsort(components, axis=1)[:, ::-1]``) -- confirm at caller.
    topics_per_chunk : int, default=6
        Maximum number of topics printed next to each other.
    n_words : int, default=20
        Number of words printed per topic; capped at ``sorting.shape[1]``.
    """
    # arrays so that plain lists also support the fancy indexing below
    feature_names = np.asarray(feature_names)
    sorting = np.asarray(sorting)
    # explicit cap instead of swallowing the IndexError from ranks that do
    # not exist in ``sorting``
    n_rows = min(n_words, sorting.shape[1])
    for start in range(0, len(topics), topics_per_chunk):
        these_topics = topics[start: start + topics_per_chunk]
        # the last chunk may hold fewer than topics_per_chunk topics
        len_this_chunk = len(these_topics)
        # print topic headers
        print(("topic {:<8}" * len_this_chunk).format(*these_topics))
        print(("-------- {0:<5}" * len_this_chunk).format(""))
        # print the top n_rows words, one row per rank
        for rank in range(n_rows):
            words = feature_names[sorting[these_topics, rank]]
            print(("{:<14}" * len_this_chunk).format(*words))
        print("\n")
def get_tree(tree, **kwargs):
    """Render a decision tree as a ``graphviz.Source`` object.

    ``tree`` and any extra keyword arguments are forwarded to
    ``export_graphviz``, whose DOT output is collected in an in-memory
    text buffer. The ``graphviz`` package is imported only when this
    function runs, so it is not needed to import this module.
    """
    try:
        from io import StringIO  # Python 3
    except ImportError:
        from StringIO import StringIO  # Python 2 fallback

    dot_buffer = StringIO()
    export_graphviz(tree, dot_buffer, **kwargs)
    dot_source = dot_buffer.getvalue()

    import graphviz
    return graphviz.Source(dot_source)
# Public names re-exported by ``from tools import *``.
__all__ = [
    'plot_2d_separator',
    'plot_2d_classification',
    'plot_2d_scores',
    'cm',
    'visualize_coefficients',
    'print_topics',
    'heatmap',
    'discrete_scatter',
]