Skip to content

Instantly share code, notes, and snippets.

@RyanKor
Created September 24, 2021 10:57
Show Gist options
  • Select an option

  • Save RyanKor/aa23b5fd7cbd231c44aebcc62ec5aea0 to your computer and use it in GitHub Desktop.

Select an option

Save RyanKor/aa23b5fd7cbd231c44aebcc62ec5aea0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# coding: utf-8
# In[1]:
import pandas as pd
import numpy as np
import glob
import os
import zipfile
import tensorflow as tf
from keras_preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import load_model
# In[2]:
# Extract the downloaded image archives.
# Each archive is unpacked into its own directory under ./res. The `with`
# block closes the archive even if extraction raises part-way through.
for local_zip, extract_dir in (
    ('./train.zip', './res/train'),
    ('./test.zip', './res/test'),
):
    with zipfile.ZipFile(local_zip, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
# In[3]:
# Dataset layout. train_dir is appended to base_dir when the class folders
# are listed.
base_dir = './res/'
train_dir = 'train/train/'
# Class folder names; the position in this list selects the one-hot column
# that is set for images found in that folder.
train_class = [
    'dog',
    'elephant',
    'giraffe',
    'guitar',
    'horse',
    'house',
    'person',
]
test_dir = 'test/test/0'
# In[4]:
# Build the label table: one row per training image, holding the image path
# followed by a one-hot vector over train_class. The table is written to
# ./train_answer.csv and read back as `data`.
train_one_hot = []
for class_index, class_name in enumerate(train_class):
    for img in os.listdir(base_dir + train_dir + class_name):
        # A fresh row is created for every image. Appending one shared list
        # that is mutated inside the loop would leave every row of a class
        # referring to the same object, i.e. to the last image's path.
        class2label = [''] + [0] * len(train_class)
        # [6:] strips the leading "train/" so the stored path is relative to
        # the './res/train/' directory.
        class2label[0] = (train_dir + class_name + "/" + img)[6:]
        class2label[class_index + 1] = 1
        train_one_hot.append(class2label)

# NOTE: the first label column is named "dogs" although the class folder is
# "dog"; it is kept unchanged so the CSV schema stays the same.
train_answer = pd.DataFrame(
    train_one_hot,
    columns=["path", "dogs", 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person'],
)
train_answer.to_csv('./train_answer.csv', index=False)
data = pd.read_csv("./train_answer.csv")
columns = data.columns
# Image generator for the training set: pixel values are scaled by 1/255 and
# random rotation, shift, shear, zoom and horizontal-flip augmentation is
# configured.
augmentation_options = dict(
    rescale=1/255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
datagen = ImageDataGenerator(**augmentation_options)

# Reads image paths from the "path" column and the label columns as raw
# values. shuffle=False: the outputs computed from this generator are later
# paired with the rows of `data`, so batch order has to follow the dataframe.
train_generator = datagen.flow_from_dataframe(
    dataframe=data,
    directory='./res/train/',
    x_col="path",
    y_col=columns[1:],
    batch_size=48,
    shuffle=False,
    class_mode="raw",
    target_size=(224, 224),
)
# In[5]:
# Run every training image through an ImageNet-pretrained ResNet50V2 and keep
# the outputs as fixed features. With include_top=True the network ends in
# its 1000-way classifier, so each image yields a 1000-dimensional vector.
# NOTE(review): per the Keras docs `pooling` only applies when
# include_top=False, so pooling="max" appears to have no effect here —
# confirm whether include_top=False was intended.
res_net = tf.keras.applications.ResNet50V2(
    include_top=True, weights='imagenet', input_shape=(224, 224, 3), pooling="max")
res_net.trainable = True
pretrained_data = res_net.predict(train_generator, verbose=1)

# Hold out 20% of the extracted features for validation. The split is done
# once; the fixed random_state makes it reproducible.
x_train, x_valid, y_train, y_valid = train_test_split(
    pretrained_data, data.iloc[:, 1:], test_size=0.2, random_state=42)
# In[6]:
# Classification head trained on the 1000-dimensional ResNet outputs: batch
# normalization on the input, three ReLU hidden layers (512, 256, 128 units)
# and a 7-way softmax, one unit per entry in train_class.
model = Sequential([
    # The shape is given as an explicit tuple. `Input(1000,)` is the same call
    # as `Input(1000)` (a bare int), which not every Keras version accepts.
    Input(shape=(1000,)),
    BatchNormalization(),
    Dense(512, activation='relu'),
    Dense(256, activation='relu'),
    Dense(128, activation='relu'),
    Dense(7, activation='softmax'),
])
model.summary()

# Write the model to res_model.h5 only when the validation loss improves, so
# the file always holds the best epoch seen so far.
mc = tf.keras.callbacks.ModelCheckpoint(
    'res_model.h5', monitor='val_loss', mode='min', verbose=1, save_best_only=True)
# In[7]:
# Optimizer learning rate changed from 1e-4 to 0.001.
optimizer = Adam(0.001)
model.compile(
    loss='categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)
# Train the head on the extracted features; the checkpoint callback runs
# after each epoch.
model.fit(
    x_train,
    y_train,
    epochs=30,
    batch_size=48,
    validation_data=(x_valid, y_valid),
    callbacks=[mc],
)
# In[8]:
# Reload the checkpointed head from res_model.h5 and stack it on top of the
# ResNet network, giving a single model that maps images to class
# probabilities. The in-memory `model` carries the weights of the last epoch,
# which are not necessarily the checkpointed ones, so the reloaded model is
# the one that goes into the final pipeline.
loaded_model = load_model('res_model.h5')
final_model = Sequential([res_net, loaded_model])
final_model.summary()
# In[9]:
# Collect the test images. glob returns matches in arbitrary order, so the
# list is sorted to make the prediction order deterministic.
# NOTE(review): the predictions are later written into the sample CSV by
# position — confirm that this order matches the CSV's row order.
test_root = './res/test/'
test_files = sorted(glob.glob(test_root + "test/0/*.jpg"))

# Store each path relative to test_root, the directory given to the generator.
test_df = pd.DataFrame({"path": [img[len(test_root):] for img in test_files]})

# Only rescaling for the test set; no augmentation is configured.
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    directory=test_root,
    x_col="path",
    # The test images have no labels: class_mode=None makes the generator
    # yield image batches only, which is the form intended for predict().
    class_mode=None,
    batch_size=48,
    # Keep the generator order equal to the dataframe order.
    shuffle=False,
    target_size=(224, 224),
)
# In[10]:
# Model outputs for every test batch, in generator order.
pred = final_model.predict(test_generator, verbose=1)
# In[11]:
# Index of the highest-scoring class for each prediction row.
answer = pred.argmax(axis=1)
# Load the sample answer file, overwrite its second column with the
# predicted class indices and save the result.
test_df = pd.read_csv("./test_answer_sample_.csv")
test_df.iloc[:, 1] = answer
test_df.to_csv('res_net_50v2.csv', index=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment