% Load the Iris train/test splits and one-hot encode the training labels.
% Each file is read with load(): one sample per row, feature columns first,
% class label in the last column. Labels are used as (label + 1) row indices
% into a 3-row target matrix, so they must be the integers 0, 1 or 2
% (NOTE(review): confirm against the data files).
clear; clc;   % plain `clear` is enough for a script; `clear all` also flushes cached functions

data  = load('Iris-train.txt');
testD = load('Iris-test.txt');    % terminated: do not echo the whole test matrix

% Test set: features x2 (one row per sample) and integer labels y2.
x2 = testD(:, 1:end-1);
y2 = testD(:, end);

% Training set: features x (one row per sample) and integer labels temp.
x    = data(:, 1:end-1);
temp = data(:, end);

% One-hot targets with one column per training sample: y(label+1, i) = 1.
% The sample count comes from the file instead of a hard-coded 75.
nTrainSamples = size(data, 1);
y = zeros(3, nTrainSamples);
y(sub2ind(size(y), temp' + 1, 1:nTrainSamples)) = 1;
  % Logistic sigmoid 1 / (1 + e^-x), applied element-wise.
  % Uses ./ (element-wise division) rather than / (matrix right division).
  % With /, a vector argument silently produces a least-squares solve
  % instead of an error. For scalars the result is identical, so callers
  % that wrap f in arrayfun keep working. Terminated with ';' so the handle
  % is not echoed to the console.
  f = @(x) 1 ./ (1 + exp(-x));
  % Learning rate and randomly initialised parameters of a network with one
  % hidden layer (input -> hidden -> output).
  % Layer sizes come from the data instead of hard-coded values:
  %   x holds one training sample per row (its width is the input size),
  %   y holds one one-hot target column per sample (its height is the class count).
  % The randn calls stay in the original order (V, W, gamma, theta), so a
  % given RNG state produces the same initial weights as before.
  delta   = 0.1;                 % learning rate
  nInput  = size(x, 2);          % number of input features
  nHidden = 10;                  % hidden-layer width
  nOutput = size(y, 1);          % number of output units / classes
  V     = randn(nInput, nHidden);    % input -> hidden weights
  W     = randn(nHidden, nOutput);   % hidden -> output weights
  gamma = randn(nHidden, 1);         % hidden thresholds, subtracted from V'*input
  theta = randn(nOutput, 1);         % output thresholds, subtracted from W'*hidden
 
  % Online (per-sample) back-propagation training of the one-hidden-layer
  % sigmoid network. After every epoch the network is evaluated on the test
  % set, and both accuracy histories are plotted live.
  %
  % Reads:   x, temp, y  (training features, 0-based labels, one-hot targets)
  %          x2, y2      (test features, 0-based labels)
  %          f, delta    (activation handle, learning rate)
  % Updates: V, gamma (input -> hidden), W, theta (hidden -> output)
  % Leaves:  hTrain, hTest (per-epoch accuracies); b, y_bar, label hold the
  %          last test-set forward pass.
  %
  % Stops once test accuracy exceeds targetAcc, or after maxEpochs epochs.
  % The original `while 1` loop could spin forever if the target was never met.
  targetAcc = 0.98;
  maxEpochs = 10000;
  nTrain = size(x, 1);     % previously hard-coded as 75
  nTest  = size(x2, 1);    % previously hard-coded as 75

  hTrain = [];
  hTest  = [];

  % Create the two curves once and update their data every epoch. The old
  % approach called plot() each epoch under `hold on`, which added two new
  % line objects to the axes per epoch.
  trainLine = plot(NaN, NaN, 'blue');
  hold on
  testLine  = plot(NaN, NaN, 'red');
  legend('训练集acc','测试集acc');

  for epoch = 1:maxEpochs
      accTrain = 0;
      for k = 1:nTrain
          % Forward pass for one training sample.
          b     = arrayfun(f, V'*x(k,:)' - gamma);   % hidden activations
          y_bar = arrayfun(f, W'*b - theta);        % output activations

          % g and e are the output- and hidden-layer deltas: sigmoid
          % derivative times error, consistent with a squared-error loss.
          % e must be computed with W *before* W is updated below.
          g = y_bar.*(1-y_bar).*(y(:,k)-y_bar);
          e = b.*(1-b).*(W*g);

          % Weights move along +delta, while thresholds move along -delta
          % because they are subtracted in the net input.
          W     = W + delta*b*g';
          theta = theta - delta*g;
          V     = V + delta*x(k,:)'*e';
          gamma = gamma - delta*e;

          % Training accuracy is accumulated on the fly, i.e. it is measured
          % with the weights as they evolve during this epoch.
          [~, label] = max(y_bar);
          accTrain = accTrain + (label == (temp(k)+1));
      end

      % Batch forward pass over the whole test set (one column per sample).
      b     = arrayfun(f, V'*x2' - repmat(gamma, 1, nTest));
      y_bar = arrayfun(f, W'*b - repmat(theta, 1, nTest));
      [~, label] = max(y_bar);
      accTest = sum(label == (y2+1)') / nTest;   % computed once per epoch

      hTrain = [hTrain accTrain/nTrain];
      hTest  = [hTest accTest];
      set(trainLine, 'XData', 1:numel(hTrain), 'YData', hTrain);
      set(testLine,  'XData', 1:numel(hTest),  'YData', hTest);
      drawnow;   % flush graphics so the curves update while training runs

      if accTest > targetAcc
          break
      end
  end

  if isempty(hTest) || hTest(end) <= targetAcc
      warning('Test accuracy did not exceed %.2f within %d epochs.', targetAcc, maxEpochs);
  end