Last active
October 7, 2017 11:43
-
-
Save halegreen/e6dc31c66d237b2b9884981b674b87bb to your computer and use it in GitHub Desktop.
CBOW (Continuous Bag-of-Words) model implementation with TensorFlow; you just need to change the generate_batch() function.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def generate_batch(data, batch_size, skip_window):
    """Generate one mini-batch of training data for the CBOW embedding model.

    A window of ``2 * skip_window + 1`` consecutive entries slides over
    ``data``. For each window position, the entries on both sides of the
    centre form one row of the batch, and the centre entry is the label.

    The function reads and advances the module-level cursor ``data_index``,
    which must be defined (e.g. ``data_index = 0``) before the first call.
    The cursor wraps around to the start of ``data`` when it reaches the end.

    :param data: indexable sequence (e.g. ``numpy.ndarray`` of ints) holding
        the training corpus, with words encoded as integers
    :param batch_size (int): number of (context, target) examples to generate
    :param skip_window (int): number of words taken on each side of the
        target word to form its context
    :return: tuple ``(batch, labels)`` where ``batch`` is an ``int32`` array
        of shape ``(batch_size, 2 * skip_window)`` holding the context words
        and ``labels`` is an ``int32`` array of shape ``(batch_size, 1)``
        holding the corresponding centre words
    """
    global data_index
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    batch = np.ndarray(shape=(batch_size, span - 1), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    # A bounded deque acts as the sliding window: appending a new entry
    # automatically drops the oldest one.
    buffer = collections.deque(maxlen=span)
    if data_index + span > len(data):
        data_index = 0
    # Fill the window with the first `span` entries.
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    for i in range(batch_size):
        # Context is every entry in the window except the middle one.
        context = [word for j, word in enumerate(buffer) if j != skip_window]
        batch[i] = context
        labels[i, 0] = buffer[skip_window]
        # Slide the window one position to the right.
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    # Sanity-check the shape of the generated batch (the original code
    # mistakenly subscripted the integer `batch_size` here).
    assert batch.shape[0] == batch_size
    assert batch.shape[1] == span - 1
    return batch, labels
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment