-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexpToyData.py
More file actions
78 lines (65 loc) · 2.66 KB
/
expToyData.py
File metadata and controls
78 lines (65 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Data Generation for Toy Example
from mixture_latent_analysis import *
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy as np
from helper import guassianPlot2D
# z ~ N(0,I), x|z ~ N(Wz+mu, sigma2*I)
def genData4PPCA(num_pts, W=None, mu=None, sigma=0.4):
    """Draw toy samples from a probabilistic-PCA generative model.

    Model: z ~ N(0, 1) (one latent dimension), x | z ~ N(W z + mu, sigma^2 I).
    W is rescaled to unit norm before sampling, so the printed ground-truth W
    is the normalized one.

    Args:
        num_pts: number of samples to draw.
        W: (D, 1) loading matrix; defaults to [[3], [1]]. Normalized to unit
            norm before use.
        mu: (D, 1) offset; defaults to [[-2], [1]].
        sigma: standard deviation of the isotropic observation noise.

    Returns:
        out: (D, num_pts) array of samples, one sample per column.
        line_points: (D, 2) array whose columns are W*z + mu at z = -2 and
            z = 2, i.e. endpoints of the true latent line for plotting.

    Raises:
        ValueError: if W is not a 2-D array with exactly one column.
    """
    W = np.asarray([[3], [1]] if W is None else W)
    mu = np.asarray([[-2], [1]] if mu is None else mu)
    if W.ndim != 2 or W.shape[1] != 1:
        raise ValueError('W must have shape (D, 1), got {}'.format(W.shape))
    W = W / np.sqrt(np.sum(W * W))
    dim = W.shape[0]

    # Latent coordinates as a (1, num_pts) row so W @ z is (D, num_pts).
    z = np.random.normal(loc=0, scale=1, size=(1, num_pts))
    line_points = np.matmul(W, [[-2, 2]]) + mu
    # A single (dim, num_pts) draw consumes the RNG stream in the same order
    # as drawing each row separately, so seeded results match row-wise draws.
    noise = np.random.normal(loc=0, scale=sigma, size=(dim, num_pts))
    out = np.matmul(W, z) + mu + noise
    print('gt W={}'.format(W))
    print('gt sigma2={}'.format(sigma*sigma))
    return out, line_points
if __name__ == '__main__':
    num_pts = 400
    data, line_points = genData4PPCA(num_pts)

    # Set True to also overlay the PPCA fit (latent line + marginal ellipse).
    show_ppca = False

    # --- Probabilistic PCA fitted with EM ---
    mu, W, sigma2 = ppca_em(data, 1)
    # Model marginal covariance of x under PPCA: W W^T + sigma2 * I.
    C = np.matmul(W, W.T) + sigma2 * np.eye(2)
    # Unit-norm direction for printing/plotting only; the fitted W is kept
    # intact so it stays consistent with C.
    W_dir = W / np.sqrt(np.sum(W * W))
    print('ppca_em W={}'.format(W_dir))
    print('ppca_em sigma2={}'.format(sigma2))
    line_points_ppca_em = np.matmul(W_dir, [[-2, 2]]) + mu

    # --- Factor analysis fitted with EM ---
    # Per the original author's note, fa_em returns mu (D,1), W (D,K), psi (D,).
    mu_fa, W_fa, psi = fa_em(data, 1)
    # Model marginal covariance of x under FA: W W^T + diag(psi).
    C_fa = np.matmul(W_fa, W_fa.T) + np.diag(psi)
    # Display-only direction; do not overwrite W_fa, which inference needs.
    W_fa_dir = W_fa / np.sqrt(np.sum(W_fa * W_fa))
    print('fa_em W_fa={}'.format(W_fa_dir))
    print('fa_em psi={}'.format(psi))
    line_points_fa_em = np.matmul(W_fa_dir, [[-2, 2]]) + mu_fa

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(line_points[0, :], line_points[1, :], 'k-',
            label='True latent space')

    if show_ppca:
        ax.plot(line_points_ppca_em[0, :], line_points_ppca_em[1, :], 'g-.',
                label='PPCA_EM latent space')
        pts = guassianPlot2D(mu, C)
        ax.plot(pts[0, :], pts[1, :], 'g-.', label='PPCA_EM marginal pdf')

    ax.plot(line_points_fa_em[0, :], line_points_fa_em[1, :], 'r--',
            label='FA_EM latent space')
    pts_fa = guassianPlot2D(mu_fa, C_fa)
    ax.plot(pts_fa[0, :], pts_fa[1, :], 'r-.', label='FA_EM marginal pdf')

    ax.scatter(data[0], data[1], c='b', marker='o', label='Data points')

    # --- Inference for a single test point ---
    X = np.array([[-1.], [2.2]])
    ax.scatter(X[0], X[1], s=[60], c='r', marker='o', label='Test Point X')
    # Pass the fitted loadings W_fa (not the normalized W_fa_dir) so the
    # posterior is computed under the same model that produced C_fa.
    Z, ZinXSpace = fa_inference(X, mu_fa, W_fa, psi)
    ax.scatter(ZinXSpace[0], ZinXSpace[1], s=[60], c='r', marker='^',
               label='Latent Variable Z for X')

    # Labels are attached to each artist, so the legend stays in sync with
    # whatever was actually drawn (including the optional PPCA overlay).
    ax.legend()
    plt.axis('equal')
    plt.show()