-
Notifications
You must be signed in to change notification settings - Fork 3
/
tcsvmLBFGS_demo.m
77 lines (64 loc) · 2 KB
/
tcsvmLBFGS_demo.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
function tcsvmLBFGS_demo
%TCSVMLBFGS_DEMO Fit a linear SVM with two solvers and compare the results.
%   Generates a training set and an independent test set with
%   tcdataGenerator(200), then fits a linear model w with tcsvmSGD and with
%   tcsvmLBFGS (both using option.C = 0.01, option.debug = 1). Each solution
%   is drawn in its own subplot of figure 1 (SGD left, LBFGS right) together
%   with the two classes, and its training/testing accuracy is printed.
%
%   NOTE(review): the plotting code uses only feature columns 1-2 and
%   w(1:3), so this presumably expects 2-D features (x is N-by-2), labels in
%   {-1, +1}, and w = [w1; w2; bias]. Confirm against tcdataGenerator and
%   the solvers.
clc
close all
% `clear all` was intentionally dropped. Inside a function the workspace is
% already empty, and `clear all` would also wipe the caller's globals,
% loaded functions and breakpoints.

%% generate data
nsamples = 200;
[x, y] = tcdataGenerator(nsamples);    % training data
[xt, yt] = tcdataGenerator(nsamples);  % testing data

% Plot inputs shared by both solvers.
xmin = min(x(:)) - 1;
xmax = max(x(:)) + 1;
data_pos = x(y == 1, :);
data_neg = x(y == -1, :);

option.C = 0.01;
option.debug = 1;

%% Solver 1: Pegasos SGD
w = tcsvmSGD(x, y, option)   % no semicolon: echo w to the console
figure(1)
subplot(121)
plotSolution(data_pos, data_neg, xmin, xmax, w);
reportAccuracy(x, y, xt, yt, w, option.C);

%% Solver 2: LBFGS
w = tcsvmLBFGS(x, y, option) % no semicolon: echo w to the console
figure(1)
subplot(122)
plotSolution(data_pos, data_neg, xmin, xmax, w);
reportAccuracy(x, y, xt, yt, w, option.C);


function plotSolution(data_pos, data_neg, xmin, xmax, w)
%PLOTSOLUTION Scatter both classes and draw the separating line of w.
%   On the current axes, plots the positive samples (blue '+') and the
%   negative samples (green 'x'). It then draws, for x1 = xmin:0.1:xmax:
%   - the line w(1)*x1 + w(2)*x2 + w(3) = 0 (solid red);
%   - the lines where that expression equals +1 and -1 (dotted red).
%   `axis tight` is applied to the scatter data before the lines are added.
%   NOTE(review): divides by w(2); a vertical boundary (w(2) == 0) gives
%   non-finite y values.
hold on
scatter(data_pos(:, 1), data_pos(:, 2), 'b+', 'SizeData', 200, 'LineWidth', 2);
scatter(data_neg(:, 1), data_neg(:, 2), 'gx', 'SizeData', 200, 'LineWidth', 2);
axis tight
margin = xmin:0.1:xmax;
plot(margin, (-w(3) - margin*w(1))/w(2), 'r', 'LineWidth', 2);
plot(margin, (1 - w(3) - margin*w(1))/w(2), 'r:', 'LineWidth', 1.5);
plot(margin, (-1 - w(3) - margin*w(1))/w(2), 'r:', 'LineWidth', 1.5);
hold off


function reportAccuracy(x, y, xt, yt, w, C)
%REPORTACCURACY Print train/test accuracy of w and title the current axes.
%   The title shows C and the training accuracy; both training and testing
%   accuracies are written to the console.
acc = linearAccuracy(x, y, w);
disp(['training acc: ', num2str(acc)])
title(['C = ', num2str(C), ', acc = ', num2str(acc)])
acc = linearAccuracy(xt, yt, w);
disp(['testing acc: ', num2str(acc)])


function acc = linearAccuracy(x, y, w)
%LINEARACCURACY Fraction of samples with y .* ([x 1] * w) > 0.
%   A column of ones is appended to x, so the last element of w acts as the
%   bias. The ones column is sized from x itself, which lets the training
%   and test sets have different numbers of rows.
X = [x, ones(size(x, 1), 1)];
acc = nnz(y .* (X*w) > 0) / length(y);