Source code for model.ner_dataset

"""
.. module:: datasets
    :synopsis: datasets
 
.. moduleauthor:: Liyuan Liu
"""

from torch.utils.data import Dataset


class CRFDataset(Dataset):
    """Dataset Class for word-level model

    args:
        data_tensor (ins_num, seq_length): words
        label_tensor (ins_num, seq_length): labels
        mask_tensor (ins_num, seq_length): padding masks
    """

    def __init__(self, data_tensor, label_tensor, mask_tensor):
        assert data_tensor.size(0) == label_tensor.size(0)
        assert data_tensor.size(0) == mask_tensor.size(0)
        self.data_tensor = data_tensor
        self.label_tensor = label_tensor
        self.mask_tensor = mask_tensor

    def __getitem__(self, index):
        return self.data_tensor[index], self.label_tensor[index], self.mask_tensor[index]

    def __len__(self):
        return self.data_tensor.size(0)
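
# A minimal usage sketch (not part of the original module): the tensor shapes
# and vocabulary sizes below are illustrative assumptions only, showing how
# CRFDataset plugs into a standard torch DataLoader.
#
#     import torch
#     from torch.utils.data import DataLoader
#
#     ins_num, seq_length = 4, 10
#     data_tensor = torch.randint(0, 100, (ins_num, seq_length))  # word indices
#     label_tensor = torch.randint(0, 5, (ins_num, seq_length))   # label indices
#     mask_tensor = torch.ones(ins_num, seq_length)               # padding masks
#
#     dataset = CRFDataset(data_tensor, label_tensor, mask_tensor)
#     loader = DataLoader(dataset, batch_size=2, shuffle=True)
#     for words, labels, masks in loader:
#         print(words.size(), labels.size(), masks.size())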

class CRFDataset_WC(Dataset):
    """Dataset Class for char-aware model

    args:
        forw_tensor (ins_num, seq_length): forward chars
        forw_index (ins_num, seq_length): index of forward chars
        back_tensor (ins_num, seq_length): backward chars
        back_index (ins_num, seq_length): index of backward chars
        word_tensor (ins_num, seq_length): words
        label_tensor (ins_num, seq_length): labels
        mask_tensor (ins_num, seq_length): padding masks
        len_tensor (ins_num, 2): length of chars (dim0) and words (dim1)
    """

    def __init__(self, forw_tensor, forw_index, back_tensor, back_index, word_tensor, label_tensor, mask_tensor, len_tensor):
        assert forw_tensor.size(0) == label_tensor.size(0)
        assert forw_tensor.size(0) == mask_tensor.size(0)
        assert forw_tensor.size(0) == forw_index.size(0)
        assert forw_tensor.size(0) == back_tensor.size(0)
        assert forw_tensor.size(0) == back_index.size(0)
        assert forw_tensor.size(0) == word_tensor.size(0)
        assert forw_tensor.size(0) == len_tensor.size(0)
        self.forw_tensor = forw_tensor
        self.forw_index = forw_index
        self.back_tensor = back_tensor
        self.back_index = back_index
        self.word_tensor = word_tensor
        self.label_tensor = label_tensor
        self.mask_tensor = mask_tensor
        self.len_tensor = len_tensor

    def __getitem__(self, index):
        return self.forw_tensor[index], self.forw_index[index], self.back_tensor[index], self.back_index[index], self.word_tensor[index], self.label_tensor[index], self.mask_tensor[index], self.len_tensor[index]

    def __len__(self):
        return self.forw_tensor.size(0)
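
# A minimal usage sketch (not part of the original module): the shapes and
# value ranges are illustrative assumptions only, showing that the char-aware
# dataset yields eight aligned tensors per instance through a DataLoader.
#
#     import torch
#     from torch.utils.data import DataLoader
#
#     ins_num, seq_length = 4, 10
#     forw_tensor = torch.randint(0, 50, (ins_num, seq_length))   # forward chars
#     forw_index = torch.randint(0, seq_length, (ins_num, seq_length))
#     back_tensor = torch.randint(0, 50, (ins_num, seq_length))   # backward chars
#     back_index = torch.randint(0, seq_length, (ins_num, seq_length))
#     word_tensor = torch.randint(0, 100, (ins_num, seq_length))  # word indices
#     label_tensor = torch.randint(0, 5, (ins_num, seq_length))   # label indices
#     mask_tensor = torch.ones(ins_num, seq_length)               # padding masks
#     len_tensor = torch.tensor([[seq_length, seq_length]] * ins_num)  # char/word lengths
#
#     dataset = CRFDataset_WC(forw_tensor, forw_index, back_tensor, back_index,
#                             word_tensor, label_tensor, mask_tensor, len_tensor)
#     loader = DataLoader(dataset, batch_size=2)
#     batch = next(iter(loader))  # tuple of eight batched tensors
#     print(len(batch))           # 8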