1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
% Load the Iris train/test splits. In each file, every column except the last
% is a feature and the last column is the class label. Labels are mapped to
% row label+1 of the one-hot matrix below, so presumably they are 0, 1, 2 --
% TODO confirm against the data files.
clear all; clc;
data = load('Iris-train.txt');
testD = load('Iris-test.txt');   % terminated with ';' so the matrix is not echoed
x2 = testD(:,1:end-1);           % test features
y2 = testD(:,end);               % test labels
x = data(:,1:end-1);             % training features
temp = data(:,end);              % training labels
% One-hot training targets: y(c+1, i) = 1 when sample i has label c.
% The sample count comes from the data instead of a hard-coded 75.
nTrain = size(x, 1);
y = zeros(3, nTrain);
for i = 1:nTrain
    y(temp(i)+1, i) = 1;
end
% Logistic sigmoid 1/(1+exp(-x)), applied element-wise. Uses ./ so that vector
% and matrix arguments are handled correctly; with matrix division '/' a
% non-scalar argument silently produced a wrong (least-squares) result.
% Scalar behaviour is unchanged, so existing arrayfun(f, ...) calls still work.
f = @(x) 1./(1+exp(-x));
% Learning rate and random initial parameters of the single-hidden-layer network.
%   V, gamma : input->hidden weights and thresholds (hidden = f(V'*x - gamma))
%   W, theta : hidden->output weights and thresholds (output = f(W'*b - theta))
% The input size follows the number of feature columns in x, and the output
% size follows the number of rows of the one-hot target matrix y, instead of
% hard-coded 4 and 3. The randn draw order (V, W, gamma, theta) is unchanged.
delta = 0.1;
nHidden = 10;
V = randn(size(x, 2), nHidden);
W = randn(nHidden, size(y, 1));
gamma = randn(nHidden, 1);
theta = randn(size(y, 1), 1);
% Train the network with per-sample (online) back-propagation. After every
% epoch, record the training and test accuracy and plot both curves.
% Reads x, temp, y (training data), x2, y2 (test data), the activation f,
% the learning rate delta, and updates V, gamma, W, theta in place.
% Stops once test accuracy exceeds 0.98, or after maxEpoch epochs.
% NOTE(review): the stopping rule looks at the test set, so the final test
% accuracy is optimistically biased; a separate validation split would avoid it.
nTrain = size(y, 2);      % training samples = columns of the one-hot targets
nTest = size(x2, 1);      % was hard-coded to 75
maxEpoch = 10000;         % safety cap: the original loop could run forever
hTrain = [];              % training-accuracy history, one entry per epoch
hTest = [];               % test-accuracy history, one entry per epoch

% Create the two curves once and update their data each epoch, instead of
% adding two new line objects on every iteration.
trainLine = plot(NaN, NaN, 'blue');
hold on
testLine = plot(NaN, NaN, 'red');
legend('训练集acc','测试集acc');   % "training-set acc", "test-set acc"

accTest = 0;
for epoch = 1:maxEpoch
    accTrain = 0;
    for k = 1:nTrain
        % Forward pass: hidden activations b, network outputs y_bar.
        b = arrayfun(f, V'*x(k,:)' - gamma);
        y_bar = arrayfun(f, W'*b - theta);
        % Output-layer term g, then hidden-layer term e. e must use W before
        % W is updated below.
        g = y_bar.*(1-y_bar).*(y(:,k)-y_bar);
        e = b.*(1-b).*(W*g);
        W = W + delta*b*g';
        theta = theta - delta*g;
        V = V + delta*x(k,:)'*e';
        gamma = gamma - delta*e;
        % Training accuracy uses the prediction made before this sample's update.
        [~, label] = max(y_bar);
        accTrain = accTrain + (label == (temp(k)+1));
    end

    % Forward pass over the whole test set at once (one column per sample).
    b = arrayfun(f, V'*x2' - repmat(gamma, 1, nTest));
    y_bar = arrayfun(f, W'*b - repmat(theta, 1, nTest));
    [~, label] = max(y_bar);
    accTest = sum(label == (y2+1)')/nTest;

    hTrain = [hTrain accTrain/nTrain];
    hTest = [hTest accTest];
    set(trainLine, 'XData', 1:numel(hTrain), 'YData', hTrain);
    set(testLine, 'XData', 1:numel(hTest), 'YData', hTest);
    drawnow;   % flush graphics so the curves update while training

    if accTest > 0.98
        break
    end
end
if accTest <= 0.98
    warning('Test accuracy did not exceed 0.98 within %d epochs.', maxEpoch);
end
|