Skip to content

Commit 924925d

Browse files
authored
switch SVM example to half-moon dataset (#421)
1 parent 992b665 commit 924925d

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

Diff for: examples/support-vector-machine/script.jl

+28-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# # Support Vector Machine
22
#
3+
# In this notebook we show how you can use KernelFunctions.jl to generate
4+
# kernel matrices for classification with a support vector machine, as
5+
# implemented by LIBSVM.
36

47
using Distributions
58
using KernelFunctions
@@ -8,39 +11,45 @@ using LinearAlgebra
811
using Plots
912
using Random
1013

11-
## Set plotting theme
12-
theme(:wong)
13-
1414
## Set seed
1515
Random.seed!(1234);
1616

17-
# Number of samples:
18-
N = 100;
17+
# ## Generate half-moon dataset
18+
19+
# Number of samples per class:
20+
nin = nout = 50;
1921

20-
# Select randomly between two classes:
21-
y_train = rand([-1, 1], N);
22+
# We generate data based on SciKit-Learn's sklearn.datasets.make_moons function:
2223

23-
# Random attributes for both classes:
24-
X = Matrix{Float64}(undef, 2, N)
25-
rand!(MvNormal(randn(2), I), view(X, :, y_train .== 1))
26-
rand!(MvNormal(randn(2), I), view(X, :, y_train .== -1));
27-
x_train = ColVecs(X);
24+
class1x = cos.(range(0, π; length=nout))
25+
class1y = sin.(range(0, π; length=nout))
26+
class2x = 1 .- cos.(range(0, π; length=nin))
27+
class2y = 1 .- sin.(range(0, π; length=nin)) .- 0.5
28+
X = hcat(vcat(class1x, class2x), vcat(class1y, class2y))
29+
X .+= 0.1randn(size(X))
30+
x_train = RowVecs(X)
31+
y_train = vcat(fill(-1, nout), fill(1, nin));
2832

29-
# Create a 2D grid:
33+
# Create a 100×100 2D grid for evaluation:
3034
test_range = range(floor(Int, minimum(X)), ceil(Int, maximum(X)); length=100)
3135
x_test = ColVecs(mapreduce(collect, hcat, Iterators.product(test_range, test_range)));
3236

37+
# ## SVM model
38+
#
3339
# Create kernel function:
34-
k = SqExponentialKernel() ∘ ScaleTransform(2.0)
40+
k = SqExponentialKernel() ∘ ScaleTransform(1.5)
3541

3642
# [LIBSVM](https://github.com/JuliaML/LIBSVM.jl) can make use of a pre-computed kernel matrix.
3743
# KernelFunctions.jl can be used to produce that.
38-
# Precomputed matrix for training (corresponds to linear kernel)
44+
#
45+
# Precomputed matrix for training
3946
model = svmtrain(kernelmatrix(k, x_train), y_train; kernel=LIBSVM.Kernel.Precomputed)
4047

4148
# Precomputed matrix for prediction
42-
y_pr, _ = svmpredict(model, kernelmatrix(k, x_train, x_test));
49+
y_pred, _ = svmpredict(model, kernelmatrix(k, x_train, x_test));
4350

44-
# Compute prediction on a grid:
45-
contourf(test_range, test_range, y_pr)
46-
scatter!(X[1, :], X[2, :]; color=y_train, lab="data", widen=false)
51+
# Visualize prediction on a grid:
52+
plot(; lim=extrema(test_range), aspect_ratio=1)
53+
contourf!(test_range, test_range, y_pred; levels=1, color=cgrad(:redsblues), alpha=0.7)
54+
scatter!(X[y_train .== -1, 1], X[y_train .== -1, 2]; color=:red, label="class 1")
55+
scatter!(X[y_train .== +1, 1], X[y_train .== +1, 2]; color=:blue, label="class 2")

0 commit comments

Comments
 (0)