-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_dataAugmentation.py
More file actions
116 lines (102 loc) · 4.01 KB
/
train_dataAugmentation.py
File metadata and controls
116 lines (102 loc) · 4.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import random
from PIL import Image
import pandas as pd
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50V2
def load_data(directory, df):
    """Load every image under *directory* and look up its label in *df*.

    Each file is resized to 224x224, converted to 3-channel RGB and stored as
    a float64 array. The file name's part before the first '.' must be an
    integer id that matches exactly one row of ``df['id']``; that row's
    ``df['label']`` value minus one becomes the label (so the labels in *df*
    are presumably 1-based -- confirm against the dataset).

    Args:
        directory: Root directory, walked recursively with ``os.walk``.
        df: DataFrame with ``id`` and ``label`` columns.

    Returns:
        ``(images, labels)`` as NumPy arrays; ``images`` has shape
        ``(n, 224, 224, 3)`` and dtype float64.

    Raises:
        ValueError: if a file name stem is not an integer, or its id does not
            match exactly one row of *df* (raised by ``Series.item()``).
    """
    images = []
    labels = []
    for subdir, _, files in os.walk(directory):
        for file in files:
            path = os.path.join(subdir, file)
            # The context manager closes the underlying file handle; without it
            # one handle stays open per image until garbage collection.
            with Image.open(path) as src:
                # Resize first (as before), then force 3 channels. For 'L'
                # images this equals the old replicate-grayscale path; it also
                # normalises RGBA/LA/CMYK/palette images, which previously
                # produced arrays with the wrong channel count or palette
                # indices instead of colours.
                rgb = src.resize((224, 224)).convert('RGB')
                image = np.asarray(rgb, dtype='float64')
            images.append(image)
            img_id = int(file.split('.')[0])
            label = df[df['id'] == img_id]['label'].item() - 1
            labels.append(label)
    return np.array(images), np.array(labels)
def load_saved_data(train_im="train_images.npy",
                    train_l="train_labels.npy",
                    test_im="test_images.npy",
                    test_l="test_labels.npy"):
    """Read the cached train/test arrays back from ``.npy`` files.

    Args:
        train_im: Path to the training images array.
        train_l: Path to the training labels array.
        test_im: Path to the test images array.
        test_l: Path to the test labels array.

    Returns:
        A 4-tuple ``(X_train, y_train, X_test, y_test)`` of NumPy arrays,
        loaded in that order.
    """
    print("Init Loading data")
    loaded = [np.load(npy_path) for npy_path in (train_im, train_l, test_im, test_l)]
    print("Finished Loading data")
    return tuple(loaded)
def get_mobile_netv2(num_classes=200):
    """Build a MobileNetV2 transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (224x224x3 input, no top)
    has every layer frozen. A trainable head is stacked on top:
    GlobalAveragePooling2D -> Dense(256, relu) -> Dropout(0.5) ->
    Dense(num_classes, softmax).

    Args:
        num_classes: Number of output classes. Defaults to 200, the value
            previously hard-coded here.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    base_model = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                                   include_top=False,
                                                   weights='imagenet')
    # Freeze the whole backbone so only the new head is trained.
    for layer in base_model.layers:
        layer.trainable = False
    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    print("returning mobilenet model")
    return model
def get_resnet_model():
    """Build a ResNet50V2 transfer-learning classifier with 200 softmax outputs.

    Uses the ImageNet-pretrained ResNet50V2 backbone (224x224x3 input, no top)
    and freezes all but its last 10 layers. The backbone's feature map is
    flattened and fed through Dense(256, relu) -> Dropout(0.5) ->
    Dense(200, softmax).

    Returns:
        An uncompiled functional ``tf.keras.models.Model``.
    """
    backbone = tf.keras.applications.ResNet50V2(
        weights='imagenet',
        include_top=False,
        input_shape=(224, 224, 3),
    )
    # Only the final 10 backbone layers stay trainable (partial fine-tuning).
    # NOTE(review): that tail may include BatchNormalization layers; Keras's
    # transfer-learning guide recommends keeping BN in inference mode while
    # fine-tuning -- confirm this is intended.
    frozen_layers = backbone.layers[:-10]
    for frozen in frozen_layers:
        frozen.trainable = False

    # Classification head on top of the flattened backbone features.
    features = tf.keras.layers.Flatten()(backbone.output)
    hidden = tf.keras.layers.Dense(256, activation='relu')(features)
    hidden = tf.keras.layers.Dropout(0.5)(hidden)
    probabilities = tf.keras.layers.Dense(200, activation='softmax')(hidden)

    classifier = tf.keras.models.Model(backbone.input, probabilities)
    print("returning resnet50 model")
    return classifier
def train_model(epochs=200, batch_size=32,
                checkpoint_path="resnet_1/cp.ckpt",
                model_path='resnet50_model.h5'):
    """Fine-tune the ResNet50V2 classifier on the cached arrays with augmentation.

    Loads data with ``load_saved_data()`` and builds the model with
    ``get_resnet_model()``. It trains with random geometric augmentation,
    writes weight checkpoints through ``ModelCheckpoint`` (default save
    frequency), and saves the final model.

    Args:
        epochs: Number of training epochs (default 200).
        batch_size: Mini-batch size for both training and validation
            (default 32).
        checkpoint_path: Filepath given to ``ModelCheckpoint``
            (weights only).
        model_path: Where the final model is saved with ``model.save``.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    X_train, y_train, X_test, y_test = load_saved_data()

    # One scale factor shared by training and validation, so both see the
    # same input range. Previously only the training batches were rescaled
    # while raw 0-255 X_test was used for validation, which made the
    # reported validation metrics meaningless.
    # NOTE(review): ResNet50V2's ImageNet weights expect inputs in [-1, 1]
    # (tf.keras.applications.resnet_v2.preprocess_input); [0, 1] is kept
    # here for compatibility with existing checkpoints -- consider switching.
    rescale = 1. / 255

    # Augmentation for training only. No featurewise_* / zca options are
    # enabled, so ImageDataGenerator.fit() (which only computes those
    # statistics) is not needed; it used to make a full copy of X_train.
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        shear_range=0.1,
        zoom_range=0.1,
        horizontal_flip=True,
        vertical_flip=False,
        rescale=rescale
    )
    # Validation: same rescale, no augmentation, fixed order.
    val_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=rescale)

    model = get_resnet_model()
    # load_data() yields integer class indices (presumably what the cached
    # .npy labels hold), which need the sparse loss. One-hot label arrays
    # still use categorical_crossentropy.
    sparse_labels = y_train.ndim == 1 or y_train.shape[-1] == 1
    loss = ('sparse_categorical_crossentropy' if sparse_labels
            else 'categorical_crossentropy')
    model.compile(loss=loss,
                  optimizer='adam',
                  metrics=['accuracy'])
    model.summary()

    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        save_weights_only=True,
        verbose=1)

    # steps_per_epoch is omitted: the flow iterator reports
    # ceil(len / batch_size) steps. That covers the final partial batch and
    # avoids steps_per_epoch == 0 when there are fewer samples than batch_size.
    history = model.fit(
        train_datagen.flow(X_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=val_datagen.flow(X_test, y_test,
                                         batch_size=batch_size,
                                         shuffle=False),
        callbacks=[cp_callback],
        validation_freq=5)
    model.save(model_path)
    return history
if __name__ == "__main__":
    # Run the full training pipeline only when executed as a script,
    # not on import.
    train_model()