-
Notifications
You must be signed in to change notification settings - Fork 3
/
tcsvmSMO_MNIST.m
70 lines (61 loc) · 1.5 KB
/
tcsvmSMO_MNIST.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
function tcsvmSMO_MNIST(num_read, C)
% TCSVMSMO_MNIST  One-vs-rest tcsvmSMO experiment on MNIST digits.
%
%   tcsvmSMO_MNIST() loads data with readMNIST(1000) and trains one binary
%   tcsvmSMO classifier per class (10 classes, one-vs-rest, option.C = 1).
%   It then prints, for each class, the binary one-vs-rest accuracy on the
%   test split, followed by the average of those accuracies.
%
%   tcsvmSMO_MNIST(NUM_READ) passes NUM_READ to readMNIST instead of 1000
%   (presumably the number of samples to load -- confirm against readMNIST).
%
%   tcsvmSMO_MNIST(NUM_READ, C) additionally sets option.C passed to
%   tcsvmSMO (default 1). Pass [] for either input to keep its default.
%
%   Notes:
%   * The reported accuracy for class c is the fraction of test samples
%     whose decision value has the same sign as the +1/-1 one-vs-rest label.
%     It is not the 10-way multiclass accuracy.
%   * Requires im2double (Image Processing Toolbox) and the project helpers
%     readMNIST and tcsvmSMO on the path.

% No 'clear all' here: inside a function it can only wipe this function's
% own workspace (including its inputs) and flush compiled functions.
if nargin < 1 || isempty(num_read)
    num_read = 1000;
end
if nargin < 2 || isempty(C)
    C = 1;
end
close all
clc

%% load data
[I, labels, I_test, labels_test] = readMNIST(num_read);

%% train
nclass = 10;
% Shift labels by one so they can be compared against class indices
% 1..nclass (assumes labels hold digits 0..9 -- confirm against readMNIST).
y_train = double(labels) + 1.0;
x_train = images_to_rows(I);
clear I labels
option.C = C;
model = cell(1, nclass);
disp('training...');
for c = 1:nclass
    disp([num2str(c), '-th loop:']);
    % One-vs-rest labels: +1 for class c, -1 for every other class.
    yc_train = -ones(size(y_train));
    yc_train(y_train == c) = 1;
    % Request all four outputs so tcsvmSMO sees the same nargout as before;
    % the fourth output is not needed here.
    [alphay, b, sv, ~] = tcsvmSMO(x_train, yc_train, option);
    mc.alphay = alphay;
    mc.b = b;
    mc.sv = sv;
    model{c} = mc;
end
clear x_train y_train

%% test
y_test = double(labels_test) + 1.0;
% Build the test matrix from I_test itself so every test image is used
% (the loop previously ran over length(I), the training-set size).
x_test = images_to_rows(I_test);
clear I_test labels_test
accuracy = zeros(1, nclass);
disp('testing...');
for c = 1:nclass
    disp([num2str(c), '-th loop:']);
    yc_test = -ones(size(y_test));
    yc_test(y_test == c) = 1;
    wc = model{c};
    % Decision value: sum over support vectors j of x * (alphay_j * sv_j)' + b.
    % Collapsing the support vectors into one weight vector first gives the
    % same value (up to rounding) without an (ntest x nsv) intermediate.
    f = x_test * (wc.sv' * wc.alphay(:)) + wc.b;
    % Flatten both operands so a row-shaped label vector cannot broadcast
    % against the column of decision values.
    acc = mean(yc_test(:) .* f(:) > 0);
    accuracy(c) = acc;
    disp(['accuracy: ', num2str(acc)])
end
disp(['avg-accuracy: ', num2str(mean(accuracy))])

function X = images_to_rows(imgs)
% IMAGES_TO_ROWS  Flatten a cell array of images into one row per image.
%   Row k of X is imgs{k}(:)' (column-major flattening), converted with
%   im2double. All images must have the same number of elements. An empty
%   cell array yields [].
rows = cellfun(@(im) reshape(im, 1, []), imgs(:), 'UniformOutput', false);
X = im2double(cell2mat(rows));