"""
.. module:: lm_lstm_crf
:synopsis: lm_lstm_crf
.. moduleauthor:: Liyuan Liu
"""
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import numpy as np
import model.crf as crf
import model.utils as utils
import model.highway as highway
class LM_LSTM_CRF(nn.Module):
    """LM_LSTM_CRF model

    args:
        tagset_size: size of label set
        char_size: size of char dictionary
        char_dim: size of char embedding
        char_hidden_dim: size of char-level lstm hidden dim
        char_rnn_layers: number of char-level lstm layers
        embedding_dim: size of word embedding
        word_hidden_dim: size of word-level blstm hidden dim
        word_rnn_layers: number of word-level lstm layers
        vocab_size: size of word dictionary
        dropout_ratio: dropout ratio
        large_CRF: whether to use CRF_L; refer to model.crf.CRF_L and model.crf.CRF_S for more details
        if_highway: whether to use highway layers
        in_doc_words: number of words that occur in the corpus (used for language model prediction)
        highway_layers: number of highway layers
    """
    def __init__(self, tagset_size, char_size, char_dim, char_hidden_dim, char_rnn_layers, embedding_dim, word_hidden_dim, word_rnn_layers, vocab_size, dropout_ratio, large_CRF=True, if_highway=False, in_doc_words=2, highway_layers=1):
        super(LM_LSTM_CRF, self).__init__()
        self.char_dim = char_dim
        self.char_hidden_dim = char_hidden_dim
        self.char_size = char_size
        self.word_dim = embedding_dim
        self.word_hidden_dim = word_hidden_dim
        self.word_size = vocab_size
        self.if_highway = if_highway

        # char-level components: one unidirectional LSTM per direction
        self.char_embeds = nn.Embedding(char_size, char_dim)
        self.forw_char_lstm = nn.LSTM(char_dim, char_hidden_dim, num_layers=char_rnn_layers, bidirectional=False, dropout=dropout_ratio)
        self.back_char_lstm = nn.LSTM(char_dim, char_hidden_dim, num_layers=char_rnn_layers, bidirectional=False, dropout=dropout_ratio)
        self.char_rnn_layers = char_rnn_layers

        # word-level components: a BLSTM over word embeddings concatenated with
        # the two char-level states; each direction gets word_hidden_dim // 2 units
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.word_lstm = nn.LSTM(embedding_dim + char_hidden_dim * 2, word_hidden_dim // 2, num_layers=word_rnn_layers, bidirectional=True, dropout=dropout_ratio)
        self.word_rnn_layers = word_rnn_layers

        self.dropout = nn.Dropout(p=dropout_ratio)

        self.tagset_size = tagset_size
        if large_CRF:
            self.crf = crf.CRF_L(word_hidden_dim, tagset_size)
        else:
            self.crf = crf.CRF_S(word_hidden_dim, tagset_size)

        if if_highway:
            self.forw2char = highway.hw(char_hidden_dim, num_layers=highway_layers, dropout_ratio=dropout_ratio)
            self.back2char = highway.hw(char_hidden_dim, num_layers=highway_layers, dropout_ratio=dropout_ratio)
            self.forw2word = highway.hw(char_hidden_dim, num_layers=highway_layers, dropout_ratio=dropout_ratio)
            self.back2word = highway.hw(char_hidden_dim, num_layers=highway_layers, dropout_ratio=dropout_ratio)
            self.fb2char = highway.hw(2 * char_hidden_dim, num_layers=highway_layers, dropout_ratio=dropout_ratio)

        # output layers for language-model pre-training
        self.char_pre_train_out = nn.Linear(char_hidden_dim, char_size)
        self.word_pre_train_out = nn.Linear(char_hidden_dim, in_doc_words)

        self.batch_size = 1
        self.word_seq_length = 1
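
    # A minimal construction sketch (hypothetical sizes; the real values come
    # from the corpus and the training script, which are not shown here):
    #
    #   model = LM_LSTM_CRF(tagset_size=12, char_size=64, char_dim=30,
    #                       char_hidden_dim=50, char_rnn_layers=1,
    #                       embedding_dim=100, word_hidden_dim=200,
    #                       word_rnn_layers=1, vocab_size=20000,
    #                       dropout_ratio=0.5, large_CRF=True,
    #                       if_highway=True, in_doc_words=15000)
    #
    # Note that word_hidden_dim should be even, since each direction of the
    # word-level BLSTM receives word_hidden_dim // 2 units.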
    def set_batch_size(self, bsize):
        """
        set batch size
        """
        self.batch_size = bsize
    def set_batch_seq_size(self, sentence):
        """
        set batch size and sequence length from a (word_seq_len, batch_size) tensor
        """
        tmp = sentence.size()
        self.word_seq_length = tmp[0]
        self.batch_size = tmp[1]
    def rand_init_embedding(self):
        """
        randomly initialize the char-level embedding
        """
        utils.init_embedding(self.char_embeds.weight)
    def load_pretrained_word_embedding(self, pre_word_embeddings):
        """
        load pre-trained word embedding

        args:
            pre_word_embeddings (self.word_size, self.word_dim) : pre-trained embedding
        """
        assert (pre_word_embeddings.size()[1] == self.word_dim)
        self.word_embeds.weight = nn.Parameter(pre_word_embeddings)
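
    # A usage sketch (the embedding tensor is assumed to be loaded elsewhere,
    # e.g. from GloVe vectors; the loading step itself is not shown here):
    #
    #   pre_embeds = torch.FloatTensor(vocab_size, embedding_dim)  # filled from file
    #   model.load_pretrained_word_embedding(pre_embeds)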
    def rand_init(self, init_char_embedding=True, init_word_embedding=False):
        """
        random initialization

        args:
            init_char_embedding: whether to randomly initialize the char embedding
            init_word_embedding: whether to randomly initialize the word embedding
        """
        if init_char_embedding:
            utils.init_embedding(self.char_embeds.weight)
        if init_word_embedding:
            utils.init_embedding(self.word_embeds.weight)
        if self.if_highway:
            self.forw2char.rand_init()
            self.back2char.rand_init()
            self.forw2word.rand_init()
            self.back2word.rand_init()
            self.fb2char.rand_init()
        utils.init_lstm(self.forw_char_lstm)
        utils.init_lstm(self.back_char_lstm)
        utils.init_lstm(self.word_lstm)
        utils.init_linear(self.char_pre_train_out)
        utils.init_linear(self.word_pre_train_out)
        self.crf.rand_init()
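
    # Typical initialization order before training (a sketch; leaving
    # init_word_embedding=False preserves pre-trained word vectors):
    #
    #   model.load_pretrained_word_embedding(pre_embeds)
    #   model.rand_init(init_char_embedding=True, init_word_embedding=False)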
    def word_pre_train_forward(self, sentence, position, hidden=None):
        """
        output of the forward language model

        args:
            sentence (char_seq_len, batch_size): char-level representation of sentence
            position (word_seq_len, batch_size): position of blank space in char-level representation of sentence
            hidden: initial hidden state

        return:
            language model output (word_seq_len * batch_size, in_doc_words), hidden
        """
        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.forw_char_lstm(d_embeds, hidden)
        # select the char-lstm states at word boundaries (blank-space positions)
        tmpsize = position.size()
        position = position.unsqueeze(2).expand(tmpsize[0], tmpsize[1], self.char_hidden_dim)
        select_lstm_out = torch.gather(lstm_out, 0, position)
        d_lstm_out = self.dropout(select_lstm_out).view(-1, self.char_hidden_dim)
        if self.if_highway:
            char_out = self.forw2word(d_lstm_out)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = d_lstm_out
        pre_score = self.word_pre_train_out(d_char_out)
        return pre_score, hidden
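
    # A pre-training sketch (hypothetical shapes and target tensor; the scores
    # feed a cross-entropy loss against the next word's id):
    #
    #   chars = torch.randint(0, char_size, (char_seq_len, batch_size))   # char ids
    #   spaces = torch.randint(0, char_seq_len, (word_seq_len, batch_size))  # space positions
    #   scores, hidden = model.word_pre_train_forward(chars, spaces)
    #   loss = nn.functional.cross_entropy(scores, next_word_ids.view(-1))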
    def word_pre_train_backward(self, sentence, position, hidden=None):
        """
        output of the backward language model

        args:
            sentence (char_seq_len, batch_size): char-level representation of sentence (reversed order)
            position (word_seq_len, batch_size): position of blank space in the reversed char-level representation of sentence
            hidden: initial hidden state

        return:
            language model output (word_seq_len * batch_size, in_doc_words), hidden
        """
        embeds = self.char_embeds(sentence)
        d_embeds = self.dropout(embeds)
        lstm_out, hidden = self.back_char_lstm(d_embeds, hidden)
        # select the char-lstm states at word boundaries (blank-space positions)
        tmpsize = position.size()
        position = position.unsqueeze(2).expand(tmpsize[0], tmpsize[1], self.char_hidden_dim)
        select_lstm_out = torch.gather(lstm_out, 0, position)
        d_lstm_out = self.dropout(select_lstm_out).view(-1, self.char_hidden_dim)
        if self.if_highway:
            char_out = self.back2word(d_lstm_out)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = d_lstm_out
        pre_score = self.word_pre_train_out(d_char_out)
        return pre_score, hidden
    def forward(self, forw_sentence, forw_position, back_sentence, back_position, word_seq, hidden=None):
        """
        args:
            forw_sentence (char_seq_len, batch_size) : char-level representation of sentence
            forw_position (word_seq_len, batch_size) : position of blank space in char-level representation of sentence
            back_sentence (char_seq_len, batch_size) : char-level representation of sentence (reversed order)
            back_position (word_seq_len, batch_size) : position of blank space in the reversed char-level representation of sentence
            word_seq (word_seq_len, batch_size) : word-level representation of sentence
            hidden: initial hidden state

        return:
            crf output (word_seq_len, batch_size, tagset_size, tagset_size)
        """
        self.set_batch_seq_size(forw_position)

        # embedding layer
        forw_emb = self.char_embeds(forw_sentence)
        back_emb = self.char_embeds(back_sentence)

        # dropout
        d_f_emb = self.dropout(forw_emb)
        d_b_emb = self.dropout(back_emb)

        # run the char-level lstms over the whole sequence
        forw_lstm_out, _ = self.forw_char_lstm(d_f_emb)  # char_seq_len * batch_size * char_hidden_dim
        back_lstm_out, _ = self.back_char_lstm(d_b_emb)  # char_seq_len * batch_size * char_hidden_dim

        # select the prediction points (word boundaries)
        forw_position = forw_position.unsqueeze(2).expand(self.word_seq_length, self.batch_size, self.char_hidden_dim)
        select_forw_lstm_out = torch.gather(forw_lstm_out, 0, forw_position)

        back_position = back_position.unsqueeze(2).expand(self.word_seq_length, self.batch_size, self.char_hidden_dim)
        select_back_lstm_out = torch.gather(back_lstm_out, 0, back_position)

        fb_lstm_out = self.dropout(torch.cat((select_forw_lstm_out, select_back_lstm_out), dim=2))
        if self.if_highway:
            char_out = self.fb2char(fb_lstm_out)
            d_char_out = self.dropout(char_out)
        else:
            d_char_out = fb_lstm_out

        # word embedding
        word_emb = self.word_embeds(word_seq)
        d_word_emb = self.dropout(word_emb)

        # combine char-level and word-level features
        word_input = torch.cat((d_word_emb, d_char_out), dim=2)

        # word-level blstm
        lstm_out, _ = self.word_lstm(word_input)
        d_lstm_out = self.dropout(lstm_out)

        # project to crf scores
        crf_out = self.crf(d_lstm_out)
        crf_out = crf_out.view(self.word_seq_length, self.batch_size, self.tagset_size, self.tagset_size)
        return crf_out
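
# A forward-pass sketch (illustrative only; the input tensors are built by the
# repository's batching code in model.utils, which is not shown here):
#
#   model.eval()
#   crf_scores = model(forw_sentence, forw_position,
#                      back_sentence, back_position, word_seq)
#   # crf_scores: (word_seq_len, batch_size, tagset_size, tagset_size);
#   # decoding the best tag path is handled by the CRF decoder in model.crf,
#   # not by this module.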