-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathplot_boundary.m
More file actions
86 lines (61 loc) · 1.82 KB
/
plot_boundary.m
File metadata and controls
86 lines (61 loc) · 1.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
function [] = plot_boundary( x1, x2, Wb, act_type, nA )
%{
FUNCTION DESCRIPTION:
---------------------------------------------------------------------------
Plot the decision boundary predicted by the trained neural net on the unit
square [0,1]x[0,1], overlaid with the training data points.
---------------------------------------------------------------------------
INPUTS:
---------------------------------------------------------------------------
x1,x2    -- co-ordinates of training data {vectors},
Wb       -- trained weights and bias vectors {cell array}; for each row l,
            Wb{l,1} and Wb{l,2} are passed to my_activate for layer l.
            Row 1 is skipped (the forward pass starts at l = 2), and the
            final layer must produce at least two outputs,
act_type -- type of activation function used {string}, passed through
            to my_activate,
nA       -- (optional) number of leading entries of x1,x2 that are of
            type A; all remaining entries are type B {integer}.
            Defaults to 5, matching the original hard-coded behaviour.
            Values larger than the number of data points are clamped.
---------------------------------------------------------------------------
OUTPUT:
---------------------------------------------------------------------------
None. Opens a new figure showing the region where output 1 exceeds
output 2 (shaded) and the training points (A: red circles, B: blue
crosses). 'hold on' is left active on the new axes.
---------------------------------------------------------------------------
NOTE(review): the evaluation grid and axis limits are fixed to [0,1], so
training data outside the unit square will not be covered by the boundary
plot -- confirm this matches how the data is generated.
---------------------------------------------------------------------------
Written by: C F Higham and D J Higham, August 2017
Available at: https://www.maths.ed.ac.uk/~dhigham/algfiles.html
Adapted by: James Rynn
Last edited: 20/03/2020
%}
%% HANDLE OPTIONAL ARGUMENTS:
% Default split between class A and class B (backward compatible).
if nargin < 5 || isempty(nA)
    nA = 5;
end
% Never index past the end of the data (x1(1:5) used to error with fewer
% than 5 points) and never use a negative count.
nA = max(0, min(nA, numel(x1)));
%% CHOOSE EVALUATION POINTS:
% Number of layers (rows of the weight/bias cell array).
L = size(Wb,1);
% Number of grid intervals per axis (grid has NN+1 points per axis).
NN = 500;
% Spacing between points.
Dx = 1/NN;
Dy = 1/NN;
% Evaluation points on the unit square.
xvals = 0:Dx:1;
yvals = 0:Dy:1;
% Predicted class mask: 1 where output 1 > output 2, 0 otherwise.
% Rows index y and columns index x, matching the layout of meshgrid below.
Mval = zeros(NN+1);
%% EVALUATE PREDICTIONS:
% Loop through evaluation points.
for k1 = 1:NN+1
    xk = xvals(k1);
    for k2 = 1:NN+1
        yk = yvals(k2);
        a = [xk; yk];
        % Forward pass; layer 1 is the input, so start at layer 2.
        for l = 2:L
            a = my_activate(a, Wb{l,1}, Wb{l,2}, act_type);
        end
        Mval(k2,k1) = a(1) > a(2);
    end
end
[X,Y] = meshgrid(xvals,yvals);
%% PLOT PREDICTIONS:
% The first nA entries of x1,x2 are type A; the rest are type B.
figure
a2 = subplot(1,1,1);
% A single contour level at 0.5 separates the 0/1 regions of the mask.
contourf(X,Y,Mval,[0.5 0.5])
hold on
colormap([1 1 1; 0.8 0.8 0.8])
plot(x1(1:nA),x2(1:nA),'ro','MarkerSize',10,'LineWidth',1.5)
plot(x1(nA+1:end),x2(nA+1:end),'bx','MarkerSize',10,'LineWidth',1.5)
a2.XTick = [0 1];
a2.YTick = [0 1];
a2.FontSize = 16;
xlim([0,1])
ylim([0,1])