forked from amadeuzou/SvmSolvers
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tcsvmQP_demo.m
41 lines (34 loc) · 1.04 KB
/
tcsvmQP_demo.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
function tcsvmQP_demo(nsamples)
%TCSVMQP_DEMO Fit a linear SVM with tcsvmQP on generated data and plot it.
%   TCSVMQP_DEMO() draws 200 training and 200 testing samples from
%   tcdataGenerator, fits a weight vector w with tcsvmQP, scatters the
%   training set (y == 1 as blue '+', y == -1 as green 'x'), overlays the
%   line w(1)*x1 + w(2)*x2 + w(3) = 0 (solid) and the = +/-1 lines (dotted),
%   and prints the training and testing accuracy. The training accuracy is
%   also shown as the figure title.
%
%   TCSVMQP_DEMO(NSAMPLES) uses NSAMPLES samples for each set instead.
%
%   NOTE(review): assumes tcdataGenerator returns x as an N-by-2 matrix and
%   labels in {+1, -1}, and that tcsvmQP returns a 3-by-1 w whose last entry
%   is the bias (matching the [x ones] design matrix below) -- confirm.
if nargin < 1 || isempty(nsamples)
    nsamples = 200;
end
clc
close all
% 'clear all' was removed: inside a function it cannot clear the caller's
% variables, and it needlessly flushes compiled functions and globals.

%% generate data
% training data
[x, y] = tcdataGenerator(nsamples);
% testing data
[xt, yt] = tcdataGenerator(nsamples);

%% Quadratic Programming Solver
w = tcsvmQP(x, y);

%% Visualize Results
xmin = min(x(:)) - 1;
xmax = max(x(:)) + 1;
% Logical indexing selects the rows of each class directly.
data_pos = x(y == 1, :);
data_neg = x(y == -1, :);
scatter(data_pos(:, 1), data_pos(:, 2), 'b+', 'SizeData', 200, 'LineWidth', 2);
hold on
scatter(data_neg(:, 1), data_neg(:, 2), 'gx', 'SizeData', 200, 'LineWidth', 2);
axis tight
% Abscissae for the boundary/margin lines (named xs to avoid shadowing
% the toolbox function margin()).
xs = xmin:0.1:xmax;
plotLevelLine(w, 0, xs, 'r', 2);      % decision boundary
plotLevelLine(w, 1, xs, 'r:', 1.5);   % +1 margin
plotLevelLine(w, -1, xs, 'r:', 1.5);  % -1 margin

%% predict
acc = predictAccuracy(x, y, w);
disp(['training acc: ', num2str(acc)]);
title([' acc = ', num2str(acc)])
% The test design matrix is sized from the test set itself; previously it
% reused the training row count, which only worked when both sizes matched.
acc = predictAccuracy(xt, yt, w);
disp(['testing acc: ', num2str(acc)])
end

function acc = predictAccuracy(x, y, w)
%PREDICTACCURACY Fraction of rows where sign([x 1]*w) agrees with label y.
%   y is flattened to a column so a row-vector label set cannot trigger
%   implicit expansion into an N-by-N matrix. Returns NaN for empty input.
scores = [x, ones(size(x, 1), 1)] * w;
acc = mean(y(:) .* scores > 0);
end

function plotLevelLine(w, level, xs, style, width)
%PLOTLEVELLINE Draw the line w(1)*x1 + w(2)*x2 + w(3) = level on the axes.
if w(2) ~= 0
    plot(xs, (level - w(3) - xs * w(1)) / w(2), style, 'LineWidth', width);
else
    % Solving for x2 would divide by zero; the line is vertical instead.
    x1 = (level - w(3)) / w(1);
    plot([x1 x1], ylim, style, 'LineWidth', width);
end
end