Last active
March 29, 2019 00:39
-
-
Save mtreviso/5fcff3c6cd20324a536b5a052f779657 to your computer and use it in GitHub Desktop.
File displayed on medium.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class CRF(nn.Module):
    """
    Linear-chain Conditional Random Field (CRF).

    Args:
        nb_labels (int): number of labels in your tagset, including special symbols.
        bos_tag_id (int): integer representing the beginning of sentence symbol in
            your tagset.
        eos_tag_id (int): integer representing the end of sentence symbol in your tagset.
        batch_first (bool): Whether the first dimension represents the batch dimension.
            Default: True.

    Attributes:
        transitions (nn.Parameter): learnable ``(nb_labels, nb_labels)`` matrix of
            transition scores; ``transitions[i, j]`` scores moving from tag ``i``
            to tag ``j`` (rows = from, columns = to).
    """

    def __init__(
        self, nb_labels, bos_tag_id, eos_tag_id, batch_first=True
    ):
        super().__init__()
        self.nb_labels = nb_labels
        self.BOS_TAG_ID = bos_tag_id
        self.EOS_TAG_ID = eos_tag_id
        # Stored for other code to consult; not read by the methods shown here.
        self.batch_first = batch_first
        # Allocated uninitialized; init_weights() fills in every entry below.
        self.transitions = nn.Parameter(torch.empty(self.nb_labels, self.nb_labels))
        self.init_weights()

    def init_weights(self):
        """(Re)initialize the transition matrix in place.

        Scores are drawn from U(-0.1, 0.1). The impossible transitions
        (anything *into* BOS and anything *out of* EOS) are then overwritten
        with -10000.0, so that exp(score) is effectively zero for them.
        """
        # initialize transitions from a random uniform distribution between -0.1 and 0.1
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        # Mutate the leaf parameter without recording the ops in the autograd
        # graph. torch.no_grad() is the supported replacement for writing
        # through ``.data``, which silently bypasses autograd's version tracking.
        with torch.no_grad():
            # enforce constraints (rows=from, columns=to) with a big negative number
            # so exp(-10000) will tend to zero
            # no transitions allowed to the beginning of sentence
            self.transitions[:, self.BOS_TAG_ID] = -10000.0
            # no transitions allowed from the end of sentence
            self.transitions[self.EOS_TAG_ID, :] = -10000.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment