forked from amueller/introduction_to_ml_with_python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_rbf_svm_parameters.py
30 lines (26 loc) · 1.15 KB
/
plot_rbf_svm_parameters.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import matplotlib.pyplot as plt
from sklearn.svm import SVC
from .plot_2d_separator import plot_2d_separator
from .tools import make_handcrafted_dataset
from .plot_helpers import discrete_scatter
def plot_svm(log_C, log_gamma, ax=None):
X, y = make_handcrafted_dataset()
C = 10. ** log_C
gamma = 10. ** log_gamma
svm = SVC(kernel='rbf', C=C, gamma=gamma).fit(X, y)
if ax is None:
ax = plt.gca()
plot_2d_separator(svm, X, ax=ax, eps=.5)
# plot data
discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)
# plot support vectors
sv = svm.support_vectors_
# class labels of support vectors are given by the sign of the dual coefficients
sv_labels = svm.dual_coef_.ravel() > 0
discrete_scatter(sv[:, 0], sv[:, 1], sv_labels, s=15, markeredgewidth=3, ax=ax)
ax.set_title("C = %.4f gamma = %.4f" % (C, gamma))
def plot_svm_interactive():
from IPython.html.widgets import interactive, FloatSlider
C_slider = FloatSlider(min=-3, max=3, step=.1, value=0, readout=False)
gamma_slider = FloatSlider(min=-2, max=2, step=.1, value=0, readout=False)
return interactive(plot_svm, log_C=C_slider, log_gamma=gamma_slider)