Skip to content

Instantly share code, notes, and snippets.

@mtreviso
Last active March 29, 2019 00:39
Show Gist options
Save mtreviso/5fcff3c6cd20324a536b5a052f779657 to your computer and use it in GitHub Desktop.
File displayed on medium.
import torch
from torch import nn
class CRF(nn.Module):
    """
    Linear-chain Conditional Random Field (CRF).

    This class defines the learnable transition-score matrix and its
    initialization. ``transitions[i, j]`` holds the score for moving from
    tag ``i`` (row = "from") to tag ``j`` (column = "to").

    Args:
        nb_labels (int): number of labels in your tagset, including special
            symbols.
        bos_tag_id (int): integer representing the beginning of sentence
            symbol in your tagset.
        eos_tag_id (int): integer representing the end of sentence symbol in
            your tagset.
        batch_first (bool): Whether the first dimension represents the batch
            dimension.
    """

    def __init__(
        self, nb_labels, bos_tag_id, eos_tag_id, batch_first=True
    ):
        super().__init__()
        self.nb_labels = nb_labels
        self.BOS_TAG_ID = bos_tag_id
        self.EOS_TAG_ID = eos_tag_id
        # Only stored here; no code in this class reads it. Presumably it is
        # consumed by forward/decoding code defined elsewhere -- confirm.
        self.batch_first = batch_first
        # (nb_labels, nb_labels) matrix created uninitialized; its values are
        # filled in by init_weights() immediately below.
        self.transitions = nn.Parameter(
            torch.empty(self.nb_labels, self.nb_labels)
        )
        self.init_weights()

    def init_weights(self):
        """Randomly initialize transition scores and apply the BOS/EOS constraints.

        Every entry is drawn uniformly from [-0.1, 0.1]. Then transitions
        *into* the beginning-of-sentence tag and transitions *out of* the
        end-of-sentence tag are set to a large negative score, so that
        exp(-10000) is effectively zero.
        """
        nn.init.uniform_(self.transitions, -0.1, 0.1)
        # Write the hard constraints in place under no_grad instead of through
        # `.data`: same values, but autograd stays aware of the mutation
        # (this mirrors how nn.init.* modifies parameters).
        with torch.no_grad():
            # no transitions allowed to the beginning of sentence (column)
            self.transitions[:, self.BOS_TAG_ID] = -10000.0
            # no transitions allowed from the end of sentence (row)
            self.transitions[self.EOS_TAG_ID, :] = -10000.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment