-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassify_knn.py
More file actions
74 lines (57 loc) · 2.18 KB
/
classify_knn.py
File metadata and controls
74 lines (57 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import sys
import tensorflow.keras as keras
import numpy as np
from utils import prepareData
class NOT_SKLEARN_KNN(object):
    """Minimal k-nearest-neighbours classifier (squared Euclidean distance).

    Despite the name, the interface mirrors scikit-learn's ``fit``/``predict``.
    Each query row is labelled by a majority vote among its ``k`` closest
    training rows; ties are resolved in favour of the smallest class label.
    """

    def __init__(self, n_neightbors=5):
        """Create an unfitted classifier.

        Args:
            n_neightbors: number of neighbours that vote on each prediction
                (the misspelt keyword is kept for backward compatibility).
        """
        self.k_neigh = n_neightbors
        self.data = []
        self.labels = []
        self.cats = []

    def fit(self, data, labels):
        """Memorise the training set.

        Args:
            data: training feature vectors, one row per sample.
            labels: class label for each row of ``data``. Presumably scalar
                labels rather than one-hot vectors — TODO confirm against the
                caller's data pipeline.
        """
        # Arrays are required for the fancy indexing / broadcasting in predict
        # (a plain list of labels would fail on ``self.labels[nn_idxs]``).
        self.data = np.asarray(data)
        self.labels = np.asarray(labels)
        # Derive the candidate classes from the training labels instead of
        # assuming exactly four classes 0..3. np.unique returns them sorted,
        # so ties still go to the smallest label; tolist() yields plain
        # Python scalars as the old ``range`` list did.
        self.cats = np.unique(self.labels).tolist()

    def predict(self, new_data):
        """Predict a class label for every row of ``new_data``.

        Args:
            new_data: iterable of query feature vectors with the same width as
                the training data.

        Returns:
            list: one predicted label per query row, drawn from the labels
            seen in ``fit``.

        Raises:
            SystemExit: if ``fit`` has not been called with non-empty data.
        """
        if min(len(self.labels), len(self.cats), len(self.data)) < 1:
            sys.exit("Error: KNN model not fitted.")
        lbls = []
        for samp in new_data:
            # Squared distance keeps the neighbour ordering without a sqrt.
            dists = np.square(samp - self.data).sum(axis=1)
            nn_idxs = dists.argsort()[:self.k_neigh]
            nn_lbls = self.labels[nn_idxs]
            votes = [np.count_nonzero(nn_lbls == cat) for cat in self.cats]
            # argmax returns the first maximum, i.e. the smallest tied label.
            lbls.append(self.cats[int(np.argmax(votes))])
        return lbls
def getKNN(nn_model, images: np.ndarray, labels: np.ndarray,
           n_neighbors: int = 5):
    """Fit a KNN classifier on the feature vectors produced by ``nn_model``.

    Args:
        nn_model: object whose ``predict`` maps ``images`` to feature vectors
            (in this script, the Keras model loaded in ``main``).
        images: training images passed to ``nn_model.predict``.
        labels: class label for each image.
        n_neighbors: number of neighbours that vote on each prediction;
            defaults to 5, the previously hard-coded value.

    Returns:
        NOT_SKLEARN_KNN: the fitted classifier.
    """
    embeddings = nn_model.predict(images)
    knn = NOT_SKLEARN_KNN(n_neightbors=n_neighbors)
    knn.fit(embeddings, labels)
    return knn
def main(model_path: str, img_fld: str):
# Training the KNN
model = keras.models.load_model(model_path)
img_h = img_w = model.inputs[0].shape[1]
train_x, test_x, train_y, test_y = prepareData(
img_fld,
img_h,
sample_size=-30,
normalize=True)
print("Building KNN model..")
knn = getKNN(model, train_x, train_y)
print("Predicting the Test dataset..")
test_vecs = model.predict(test_x)
test_pred = knn.predict(test_vecs)
accuracy = np.asarray(test_pred == test_y).sum() / len(test_y)
print("Accuracy: %f" % accuracy)
if __name__ == '__main__':
    # Hide all GPUs so the script runs on CPU only.
    # NOTE(review): tensorflow is imported at the top of the file, before this
    # line executes; this relies on TF initialising CUDA lazily — confirm, or
    # set the variable before the import.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    cli = argparse.ArgumentParser(description='Train KNN')
    cli.add_argument(
        '--model',
        dest="model",
        type=str,
        required=True,
        help='The trained model to load',
    )
    cli.add_argument(
        '--images',
        dest="img_folder",
        type=str,
        required=True,
        help='Location of the images',
    )
    cli_args = cli.parse_args()
    main(cli_args.model, cli_args.img_folder)