Skip to content

Instantly share code, notes, and snippets.

View ajinkyajawale14499's full-sized avatar
🎯
Focusing

Ajinkya ajinkyajawale14499

🎯
Focusing
View GitHub Profile
# Peek at a single batch from the training generator to sanity-check the
# image and label tensor shapes before building the model.
image_batch, label_batch = next(iter(train_generator))
print(image_batch.shape, label_batch.shape)

print(train_generator.class_indices)

# Write one class name per line, ordered by the integer index the generator
# assigned to each class, so line i of labels.txt names model output unit i.
# (Sorting by index rather than by name stays correct even if the class
# order was supplied explicitly instead of inferred alphabetically.)
class_indices = train_generator.class_indices
labels = '\n'.join(sorted(class_indices, key=class_indices.get))
with open('labels.txt', 'w', encoding='utf-8') as f:
    f.write(labels)
IMG_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)

# Create the base model from the pre-trained MobileNetV2, dropping its
# ImageNet classification head (include_top=False) so a new head can be
# stacked on top for this dataset.
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

# Freeze the pretrained feature extractor for the initial training phase.
# Otherwise the large gradients from the randomly initialised head would
# update (and degrade) every ImageNet weight at Adam's default learning rate.
# Unfreezing the top layers is left to a later fine-tuning step.
base_model.trainable = False

# One output unit per class discovered by the training generator, instead of
# a hard-coded count, so the script works for any number of classes.
num_classes = len(train_generator.class_indices)

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# Configure the model for multi-class classification and run the first
# training phase: `epochs` passes over the training generator, with the
# validation generator evaluated at the end of each epoch.
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

epochs = 10
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
)
# Per-epoch metrics recorded by model.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))

# Top panel: training vs. validation accuracy.
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')  # the curves carry labels; show them
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')

# Bottom panel: training vs. validation loss.
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.xlabel('epoch')
plt.title('Training and Validation Loss')
plt.show()
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine tune from this layer onwards: layers with index >= fine_tune_at are
# trained, earlier (more generic) layers stay frozen.
fine_tune_at = 100

# Make the whole base trainable first. Setting `trainable` on a model
# propagates to all of its layers, so the loop below then re-freezes only
# the lower layers. Without this, a base model that was frozen for the
# initial training phase would remain entirely frozen.
base_model.trainable = True
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile so the changed `trainable` flags take effect. The very low
# learning rate keeps updates to the pretrained weights small.
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(1e-5),
              metrics=['accuracy'])

# Run the fine-tuning pass itself; compiling alone does not train anything.
fine_tune_epochs = 5
history_fine = model.fit(train_generator,
                         epochs=fine_tune_epochs,
                         validation_data=val_generator)
# Export the trained Keras model as a SavedModel, then convert that
# SavedModel into a TensorFlow Lite model file.
saved_model_dir = 'save/fine_tuning'
tf.saved_model.save(model, saved_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# convert() returns the serialized model, so write it in binary mode.
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# download the model & labels