forked from amueller/introduction_to_ml_with_python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_linear_svc_regularization.py
37 lines (30 loc) · 1.08 KB
/
plot_linear_svc_regularization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import matplotlib.pyplot as plt
import numpy as np
from sklearn.svm import LinearSVC
from sklearn.datasets import make_blobs
from .plot_helpers import discrete_scatter
def plot_linear_svc_regularization():
    """Show how the ``C`` parameter changes a ``LinearSVC`` decision boundary.

    Builds a 30-sample, two-centre blob dataset, forces samples 7 and 27
    into class 0, and then fits one ``LinearSVC`` per regularization
    strength (``C`` = 0.01, 10 and 1000).  Each fit is drawn in its own
    panel of a 1x3 figure: the samples are scattered with
    ``discrete_scatter`` and the line ``w0 * x + w1 * y + b = 0`` derived
    from the fitted coefficients is overlaid in black.

    Returns nothing; the figure is created through ``pyplot`` and left for
    the caller to show or save.
    """
    points, labels = make_blobs(centers=2, random_state=4, n_samples=30)
    _, axes = plt.subplots(1, 3, figsize=(12, 4))

    # Hand-tweak the dataset: relabel two specific samples as class 0.
    for sample_index in (7, 27):
        labels[sample_index] = 0

    # Axis limits: the data range padded by half a unit on every side.
    feature_0, feature_1 = points[:, 0], points[:, 1]
    x_limits = (feature_0.min() - .5, feature_0.max() + .5)
    y_limits = (feature_1.min() - .5, feature_1.max() + .5)

    # Fixed x-range over which every boundary line is evaluated.
    line_x = np.linspace(6, 13)

    for ax, C in zip(axes, [1e-2, 10, 1e3]):
        discrete_scatter(feature_0, feature_1, labels, ax=ax)

        model = LinearSVC(C=C, tol=0.00001, dual=False)
        model.fit(points, labels)

        # Solve w0 * x + w1 * y + b = 0 for y to get the boundary line.
        weights = model.coef_[0]
        slope = -weights[0] / weights[1]
        offset = (model.intercept_[0]) / weights[1]
        line_y = slope * line_x - offset

        ax.plot(line_x, line_y, c='k')
        ax.set_xlim(*x_limits)
        ax.set_ylim(*y_limits)
        # Empty tick tuples hide the ticks on both axes.
        ax.set_xticks(())
        ax.set_yticks(())
        ax.set_title("C = %f" % C)

    # A single legend on the left-most panel serves all three.
    axes[0].legend(loc="best")
# Script entry point: build the figure and display it.
# NOTE: this module uses a relative import (``from .plot_helpers import ...``),
# so it has to be launched as part of its package (``python -m ...``);
# running the file directly cannot resolve that import.
if __name__ == "__main__":
    plot_linear_svc_regularization()
    plt.show()