Skip to content

Instantly share code, notes, and snippets.

@Allesanddro
Created March 23, 2020 14:41
Show Gist options
  • Save Allesanddro/67903e3133aa0455306faa3ff3428061 to your computer and use it in GitHub Desktop.
Save Allesanddro/67903e3133aa0455306faa3ff3428061 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import, division,print_function,unicode_literals
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def plotImages(images_arr, num_images=5):
    """Show the first few images of a batch side by side, without axes.

    Args:
        images_arr: Sequence of images accepted by ``Axes.imshow`` (for
            example a batch drawn from one of the data generators). Only the
            first ``num_images`` entries are drawn; the rest are ignored.
        num_images: Number of panels in the single-row figure (default 5,
            the original hard-coded value).
    """
    # squeeze=False always yields a 2-D array of Axes, so flatten() works
    # even when num_images == 1.
    _, axes = plt.subplots(1, num_images, figsize=(20, 20), squeeze=False)
    axes = axes.flatten()
    # Turn every panel off up front so that unused panels (fewer images
    # than panels) are blank instead of showing empty tick frames.
    for ax in axes:
        ax.axis('off')
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()
# Root of the training workspace; the dataset and checkpoints live under it.
PATH = "/root/work/train_nsfw"

# Training hyperparameters.
batch_size = 128
epochs = 65

# Target size every image is resized to by the data generators.
IMG_HEIGHT = 200
IMG_WIDTH = 200

# Model weights are written to <PATH>/checkpoint/cp.ckpt by the callback
# below (weights only; with the callback's default settings the file is
# rewritten as training progresses).
checkpoint_dir = os.path.join(PATH, 'checkpoint')
checkpoint_path = os.path.join(checkpoint_dir, 'cp.ckpt')
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    verbose=1,
)
# Expected dataset layout: <PATH>/data/{train,test}/{sfw,nsfw}/<image files>.
train_dir = os.path.join(PATH, 'data', 'train')
validation_dir = os.path.join(PATH, 'data', 'test')

train_sfw_dir = os.path.join(train_dir, 'sfw')
train_nsfw_dir = os.path.join(train_dir, 'nsfw')
validation_sfw_dir = os.path.join(validation_dir, 'sfw')
validation_nsfw_dir = os.path.join(validation_dir, 'nsfw')


def _count_files(directory):
    """Return how many regular, non-hidden files sit directly in *directory*.

    Subdirectories and dot-files (e.g. ``.DS_Store``) are skipped so they are
    not counted as images, which a bare ``len(os.listdir(...))`` would do.
    """
    return sum(
        1
        for name in os.listdir(directory)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    )


num_sfw_train = _count_files(train_sfw_dir)
num_nsfw_train = _count_files(train_nsfw_dir)
num_sfw_test = _count_files(validation_sfw_dir)
num_nsfw_test = _count_files(validation_nsfw_dir)

total_train = num_sfw_train + num_nsfw_train
total_val = num_sfw_test + num_nsfw_test

print("Total training images:", total_train)
print("Total validation images:", total_val)
# Training images: pixel values scaled by 1/255, plus random augmentation
# (rotations within +/-45 degrees, horizontal flips, zoom in [0.5, 1.5]).
train_image_generator = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=45,
    horizontal_flip=True,
    zoom_range=0.5,
)
# Validation images: rescaling only, no augmentation.
validation_image_generator = ImageDataGenerator(rescale=1. / 255)

# Each iterator walks the class sub-folders of its directory and yields
# (images, labels) batches of batch_size images resized to
# IMG_HEIGHT x IMG_WIDTH, with 0/1 labels (class_mode='binary').
# NOTE(review): assuming only sfw/ and nsfw/ exist, Keras orders classes
# alphanumerically, so nsfw -> 0 and sfw -> 1 -- confirm via
# train_data_gen.class_indices.
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    class_mode='binary',
    shuffle=True,
)
val_data_gen = validation_image_generator.flow_from_directory(
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=batch_size,
    class_mode='binary',
)

# Draw one augmented training batch; its labels are discarded.
# NOTE(review): sample_training_images is not used in the visible code;
# presumably intended for plotImages(sample_training_images[:5]).
sample_training_images, _ = next(train_data_gen)
# Small CNN: three conv/max-pool stages (16, 32, 64 filters) with light
# dropout, then a 512-unit dense layer and a single output unit.
model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu',
           input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Dropout(0.1),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Dropout(0.1),
    Flatten(),
    Dense(512, activation='relu'),
    # No activation: the model emits a raw logit, matched by
    # from_logits=True in the loss below.
    Dense(1),
])

# A plain 'accuracy' string resolves to binary accuracy with a 0.5
# threshold applied directly to the model output. Because that output is a
# logit, the matching decision boundary is 0 (sigmoid(0) == 0.5), so the
# threshold is set explicitly. name='accuracy' keeps the training history
# keys unchanged ('accuracy' / 'val_accuracy').
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0,
                                                       name='accuracy')])
# Resume from previously saved weights when a checkpoint exists. On a fresh
# run the checkpoint directory is empty and an unconditional load_weights()
# call would raise before training ever starts.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    print("Restoring weights from", latest_checkpoint)
    model.load_weights(latest_checkpoint)
else:
    print("No checkpoint found in", checkpoint_dir, "- training from scratch")

model.summary()

# model.fit accepts Keras iterators directly (fit_generator is deprecated).
# The directory iterators report len() == ceil(samples / batch_size), so the
# final partial batch is included each epoch. Also, validation_steps is at
# least 1 whenever the validation set is non-empty, even if it holds fewer
# images than batch_size. Floor division would have made it 0 in that case.
history = model.fit(
    train_data_gen,
    steps_per_epoch=len(train_data_gen),
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=len(val_data_gen),
    callbacks=[cp_callback],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment