-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path01_image_classification.py
More file actions
119 lines (92 loc) · 3.66 KB
/
01_image_classification.py
File metadata and controls
119 lines (92 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
# Load the CIFAR-10 dataset: uint8 images of shape (N, 32, 32, 3) and
# integer class labels of shape (N, 1).
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()
# Normalize pixel values to [0, 1]. Cast to float32 first: dividing the uint8
# arrays by the Python float 255.0 directly would promote them to float64
# (twice the memory), and Keras computes in float32 anyway.
train_images = train_images.astype('float32') / 255.0
test_images = test_images.astype('float32') / 255.0
# Split the training data into training (80%) and validation (20%) sets.
# A fixed random_state keeps the split reproducible across runs.
from sklearn.model_selection import train_test_split
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images, train_labels, test_size=0.2, random_state=42
)
# Build the CNN: three conv blocks (3x3 conv + 2x2 max-pool) followed by a
# small dense classifier head. Layers are listed in the order they are applied.
model = models.Sequential([
    # Conv block 1: 32 filters; input_shape fixes the expected image size.
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    # Conv block 2: 64 filters.
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    # Conv block 3: 64 filters.
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    # Classifier head: flatten the feature maps, one hidden dense layer, then
    # a softmax over the 10 CIFAR-10 classes.
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax'),
])
# Integer (not one-hot) labels pair with sparse categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# Print a layer-by-layer summary of the architecture.
model.summary()
# Training hyperparameters, named once so the batch size is not repeated.
BATCH_SIZE = 32
EPOCHS = 30
# Data augmentation: random geometric transforms applied on the fly to each
# training batch. The validation data below is passed un-augmented.
datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Fit the model. Ceiling division makes an epoch cover every training image,
# including the final partial batch (floor division would skip it each epoch).
history = model.fit(
    datagen.flow(train_images, train_labels, batch_size=BATCH_SIZE),
    steps_per_epoch=(len(train_images) + BATCH_SIZE - 1) // BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_images, val_labels),
)
# Score the trained network on the held-out CIFAR-10 test split. evaluate()
# returns the loss followed by each compiled metric (here: accuracy).
test_loss, test_acc = model.evaluate(x=test_images, y=test_labels)
print(f"Test accuracy: {test_acc}")
# Plot the training vs. validation curves recorded by model.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
# Number epochs from 1 so the x-axis matches Keras' "Epoch 1/N" logging
# (range(len(acc)) would label the first epoch as 0).
epochs = range(1, len(acc) + 1)


def _plot_train_val(x, train_values, val_values, short_name, long_name):
    """Plot a training (dots) vs. validation (line) curve on the current figure."""
    plt.plot(x, train_values, 'bo', label=f'Training {short_name}')
    plt.plot(x, val_values, 'b', label=f'Validation {short_name}')
    plt.title(f'Training and validation {long_name}')
    plt.xlabel('Epoch')
    plt.ylabel(long_name.capitalize())
    plt.legend()


_plot_train_val(epochs, acc, val_acc, 'acc', 'accuracy')
plt.figure()  # start a second figure so the loss curves are not overlaid
_plot_train_val(epochs, loss, val_loss, 'loss', 'loss')
plt.show()
# Persist the trained model (the .h5 extension selects Keras' HDF5 format),
# then read it back. Reloading confirms the file round-trips, and the rest of
# the script uses the reloaded model.
MODEL_PATH = 'cnn_image_classification_model.h5'
model.save(MODEL_PATH)
model = tf.keras.models.load_model(MODEL_PATH)
# Helpers for running the trained model on new image files.
def preprocess_image(image_path):
    """Load an image file and turn it into a one-image model input batch.

    The image is resized to 32x32 (the input size the model was built for)
    and scaled to [0, 1], matching the normalization applied to the training
    data.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A float array of shape (1, 32, 32, 3).
    """
    pixels = tf.keras.utils.img_to_array(
        tf.keras.utils.load_img(image_path, target_size=(32, 32))
    )
    # Prepend a batch axis: the model expects a batch, even of one image.
    batch = pixels[np.newaxis, ...]
    return batch / 255.0
def predict_image(image_path, classifier=None):
    """Predict the class index for a single image file.

    Args:
        image_path: Path to the image file on disk.
        classifier: Optional Keras model to use instead of the module-level
            ``model``. Defaults to ``model``, so ``predict_image(path)``
            behaves as before.

    Returns:
        The index of the highest-scoring class as a plain Python ``int``
        (0-9 for the 10-way model defined in this script).
    """
    net = model if classifier is None else classifier
    img_array = preprocess_image(image_path)
    # One row of per-class scores, because the batch holds a single image.
    predictions = net.predict(img_array)
    # argmax over the class axis; int() unwraps the NumPy integer scalar.
    return int(np.argmax(predictions, axis=1)[0])
# Human-readable names for the 10 CIFAR-10 label indices, in label order.
CIFAR10_CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Example predictions on new images. These paths look like Colab uploads
# (/content); adjust them for your environment. Missing files are reported
# and skipped instead of aborting the script with a traceback.
import os

for image_path in ['/content/Vidwud_faceswap.jpg', '/content/gosau-3724039.jpg']:
    if not os.path.isfile(image_path):
        print(f'Skipping {image_path}: file not found')
        continue
    predicted_class = predict_image(image_path)
    print(f'Predicted class: {predicted_class} ({CIFAR10_CLASSES[predicted_class]})')