@halegreen
Last active October 7, 2017 11:43
CBOW (Continuous Bag-of-Words) model implementation with TensorFlow: only the generate_batch() function needs to change.
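To make the CBOW pairing concrete, here is a minimal pure-Python sketch of the (context, target) pairs that generate_batch() emits row by row: the 2*skip_window surrounding words predict the middle word. The helper name `cbow_pairs` is hypothetical and not part of the gist. Unlike generate_batch(), it does not wrap around the end of the corpus, so it only yields positions with a full window.

```python
def cbow_pairs(tokens, skip_window):
    """Return (context, target) pairs for every position with a full window."""
    pairs = []
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    for start in range(len(tokens) - span + 1):
        window = tokens[start:start + span]
        target = window[skip_window]
        # Context = every word in the window except the middle one.
        context = window[:skip_window] + window[skip_window + 1:]
        pairs.append((context, target))
    return pairs


print(cbow_pairs(["the", "quick", "brown", "fox", "jumps"], skip_window=1))
# [(['the', 'brown'], 'quick'), (['quick', 'fox'], 'brown'), (['brown', 'jumps'], 'fox')]
```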
import collections

import numpy as np

data_index = 0  # global cursor into the corpus, advanced by generate_batch()


def generate_batch(data, batch_size, skip_window):
    """
    Generates a mini-batch of training data for the CBOW embedding model.

    :param data: numpy.ndarray(dtype=int, shape=(corpus_size,)) holding the
        training corpus, with each word encoded as an integer
    :param batch_size: int, number of examples in the batch
    :param skip_window: int, number of words on each side (left and right)
        of the target word that form its context window
    :return: (batch, labels). batch has shape (batch_size, 2*skip_window):
        each row holds all the context words of one example. labels has
        shape (batch_size, 1) and holds the word in the middle of that context.
    """
    global data_index
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    batch = np.ndarray(shape=(batch_size, span - 1), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    buffer = collections.deque(maxlen=span)
    if data_index + span > len(data):
        data_index = 0
    # Fill the sliding window with the first `span` words.
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    for i in range(batch_size):
        target = skip_window  # the target word sits in the middle of the window
        col_idx = 0
        for j in range(span):
            if j == target:  # skip the middle word; it is the label
                continue
            batch[i, col_idx] = buffer[j]
            col_idx += 1
        labels[i, 0] = buffer[target]
        # Slide the window one word to the right (the deque drops the oldest word).
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    assert batch.shape[0] == batch_size
    assert batch.shape[1] == span - 1
    return batch, labels
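The model side is not shown here. In CBOW, the context words of each row are usually combined into a single vector, typically by averaging their embeddings, before predicting the label. The NumPy sketch below illustrates that step on the (batch_size, 2*skip_window) layout generate_batch() returns; `average_context` and the toy `embeddings` matrix are hypothetical, not part of the gist. In TensorFlow the analogous operation would presumably be `tf.reduce_mean` over axis 1 of the output of `tf.nn.embedding_lookup`.

```python
import numpy as np


def average_context(embeddings, batch):
    """Average the embedding vectors of each row's context words.

    embeddings: (vocab_size, dim) float array (hypothetical lookup table).
    batch: (batch_size, 2*skip_window) int array, as produced by generate_batch().
    Returns a (batch_size, dim) array: one combined context vector per example.
    """
    # Integer-array indexing gathers a (batch_size, 2*skip_window, dim) array.
    context_vectors = embeddings[batch]
    # Averaging over the context axis collapses it to (batch_size, dim).
    return context_vectors.mean(axis=1)


embeddings = np.arange(12, dtype=np.float32).reshape(6, 2)  # 6 words, dim 2
batch = np.array([[0, 2], [1, 5]], dtype=np.int32)
print(average_context(embeddings, batch))  # rows [2. 3.] and [6. 7.]
```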