Gist by @sgl0v, created June 2, 2021 19:15
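import numpy as np
from skimage.color import rgb2lab
from tensorflow.keras.preprocessing.image import img_to_array, load_img


def train_model(image_name, model, epochs=1000):
    """Fit `model` to predict an image's a*b* chroma channels from its L* channel."""
    # Load the image and scale RGB values to [0, 1], the range rgb2lab expects.
    image = img_to_array(load_img(image_name)).astype(np.float32) / 255.0
    # Convert to CIELAB once instead of calling rgb2lab twice.
    lab = rgb2lab(image)
    # Take the size from the image rather than hard-coding 400x400.
    height, width = lab.shape[:2]
    # Input X: lightness channel L* (0..100), shaped (batch, H, W, 1).
    X = lab[:, :, 0].reshape(1, height, width, 1)
    # Target Y: chroma channels a*, b* divided by 128, so roughly in [-1, 1].
    Y = (lab[:, :, 1:] / 128).reshape(1, height, width, 2)
    model.fit(x=X, y=Y, batch_size=1, epochs=epochs)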
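Because the targets are a*b* divided by 128, a prediction has to be multiplied back by 128 and recombined with L* before it can be viewed as RGB. The sketch below shows that inverse step with scikit-image's `lab2rgb`. The helpers `split_lab` and `merge_lab` are hypothetical names introduced for illustration and are not part of the gist. In practice `ab_scaled` would be something like `model.predict(X)[0]`; here the exact split channels stand in for a prediction, so the round trip should reproduce the input image.

```python
import numpy as np
from skimage.color import lab2rgb, rgb2lab


def split_lab(rgb):
    """Hypothetical helper: split an RGB image in [0, 1] into (L*, a*b*/128),
    mirroring the X/Y preparation in train_model."""
    lab = rgb2lab(rgb)
    return lab[:, :, 0], lab[:, :, 1:] / 128


def merge_lab(lightness, ab_scaled):
    """Hypothetical helper: rebuild an RGB image from L* and a (predicted) a*b*/128."""
    lab = np.zeros(lightness.shape + (3,))
    lab[:, :, 0] = lightness
    lab[:, :, 1:] = ab_scaled * 128  # undo the Y /= 128 normalization
    return lab2rgb(lab)


# Round trip on a small random image: merging the split channels
# should recover the original RGB values up to floating-point error.
rng = np.random.default_rng(0)
rgb = rng.random((4, 4, 3))
L, ab = split_lab(rgb)
restored = merge_lab(L, ab)
print(np.allclose(restored, rgb, atol=1e-5))
```

Keeping the scale factor (128) in one place on each side makes it easy to check that training and inference use the same normalization.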