Skip to content

Instantly share code, notes, and snippets.

@ameasure
Last active July 30, 2021 08:48
Show Gist options
  • Save ameasure/985c87bb8b34ac30269f to your computer and use it in GitHub Desktop.
Save ameasure/985c87bb8b34ac30269f to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
Example of text classification using a Convolution1D network with a one-hot
representation. Adapted from the imdb_cnn.py example.
Gets to 0.8292 test accuracy after 2 epochs. 153s/epoch on GTX660 GPU.
"""
from __future__ import print_function
import numpy as np
np.random.seed(1337) # for reproducibility
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution1D, MaxPooling1D
from keras.datasets import imdb
# set parameters:
# Vocabulary size: passed to imdb.load_data as nb_words, used as the one-hot
# width (vocab_size) in the generators and as the last input dimension of the
# first layer.
max_features = 5000
# Sequence length handed to sequence.pad_sequences; also the time dimension of
# the model's input_shape.
maxlen = 100
# Number of samples in every batch produced by generate_one_hot.
batch_size = 32
# nb_filter argument of the Convolution1D layer.
nb_filter = 250
# filter_length argument of the Convolution1D layer.
filter_length = 3
# Number of units in the hidden Dense layer.
hidden_dims = 250
# nb_epoch argument passed to model.fit_generator.
nb_epoch = 2
# define the generator that will create one hot outputs on the fly
def generate_one_hot(X, Y, vocab_size, batch_size):
    """
    Endlessly yield mini-batches with X one-hot encoded on the fly.

    Only one batch is expanded to one-hot form at a time, so the full
    [n_samples, timesteps, vocab_size] tensor is never materialised.

    Inputs:
    X: [n_samples, timesteps] each value is the integer index of a token,
       in the range [0, vocab_size)
    Y: [n_samples] or [n_samples, n_categories]
    vocab_size: size of the one-hot axis
    batch_size: number of samples per yielded batch (1 <= batch_size <= n_samples)

    Returns: generator of training tuples (x_batch, y_batch) where x_batch is
    [batch_size, timesteps, vocab_size] and y_batch is [batch_size, ...].

    The first pass visits the samples in their original order. Whenever fewer
    than batch_size unseen samples are left, those leftovers are moved to the
    front of the next pass and the already-seen samples are shuffled behind
    them, so every batch is full and no sample is duplicated or dropped.

    Raises ValueError if batch_size is not in [1, n_samples] or if X and Y
    have different lengths.
    """
    # Accept plain lists for both inputs; fancy indexing below needs ndarrays.
    X = np.asarray(X)
    Y = np.asarray(Y)
    n_samples = len(X)
    if batch_size < 1 or batch_size > n_samples:
        raise ValueError(
            'batch_size must be between 1 and n_samples, got batch_size=%s, '
            'n_samples=%s' % (batch_size, n_samples))
    if len(Y) != n_samples:
        raise ValueError(
            'X and Y must have the same number of samples, got %s and %s'
            % (n_samples, len(Y)))
    seq_len = X.shape[1]

    # Loop-invariant index grids: rows broadcasts against cols to address every
    # (sample, timestep) cell of a batch.
    rows = np.arange(batch_size)[:, np.newaxis]
    cols = np.arange(seq_len)

    # Shuffle an index array instead of copying X and Y on every pass.
    order = np.arange(n_samples)
    start = 0
    while True:
        stop = start + batch_size
        batch_idx = order[start:stop]

        X_out = np.zeros([batch_size, seq_len, vocab_size])
        # Set the cell selected by each token index to 1.
        X_out[rows, cols, X[batch_idx]] = 1
        Y_out = Y[batch_idx]

        start = stop
        if start + batch_size > n_samples:
            print('reshuffling, %s + %s > %s' % (start, batch_size, n_samples))
            # Unseen tail goes first; only the consumed part is permuted, so
            # the working set keeps exactly n_samples distinct entries.
            leftover = order[start:]
            order = np.concatenate(
                (leftover, np.random.permutation(order[:start])))
            start = 0
        yield (X_out, Y_out)
# Load the IMDB dataset as sequences of token indices.
# NOTE(review): nb_words=max_features presumably caps the token indices so
# they fit the one-hot width used later, and test_split=0.2 presumably holds
# out 20% of the data for testing -- confirm against the installed Keras.
print('Loading data...')
(X_train, y_train), (X_test, y_test) = imdb.load_data(nb_words=max_features,
test_split=0.2)
print(len(X_train), 'train sequences')
print(len(X_test), 'test sequences')
# Bring every sequence to a common length of maxlen so the data becomes a
# 2-D (samples x time) array, which generate_one_hot indexes batch-wise.
print('Pad sequences (samples x time)')
X_train = sequence.pad_sequences(X_train, maxlen=maxlen)
X_test = sequence.pad_sequences(X_test, maxlen=maxlen)
print('X_train shape:', X_train.shape)
print('X_test shape:', X_test.shape)
# Build a small 1-D convolutional classifier that consumes one-hot encoded
# sequences directly (no embedding layer).
print('Build model...')
model = Sequential()
# we add a Convolution1D, which will learn nb_filter
# word group filters of size filter_length.
# input_shape is (maxlen, max_features): one one-hot vector of width
# max_features per timestep, matching the batches from generate_one_hot.
model.add(Convolution1D(nb_filter=nb_filter,
filter_length=filter_length,
border_mode='valid',
activation='relu',
subsample_length=1, input_shape=(maxlen, max_features)))
# we use standard max pooling (halving the output of the previous layer):
model.add(MaxPooling1D(pool_length=2))
# We flatten the output of the conv layer,
# so that we can add a vanilla dense layer:
model.add(Flatten())
# We add a vanilla hidden layer. Note that dropout is applied to the Dense
# output before the ReLU activation here:
model.add(Dense(hidden_dims))
model.add(Dropout(0.25))
model.add(Activation('relu'))
# Single sigmoid output unit paired with binary cross-entropy for the
# two-class target:
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(loss='binary_crossentropy',
optimizer='rmsprop')
# Create one endless batch generator for training and one for validation;
# both expand token indices to one-hot vectors of width max_features.
train_generator = generate_one_hot(X_train, y_train, vocab_size=max_features, batch_size=batch_size)
valid_generator = generate_one_hot(X_test, y_test, vocab_size=max_features, batch_size=batch_size)
# The generators never terminate, so the number of samples that make up one
# epoch (samples_per_epoch) and one validation pass (nb_val_samples) is given
# explicitly.
# NOTE(review): show_accuracy / samples_per_epoch / nb_val_samples are
# arguments of the legacy Keras API this script was written for -- confirm
# they exist in the installed version.
model.fit_generator(generator=train_generator, samples_per_epoch=len(X_train),
nb_epoch=nb_epoch, show_accuracy=True,
validation_data=valid_generator, nb_val_samples=len(X_test))
@sivaratna
Copy link

Hi,

Thank you very much for this open-source example of using a 1D CNN with Keras. This is the first example I have worked with that uses CNNs and Keras. I hope I will be able to apply a 1D CNN with Keras to my own problem after playing around with your example. Just wondering: is "imdb" something like a database of datasets that you used in this example?

Thank you for your help,
Siva

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment