diff --git a/.travis.yml b/.travis.yml index f1411e0ef6..036bab819e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,6 +14,7 @@ addons: before_install: # Install CPU version of PyTorch. - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install torch==1.2.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html; fi + - pip install --upgrade setuptools - pip install -r requirements.opt.txt - python setup.py install env: diff --git a/bert_ckp_convert.py b/bert_ckp_convert.py new file mode 100755 index 0000000000..c8d1814b49 --- /dev/null +++ b/bert_ckp_convert.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +""" Convert weights of huggingface Bert to onmt Bert""" +from argparse import ArgumentParser +import torch +from onmt.encoders.bert import BertEncoder +from onmt.models.bert_generators import BertPreTrainingHeads +from onmt.modules.bert_embeddings import BertEmbeddings +from collections import OrderedDict +import re + + +def decrement(matched): + value = int(matched.group(1)) + if value < 1: + raise ValueError('Value Error when converting string') + string = "bert.encoder.layer.{}.output.LayerNorm".format(value-1) + return string + + +def mapping_key(key, max_layers): + if 'bert.embeddings' in key: + key = key + + elif 'bert.encoder' in key: + # convert layer_norm weights + key = re.sub(r'bert.encoder.0.layer_norm\.(.*)', + r'bert.embeddings.LayerNorm.\1', key) + key = re.sub(r'bert.encoder\.(\d+)\.layer_norm', + decrement, key) + # convert attention weights + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_keys\.(.*)', + r'bert.encoder.layer.\1.attention.self.key.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_values\.(.*)', + r'bert.encoder.layer.\1.attention.self.value.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.linear_query\.(.*)', + r'bert.encoder.layer.\1.attention.self.query.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.self_attn.final_linear\.(.*)', + r'bert.encoder.layer.\1.attention.output.dense.\2', key) + # convert feed forward weights + key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.layer_norm\.(.*)', + r'bert.encoder.layer.\1.attention.output.LayerNorm.\2', + key) + key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.w_1\.(.*)', + r'bert.encoder.layer.\1.intermediate.dense.\2', key) + key = re.sub(r'bert.encoder\.(\d+)\.feed_forward.w_2\.(.*)', + r'bert.encoder.layer.\1.output.dense.\2', key) + + elif 'bert.layer_norm' in key: + key = re.sub(r'bert.layer_norm', + r'bert.encoder.layer.' + str(max_layers - 1) + + '.output.LayerNorm', key) + elif 'bert.pooler' in key: + key = key + elif 'generator.next_sentence' in key: + key = re.sub(r'generator.next_sentence.linear\.(.*)', + r'cls.seq_relationship.\1', key) + elif 'generator.mask_lm' in key: + key = re.sub(r'generator.mask_lm.bias', + r'cls.predictions.bias', key) + key = re.sub(r'generator.mask_lm.decode.weight', + r'cls.predictions.decoder.weight', key) + key = re.sub(r'generator.mask_lm.transform.dense\.(.*)', + r'cls.predictions.transform.dense.\1', key) + key = re.sub(r'generator.mask_lm.transform.layer_norm\.(.*)', + r'cls.predictions.transform.LayerNorm.\1', key) + else: + raise KeyError("Unexpected keys! Please provide HuggingFace weights") + return key + + +def convert_bert_weights(bert_model, weights, n_layers=12): + bert_model_keys = bert_model.state_dict().keys() + bert_weights = OrderedDict() + generator_weights = OrderedDict() + model_weights = {"bert": bert_weights, + "generator": generator_weights} + hugface_keys = weights.keys() + try: + for key in bert_model_keys: + hugface_key = mapping_key(key, n_layers) + if hugface_key not in hugface_keys: + if 'LayerNorm' in hugface_key: + # Fix LayerNorm of old huggingface ckp + hugface_key = re.sub(r'LayerNorm.weight', + r'LayerNorm.gamma', hugface_key) + hugface_key = re.sub(r'LayerNorm.bias', + r'LayerNorm.beta', hugface_key) + if hugface_key in hugface_keys: + print("[OLD Weights file]gamma/beta is used in " + + "naming BertLayerNorm. Mapping succeed.") + else: + raise KeyError("Failed fix LayerNorm %s, check file" + % hugface_key) + else: + raise KeyError("Mapped key %s not in weight file" + % hugface_key) + if 'generator' not in key: + onmt_key = re.sub(r'bert\.(.*)', r'\1', key) + model_weights['bert'][onmt_key] = weights[hugface_key] + else: + onmt_key = re.sub(r'generator\.(.*)', r'\1', key) + model_weights['generator'][onmt_key] = weights[hugface_key] + except KeyError: + print("Unsuccessful convert.") + raise + return model_weights + + +def main(): + parser = ArgumentParser() + parser.add_argument("--layers", type=int, default=None, required=True) + + parser.add_argument("--bert_model_weights_file", "-i", type=str, + default=None, required=True, help="Path to the " + "huggingface Bert weights file download from " + "https://github.com/huggingface/pytorch-transformers") + + parser.add_argument("--output_name", "-o", type=str, + default=None, required=True, + help="output onmt version Bert weight file Path") + args = parser.parse_args() + + print("Model contain {} layers.".format(args.layers)) + + print("Load weights from {}.".format(args.bert_model_weights_file)) + + bert_weights = torch.load(args.bert_model_weights_file) + embeddings = BertEmbeddings(28996) # vocab don't bother the conversion + bert_encoder = BertEncoder(embeddings) + generator = BertPreTrainingHeads(bert_encoder.d_model, + embeddings.vocab_size) + bertlm = torch.nn.Sequential(OrderedDict([ + ('bert', bert_encoder), + ('generator', generator)])) + model_weights = convert_bert_weights(bertlm, bert_weights, args.layers) + + ckp = {'model': model_weights['bert'], + 'generator': model_weights['generator']} + + outfile = args.output_name + print("Converted weights file in {}".format(outfile)) + torch.save(ckp, outfile) + + +if __name__ == '__main__': + main() diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 4cd19c986e..aa22fb67f3 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -150,6 +150,150 @@ will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, **Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing. +## How do I use BERT? +BERT is a general-purpose "language understanding" model introduced by Google, it can be used for various downstream NLP tasks and easily adapted into a new task using transfer learning. Using BERT has two stages: Pre-training and fine-tuning. But as the Pre-training is super expensive, we do not recommand you to pre-train a BERT from scratch. Instead loading weights from a existing pretrained model and fine-tuning is suggested. Currently we support sentence(-pair) classification and token tagging downstream task. + +### Use pretrained BERT weights +To use weights from a existing huggingface's pretrained model, we provide you a script to convert huggingface's BERT model weights into ours. + +Usage: +```bash +python bert_ckp_convert.py --layers NUMBER_LAYER + --bert_model_weights_file HUGGINGFACE_BERT_WEIGHTS + --output_name OUTPUT_FILE +``` +* Go to [modeling_bert.py](https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py) to check all available pretrained model. + +### Preprocess train/dev dataset +To generate train/dev data for BERT, you can use preprocess_bert.py by providing raw data in certain format and choose a BERT Tokenizer model `-vm` coherent with pretrained model. +#### Classification +For classification dataset, we support input file in csv or plain text file format. + +* For csv file, each line should contain a instance with one or two sentence column and one column for label as in GLUE dataset, other csv format dataset should be compatible. A typical csv file should be like: + + | ID | SENTENCE_A | SENTENCE_B(Optional) | LABEL | + | -- | ------------------------ | ------------------------ | ------- | + | 0 | sentence a of instance 0 | sentence b of instance 0 | class 2 | + | 1 | sentence a of instance 1 | sentence b of instance 1 | class 1 | + | ...| ... | ... | ... | + + Then calling `preprocess_bert.py` and providing input sentence columns and label column: + ```bash + python preprocess_bert.py --task classification --corpus_type {train, valid} + --file_type csv [--delimiter '\t'] [--skip_head] + --input_columns 1 2 --label_column 3 + --data DATA_DIR/FILENAME.tsv + --save_data dataset + -vm bert-base-cased --max_seq_len 256 [--do_lower_case] + [--sort_label_vocab] [--do_shuffle] + ``` +* For plain text format, we accept multiply files as input, each file contains instances for one specific class. Each line of the file contains one instance which could be composed by one sentence or two separated by ` ||| `. All input file should be arranged in following way: + ``` + . + ├── LABEL_A + │   └── FILE_WITH_INSTANCE_A + └── LABEL_B + └── FILE_WITH_INSTANCE_B + ``` + Then call `preprocess_bert.py` as following to generate training data: + ```bash + python preprocess_bert.py --task classification --corpus_type {'train', 'valid'} + --file_type txt [--delimiter ' ||| '] + --data DIR_BASE/LABEL_1/FILENAME1 ... DIR_BASE/LABEL_N/FILENAME2 + --save_data dataset + --vocab_model {bert-base-uncased,...} + --max_seq_len 256 [--do_lower_case] + [--sort_label_vocab] [--do_shuffle] + ``` +#### Tagging +For tagging dataset, we support input file in plain text file format. + +Each line of the input file should contain one token and its tagging, different fields should be separated by a delimiter(default space) while sentences are separated by a blank line. + +A example of input file is given below (`Token X X Label`): + ``` + -DOCSTART- -X- O O + + CRICKET NNP I-NP O + - : O O + LEICESTERSHIRE NNP I-NP I-ORG + TAKE NNP I-NP O + OVER IN I-PP O + AT NNP I-NP O + TOP NNP I-NP O + AFTER NNP I-NP O + INNINGS NNP I-NP O + VICTORY NN I-NP O + . . O O + + LONDON NNP I-NP I-LOC + 1996-08-30 CD I-NP O + + ``` +Then call preprocess_bert.py providing token column and label column as following to generate training data for token tagging task: + ```bash + python preprocess_bert.py --task tagging --corpus_type {'train', 'valid'} + --file_type txt [--delimiter ' '] + --input_columns 1 --label_column 3 + --data DATA_DIR/FILENAME + --save_data dataset + --vocab_model {bert-base-uncased,...} + --max_seq_len 256 [--do_lower_case] + [--sort_label_vocab] [--do_shuffle] + ``` +#### Pretraining objective +Even if it's not recommended, we also provide you a script to generate pretraining dataset as you may want to finetuning a existing pretrained model on masked language modeling and next sentence prediction. + +The script expects a single file as input, consisting of untokenized text, with one sentence per line, and one blank line between documents. +A usage example is given below: +```bash +python pregenerate_bert_training_data.py --input_file INPUT_FILE + --output_dir OUTPUT_DIR + --output_name OUTPUT_FILE_PREFIX + --corpus_type {'train', 'valid'} + --vocab_model {bert-base-uncased,...} + [--do_lower_case] [--do_whole_word_mask] [--reduce_memory] + --epochs_to_generate 2 + --max_seq_len 128 + --short_seq_prob 0.1 --masked_lm_prob 0.15 + --max_predictions_per_seq 20 + [--save_json] +``` + +### Training +After preprocessed data have been generated, you can load weights from a pretrained BERT and transfer it to downstream task with a task specific output head. This task specific head will be initialized by a method you choose if there is no such architecture in weights file specified by `--train_from`. Among all available optimizers, you are suggested to use `--optim bertadam` as it is the method used to train BERT. `warmup_steps` could be set as 1% of `train_steps` as in original paper if use linear decay method. + +A usage example is given below: +```bash +python train.py --is_bert --task_type {pretraining, classification, tagging} + --data PREPROCESSED_DATAIFILE + --train_from CONVERTED_CHECKPOINT.pt [--param_init 0.1] + --save_model MODEL_PREFIX --save_checkpoint_steps 1000 + [--world_size 2] [--gpu_ranks 0 1] + --word_vec_size 768 --rnn_size 768 + --layers 12 --heads 8 --transformer_ff 3072 + --activation gelu --dropout 0.1 --average_decay 0.0001 + --batch_size 8 [--accum_count 4] --optim bertadam [--max_grad_norm 0] + --learning_rate 2e-5 --learning_rate_decay 0.99 --decay_method linear + --train_steps 4000 --valid_steps 200 --warmup_steps 40 + [--report_every 10] [--seed 3435] + [--tensorboard] [--tensorboard_log_dir LOGDIR] +``` + +### Predicting +After training, you can use `predict.py` to generate predicting for raw file. Make sure to use the same BERT Tokenizer model `--vocab_model` as in training data. + +For classification task, file to be predicted should be one sentence(-pair) a line with ` ||| ` separating sentence. +For tagging task, each line should be a tokenized sentence with tokens separated by space. + +Usage: +```bash +python predict.py --task {classification, tagging} + --model ONMT_BERT_CHECKPOINT.pt + --vocab_model bert-base-uncased [--do_lower_case] + --data DATA_2_PREDICT [--delimiter {' ||| ', ' '}] --max_seq_len 256 + --output PREDICT.txt [--batch_size 8] [--gpu 1] [--seed 3435] +``` ## Can I get word alignment while translating? ### Raw alignments from averaging Transformer attention heads diff --git a/docs/source/refs.bib b/docs/source/refs.bib index bee6f86ec8..b25ab7879a 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -436,6 +436,55 @@ @article{DBLP:journals/corr/MartinsA16 bibsource = {dblp computer science bibliography, https://dblp.org} } +@article{DBLP:journals/corr/abs-1711-05101, + author = {Ilya Loshchilov and + Frank Hutter}, + title = {Fixing Weight Decay Regularization in Adam}, + journal = {CoRR}, + volume = {abs/1711.05101}, + year = {2017}, + url = {http://arxiv.org/abs/1711.05101}, + archivePrefix = {arXiv}, + eprint = {1711.05101}, + timestamp = {Mon, 13 Aug 2018 16:48:18 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1711-05101}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@article{DBLP:journals/corr/abs-1810-04805, + author = {Jacob Devlin and + Ming{-}Wei Chang and + Kenton Lee and + Kristina Toutanova}, + title = {{BERT:} Pre-training of Deep Bidirectional Transformers for Language + Understanding}, + journal = {CoRR}, + volume = {abs/1810.04805}, + year = {2018}, + url = {http://arxiv.org/abs/1810.04805}, + archivePrefix = {arXiv}, + eprint = {1810.04805}, + timestamp = {Tue, 30 Oct 2018 20:39:56 +0100}, + biburl = {https://dblp.org/rec/bib/journals/corr/abs-1810-04805}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + +@article{DBLP:journals/corr/HendrycksG16, + author = {Dan Hendrycks and + Kevin Gimpel}, + title = {Bridging Nonlinearities and Stochastic Regularizers with Gaussian + Error Linear Units}, + journal = {CoRR}, + volume = {abs/1606.08415}, + year = {2016}, + url = {http://arxiv.org/abs/1606.08415}, + archivePrefix = {arXiv}, + eprint = {1606.08415}, + timestamp = {Mon, 13 Aug 2018 16:46:20 +0200}, + biburl = {https://dblp.org/rec/bib/journals/corr/HendrycksG16}, + bibsource = {dblp computer science bibliography, https://dblp.org} +} + @inproceedings{garg2019jointly, title = {Jointly Learning to Align and Translate with Transformer Models}, author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias}, diff --git a/onmt/__init__.py b/onmt/__init__.py index 3d2a1650b3..9a6b429267 100644 --- a/onmt/__init__.py +++ b/onmt/__init__.py @@ -2,9 +2,9 @@ from __future__ import division, print_function import onmt.inputters +import onmt.models import onmt.encoders import onmt.decoders -import onmt.models import onmt.utils import onmt.modules from onmt.trainer import Trainer diff --git a/onmt/bin/train.py b/onmt/bin/train.py index b3115acd40..86068044ec 100755 --- a/onmt/bin/train.py +++ b/onmt/bin/train.py @@ -27,8 +27,12 @@ def train(opt): logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) - logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) - vocab = checkpoint['vocab'] + if 'vocab' in checkpoint: + logger.info('Loading vocab from checkpoint at %s.' + % opt.train_from) + vocab = checkpoint['vocab'] + else: + vocab = torch.load(opt.data + '.vocab.pt') else: vocab = torch.load(opt.data + '.vocab.pt') @@ -114,19 +118,38 @@ def next_batch(device_id): for device_id, q in cycle(enumerate(queues)): b.dataset = None - if isinstance(b.src, tuple): - b.src = tuple([_.to(torch.device(device_id)) - for _ in b.src]) + if opt.is_bert: + if isinstance(b.tokens, tuple): + b.tokens = tuple([_.to(torch.device(device_id)) + for _ in b.tokens]) + else: + b.tokens = b.tokens.to(torch.device(device_id)) + b.segment_ids = b.segment_ids.to(torch.device(device_id)) + if opt.task_type == 'pretraining': + b.is_next = b.is_next.to(torch.device(device_id)) + b.lm_labels_ids = b.lm_labels_ids.to(torch.device(device_id)) + elif opt.task_type == 'classification': + b.category = b.category.to(torch.device(device_id)) + elif opt.task_type == 'prediction' or opt.task_type == 'tagging': + b.token_labels = b.token_labels.to( + torch.device(device_id)) + else: + raise ValueError("task type Error") + else: - b.src = b.src.to(torch.device(device_id)) - b.tgt = b.tgt.to(torch.device(device_id)) - b.indices = b.indices.to(torch.device(device_id)) - b.alignment = b.alignment.to(torch.device(device_id)) \ - if hasattr(b, 'alignment') else None - b.src_map = b.src_map.to(torch.device(device_id)) \ - if hasattr(b, 'src_map') else None - b.align = b.align.to(torch.device(device_id)) \ - if hasattr(b, 'align') else None + if isinstance(b.src, tuple): + b.src = tuple([_.to(torch.device(device_id)) + for _ in b.src]) + else: + b.src = b.src.to(torch.device(device_id)) + b.tgt = b.tgt.to(torch.device(device_id)) + b.indices = b.indices.to(torch.device(device_id)) + b.alignment = b.alignment.to(torch.device(device_id)) \ + if hasattr(b, 'alignment') else None + b.src_map = b.src_map.to(torch.device(device_id)) \ + if hasattr(b, 'src_map') else None + b.align = b.align.to(torch.device(device_id)) \ + if hasattr(b, 'align') else None # hack to dodge unpicklable `dict_keys` b.fields = list(b.fields) diff --git a/onmt/encoders/__init__.py b/onmt/encoders/__init__.py index 53daac6d82..d59b097fea 100644 --- a/onmt/encoders/__init__.py +++ b/onmt/encoders/__init__.py @@ -6,11 +6,12 @@ from onmt.encoders.mean_encoder import MeanEncoder from onmt.encoders.audio_encoder import AudioEncoder from onmt.encoders.image_encoder import ImageEncoder +from onmt.encoders.bert import BertEncoder str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder, "transformer": TransformerEncoder, "img": ImageEncoder, - "audio": AudioEncoder, "mean": MeanEncoder} + "audio": AudioEncoder, "mean": MeanEncoder, "bert": BertEncoder} __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", - "MeanEncoder", "str2enc"] + "MeanEncoder", "str2enc", "BertEncoder"] diff --git a/onmt/encoders/bert.py b/onmt/encoders/bert.py new file mode 100644 index 0000000000..03f714d00e --- /dev/null +++ b/onmt/encoders/bert.py @@ -0,0 +1,120 @@ +import torch.nn as nn +from onmt.encoders.transformer import TransformerEncoderLayer + + +class BertEncoder(nn.Module): + """BERT Encoder: A Transformer Encoder with LayerNorm and BertPooler. + :cite:`DBLP:journals/corr/abs-1810-04805` + + Args: + embeddings (onmt.modules.BertEmbeddings): embeddings to use + num_layers (int): number of encoder layers. + d_model (int): size of the model + heads (int): number of heads + d_ff (int): size of the inner FF layer + dropout (float): dropout parameters + """ + + def __init__(self, embeddings, num_layers=12, d_model=768, heads=12, + d_ff=3072, dropout=0.1, attention_dropout=0.1, + max_relative_positions=0): + super(BertEncoder, self).__init__() + self.num_layers = num_layers + self.d_model = d_model + self.heads = heads + self.dropout = dropout + # Feed-Forward size should be 4*d_model as in paper + self.d_ff = d_ff + + self.embeddings = embeddings + # Transformer Encoder Block + self.encoder = nn.ModuleList( + [TransformerEncoderLayer(d_model, heads, d_ff, + dropout, attention_dropout, + max_relative_positions=max_relative_positions, + activation='gelu', is_bert=True) for _ in range(num_layers)]) + + self.layer_norm = nn.LayerNorm(d_model, eps=1e-12) + self.pooler = BertPooler(d_model) + + @classmethod + def from_opt(cls, opt, embeddings): + """Alternate constructor.""" + return cls( + embeddings, + opt.layers, + opt.word_vec_size, + opt.heads, + opt.transformer_ff, + opt.dropout[0] if type(opt.dropout) is list else opt.dropout, + opt.attention_dropout[0] if type(opt.attention_dropout) + is list else opt.attention_dropout, + opt.max_relative_positions) + + def forward(self, input_ids, token_type_ids=None, input_mask=None, + output_all_encoded_layers=False): + """ + Args: + input_ids (Tensor): ``(B, S)``, padding ids=0 + token_type_ids (Tensor): ``(B, S)``, A(0), B(1), pad(0) + input_mask (Tensor): ``(B, S)``, 1 for masked (padding) + output_all_encoded_layers (bool): if out contain all hidden layer + Returns: + all_encoder_layers (list of Tensor): ``(B, S, H)``, token level + pooled_output (Tensor): ``(B, H)``, sequence level + """ + + # OpenNMT waiting for mask of size [B, 1, T], + # while in MultiHeadAttention part2 -> [B, 1, 1, T] + if input_mask is None: + # shape: 2D tensor [batch, seq] + padding_idx = self.embeddings.word_padding_idx + # shape: 2D tensor [batch, seq]: 1 for tokens, 0 for paddings + input_mask = input_ids.data.eq(padding_idx) + # [batch, seq] -> [batch, 1, seq] + attention_mask = input_mask.unsqueeze(1) + + # embedding vectors: [batch, seq, hidden_size] + out = self.embeddings(input_ids, token_type_ids) + all_encoder_layers = [] + for layer in self.encoder: + out = layer(out, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(self.layer_norm(out)) + out = self.layer_norm(out) + if not output_all_encoded_layers: + all_encoder_layers.append(out) + pooled_out = self.pooler(out) + return all_encoder_layers, pooled_out + + def update_dropout(self, dropout): + self.dropout = dropout + self.embeddings.update_dropout(dropout) + for layer in self.encoder: + layer.update_dropout(dropout) + + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + """A pooling block (Linear layer followed by Tanh activation). + + Args: + hidden_size (int): size of hidden layer. + """ + + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation_fn = nn.Tanh() + + def forward(self, hidden_states): + """hidden_states[:, 0, :] --> {Linear, Tanh} --> Returns. + + Args: + hidden_states (Tensor): last layer's hidden_states, ``(B, S, H)`` + Returns: + pooled_output (Tensor): transformed output of last layer's hidden + """ + + first_token_tensor = hidden_states[:, 0, :] # [batch, d_model] + pooled_output = self.activation_fn(self.dense(first_token_tensor)) + return pooled_output diff --git a/onmt/encoders/transformer.py b/onmt/encoders/transformer.py index 5c29d8ded1..fb9da4eb88 100644 --- a/onmt/encoders/transformer.py +++ b/onmt/encoders/transformer.py @@ -21,18 +21,37 @@ class TransformerEncoderLayer(nn.Module): heads (int): the number of head for MultiHeadedAttention. d_ff (int): the second-layer of the PositionwiseFeedForward. dropout (float): dropout probability(0-1.0). + activation (str): activation function to chose from + ['relu', 'gelu'] + is_bert (bool): default False. When set True, + layer_norm will be performed on the + direct connection of residual block. """ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout, - max_relative_positions=0): + max_relative_positions=0, activation='relu', is_bert=False): super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiHeadedAttention( heads, d_model, dropout=attention_dropout, max_relative_positions=max_relative_positions) - self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout, + activation, is_bert) + self.layer_norm = nn.LayerNorm( + d_model, eps=1e-12 if is_bert else 1e-6) self.dropout = nn.Dropout(dropout) + self.is_bert = is_bert + + def residual(self, output, x): + """A Residual connection. + + Official BERT perform residual connection on layer normed input. + BERT's layer_norm is done before pass into next block while onmt's + layer_norm is performed at the begining. + """ + + maybe_norm = self.layer_norm(x) if self.is_bert else x + return output + maybe_norm def forward(self, inputs, mask): """ @@ -45,10 +64,11 @@ def forward(self, inputs, mask): * outputs ``(batch_size, src_len, model_dim)`` """ + input_norm = self.layer_norm(inputs) context, _ = self.self_attn(input_norm, input_norm, input_norm, mask=mask, attn_type="self") - out = self.dropout(context) + inputs + out = self.residual(self.dropout(context), inputs) return self.feed_forward(out) def update_dropout(self, dropout, attention_dropout): diff --git a/onmt/inputters/__init__.py b/onmt/inputters/__init__.py index 8476af1b70..458bec1f84 100644 --- a/onmt/inputters/__init__.py +++ b/onmt/inputters/__init__.py @@ -4,7 +4,7 @@ e.g., from a line of text to a sequence of embeddings. """ from onmt.inputters.inputter import \ - load_old_vocab, get_fields, OrderedIterator, \ + load_old_vocab, get_fields, get_bert_fields, OrderedIterator, \ build_vocab, old_style_vocab, filter_example from onmt.inputters.dataset_base import Dataset from onmt.inputters.text_dataset import text_sort_key, TextDataReader @@ -12,7 +12,8 @@ from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader from onmt.inputters.vec_dataset import vec_sort_key, VecDataReader from onmt.inputters.datareader_base import DataReaderBase - +from onmt.inputters.dataset_bert import BertDataset, bert_text_sort_key,\ + ClassifierDataset, TaggerDataset str2reader = { "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader, @@ -27,4 +28,5 @@ 'build_vocab', 'OrderedIterator', 'text_sort_key', 'img_sort_key', 'audio_sort_key', 'vec_sort_key', 'TextDataReader', 'ImageDataReader', 'AudioDataReader', - 'VecDataReader'] + 'VecDataReader', 'get_bert_fields', 'bert_text_sort_key', + 'BertDataset', 'ClassifierDataset', 'TaggerDataset'] diff --git a/onmt/inputters/dataset_bert.py b/onmt/inputters/dataset_bert.py new file mode 100644 index 0000000000..69d6306cb8 --- /dev/null +++ b/onmt/inputters/dataset_bert.py @@ -0,0 +1,263 @@ +import torch +from torchtext.data import Dataset as TorchtextDataset +from torchtext.data import Example +from random import random + + +def bert_text_sort_key(ex): + """Sort using the number of tokens in the sequence.""" + return len(ex.tokens) + + +def truncate_seq(tokens, max_num_tokens): + """Truncates a sequences randomly from front or back + to a maximum sequence length.""" + while True: + total_length = len(tokens) + if total_length <= max_num_tokens: + break + assert len(tokens) >= 1 + # We want to sometimes truncate from the front and sometimes + # from the back to add more randomness and avoid biases. + if random() < 0.5: + del tokens[0] + else: + tokens.pop() + + +def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens): + """Truncates a pair of sequences to a maximum sequence length. + Lifted from Google's BERT repo: create_pretraining_data.py in + https://github.com/google-research/bert/""" + + while True: + total_length = len(tokens_a) + len(tokens_b) + if total_length <= max_num_tokens: + break + trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b + assert len(trunc_tokens) >= 1 + + # We want to sometimes truncate from the front and sometimes from the + # back to add more randomness and avoid biases. + if random() < 0.5: + del trunc_tokens[0] + else: + trunc_tokens.pop() + + +def create_sentence_instance(sentence, tokenizer, + max_seq_length, random_trunc=False): + """Create single processed instance in BERT format. + + Args: + sentence (str): a raw single sentence. + tokenizer (onmt.utils.BertTokenizer): tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + random_trunc (bool): if false, trunc tail. + + Returns: + (list, list): + + * tokens_processed: ["[CLS]", sent_a, "[SEP]"] + * segment_ids: [0, ..., 0] + """ + tokens = tokenizer.tokenize(sentence) + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 2 + if len(tokens) > max_num_tokens: + if random_trunc is True: + truncate_seq(tokens, max_num_tokens) + else: + tokens = tokens[:max_num_tokens] + tokens_processed = ["[CLS]"] + tokens + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens) + 2)] + return tokens_processed, segment_ids + + +def create_sentence_pair_instance(sent_a, sent_b, tokenizer, max_seq_length): + """Create single processed instance in BERT format. + + Args: + sent_a (str): a raw single sentence. + sent_b (str): another raw single sentence. + tokenizer (onmt.utils.BertTokenizer): tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + + Returns: + (list, list): + + * tokens_processed: ["[CLS]", sent_a, "[SEP]", sent_b, "[SEP]"] + * segment_ids: [0, ..., 0, 1, ..., 1] + """ + tokens_a = tokenizer.tokenize(sent_a) + tokens_b = tokenizer.tokenize(sent_b) + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) + tokens_processed = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + \ + [1 for _ in range(len(tokens_b) + 1)] + return tokens_processed, segment_ids + + +class BertDataset(TorchtextDataset): + """Defines a BERT dataset composed of Examples along with its Fields. + + Args: + fields_dict (dict[str, Field]): a dict containing all Field with + its name. + instances (Iterable[dict[]]): a list of dictionary, each dict + represent one Example with its field specified by fields_dict. + """ + + def __init__(self, fields_dict, instances, + sort_key=bert_text_sort_key, filter_pred=None): + self.sort_key = sort_key + examples = [] + ex_fields = {k: [(k, v)] for k, v in fields_dict.items()} + for instance in instances: + ex = Example.fromdict(instance, ex_fields) + examples.append(ex) + fields_list = list(fields_dict.items()) + + super(BertDataset, self).__init__(examples, fields_list, filter_pred) + + def __getattr__(self, attr): + # avoid infinite recursion when fields isn't defined + if 'fields' not in vars(self): + raise AttributeError + if attr in self.fields: + return (getattr(x, attr) for x in self.examples) + else: + raise AttributeError + + def save(self, path, remove_fields=True): + if remove_fields: + self.fields = [] + torch.save(self, path) + + +class ClassifierDataset(BertDataset): + """Defines a BERT dataset composed of Examples along with its Fields. + Fields include "tokens", "segment_ids", "category". + + Args: + fields_dict (dict[str, Field]): a dict containing all Field with + its name. + data (list[]): a list of sequence (sentence or sentence pair), + possible with its label becoming tuple(list[]). + tokenizer (onmt.utils.BertTokenizer): a tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in input sequence. + """ + + def __init__(self, fields_dict, data, tokenizer, + max_seq_len=256, delimiter=' ||| '): + if not isinstance(data, tuple): + data = data, [None for _ in range(len(data))] + instances = self.create_instances( + data, tokenizer, delimiter, max_seq_len) + super(ClassifierDataset, self).__init__(fields_dict, instances) + + def create_instances(self, data, tokenizer, + delimiter, max_seq_len): + """Return data instances in the form of list of dict. + + Args: + data (list[]): a list of sequence (sentence or sentence pair), + possible with its label becoming tuple(list[]). + tokenizer (onmt.utils.BertTokenizer): tokenizer to use on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in sequence. + + Returns: + instances (list of dict): list of sequence classification instance. + """ + + instances = [] + for sentence, label in zip(*data): + sentences = sentence.strip().split(delimiter, 1) + if len(sentences) == 2: + sent_a, sent_b = sentences + tokens, segment_ids = create_sentence_pair_instance( + sent_a, sent_b, tokenizer, max_seq_len) + else: + sentence = sentences[0] + tokens, segment_ids = create_sentence_instance( + sentence, tokenizer, max_seq_len, random_trunc=False) + instance = { + "tokens": tokens, + "segment_ids": segment_ids, + "category": label} + instances.append(instance) + return instances + + +class TaggerDataset(BertDataset): + """Defines a BERT dataset composed of Examples along with its Fields. + + Args: + fields_dict (dict[str, Field]): a dict containing all Field with + its name. + data (list of str|tuple of list): a list of sequence, each sequence is + composed with tokens that to be tagging. Can also combined with + its tags as tuple([tokens], [tags]). + tokenizer (onmt.utils.BertTokenizer): a tokenizer to be used on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in input sequence. + """ + + def __init__(self, fields_dict, data, tokenizer, + max_seq_len=256, delimiter=' '): + targer_field = fields_dict["token_labels"] + self.pad_tok = targer_field.pad_token + if hasattr(targer_field, 'vocab'): # when predicting + self.predict_tok = targer_field.vocab.itos[-1] + if not isinstance(data, tuple): + data = (data, [None for _ in range(len(data))]) + instances = self.create_instances( + data, tokenizer, delimiter, max_seq_len) + super(TaggerDataset, self).__init__(fields_dict, instances) + + def create_instances(self, datas, tokenizer, delimiter, max_seq_len): + """Return data instances in the form of list of dict. + + Args: + data (list[]): a list of sequence (sentence or sentence pair), + possible with its label becoming tuple(list[]). + tokenizer (onmt.utils.BertTokenizer): tokenizer to use on data. + max_seq_len (int): maximum length of sequence. + delimiter (str): delimiter used to separate tokens in sequence. + + Returns: + instances (list of dict): list of tokens tagging instance. + """ + + instances = [] + for words, taggings in zip(*datas): + if isinstance(words, str): # build from raw sentence + words = words.strip().split(delimiter) + if taggings is None: # when predicting + assert hasattr(self, 'predict_tok') + taggings = [self.predict_tok for _ in range(len(words))] + sentence = [] + tags = [] + max_num_tokens = max_seq_len - 2 + for word, tag in zip(words, taggings): + tokens = tokenizer.tokenize(word) + n_pad = len(tokens) - 1 + paded_tag = [tag] + [self.pad_tok] * n_pad + if len(sentence) + len(tokens) > max_num_tokens: + break + else: + sentence.extend(tokens) + tags.extend(paded_tag) + sentence = ["[CLS]"] + sentence + ["[SEP]"] + tags = [self.pad_tok] + tags + [self.pad_tok] + segment_ids = [0 for _ in range(len(sentence))] + instance = { + "tokens": sentence, + "segment_ids": segment_ids, + "token_labels": tags} + instances.append(instance) + return instances diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index d4c43342c0..aaaec8727d 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -185,6 +185,41 @@ def get_fields( return fields +def get_bert_fields(task='pretraining', pad='[PAD]', bos='[CLS]', + eos='[SEP]', unk='[UNK]'): + fields = {} + tokens = Field(sequential=True, use_vocab=True, pad_token=pad, + unk_token=unk, include_lengths=True, batch_first=True) + fields["tokens"] = tokens + + segment_ids = Field(use_vocab=False, dtype=torch.long, unk_token=None, + sequential=True, pad_token=0, batch_first=True) + fields["segment_ids"] = segment_ids + if task == 'pretraining': + is_next = Field(use_vocab=False, dtype=torch.long, + sequential=False, batch_first=True) # 0/1 + fields["is_next"] = is_next + + lm_labels_ids = Field(sequential=True, use_vocab=False, + pad_token=-1, batch_first=True) + fields["lm_labels_ids"] = lm_labels_ids + + elif task == 'classification': + category = LabelField(sequential=False, use_vocab=True, + pad_token=None, batch_first=True) + fields["category"] = category + + elif task == 'generation' or task == 'tagging': + token_labels = Field(sequential=True, use_vocab=True, unk_token=None, + pad_token=pad, batch_first=True) + fields["token_labels"] = token_labels + + else: + raise ValueError("task %s has not been implemented yet!" % task) + + return fields + + def load_old_vocab(vocab, data_type="text", dynamic_dict=False): """Update a legacy vocab/field format. @@ -396,6 +431,33 @@ def _build_fields_vocab(fields, counters, data_type, share_vocab, return fields +def _build_bert_fields_vocab(fields, counters, vocab_size, label_name=None, + tokens_min_frequency=1, vocab_size_multiple=1): + tokens_field = fields["tokens"] + tokens_counter = counters["tokens"] + # NOTE: Do not use _build_field_vocab + # as the special tokens is fixed in origin bert vocab file + # _build_field_vocab(tokens_field, tokens_counter, + # size_multiple=vocab_size_multiple, + # max_size=vocab_size, min_freq=tokens_min_frequency) + tokens_field.vocab = tokens_field.vocab_cls( + tokens_counter, specials=[], max_size=vocab_size, + min_freq=tokens_min_frequency) + if vocab_size_multiple > 1: + _pad_vocab_to_multiple(tokens_field.vocab, vocab_size_multiple) + + if label_name is not None: + label_field = fields[label_name] + label_counter = counters[label_name] + all_specials = [label_field.unk_token, label_field.pad_token, + label_field.init_token, label_field.eos_token] + specials = [tok for tok in all_specials if tok is not None] + + label_field.vocab = label_field.vocab_cls( + label_counter, specials=specials) + return fields + + def build_vocab(train_dataset_files, fields, data_type, share_vocab, src_vocab_path, src_vocab_size, src_words_min_frequency, tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency, @@ -630,7 +692,7 @@ def __iter__(self): instead of a torchtext.data.Batch object. """ while True: - self.init_epoch() + self.init_epoch() # Inside, create_batches() will be called for idx, minibatch in enumerate(self.batches): # fast-forward if loaded from state if self._iterations_this_epoch > idx: @@ -801,19 +863,32 @@ def max_tok_len(new, count, sofar): such that the total number of src/tgt tokens (including padding) in a batch <= batch_size """ - # Maintains the longest src and tgt length in the current batch - global max_src_in_batch, max_tgt_in_batch # this is a hack - # Reset current longest length at a new batch (count=1) - if count == 1: - max_src_in_batch = 0 - max_tgt_in_batch = 0 - # Src: [ w1 ... wN ] - max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2) - # Tgt: [w1 ... wM ] - max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt[0]) + 1) - src_elements = count * max_src_in_batch - tgt_elements = count * max_tgt_in_batch - return max(src_elements, tgt_elements) + if hasattr(new, 'tokens'): + # when a example has the attr 'tokens', + # this means we are loading Bert Data + # Maintains the longest token length in the current batch + global max_tokens_in_batch + # Reset current longest length at a new batch (count=1) + if count == 1: + max_tokens_in_batch = 0 + # tokens: ['[CLS]', '[MASK]', ..., '[SEP]','This',...,'B','[SEP]'] + max_tokens_in_batch = max(max_tokens_in_batch, len(new.tokens)) + tokens_nelem = count * max_tokens_in_batch + return tokens_nelem + else: + # Maintains the longest src and tgt length in the current batch + global max_src_in_batch, max_tgt_in_batch # this is a hack + # Reset current longest length at a new batch (count=1) + if count == 1: + max_src_in_batch = 0 + max_tgt_in_batch = 0 + # Src: [ w1 ... wN ] + max_src_in_batch = max(max_src_in_batch, len(new.src[0]) + 2) + # Tgt: [w1 ... wM ] + max_tgt_in_batch = max(max_tgt_in_batch, len(new.tgt[0]) + 1) + src_elements = count * max_src_in_batch + tgt_elements = count * max_tgt_in_batch + return max(src_elements, tgt_elements) def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): diff --git a/onmt/model_builder.py b/onmt/model_builder.py index 9031dd40ac..5df3725201 100644 --- a/onmt/model_builder.py +++ b/onmt/model_builder.py @@ -13,20 +13,34 @@ from onmt.decoders import str2dec -from onmt.modules import Embeddings, VecEmbedding, CopyGenerator +from onmt.modules import Embeddings, VecEmbedding, CopyGenerator, \ + BertEmbeddings from onmt.modules.util_class import Cast from onmt.utils.misc import use_gpu from onmt.utils.logging import logger from onmt.utils.parse import ArgumentParser +from onmt.models import BertPreTrainingHeads, ClassificationHead, \ + TokenGenerationHead, TokenTaggingHead + def build_embeddings(opt, text_field, for_encoder=True): """ Args: opt: the option in current environment. - text_field(TextMultiField): word and feats field. + text_field(TextMultiField | Field): word and feats field. for_encoder(bool): build Embeddings for encoder or decoder? """ + if opt.is_bert: + token_fields_vocab = text_field.vocab + vocab_size = len(token_fields_vocab) + emb_dim = opt.word_vec_size + return BertEmbeddings( + vocab_size, emb_dim, + dropout=(opt.dropout[0] if type(opt.dropout) is list + else opt.dropout) + ) + emb_dim = opt.src_word_vec_size if for_encoder else opt.tgt_word_vec_size if opt.model_type == "vec" and for_encoder: @@ -71,8 +85,11 @@ def build_encoder(opt, embeddings): opt: the option in current environment. embeddings (Embeddings): vocab embeddings for this encoder. """ - enc_type = opt.encoder_type if opt.model_type == "text" \ - or opt.model_type == "vec" else opt.model_type + if opt.is_bert: + enc_type = 'bert' + else: + enc_type = opt.encoder_type if opt.model_type == "text" \ + or opt.model_type == "vec" else opt.model_type return str2enc[enc_type].from_opt(opt, embeddings) @@ -90,6 +107,7 @@ def build_decoder(opt, embeddings): def load_test_model(opt, model_path=None): if model_path is None: + assert hasattr(opt, 'models') model_path = opt.models[0] checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage) @@ -129,7 +147,7 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): gpu_id (int or NoneType): Which GPU to use. Returns: - the NMTModel. + the NMTModel or BertEncoder(with generator). """ # for back compat when attention_dropout was not defined @@ -139,7 +157,10 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): model_opt.attention_dropout = model_opt.dropout # Build embeddings. - if model_opt.model_type == "text" or model_opt.model_type == "vec": + if model_opt.is_bert: + src_field = fields["tokens"] + src_emb = build_embeddings(model_opt, src_field) + elif model_opt.model_type == "text" or model_opt.model_type == "vec": src_field = fields["src"] src_emb = build_embeddings(model_opt, src_field) else: @@ -148,31 +169,32 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): # Build encoder. encoder = build_encoder(model_opt, src_emb) - # Build decoder. - tgt_field = fields["tgt"] - tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False) + if not model_opt.is_bert: + # Build decoder. + tgt_field = fields["tgt"] + tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False) - # Share the embedding matrix - preprocess with share_vocab required. - if model_opt.share_embeddings: - # src/tgt vocab should be the same if `-share_vocab` is specified. - assert src_field.base_field.vocab == tgt_field.base_field.vocab, \ - "preprocess with -share_vocab if you use share_embeddings" + # Share the embedding matrix - preprocess with share_vocab required. + if model_opt.share_embeddings: + # src/tgt vocab should be the same if `-share_vocab` is specified. + assert src_field.base_field.vocab == tgt_field.base_field.vocab, \ + "preprocess with -share_vocab if you use share_embeddings" - tgt_emb.word_lut.weight = src_emb.word_lut.weight + tgt_emb.word_lut.weight = src_emb.word_lut.weight - decoder = build_decoder(model_opt, tgt_emb) + decoder = build_decoder(model_opt, tgt_emb) - # Build NMTModel(= encoder + decoder). if gpu and gpu_id is not None: device = torch.device("cuda", gpu_id) elif gpu and not gpu_id: device = torch.device("cuda") elif not gpu: device = torch.device("cpu") - model = onmt.models.NMTModel(encoder, decoder) # Build Generator. - if not model_opt.copy_attn: + if model_opt.is_bert: + generator = build_bert_generator(model_opt, fields, encoder) + elif not model_opt.copy_attn: if model_opt.generator_function == "sparsemax": gen_func = onmt.modules.sparse_activations.LogSparsemax(dim=-1) else: @@ -191,40 +213,62 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None): pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token] generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx) + if model_opt.is_bert: + model = encoder + else: + # Build NMTModel(= encoder + decoder). + model = onmt.models.NMTModel(encoder, decoder) # Load the model states from checkpoint or initialize them. + model_init = {'model': False, 'generator': False} if checkpoint is not None: - # This preserves backward-compat for models using customed layernorm - def fix_key(s): - s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2', - r'\1.layer_norm\2.bias', s) - s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2', - r'\1.layer_norm\2.weight', s) - return s - - checkpoint['model'] = {fix_key(k): v - for k, v in checkpoint['model'].items()} - # end of patch for backward compatibility - - model.load_state_dict(checkpoint['model'], strict=False) - generator.load_state_dict(checkpoint['generator'], strict=False) - else: - if model_opt.param_init != 0.0: - for p in model.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - for p in generator.parameters(): - p.data.uniform_(-model_opt.param_init, model_opt.param_init) - if model_opt.param_init_glorot: - for p in model.parameters(): - if p.dim() > 1: - xavier_uniform_(p) - for p in generator.parameters(): - if p.dim() > 1: - xavier_uniform_(p) - - if hasattr(model.encoder, 'embeddings'): + assert 'model' in checkpoint + if not model_opt.is_bert: + # This preserves back-compat for models using customed layernorm + def fix_key(s): + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.b_2', + r'\1.layer_norm\2.bias', s) + s = re.sub(r'(.*)\.layer_norm((_\d+)?)\.a_2', + r'\1.layer_norm\2.weight', s) + return s + + checkpoint['model'] = {fix_key(k): v + for k, v in checkpoint['model'].items()} + # end of patch for backward compatibility + # if model.state_dict().keys() != checkpoint['model'].keys(): + # raise ValueError("Checkpoint don't match actual model!") + logger.info("Load Model Parameters...") + model.load_state_dict(checkpoint['model'], strict=True) + model_init['model'] = True + if generator.state_dict().keys() == checkpoint['generator'].keys(): + logger.info("Load generator Parameters...") + generator.load_state_dict(checkpoint['generator'], strict=True) + model_init['generator'] = True + + for module_name, is_init in model_init.items(): + if not is_init: + logger.info("Initialize {} Parameters...".format(module_name)) + sub_module = model if module_name == 'model' else generator + if model_opt.param_init != 0.0: + logger.info('Initialize weights using a uniform distribution') + for p in sub_module.parameters(): + p.data.uniform_(-model_opt.param_init, + model_opt.param_init) + if model_opt.param_init_normal != 0.0: + logger.info('Initialize weights using a normal distribution') + normal_std = model_opt.param_init_normal + for p in sub_module.parameters(): + p.data.normal_(mean=0, std=normal_std) + if model_opt.param_init_glorot: + logger.info('Glorot initialization') + for p in sub_module.parameters(): + if p.dim() > 1: + xavier_uniform_(p) + + if checkpoint is None: + if hasattr(model, 'encoder') and hasattr(model.encoder, 'embeddings'): model.encoder.embeddings.load_pretrained_vectors( model_opt.pre_word_vecs_enc) - if hasattr(model.decoder, 'embeddings'): + if hasattr(model, 'decoder') and hasattr(model.decoder, 'embeddings'): model.decoder.embeddings.load_pretrained_vectors( model_opt.pre_word_vecs_dec) @@ -240,3 +284,39 @@ def build_model(model_opt, opt, fields, checkpoint): model = build_base_model(model_opt, fields, use_gpu(opt), checkpoint) logger.info(model) return model + + +def build_bert_generator(model_opt, fields, bert_encoder): + """Main part for transfer learning: + set opt.task_type to `pretraining` if want finetuning Bert; + set opt.task_type to `classification` if want sentence level task; + set opt.task_type to `generation` if want token level task. + Both all_encoder_layers and pooled_output will be feed to generator, + pretraining task will use the two, + while only pooled_output will be used for classification generator; + only all_encoder_layers will be used for generation generator + """ + task = model_opt.task_type + dropout = model_opt.dropout[0] if type(model_opt.dropout) is list \ + else model_opt.dropout + if task == 'pretraining': + generator = BertPreTrainingHeads( + bert_encoder.d_model, bert_encoder.embeddings.vocab_size) + if model_opt.reuse_embeddings: + generator.mask_lm.decode.weight = \ + bert_encoder.embeddings.word_embeddings.weight + elif task == 'generation': + generator = TokenGenerationHead( + bert_encoder.d_model, bert_encoder.vocab_size) + if model_opt.reuse_embeddings: + generator.decode.weight = \ + bert_encoder.embeddings.word_embeddings.weight + elif task == 'classification': + n_class = len(fields["category"].vocab.stoi) + logger.info('Generator of classification with %s class.' % n_class) + generator = ClassificationHead(bert_encoder.d_model, n_class, dropout) + elif task == 'tagging': + n_class = len(fields["token_labels"].vocab.stoi) + logger.info('Generator of tagging with %s tag.' % n_class) + generator = TokenTaggingHead(bert_encoder.d_model, n_class, dropout) + return generator diff --git a/onmt/models/__init__.py b/onmt/models/__init__.py index c0e48bf506..bae6e3a06a 100644 --- a/onmt/models/__init__.py +++ b/onmt/models/__init__.py @@ -1,5 +1,9 @@ """Module defining models.""" from onmt.models.model_saver import build_model_saver, ModelSaver from onmt.models.model import NMTModel +from onmt.models.bert_generators import BertPreTrainingHeads,\ + ClassificationHead, TokenGenerationHead, TokenTaggingHead -__all__ = ["build_model_saver", "ModelSaver", "NMTModel"] +__all__ = ["build_model_saver", "ModelSaver", "NMTModel", + "BertPreTrainingHeads", "ClassificationHead", + "TokenGenerationHead", "TokenTaggingHead"] diff --git a/onmt/models/bert_generators.py b/onmt/models/bert_generators.py new file mode 100644 index 0000000000..34f734df76 --- /dev/null +++ b/onmt/models/bert_generators.py @@ -0,0 +1,208 @@ +import torch + +import torch.nn as nn +from onmt.utils import get_activation_fn + + +class BertPreTrainingHeads(nn.Module): + """ + Bert Pretraining Heads: Masked Language Models, Next Sentence Prediction + + Args: + hidden_size (int): output size of BERT model + vocab_size (int): total vocab size + """ + def __init__(self, hidden_size, vocab_size): + super(BertPreTrainingHeads, self).__init__() + self.next_sentence = NextSentencePrediction(hidden_size) + self.mask_lm = MaskedLanguageModel(hidden_size, vocab_size) + + def forward(self, x, pooled_out): + """ + Args: + x (list of Tensor): all_encoder_layers, shape ``(B, S, H)`` + pooled_output (Tensor): second output of bert encoder, ``(B, H)`` + Returns: + seq_class_log_prob (Tensor): next sentence prediction, ``(B, 2)`` + prediction_log_prob (Tensor): mlm prediction, ``(B, S, vocab)`` + """ + seq_class_log_prob = self.next_sentence(pooled_out) + prediction_log_prob = self.mask_lm(x[-1]) + return seq_class_log_prob, prediction_log_prob + + +class MaskedLanguageModel(nn.Module): + """predicting origin token from masked input sequence + n-class classification problem, n-class = vocab_size + + Args: + hidden_size (int): output size of BERT model + vocab_size (int): total vocab size + """ + + def __init__(self, hidden_size, vocab_size): + super(MaskedLanguageModel, self).__init__() + self.transform = BertPredictionTransform(hidden_size) + + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, x): + """ + Args: + x (Tensor): first output of bert encoder, ``(B, S, H)`` + Returns: + prediction_log_prob (Tensor): shape ``(B, S, vocab)`` + """ + x = self.transform(x) # (batch, seq, d_model) + prediction_scores = self.decode(x) + self.bias # (batch, seq, vocab) + prediction_log_prob = self.log_softmax(prediction_scores) + return prediction_log_prob + + +class NextSentencePrediction(nn.Module): + """ + 2-class classification model : is_next, is_random_next + + Args: + hidden_size (int): BERT model output size + """ + + def __init__(self, hidden_size): + super(NextSentencePrediction, self).__init__() + self.linear = nn.Linear(hidden_size, 2) + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, x): + """ + Args: + x (Tensor): second output of bert encoder, ``(B, H)`` + Returns: + seq_class_prob (Tensor): ``(B, 2)`` + """ + seq_relationship_score = self.linear(x) # (batch, 2) + seq_class_log_prob = self.log_softmax(seq_relationship_score) + return seq_class_log_prob + + +class BertPredictionTransform(nn.Module): + """{Linear(h,h), Activation, LN} block.""" + + def __init__(self, hidden_size): + """ + Args: + hidden_size (int): BERT model hidden layer size. + """ + + super(BertPredictionTransform, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = get_activation_fn('gelu') + self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-12) + + def forward(self, hidden_states): + """ + Args: + hidden_states (Tensor): BERT encoder output ``(B, S, H)`` + """ + + hidden_states = self.layer_norm(self.activation( + self.dense(hidden_states))) + return hidden_states + + +class ClassificationHead(nn.Module): + """n-class Sentence classification head + + Args: + hidden_size (int): BERT model output size + n_class (int): number of classification label + """ + + def __init__(self, hidden_size, n_class, dropout=0.1): + """ + """ + super(ClassificationHead, self).__init__() + self.dropout = nn.Dropout(dropout) + self.linear = nn.Linear(hidden_size, n_class) + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, all_hidden, pooled): + """ + Args: + all_hidden (list of Tensor): layers output, list [``(B, S, H)``] + pooled (Tensor): last layer hidden [CLS], ``(B, H)`` + Returns: + class_log_prob (Tensor): shape ``(B, 2)`` + None: this is a placeholder for token level prediction task + """ + + pooled = self.dropout(pooled) + score = self.linear(pooled) # (batch, n_class) + class_log_prob = self.log_softmax(score) # (batch, n_class) + return class_log_prob, None + + +class TokenTaggingHead(nn.Module): + """n-class Token Tagging head + + Args: + hidden_size (int): BERT model output size + n_class (int): number of tagging label + """ + + def __init__(self, hidden_size, n_class, dropout=0.1): + super(TokenTaggingHead, self).__init__() + self.dropout = nn.Dropout(dropout) + self.linear = nn.Linear(hidden_size, n_class) + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, all_hidden, pooled): + """ + Args: + all_hidden (list of Tensor): layers output, list [``(B, S, H)``] + pooled (Tensor): last layer hidden [CLS], ``(B, H)`` + Returns: + None: this is a placeholder for sentence level task + tok_class_log_prob (Tensor): shape ``(B, S, n_class)`` + """ + last_hidden = all_hidden[-1] + last_hidden = self.dropout(last_hidden) # (batch, seq, d_model) + score = self.linear(last_hidden) # (batch, seq, n_class) + tok_class_log_prob = self.log_softmax(score) # (batch, seq, n_class) + return None, tok_class_log_prob + + +class TokenGenerationHead(nn.Module): + """ + Token generation head: generation token from input sequence + + Args: + hidden_size (int): output size of BERT model + vocab_size (int): total vocab size + """ + + def __init__(self, hidden_size, vocab_size): + super(TokenGenerationHead, self).__init__() + self.transform = BertPredictionTransform(hidden_size) + + self.decode = nn.Linear(hidden_size, vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + self.log_softmax = nn.LogSoftmax(dim=-1) + + def forward(self, all_hidden, pooled): + """ + Args: + all_hidden (list of Tensor): layers output, list [``(B, S, H)``] + pooled (Tensor): last layer hidden [CLS], ``(B, H)`` + Returns: + None: this is a placeholder for sentence level task + prediction_log_prob (Tensor): shape ``(B, S, vocab)`` + """ + last_hidden = all_hidden[-1] + y = self.transform(last_hidden) # (batch, seq, d_model) + prediction_scores = self.decode(y) + self.bias # (batch, seq, vocab) + prediction_log_prob = self.log_softmax(prediction_scores) + return None, prediction_log_prob diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py index 4f2534b82e..9c847eaebd 100644 --- a/onmt/models/model_saver.py +++ b/onmt/models/model_saver.py @@ -1,6 +1,6 @@ import os import torch - +from torchtext.data import Field from collections import deque from onmt.utils.logging import logger @@ -107,15 +107,29 @@ def _save(self, step, model): # were not originally here. vocab = deepcopy(self.fields) - for side in ["src", "tgt"]: - keys_to_pop = [] - if hasattr(vocab[side], "fields"): - unk_token = vocab[side].fields[0][1].vocab.itos[0] - for key, value in vocab[side].fields[0][1].vocab.stoi.items(): - if value == 0 and key != unk_token: - keys_to_pop.append(key) - for key in keys_to_pop: - vocab[side].fields[0][1].vocab.stoi.pop(key, None) + for name, field in vocab.items(): + if isinstance(field, Field): + if hasattr(field, "vocab") and \ + (field.unk_token is not None): + assert name == 'tokens' + keys_to_pop = [] + unk_token = field.unk_token + unk_id = field.vocab.stoi[unk_token] + for key, value in field.vocab.stoi.items(): + if value == unk_id and key != unk_token: + keys_to_pop.append(key) + for key in keys_to_pop: + field.vocab.stoi.pop(key, None) + else: + if hasattr(field, "fields"): + assert name in ["src", "tgt"] + keys_to_pop = [] + unk_token = field.fields[0][1].vocab.itos[0] + for key, value in field.fields[0][1].vocab.stoi.items(): + if value == 0 and key != unk_token: + keys_to_pop.append(key) + for key in keys_to_pop: + field.fields[0][1].vocab.stoi.pop(key, None) checkpoint = { 'model': model_state_dict, diff --git a/onmt/modules/__init__.py b/onmt/modules/__init__.py index 763ac8448a..6c3bd9664a 100644 --- a/onmt/modules/__init__.py +++ b/onmt/modules/__init__.py @@ -3,16 +3,18 @@ from onmt.modules.gate import context_gate_factory, ContextGate from onmt.modules.global_attention import GlobalAttention from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention -from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ - CopyGeneratorLossCompute from onmt.modules.multi_headed_attn import MultiHeadedAttention from onmt.modules.embeddings import Embeddings, PositionalEncoding, \ VecEmbedding +from onmt.modules.bert_embeddings import BertEmbeddings from onmt.modules.weight_norm import WeightNormConv2d from onmt.modules.average_attn import AverageAttention +from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ + CopyGeneratorLossCompute __all__ = ["Elementwise", "context_gate_factory", "ContextGate", "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", "CopyGeneratorLoss", "CopyGeneratorLossCompute", "MultiHeadedAttention", "Embeddings", "PositionalEncoding", - "WeightNormConv2d", "AverageAttention", "VecEmbedding"] + "WeightNormConv2d", "AverageAttention", "VecEmbedding", + "BertEmbeddings"] diff --git a/onmt/modules/bert_embeddings.py b/onmt/modules/bert_embeddings.py new file mode 100644 index 0000000000..4b8f647b0c --- /dev/null +++ b/onmt/modules/bert_embeddings.py @@ -0,0 +1,66 @@ +import torch +import torch.nn as nn + + +class BertEmbeddings(nn.Module): + """ BERT input embeddings is sum of: + 1. Token embeddings: called word_embeddings + 2. Segmentation embeddings: called token_type_embeddings + 3. Position embeddings: called position_embeddings + :cite:`DBLP:journals/corr/abs-1810-04805` section 3.2 + + Args: + vocab_size (int): Size of the embedding vocabulary. + embed_size (int): Width of the word embeddings. + pad_idx (int): padding index + dropout (float): dropout rate + max_position (int): max sentence length in input + num_sentence (int): number of segment + """ + def __init__(self, vocab_size, embed_size=768, pad_idx=0, + dropout=0.1, max_position=512, num_sentence=2): + super(BertEmbeddings, self).__init__() + self.vocab_size = vocab_size + self.embed_size = embed_size + self.word_padding_idx = pad_idx + # Token embeddings: for input tokens + self.word_embeddings = nn.Embedding( + vocab_size, embed_size, padding_idx=pad_idx) + # Position embeddings: for Position Encoding + self.position_embeddings = nn.Embedding(max_position, embed_size) + # Segmentation embeddings: for distinguish sentences A/B + self.token_type_embeddings = nn.Embedding( + num_sentence, embed_size, padding_idx=pad_idx) + + self.dropout = nn.Dropout(dropout) + + def forward(self, input_ids, token_type_ids=None): + """ + Args: + input_ids (Tensor): ``(B, S)``. + token_type_ids (Tensor): segment id ``(B, S)``. + + Returns: + embeddings (Tensor): final embeddings, ``(B, S, H)``. + """ + seq_length = input_ids.size(1) + position_ids = torch.arange( + seq_length, dtype=torch.long, device=input_ids.device) + # [[0,1,...,seq_length-1]] -> [[0,1,...,seq_length-1] *batch_size] + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + word_embeds = self.word_embeddings(input_ids) + position_embeds = self.position_embeddings(position_ids) + token_type_embeds = self.token_type_embeddings(token_type_ids) + embeddings = word_embeds + position_embeds + token_type_embeds + # NOTE: in our version, LayerNorm is done in EncoderLayer + # before fed into Attention comparing to original one + # embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def update_dropout(self, dropout): + self.dropout.p = dropout diff --git a/onmt/modules/position_ffn.py b/onmt/modules/position_ffn.py index fb8df80aa7..a35f6a1168 100644 --- a/onmt/modules/position_ffn.py +++ b/onmt/modules/position_ffn.py @@ -2,6 +2,8 @@ import torch.nn as nn +from onmt.utils import get_activation_fn + class PositionwiseFeedForward(nn.Module): """ A two-layer Feed-Forward-Network with residual layer norm. @@ -11,16 +13,34 @@ class PositionwiseFeedForward(nn.Module): d_ff (int): the hidden layer size of the second-layer of the FNN. dropout (float): dropout probability in :math:`[0, 1)`. + activation (str): activation function to use. ['relu', 'gelu'] + is_bert (bool): default False. When set True, + layer_norm will be performed on the + direct connection of residual block. """ - def __init__(self, d_model, d_ff, dropout=0.1): + def __init__(self, d_model, d_ff, dropout=0.1, + activation='relu', is_bert=False): super(PositionwiseFeedForward, self).__init__() self.w_1 = nn.Linear(d_model, d_ff) self.w_2 = nn.Linear(d_ff, d_model) - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) + self.layer_norm = nn.LayerNorm( + d_model, eps=1e-12 if is_bert else 1e-6) self.dropout_1 = nn.Dropout(dropout) - self.relu = nn.ReLU() + self.activation = get_activation_fn(activation) self.dropout_2 = nn.Dropout(dropout) + self.is_bert = is_bert + + def residual(self, output, x): + """A Residual connection. + + Official BERT perform residual connection on layer normed input. + BERT's layer_norm is done before pass into next block while onmt's + layer_norm is performed at the begining. + """ + + maybe_norm = self.layer_norm(x) if self.is_bert else x + return output + maybe_norm def forward(self, x): """Layer definition. @@ -32,9 +52,9 @@ def forward(self, x): (FloatTensor): Output ``(batch_size, input_len, model_dim)``. """ - inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) + inter = self.dropout_1(self.activation(self.w_1(self.layer_norm(x)))) output = self.dropout_2(self.w_2(inter)) - return output + x + return self.residual(output, x) def update_dropout(self, dropout): self.dropout_1.p = dropout diff --git a/onmt/opts.py b/onmt/opts.py index 1c56ed23b7..ec4df7d2b5 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -59,6 +59,7 @@ def model_opts(parser): # Encoder-Decoder Options group = parser.add_argument_group('Model- Encoder-Decoder') + group.add('--is_bert', '-is_bert', action='store_true') group.add('--model_type', '-model_type', default='text', choices=['text', 'img', 'audio', 'vec'], help="Type of source model to use. Allows " @@ -151,6 +152,9 @@ def model_opts(parser): help='Number of heads for transformer self-attention') group.add('--transformer_ff', '-transformer_ff', type=int, default=2048, help='Size of hidden transformer feed-forward') + group.add('--activation', '-activation', default='relu', + choices=['relu', 'gelu'], + help='type of activation function used in Bert encoder.') group.add('--aan_useffn', '-aan_useffn', action="store_true", help='Turn on the FFN layer in the AAN decoder') @@ -334,9 +338,87 @@ def preprocess_opts(parser): "model faster and smaller") +def preprocess_bert_opts(parser): + """ Pre-procesing options for pretrained model """ + # Data options + group = parser.add_argument_group('Common') + group.add('--task', '-task', type=str, required=True, + choices=["classification", "tagging"], + help="Target task to perform") + group.add('--corpus_type', '-corpus_type', type=str, default="train", + choices=['train', 'valid'], + help="corpus type choose from ['train', 'valid'], " + + "Vocab file will be generate if `train`") + + group = parser.add_argument_group('Data') + group.add('--file_type', type=str, default="txt", choices=["csv", "txt"], + help="input file type. Choose [txt|csv]") + group.add('--data', '-data', type=str, nargs='+', default=[], + required=True, + help="input datas to prepare: [CLS]" + + "Single file for csv with column indicate label," + + "One file for each class as path/label/file; [TAG]" + + "Single file contain (tok, tag) in each line," + + "Sentence separated by blank line.") + group.add('--skip_head', '-skip_head', action="store_true", + help="CSV: If csv file contain head line.") + group.add('--do_lower_case', '-lower', action='store_true', + help='lowercase data') + group.add("--max_seq_len", type=int, default=256, + help="Maximum sequence length to keep.") + group.add('--save_data', '-save_data', type=str, required=True, + help="Output file Prefix for the prepared data") + + group = parser.add_argument_group('Columns') + # options for column-like input file with fields seperate by -delimiter + group.add('--delimiter', '-delimiter', type=str, default=' ', + help="delimiter used in input file for seperate fields.") + group.add('--input_columns', type=int, nargs='+', default=[], + help="Column where contain sentence A(,B)") + group.add('--label_column', type=int, default=None, + help="Column where contain label") + + group = parser.add_argument_group('Vocab') + group.add('--labels', '-labels', type=str, nargs='+', default=[], + help="Candidate labels, will be used to build label vocab. " + + "If not given, this will be built from input file.") + group.add('--sort_label_vocab', '-sort_label', type=bool, default=True, + help="sort label vocab in alphabetic order.") + group.add("--vocab_model", "-vm", type=str, default="bert-base-uncased", + choices=["bert-base-uncased", "bert-large-uncased", + "bert-base-cased", "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-base-cased-finetuned-mrpc"], + help="Pretrained BertTokenizer model use to tokenizer text.") + + # Data processing options + group = parser.add_argument_group('Random') + group.add('--do_shuffle', '-shuffle', action="store_true", + help="Shuffle data") + + group = parser.add_argument_group('Logging') + group.add('--log_file', '-log_file', type=str, default="", + help="Output logs to a file under this path.") + + def train_opts(parser): """ Training and saving options """ + group = parser.add_argument_group('Pretrain-finetuning') + group.add('--task_type', '-task_type', type=str, default="none", + choices=["none", "pretraining", "classification", "tagging"], + help="Downstream task for Bert if is_bert set True" + "Choose from pretraining Bert," + "use pretrained Bert for classification," + "use pretrained Bert for token generation.") + group.add('--reuse_embeddings', '-reuse_embeddings', type=bool, + default=False, help="if reuse embeddings for generator " + + "only for generation or pretraining task") + group = parser.add_argument_group('General') group.add('--data', '-data', required=True, help='Path prefix to the ".train.pt" and ' @@ -391,6 +473,10 @@ def train_opts(parser): group.add('--param_init_glorot', '-param_init_glorot', action='store_true', help="Init parameters with xavier_uniform. " "Required for transformer.") + group.add('--param_init_normal', '-param_normal', type=float, default=0.0, + help="Parameters are initialized over normal distribution " + "with (mean=0, std=param_init_normal). Used in BERT with 0.02." + "Set value > 0 and param_init 0.0 to activate.") group.add('--train_from', '-train_from', default='', type=str, help="If training from a checkpoint then this is the " @@ -463,7 +549,7 @@ def train_opts(parser): nargs="*", default=None, help='Criteria to use for early stopping.') group.add('--optim', '-optim', default='sgd', - choices=['sgd', 'adagrad', 'adadelta', 'adam', + choices=['sgd', 'adagrad', 'adadelta', 'adam', 'bertadam', 'sparseadam', 'adafactor', 'fusedadam'], help="Optimization method.") group.add('--adagrad_accumulator_init', '-adagrad_accumulator_init', @@ -540,10 +626,14 @@ def train_opts(parser): help="Decay every decay_steps") group.add('--decay_method', '-decay_method', type=str, default="none", - choices=['noam', 'noamwd', 'rsqrt', 'none'], + choices=['none', 'noam', 'noamwd', 'rsqrt', 'linear', + 'linearconst', 'cosine', 'cosine_hard_restart', + 'cosine_warmup_restart'], help="Use a custom decay rate.") group.add('--warmup_steps', '-warmup_steps', type=int, default=4000, help="Number of warmup steps for custom decay.") + group.add('--cycles', '-cycles', type=int, default=None, + help="required for cosine related decay.") group = parser.add_argument_group('Logging') group.add('--report_every', '-report_every', type=int, default=50, @@ -750,6 +840,66 @@ def translate_opts(parser): "model faster and smaller") +def predict_opts(parser): + """ Prediction [Using Pretrained model] options """ + group = parser.add_argument_group('Model') + group.add("--vocab_model", type=str, + default="bert-base-uncased", + choices=["bert-base-uncased", "bert-large-uncased", + "bert-base-cased", "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-base-cased-finetuned-mrpc"], + help="Bert pretrained tokenizer model to use.") + group.add("--model", type=str, default=None, required=True, + help="Path to Bert model that for predicting.") + group.add('--task', type=str, default=None, required=True, + choices=["classification", "tagging"], + help="Target task to perform") + + group = parser.add_argument_group('Data') + group.add('--data', '-i', type=str, default=None, required=True, + help="predicting data for classification / tagging" + + "Classification: Sentence1 ||| Sentence2, " + + "Tagging: one tokenized sentence a line") + group.add("--do_lower_case", action="store_true", help='lowercase data') + group.add('--delimiter', '-d', type=str, default=None, + help="Delimiter used for seperate sentence/word. " + + "Default: ' ||| ' for sentence used in [CLS], " + + " ' ' for word used in [TAG].") + group.add("--max_seq_len", type=int, default=256, + help="max sequence length for prepared data," + "set the limite of position encoding") + group.add('--output', '-output', default=None, required=True, + help="Path to output the predictions") + group.add('--shard_size', '-shard_size', type=int, default=10000, + help="Divide data into smaller multiple data files, " + "then build shards, each shard will have " + "opt.shard_size samples except last shard. " + "shard_size=0 means no segmentation " + "shard_size>0 segment data into multiple shards, " + "each shard has shard_size samples") + + group = parser.add_argument_group('Efficiency') + group.add('--batch_size', '-batch_size', type=int, default=8, + help='Batch size') + group.add('--batch_type', '-batch_type', default='sents', + choices=["sents", "tokens"], + help="Batch grouping for batch_size. Standard " + "is sents. Tokens will do dynamic batching") + group.add('--gpu', '-gpu', type=int, default=-1, help="Device to run on") + group.add('--seed', '-seed', type=int, default=829, help="Random seed") + group.add('--log_file', '-log_file', type=str, default="", + help="Output logs to a file under this path.") + group.add('--fp32', '-fp32', action='store_true', + help="Force the model to be in FP32 " + "because FP16 is very slow on GTX1080(ti).") + group.add('--verbose', '-verbose', action="store_true", + help='Print scores and predictions for each sentence') + # Copyright 2016 The Chromium Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. diff --git a/onmt/train_single.py b/onmt/train_single.py index f2487e6b0a..99ee8c6f94 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -51,11 +51,19 @@ def main(opt, device_id, batch_queue=None, semaphore=None): logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) - model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) - ArgumentParser.update_model_opts(model_opt) - ArgumentParser.validate_model_opts(model_opt) - logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) - vocab = checkpoint['vocab'] + if 'opt' in checkpoint: + model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) + ArgumentParser.update_model_opts(model_opt) + ArgumentParser.validate_model_opts(model_opt) + else: + model_opt = opt + + if 'vocab' in checkpoint: + logger.info('Loading vocab from checkpoint at %s.', + opt.train_from) + vocab = checkpoint['vocab'] + else: + vocab = torch.load(opt.data + '.vocab.pt') else: checkpoint = None model_opt = opt @@ -69,23 +77,31 @@ def main(opt, device_id, batch_queue=None, semaphore=None): else: fields = vocab - # Report src and tgt vocab sizes, including for features - for side in ['src', 'tgt']: - f = fields[side] - try: - f_iter = iter(f) - except TypeError: - f_iter = [(side, f)] - for sn, sf in f_iter: - if sf.use_vocab: - logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) - - # Build model. + if opt.is_bert: + # Report bert tokens vocab sizes, including for features + f = fields['tokens'] + logger.info(' * %s vocab size = %d' % ("BERT", len(f.vocab))) + else: + # Report src and tgt vocab sizes, including for features + for side in ['src', 'tgt']: + f = fields[side] + try: + f_iter = iter(f) + except TypeError: + f_iter = [(side, f)] + for sn, sf in f_iter: + if sf.use_vocab: + logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) + model = build_model(model_opt, opt, fields, checkpoint) n_params, enc, dec = _tally_parameters(model) logger.info('encoder: %d' % enc) - logger.info('decoder: %d' % dec) + if opt.is_bert: + logger.info('generator: %d' % dec) + else: + logger.info('decoder: %d' % dec) logger.info('* number of parameters: %d' % n_params) + _check_save_model_path(opt) # Build optimizer. diff --git a/onmt/trainer.py b/onmt/trainer.py index 4328ca52ea..42b57c6bf5 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -30,8 +30,14 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_saver(:obj:`onmt.models.ModelSaverBase`): the utility object used to save the model """ - - tgt_field = dict(fields)["tgt"].base_field + if not opt.is_bert: + tgt_field = dict(fields)["tgt"].base_field + elif opt.task_type == 'tagging' or opt.task_type == 'generation': + tgt_field = fields["token_labels"] + elif opt.task_type == 'classification': + tgt_field = fields["category"] + else: # pretraining task + tgt_field = fields["lm_labels_ids"] train_loss = onmt.utils.loss.build_loss_compute(model, tgt_field, opt) valid_loss = onmt.utils.loss.build_loss_compute( model, tgt_field, opt, train=False) @@ -70,7 +76,8 @@ def build_trainer(opt, device_id, model, fields, optim, model_saver=None): model_dtype=opt.model_dtype, earlystopper=earlystopper, dropout=dropout, - dropout_steps=dropout_steps) + dropout_steps=dropout_steps, + is_bert=opt.is_bert) return trainer @@ -101,13 +108,13 @@ class Trainer(object): """ def __init__(self, model, train_loss, valid_loss, optim, - trunc_size=0, shard_size=32, - norm_method="sents", accum_count=[1], - accum_steps=[0], + trunc_size=0, shard_size=32, norm_method="sents", + accum_count=[1], accum_steps=[0], n_gpu=1, gpu_rank=1, gpu_verbose_level=0, report_manager=None, with_align=False, model_saver=None, average_decay=0, average_every=1, model_dtype='fp32', - earlystopper=None, dropout=[0.3], dropout_steps=[0]): + earlystopper=None, dropout=[0.3], dropout_steps=[0], + is_bert=False): # Basic attributes. self.model = model self.train_loss = train_loss @@ -132,6 +139,7 @@ def __init__(self, model, train_loss, valid_loss, optim, self.earlystopper = earlystopper self.dropout = dropout self.dropout_steps = dropout_steps + self.is_bert = is_bert for i in range(len(self.accum_count_l)): assert self.accum_count_l[i] > 0 @@ -140,6 +148,10 @@ def __init__(self, model, train_loss, valid_loss, optim, """To enable accumulated gradients, you must disable target sequence truncating.""" + if self.is_bert: + assert self.trunc_size == 0 + """ Bert currently not support target sequence truncating""" + # Set model in training mode. self.model.train() @@ -162,12 +174,13 @@ def _accum_batches(self, iterator): self.accum_count = self._accum_count(self.optim.training_step) for batch in iterator: batches.append(batch) - if self.norm_method == "tokens": - num_tokens = batch.tgt[1:, :, 0].ne( - self.train_loss.padding_idx).sum() - normalization += num_tokens.item() - else: - normalization += batch.batch_size + if not self.is_bert: # Bert don't need normalization + if self.norm_method == "tokens": + num_tokens = batch.tgt[1:, :, 0].ne( + self.train_loss.padding_idx).sum() + normalization += num_tokens.item() + else: + normalization += batch.batch_size if len(batches) == self.accum_count: yield batches, normalization self.accum_count = self._accum_count(self.optim.training_step) @@ -216,9 +229,12 @@ def train(self, else: logger.info('Start training loop and validate every %d steps...', valid_steps) - - total_stats = onmt.utils.Statistics() - report_stats = onmt.utils.Statistics() + if self.is_bert: + total_stats = onmt.utils.BertStatistics() + report_stats = onmt.utils.BertStatistics() + else: + total_stats = onmt.utils.Statistics() + report_stats = onmt.utils.Statistics() self._start_report_manager(start_time=total_stats.start_time) for i, (batches, normalization) in enumerate( @@ -239,10 +255,16 @@ def train(self, .all_gather_list (normalization)) - self._gradient_accumulation( - batches, normalization, total_stats, - report_stats) + # Training Step: Forward -> compute Loss -> optimize + if self.is_bert: + self._bert_gradient_accumulation( + batches, total_stats, report_stats) + else: + self._gradient_accumulation( + batches, normalization, total_stats, + report_stats) + # Moving average if self.average_decay > 0 and i % self.average_every == 0: self._update_average(step) @@ -251,6 +273,7 @@ def train(self, self.optim.learning_rate(), report_stats) + # Part: validation if valid_iter is not None and step % valid_steps == 0: if self.gpu_verbose_level > 0: logger.info('GpuRank %d: validate step %d' @@ -287,7 +310,10 @@ def train(self, def validate(self, valid_iter, moving_average=None): """ Validate model. + + Args: valid_iter: validate data iterator + Returns: :obj:`nmt.Statistics`: validation loss statistics """ @@ -306,22 +332,47 @@ def validate(self, valid_iter, moving_average=None): valid_model.eval() with torch.no_grad(): - stats = onmt.utils.Statistics() + if self.is_bert: + stats = onmt.utils.BertStatistics() + for batch in valid_iter: + # input_ids: Size([batch_size, max_seq_length_in_batch]), + # seq_lengths: Size([batch_size]) + if isinstance(batch.tokens, tuple): + input_ids, _ = batch.tokens + else: + input_ids, _ = (batch.tokens, None) + # segment_ids: Size([batch_size, max_seq_length_in_batch]) + # 0 for sens A, 1 for sens B. 0 padding + token_type_ids = batch.segment_ids + # F-prop through the model. + all_encoder_layers, pooled_out = \ + valid_model(input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = \ + valid_model.generator(all_encoder_layers, pooled_out) + + outputs = (seq_class_log_prob, prediction_log_prob) + # Compute loss. + _, batch_stats = self.valid_loss(batch, outputs) + + # Update statistics. + stats.update(batch_stats) + else: + stats = onmt.utils.Statistics() + for batch in valid_iter: + src, src_lengths = batch.src if isinstance( + batch.src, tuple) else (batch.src, None) + tgt = batch.tgt - for batch in valid_iter: - src, src_lengths = batch.src if isinstance(batch.src, tuple) \ - else (batch.src, None) - tgt = batch.tgt + # F-prop through the model. + outputs, attns = valid_model(src, tgt, src_lengths, + with_align=self.with_align) - # F-prop through the model. - outputs, attns = valid_model(src, tgt, src_lengths, - with_align=self.with_align) + # Compute loss. + _, batch_stats = self.valid_loss(batch, outputs, attns) - # Compute loss. - _, batch_stats = self.valid_loss(batch, outputs, attns) + # Update statistics. + stats.update(batch_stats) - # Update statistics. - stats.update(batch_stats) if moving_average: for param_data, param in zip(model_params_data, self.model.parameters()): @@ -462,3 +513,76 @@ def _report_step(self, learning_rate, step, train_stats=None, return self.report_manager.report_step( learning_rate, step, train_stats=train_stats, valid_stats=valid_stats) + + def _bert_gradient_accumulation(self, true_batches, + total_stats, report_stats): + """As the loss will be reduced by mean, normalization is not needed. + But we still need to average between GPUs. + """ + if self.accum_count > 1: + self.optim.zero_grad() + + for k, batch in enumerate(true_batches): + # target_size = batch.tgt.size(0) + # NOTE: for batch in BERT : + # batch_first is True -> [batch, seq, vocab] + if isinstance(batch.tokens, tuple): + input_ids, seq_lengths = batch.tokens + else: + input_ids, seq_lengths = (batch.tokens, None) + + if seq_lengths is not None: + report_stats.n_src_words += seq_lengths.sum().item() + + token_type_ids = batch.segment_ids + + # 1. F-prop all to get log likelihood of two task. + if self.accum_count == 1: + self.optim.zero_grad() + + all_encoder_layers, pooled_out = self.model( + input_ids, token_type_ids) + seq_class_log_prob, prediction_log_prob = self.model.generator( + all_encoder_layers, pooled_out) + # NOTE: (batch_size, 2), (batch_size, seq_size, vocab_size) + outputs = (seq_class_log_prob, prediction_log_prob) + + # 2. Compute loss. + try: + loss, batch_stats = self.train_loss(batch, outputs) + + if loss is not None: + self.optim.backward(loss) + + total_stats.update(batch_stats) + report_stats.update(batch_stats) + except Exception: + traceback.print_exc() + logger.info("At step %d, we removed a batch - accum %d", + self.optim.training_step, k) + + # 3. Update the parameters and statistics. + if self.accum_count == 1: + # Multi GPU gradient gather + if self.n_gpu > 1: + grads = [p.grad.data for p in self.model.parameters() + if p.requires_grad + and p.grad is not None] + + # NOTE: average the gradient across the GPU + onmt.utils.distributed.all_reduce_and_rescale_tensors( + grads, float(self.n_gpu)) + + self.optim.step() + + # in case of multi step gradient accumulation, + # update only after accum batches + if self.accum_count > 1: + if self.n_gpu > 1: + grads = [p.grad.data for p in self.model.parameters() + if p.requires_grad + and p.grad is not None] + # NOTE: average the gradient across the GPU + onmt.utils.distributed.all_reduce_and_rescale_tensors( + grads, float(self.n_gpu)) + self.optim.step() diff --git a/onmt/translate/__init__.py b/onmt/translate/__init__.py index 6cd5668991..004ef222ac 100644 --- a/onmt/translate/__init__.py +++ b/onmt/translate/__init__.py @@ -7,8 +7,9 @@ from onmt.translate.penalties import PenaltyBuilder from onmt.translate.translation_server import TranslationServer, \ ServerModelError +from onmt.translate.predictor import Classifier, Tagger __all__ = ['Translator', 'Translation', 'BeamSearch', 'GNMTGlobalScorer', 'TranslationBuilder', 'PenaltyBuilder', 'TranslationServer', 'ServerModelError', - "DecodeStrategy", "GreedySearch"] + "DecodeStrategy", "GreedySearch", "Classifier", "Tagger"] diff --git a/onmt/translate/predictor.py b/onmt/translate/predictor.py new file mode 100644 index 0000000000..5ebffa2293 --- /dev/null +++ b/onmt/translate/predictor.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python +""" Classifier Class and builder """ +from __future__ import print_function +import codecs +import time + +import torch +import torchtext.data +import onmt.model_builder +import onmt.inputters as inputters +from onmt.inputters.inputter import max_tok_len +from onmt.utils.misc import set_random_seed + + +def build_classifier(opt, logger=None, out_file=None): + """Return a classifier with result redirect to `out_file`.""" + + if out_file is None: + out_file = codecs.open(opt.output, 'w', 'utf-8') + + load_model = onmt.model_builder.load_test_model + fields, model, model_opt = load_model(opt, opt.model) + + classifier = Classifier.from_opt( + model, + fields, + opt, + model_opt, + out_file=out_file, + logger=logger + ) + return classifier + + +def build_tagger(opt, logger=None, out_file=None): + """Return a tagger with result redirect to `out_file`.""" + + if out_file is None: + out_file = codecs.open(opt.output, 'w', 'utf-8') + + load_model = onmt.model_builder.load_test_model + fields, model, model_opt = load_model(opt, opt.model) + + tagger = Tagger.from_opt( + model, + fields, + opt, + model_opt, + out_file=out_file, + logger=logger + ) + return tagger + + +class Predictor(object): + """Predictor a batch of data with a saved model. + + Args: + model (nn.Sequential): model to use + fields (dict[str, torchtext.data.Field]): A dict of field. + gpu (int): GPU device. Set to negative for no GPU. + data_type (str): Source data type. + verbose (bool): output every predition with confidences. + report_time (bool): Print/log total time/frequency. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + logger (logging.Logger or NoneType): Logger. + """ + + def __init__( + self, + model, + fields, + gpu=-1, + verbose=False, + out_file=None, + report_time=True, + logger=None, + seed=-1): + self.model = model + self.fields = fields + + self._gpu = gpu + self._use_cuda = gpu > -1 + self._dev = torch.device("cuda", self._gpu) \ + if self._use_cuda else torch.device("cpu") + + self.verbose = verbose + self.report_time = report_time + self.out_file = out_file + self.logger = logger + + set_random_seed(seed, self._use_cuda) + + @classmethod + def from_opt( + cls, + model, + fields, + opt, + model_opt, + out_file=None, + logger=None): + """Alternate constructor. + + Args: + model (onmt.modules): See :func:`__init__()`. + fields (dict[str, torchtext.data.Field]): See + :func:`__init__()`. + opt (argparse.Namespace): Command line options + model_opt (argparse.Namespace): Command line options saved with + the model checkpoint. + out_file (TextIO or codecs.StreamReaderWriter): See + :func:`__init__()`. + logger (logging.Logger or NoneType): See :func:`__init__()`. + """ + + return cls( + model, + fields, + gpu=opt.gpu, + verbose=opt.verbose, + out_file=out_file, + logger=logger, + seed=opt.seed) + + def _log(self, msg): + if self.logger is not None: + self.logger.info(msg) + else: + print(msg) + + +class Classifier(Predictor): + """classify a batch of sentences with a saved model. + + Args: + model (nn.Sequential): BERT model to use for classify + fields (dict[str, torchtext.data.Field]): A dict of field. + gpu (int): GPU device. Set to negative for no GPU. + data_type (str): Source data type. + verbose (bool): output every predition with confidences. + report_time (bool): Print/log total time/frequency. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + logger (logging.Logger or NoneType): Logger. + """ + + def __init__( + self, + model, + fields, + gpu=-1, + verbose=False, + out_file=None, + report_time=True, + logger=None, + seed=-1): + super(Classifier, self).__init__( + model, + fields, + gpu=gpu, + verbose=verbose, + out_file=out_file, + report_time=report_time, + logger=logger, + seed=seed) + label_field = self.fields["category"] + self.label_vocab = label_field.vocab + + def classify(self, data, batch_size, tokenizer, delimiter=' ||| ', + max_seq_len=256, batch_type="sents"): + """Classify content of ``data``. + + Args: + data (list of str): ['Sentence1 ||| Sentence2',...]. + batch_size (int): size of examples per mini-batch + batch_type (str): Batch grouping for batch_size. Chose from + {'sents', 'tokens'}, default batch_size count by sentence. + + Returns: + all_predictions (list of str):[c1, ..., cn]. + """ + + dataset = inputters.ClassifierDataset( + self.fields, data, tokenizer, max_seq_len, delimiter) + + data_iter = torchtext.data.Iterator( + dataset=dataset, + batch_size=batch_size, + batch_size_fn=max_tok_len if batch_type == "tokens" else None, + device=self._dev, + train=False, + sort=False, + sort_within_batch=False, + shuffle=False + ) + + all_predictions = [] + + start_time = time.time() + + for batch in data_iter: + pred_sents_labels = self.classify_batch(batch) + all_predictions.extend(pred_sents_labels) + self.out_file.write('\n'.join(pred_sents_labels) + '\n') + self.out_file.flush() + + end_time = time.time() + + if self.report_time: + total_time = end_time - start_time + self._log("Total classification time: %f s" % total_time) + self._log("Average classification time: %f s" % ( + total_time / len(all_predictions))) + self._log("Sentences per second: %f" % ( + len(all_predictions) / total_time)) + return all_predictions + + def classify_batch(self, batch): + """Classify a batch of sentences.""" + with torch.no_grad(): + input_ids, _ = batch.tokens + token_type_ids = batch.segment_ids + all_encoder_layers, pooled_out = self.model( + input_ids, token_type_ids) + seq_class_log_prob, _ = self.model.generator( + all_encoder_layers, pooled_out) + # Predicting + pred_sents_ids = seq_class_log_prob.argmax(-1).tolist() + pred_sents_labels = [self.label_vocab.itos[index] + for index in pred_sents_ids] + if self.verbose: + seq_class_prob = seq_class_log_prob.exp() + category_probs = seq_class_prob.tolist() + preds = ['\t'.join(map(str, category_prob)) + '\t' + pred + for category_prob, pred in zip( + category_probs, pred_sents_labels)] + return preds + return pred_sents_labels + + +class Tagger(Predictor): + """Tagging a batch of sentences with a saved model. + + Args: + model (nn.Sequential): BERT model to use for Tagging + fields (dict[str, torchtext.data.Field]): A dict of field. + gpu (int): GPU device. Set to negative for no GPU. + data_type (str): Source data type. + verbose (bool): output every predition with confidences. + report_time (bool): Print/log total time/frequency. + out_file (TextIO or codecs.StreamReaderWriter): Output file. + logger (logging.Logger or NoneType): Logger. + """ + + def __init__( + self, + model, + fields, + gpu=-1, + verbose=False, + out_file=None, + report_time=True, + logger=None, + seed=-1): + super(Tagger, self).__init__( + model, + fields, + gpu=gpu, + verbose=verbose, + out_file=out_file, + report_time=report_time, + logger=logger, + seed=seed) + label_field = self.fields["token_labels"] + self.label_vocab = label_field.vocab + self.pad_token = label_field.pad_token + self.pad_index = self.label_vocab.stoi[self.pad_token] + + def tagging(self, data, batch_size, tokenizer, delimiter=' ', + max_seq_len=256, batch_type="sents"): + """Tagging content of ``data``. + + Args: + data (list of str): ['T1 T2 ... Tn',...]. + batch_size (int): size of examples per mini-batch + + Returns: + all_predictions (list of list of str): [['L1', ..., 'Ln'],...]. + """ + dataset = inputters.TaggerDataset( + self.fields, data, tokenizer, max_seq_len, delimiter) + + data_iter = torchtext.data.Iterator( + dataset=dataset, + batch_size=batch_size, + batch_size_fn=max_tok_len if batch_type == "tokens" else None, + device=self._dev, + train=False, + sort=False, + sort_within_batch=False, + shuffle=False + ) + + all_predictions = [] + + start_time = time.time() + + for batch in data_iter: + pred_tokens_tag = self.tagging_batch(batch) + all_predictions.extend(pred_tokens_tag) + for pred_sent in pred_tokens_tag: + self.out_file.write('\n'.join(pred_sent) + '\n' + '\n') + self.out_file.flush() + + end_time = time.time() + + if self.report_time: + total_time = end_time - start_time + self._log("Total tagging time (s): %f" % total_time) + self._log("Average tagging time (s): %f" % ( + total_time / len(all_predictions))) + self._log("Sentence per second: %f" % ( + len(all_predictions) / total_time)) + return all_predictions + + def tagging_batch(self, batch): + """Tagging a batch of sentences.""" + with torch.no_grad(): + # Batch + input_ids, _ = batch.tokens + token_type_ids = batch.segment_ids + taggings = batch.token_labels + # Forward + all_encoder_layers, pooled_out = self.model( + input_ids, token_type_ids) + _, prediction_log_prob = self.model.generator( + all_encoder_layers, pooled_out) + # Predicting + pred_tag_ids = prediction_log_prob.argmax(-1) + non_padding = taggings.ne(self.pad_index) + batch_tag_ids, batch_mask = list(pred_tag_ids), list(non_padding) + batch_tag_select_ids = [pred.masked_select(mask).tolist() + for pred, mask in + zip(batch_tag_ids, batch_mask)] + + pred_tokens_tag = [[self.label_vocab.itos[index] + for index in tag_select_ids] + for tag_select_ids in batch_tag_select_ids] + return pred_tokens_tag diff --git a/onmt/utils/__init__.py b/onmt/utils/__init__.py index 6e7873aca4..e2b4b782fa 100644 --- a/onmt/utils/__init__.py +++ b/onmt/utils/__init__.py @@ -1,13 +1,19 @@ """Module defining various utilities.""" + from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed from onmt.utils.alignment import make_batch_align_matrix from onmt.utils.report_manager import ReportMgr, build_report_manager -from onmt.utils.statistics import Statistics +from onmt.utils.statistics import Statistics, BertStatistics from onmt.utils.optimizers import MultipleOptimizer, \ - Optimizer, AdaFactor + Optimizer, AdaFactor, AdamW from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts +from onmt.utils.activation_fn import get_activation_fn +from onmt.utils.bert_tokenization import BertTokenizer +from onmt.utils.bert_vocab_archive_map import PRETRAINED_VOCAB_ARCHIVE_MAP __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", - "build_report_manager", "Statistics", - "MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping", - "scorers_from_opts", "make_batch_align_matrix"] + "build_report_manager", "Statistics", "BertStatistics", + "MultipleOptimizer", "Optimizer", "AdaFactor", "AdamW", + "EarlyStopping", "scorers_from_opts", "get_activation_fn", + "BertTokenizer", "PRETRAINED_VOCAB_ARCHIVE_MAP", + "make_batch_align_matrix"] diff --git a/onmt/utils/activation_fn.py b/onmt/utils/activation_fn.py new file mode 100644 index 0000000000..9497a8b845 --- /dev/null +++ b/onmt/utils/activation_fn.py @@ -0,0 +1,36 @@ +import torch +import torch.nn as nn +import math + + +def get_activation_fn(activation): + """Return an activation function Module according to its name.""" + if activation == 'gelu': + fn = GELU() + elif activation == 'relu': + fn = nn.ReLU() + elif activation == 'tanh': + fn = nn.Tanh() + else: + raise ValueError("Please pass a valid \ + activation function") + return fn + + +class GELU(nn.Module): + """ Implementation of the gelu activation function + :cite:`DBLP:journals/corr/HendrycksG16` + + For information: OpenAI GPT's gelu is slightly different + (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) + * (x + 0.044715 * torch.pow(x, 3)))) + + Examples:: + >>> m = GELU() + >>> inputs = torch.randn(2) + >>> outputs = m(inputs) + """ + def forward(self, x): + gelu = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + return gelu diff --git a/onmt/utils/bert_tokenization.py b/onmt/utils/bert_tokenization.py new file mode 100644 index 0000000000..7dda9cbb63 --- /dev/null +++ b/onmt/utils/bert_tokenization.py @@ -0,0 +1,429 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tokenization classes.""" + +from __future__ import absolute_import, division, \ + print_function, unicode_literals + +import collections +import logging +import os +import unicodedata +from io import open + +from .file_utils import cached_path +from onmt.utils.bert_vocab_archive_map import PRETRAINED_VOCAB_ARCHIVE_MAP + +logger = logging.getLogger(__name__) + +PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { + 'bert-base-uncased': 512, + 'bert-large-uncased': 512, + 'bert-base-cased': 512, + 'bert-large-cased': 512, + 'bert-base-multilingual-uncased': 512, + 'bert-base-multilingual-cased': 512, + 'bert-base-chinese': 512, + 'bert-base-german-cased': 512, + 'bert-large-uncased-whole-word-masking': 512, + 'bert-large-cased-whole-word-masking': 512, + 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, + 'bert-large-cased-whole-word-masking-finetuned-squad': 512, + 'bert-base-cased-finetuned-mrpc': 512, +} +VOCAB_NAME = 'vocab.txt' + + +def load_vocab(vocab_file): + """Loads a vocabulary file into a dictionary.""" + vocab = collections.OrderedDict() + with open(vocab_file, "r", encoding="utf-8") as reader: + tokens = reader.readlines() + for index, token in enumerate(tokens): + token = token.rstrip('\n') + vocab[token] = index + return vocab + + +def whitespace_tokenize(text): + """Runs basic whitespace cleaning and splitting on a piece of text.""" + text = text.strip() + if not text: + return [] + tokens = text.split() + return tokens + + +class BertTokenizer(object): + """Runs end-to-end tokenization: punctuation splitting + wordpiece""" + + def __init__(self, vocab_file, do_lower_case=True, max_len=None, + do_basic_tokenize=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BertTokenizer. + + Args: + vocab_file (str): Path to a one-wordpiece-per-line vocabulary file + do_lower_case (bool): If to lower case the input, Only has + an effect when do_wordpiece_only=False + do_basic_tokenize (bool): If to do basic tokenization before WP. + max_len (int): Maximum length to truncate tokenized sequences to; + Effective maximum length is always the minimum of + this value (if specified) and the underlying BERT + model's sequence length. + never_split (list): List of tokens which will never be split during + tokenization. Only has an effect when + do_wordpiece_only=False. + """ + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. " + "To load the vocabulary from a Google pretrained model use " + "`tokenizer = BertTokenizer.from_pretrained(" + "PRETRAINED_MODEL_NAME)`".format(vocab_file)) + self.vocab = load_vocab(vocab_file) + self.ids_to_tokens = collections.OrderedDict( + [(ids, tok) for tok, ids in self.vocab.items()]) + self.do_basic_tokenize = do_basic_tokenize + if do_basic_tokenize: + self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, + never_split=never_split) + self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) + self.max_len = max_len if max_len is not None else int(1e12) + + def tokenize(self, text): + split_tokens = [] + if self.do_basic_tokenize: + for token in self.basic_tokenizer.tokenize(text): + for sub_token in self.wordpiece_tokenizer.tokenize(token): + split_tokens.append(sub_token) + else: + split_tokens = self.wordpiece_tokenizer.tokenize(text) + return split_tokens + + def convert_tokens_to_ids(self, tokens): + """Converts a sequence of tokens into ids using the vocab.""" + ids = [] + for token in tokens: + ids.append(self.vocab[token]) + if len(ids) > self.max_len: + logger.warning( + "Token indices sequence length is longer than the specified " + "maximum sequence length for this BERT model ({} > {}). " + "Running this sequence through BERT will result in " + "indexing errors".format(len(ids), self.max_len) + ) + return ids + + def convert_ids_to_tokens(self, ids): + """Converts a sequence of ids in wordpiece tokens using the vocab.""" + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + index = 0 + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_NAME) + with open(vocab_file, "w", encoding="utf-8") as writer: + for token, token_index in sorted(self.vocab.items(), + key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: vocabulary " + "indices are not consecutive. Please " + "check that the vocabulary is not " + "corrupted!".format(vocab_file)) + index = token_index + writer.write(token + u'\n') + index += 1 + return vocab_file + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, + *inputs, **kwargs): + """ + Instantiate a PreTrainedBertModel from a pre-trained model file. + Download and cache the pre-trained model file if needed. + """ + if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: + vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ + pretrained_model_name_or_path] + if ('-cased' in pretrained_model_name_or_path + and kwargs.get('do_lower_case', True)): + logger.warning("The pre-trained model you are loading is " + "a cased model but you have not set " + "`do_lower_case` to False. We are setting " + "`do_lower_case=False` for you but " + "you may want to check this behavior.") + kwargs['do_lower_case'] = False + elif ('-cased' not in pretrained_model_name_or_path + and not kwargs.get('do_lower_case', True)): + logger.warning("The pre-trained model you are loading is " + "a uncased model but you have set " + "`do_lower_case` to False. We are setting " + "`do_lower_case=True` for you but " + "you may want to check this behavior.") + kwargs['do_lower_case'] = True + else: + vocab_file = pretrained_model_name_or_path + if os.path.isdir(vocab_file): + vocab_file = os.path.join(vocab_file, VOCAB_NAME) + # redirect to the cache, if necessary + try: + resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) + except EnvironmentError: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file" + " associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), + vocab_file)) + return None + if resolved_vocab_file == vocab_file: + logger.info("loading vocabulary file {}".format(vocab_file)) + else: + logger.info("loading vocabulary file {} from cache at {}".format( + vocab_file, resolved_vocab_file)) + if (pretrained_model_name_or_path + in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP): + # if we're using a pretrained model, ensure the tokenizer wont + # index sequences longer than the number of positional embeddings + max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ + pretrained_model_name_or_path] + kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Instantiate tokenizer. + tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) + return tokenizer + + +class BasicTokenizer(object): + """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" + + def __init__(self, + do_lower_case=True, + never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): + """Constructs a BasicTokenizer. + + Args: + do_lower_case: Whether to lower case the input. + """ + self.do_lower_case = do_lower_case + self.never_split = never_split + + def tokenize(self, text): + """Tokenizes a piece of text.""" + text = self._clean_text(text) + # This was added on Nov. 1st, 2018 for the multilingual and Chinese + # models. This is also applied to the English models now, but it does + # not matter since the English models were not trained on any Chinese + # data and generally don't have any Chinese data in them (there are + # Chinese characters in the vocabulary because Wikipedia does have + # some Chinese words in the English Wikipedia.). + text = self._tokenize_chinese_chars(text) + orig_tokens = whitespace_tokenize(text) + split_tokens = [] + for token in orig_tokens: + if self.do_lower_case and token not in self.never_split: + token = token.lower() + token = self._run_strip_accents(token) + split_tokens.extend(self._run_split_on_punc(token)) + + output_tokens = whitespace_tokenize(" ".join(split_tokens)) + return output_tokens + + def _run_strip_accents(self, text): + """Strips accents from a piece of text.""" + text = unicodedata.normalize("NFD", text) + output = [] + for char in text: + cat = unicodedata.category(char) + if cat == "Mn": + continue + output.append(char) + return "".join(output) + + def _run_split_on_punc(self, text): + """Splits punctuation on a piece of text.""" + if text in self.never_split: + return [text] + chars = list(text) + i = 0 + start_new_word = True + output = [] + while i < len(chars): + char = chars[i] + if _is_punctuation(char): + output.append([char]) + start_new_word = True + else: + if start_new_word: + output.append([]) + start_new_word = False + output[-1].append(char) + i += 1 + + return ["".join(x) for x in output] + + def _tokenize_chinese_chars(self, text): + """Adds whitespace around any CJK character.""" + output = [] + for char in text: + cp = ord(char) + if self._is_chinese_char(cp): + output.append(" ") + output.append(char) + output.append(" ") + else: + output.append(char) + return "".join(output) + + def _is_chinese_char(self, cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "Chinese character" as anything in CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that CJK Unicode block is NOT all Japanese and Korean chars, + # despite its name. Modern Korean Hangul alphabet is a different block + # as is Japanese Hiragana and Katakana. Those alphabets are used to + # write space-separated words, so they are not treated specially and + # handled like the all of the other languages. + if ((cp >= 0x4E00 and cp <= 0x9FFF) or # + (cp >= 0x3400 and cp <= 0x4DBF) or # + (cp >= 0x20000 and cp <= 0x2A6DF) or # + (cp >= 0x2A700 and cp <= 0x2B73F) or # + (cp >= 0x2B740 and cp <= 0x2B81F) or # + (cp >= 0x2B820 and cp <= 0x2CEAF) or + (cp >= 0xF900 and cp <= 0xFAFF) or # + (cp >= 0x2F800 and cp <= 0x2FA1F)): # + return True + + return False + + def _clean_text(self, text): + """Performs invalid character removal + and whitespace cleanup on text.""" + output = [] + for char in text: + cp = ord(char) + if cp == 0 or cp == 0xfffd or _is_control(char): + continue + if _is_whitespace(char): + output.append(" ") + else: + output.append(char) + return "".join(output) + + +class WordpieceTokenizer(object): + """Runs WordPiece tokenization.""" + + def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): + self.vocab = vocab + self.unk_token = unk_token + self.max_input_chars_per_word = max_input_chars_per_word + + def tokenize(self, text): + """Tokenizes a piece of text into its word pieces. + + This uses a greedy longest-match-first algorithm to perform + tokenization using the given vocabulary. + + For example: + input = "unaffable" + output = ["un", "##aff", "##able"] + + Args: + text: A single token or whitespace separated tokens. This should have + already been passed through `BasicTokenizer`. + + Returns: + A list of wordpiece tokens. + """ + + output_tokens = [] + for token in whitespace_tokenize(text): + chars = list(token) + if len(chars) > self.max_input_chars_per_word: + output_tokens.append(self.unk_token) + continue + + is_bad = False + start = 0 + sub_tokens = [] + while start < len(chars): + end = len(chars) + cur_substr = None + while start < end: + substr = "".join(chars[start:end]) + if start > 0: + substr = "##" + substr + if substr in self.vocab: + cur_substr = substr + break + end -= 1 + if cur_substr is None: + is_bad = True + break + sub_tokens.append(cur_substr) + start = end + + if is_bad: + output_tokens.append(self.unk_token) + else: + output_tokens.extend(sub_tokens) + return output_tokens + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or + (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False diff --git a/onmt/utils/bert_vocab_archive_map.py b/onmt/utils/bert_vocab_archive_map.py new file mode 100644 index 0000000000..9987a2edb9 --- /dev/null +++ b/onmt/utils/bert_vocab_archive_map.py @@ -0,0 +1,18 @@ +# coding=utf-8 +# flake8: noqa + +PRETRAINED_VOCAB_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", + 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", + 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", +} \ No newline at end of file diff --git a/onmt/utils/file_utils.py b/onmt/utils/file_utils.py new file mode 100644 index 0000000000..caddaa819d --- /dev/null +++ b/onmt/utils/file_utils.py @@ -0,0 +1,236 @@ +""" +Utilities for working with the local dataset cache. +Get from https://github.com/huggingface/pytorch-transformers. +This file is adapted from the AllenNLP library +at https://github.com/allenai/allennlp +Copyright by the AllenNLP authors. +""" +from __future__ import absolute_import, division, \ + print_function, unicode_literals + +import sys +import json +import logging +import os +import shutil +import tempfile +import fnmatch +from functools import wraps +from hashlib import sha256 +from io import open + +import boto3 +from botocore.exceptions import ClientError +import requests +from tqdm import tqdm + +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse + +try: + from pathlib import Path + PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv( + 'PYTORCH_PRETRAINED_BERT_CACHE', + Path.home() / '.pytorch_pretrained_bert')) +except (AttributeError, ImportError): + PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( + 'PYTORCH_PRETRAINED_BERT_CACHE', + os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def url_to_filename(url, etag=None): + """ + Convert `url` into a hashed filename in a repeatable way. + If `etag` is specified, append its hash to the url's, delimited + by a period. + """ + url_bytes = url.encode('utf-8') + url_hash = sha256(url_bytes) + filename = url_hash.hexdigest() + + if etag: + etag_bytes = etag.encode('utf-8') + etag_hash = sha256(etag_bytes) + filename += '.' + etag_hash.hexdigest() + + return filename + + +def cached_path(url_or_filename, cache_dir=None): + """ + Given something that might be a URL (or might be a local path), + determine which. If it's a URL, download the file and cache it, and + return the path to the cached file. If it's already a local path, + make sure the file exists and then return the path. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): + url_or_filename = str(url_or_filename) + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + parsed = urlparse(url_or_filename) + + if parsed.scheme in ('http', 'https', 's3'): + # URL, so get it from the cache (downloading if necessary) + return get_from_cache(url_or_filename, cache_dir) + elif os.path.exists(url_or_filename): + # File, and it exists. + return url_or_filename + elif parsed.scheme == '': + # File, but it doesn't exist. + raise EnvironmentError("file {} not found".format(url_or_filename)) + else: + # Something unknown + raise ValueError( + "unable to parse {} as a URL/local path".format(url_or_filename)) + + +def split_s3_path(url): + """Split a full s3 path into the bucket name and path.""" + parsed = urlparse(url) + if not parsed.netloc or not parsed.path: + raise ValueError("bad s3 path {}".format(url)) + bucket_name = parsed.netloc + s3_path = parsed.path + # Remove '/' at beginning of path. + if s3_path.startswith("/"): + s3_path = s3_path[1:] + return bucket_name, s3_path + + +def s3_request(func): + """ + Wrapper function for s3 requests in order to create more helpful error + messages. + """ + + @wraps(func) + def wrapper(url, *args, **kwargs): + try: + return func(url, *args, **kwargs) + except ClientError as exc: + if int(exc.response["Error"]["Code"]) == 404: + raise EnvironmentError("file {} not found".format(url)) + else: + raise + + return wrapper + + +@s3_request +def s3_etag(url): + """Check ETag on S3 object.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_object = s3_resource.Object(bucket_name, s3_path) + return s3_object.e_tag + + +@s3_request +def s3_get(url, temp_file): + """Pull a file directly from S3.""" + s3_resource = boto3.resource("s3") + bucket_name, s3_path = split_s3_path(url) + s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) + + +def http_get(url, temp_file): + req = requests.get(url, stream=True) + content_length = req.headers.get('Content-Length') + total = int(content_length) if content_length is not None else None + progress = tqdm(unit="B", total=total) + for chunk in req.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + temp_file.write(chunk) + progress.close() + + +def get_from_cache(url, cache_dir=None): + """ + Given a URL, look for the corresponding dataset in the local cache. + If it's not there, download it. Then return the path to the cached file. + """ + if cache_dir is None: + cache_dir = PYTORCH_PRETRAINED_BERT_CACHE + if sys.version_info[0] == 3 and isinstance(cache_dir, Path): + cache_dir = str(cache_dir) + + if not os.path.exists(cache_dir): + os.makedirs(cache_dir) + + # Get eTag to add to filename, if it exists. + if url.startswith("s3://"): + etag = s3_etag(url) + else: + try: + response = requests.head(url, allow_redirects=True) + if response.status_code != 200: + etag = None + else: + etag = response.headers.get("ETag") + except EnvironmentError: + etag = None + + if sys.version_info[0] == 2 and etag is not None: + etag = etag.decode('utf-8') + filename = url_to_filename(url, etag) + + # get cache path to put the file + cache_path = os.path.join(cache_dir, filename) + + # If we don't have a connection (etag is None) and can't identify the file + # try to get the last downloaded one + if not os.path.exists(cache_path) and etag is None: + matching_files = fnmatch.filter(os.listdir(cache_dir), + filename + '.*') + matching_files = list( + filter(lambda s: not s.endswith('.json'), matching_files)) + if matching_files: + cache_path = os.path.join(cache_dir, matching_files[-1]) + + if not os.path.exists(cache_path): + # Download to temporary file, then copy to cache dir once finished. + # Or you get corrupt cache entries if the download gets interrupted. + with tempfile.NamedTemporaryFile() as temp_file: + logger.info("%s not found in cache, downloading to %s", + url, temp_file.name) + + # GET file object + if url.startswith("s3://"): + s3_get(url, temp_file) + else: + http_get(url, temp_file) + + # we are copying the file before close it, so flush to avoid trunc + temp_file.flush() + # shutil.copyfileobj() starts at the current position, + # so go to the start + temp_file.seek(0) + + logger.info("copying %s to cache at %s", + temp_file.name, cache_path) + with open(cache_path, 'wb') as cache_file: + shutil.copyfileobj(temp_file, cache_file) + + logger.info("creating metadata file for %s", cache_path) + meta = {'url': url, 'etag': etag} + meta_path = cache_path + '.json' + with open(meta_path, 'w') as meta_file: + output_string = json.dumps(meta) + if (sys.version_info[0] == 2 + and isinstance(output_string, str)): + # The beauty of python 2 + output_string = unicode( # noqa: F821 + output_string, 'utf-8') + meta_file.write(output_string) + + logger.info("removing temp file %s", temp_file.name) + + return cache_path diff --git a/onmt/utils/loss.py b/onmt/utils/loss.py index c48f0d3d21..2f492ff768 100644 --- a/onmt/utils/loss.py +++ b/onmt/utils/loss.py @@ -10,6 +10,7 @@ import onmt from onmt.modules.sparse_losses import SparsemaxLoss from onmt.modules.sparse_activations import LogSparsemax +from sklearn.metrics import f1_score def build_loss_compute(model, tgt_field, opt, train=True): @@ -22,48 +23,217 @@ def build_loss_compute(model, tgt_field, opt, train=True): for when using a copy mechanism. """ device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu") - - padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] - unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token] - - if opt.lambda_coverage != 0: - assert opt.coverage_attn, "--coverage_attn needs to be set in " \ - "order to use --lambda_coverage != 0" - - if opt.copy_attn: - criterion = onmt.modules.CopyGeneratorLoss( - len(tgt_field.vocab), opt.copy_attn_force, - unk_index=unk_idx, ignore_index=padding_idx - ) - elif opt.label_smoothing > 0 and train: - criterion = LabelSmoothingLoss( - opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx - ) - elif isinstance(model.generator[-1], LogSparsemax): - criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum') - else: - criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') - - # if the loss function operates on vectors of raw logits instead of - # probabilities, only the first part of the generator needs to be - # passed to the NMTLossCompute. At the moment, the only supported - # loss function of this kind is the sparsemax loss. - use_raw_logits = isinstance(criterion, SparsemaxLoss) - loss_gen = model.generator[0] if use_raw_logits else model.generator - if opt.copy_attn: - compute = onmt.modules.CopyGeneratorLossCompute( - criterion, loss_gen, tgt_field.vocab, opt.copy_loss_by_seqlength, - lambda_coverage=opt.lambda_coverage - ) + if opt.is_bert: + if tgt_field.pad_token is not None: + if tgt_field.use_vocab: + padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] + else: # target is pre-numerized: -1 for unmasked token in mlm + padding_idx = tgt_field.pad_token + criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='mean') + else: # sentence level + criterion = nn.NLLLoss(reduction='mean') + task = opt.task_type + compute = BertLoss(criterion, task) else: - compute = NMTLossCompute( - criterion, loss_gen, lambda_coverage=opt.lambda_coverage, - lambda_align=opt.lambda_align) + assert isinstance(model, onmt.models.NMTModel) + padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token] + unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token] + if opt.lambda_coverage != 0: + assert opt.coverage_attn, "--coverage_attn needs to be set in " \ + "order to use --lambda_coverage != 0" + if opt.copy_attn: + criterion = onmt.modules.CopyGeneratorLoss( + len(tgt_field.vocab), opt.copy_attn_force, + unk_index=unk_idx, ignore_index=padding_idx + ) + elif opt.label_smoothing > 0 and train: + criterion = LabelSmoothingLoss(opt.label_smoothing, + len(tgt_field.vocab), + ignore_index=padding_idx) + elif isinstance(model.generator[-1], LogSparsemax): + criterion = SparsemaxLoss(ignore_index=padding_idx, + reduction='sum') + else: + criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum') + + # if the loss function operates on vectors of raw logits instead of + # probabilities, only the first part of the generator needs to be + # passed to the NMTLossCompute. At the moment, the only supported + # loss function of this kind is the sparsemax loss. + use_raw_logits = isinstance(criterion, SparsemaxLoss) + loss_gen = model.generator[0] if use_raw_logits else model.generator + if opt.copy_attn: + compute = onmt.modules.CopyGeneratorLossCompute( + criterion, loss_gen, tgt_field.vocab, + opt.copy_loss_by_seqlength, + lambda_coverage=opt.lambda_coverage + ) + else: + compute = NMTLossCompute( + criterion, loss_gen, lambda_coverage=opt.lambda_coverage, + lambda_align=opt.lambda_align) compute.to(device) - return compute +class BertLoss(nn.Module): + """Class for managing BERT loss computation which is reduced by mean. + + Args: + criterion (:obj:`nn.NLLLoss`) : module that measures loss + between input and target. + task (str): BERT downstream task. + """ + def __init__(self, criterion, task): + super(BertLoss, self).__init__() + self.criterion = criterion + self.task = task + + @property + def padding_idx(self): + return self.criterion.ignore_index + + def _bottle(self, _v): + return _v.view(-1, _v.size(2)) + + def _stats(self, loss, tokens_scores, tokens_target, + sents_scores, sents_target): + """ + Args: + loss (:obj:`FloatTensor`): the loss reduced by mean. + tokens_scores (:obj:`FloatTensor`): scores for each token + tokens_target (:obj:`FloatTensor`): true targets for each token + sents_scores (:obj:`FloatTensor`): scores for each sentence + sents_target (:obj:`FloatTensor`): true targets for each sentence + + Returns: + :obj:`onmt.utils.BertStatistics` : statistics for this batch. + """ + if self.task == 'pretraining': + # masked lm task: token level + pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) + non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) + tokens_match = pred_tokens.eq( + tokens_target).masked_select(non_padding) + n_correct_tokens = tokens_match.sum().item() + n_tokens = non_padding.sum().item() + f1 = 0 + # next sentence prediction task: sentence level + pred_sents = sents_scores.argmax(-1) # (B, 2) -> (2) + n_correct_sents = sents_target.eq(pred_sents).sum().item() + n_sentences = len(sents_target) + + elif self.task == 'classification': + # token level task: Not valide + n_correct_tokens = 0 + n_tokens = 0 + f1 = 0 + # sentence level task: + pred_sents = sents_scores.argmax(-1) # (B, n_label) -> (n_label) + n_correct_sents = sents_target.eq(pred_sents).sum().item() + n_sentences = len(sents_target) + + elif self.task == 'tagging': + # token level task: + pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) + non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) + tokens_match = pred_tokens.eq( + tokens_target).masked_select(non_padding) + n_correct_tokens = tokens_match.sum().item() + n_tokens = non_padding.sum().item() + # for f1: + tokens_target_select = tokens_target.masked_select(non_padding) + pred_tokens_select = pred_tokens.masked_select(non_padding) + f1 = f1_score(tokens_target_select.cpu(), + pred_tokens_select.cpu(), average="micro") + + # sentence level task: Not valide + n_correct_sents = 0 + n_sentences = 0 + + elif self.task == 'generation': + # token level task: + pred_tokens = tokens_scores.argmax(1) # (B*S, V) -> (B*S) + non_padding = tokens_target.ne(self.padding_idx) # mask: (B*S) + tokens_match = pred_tokens.eq( + tokens_target).masked_select(non_padding) + n_correct_tokens = tokens_match.sum().item() + n_tokens = non_padding.sum().item() + f1 = 0 + # sentence level task: Not valide + n_correct_sents = 0 + n_sentences = 0 + else: + raise ValueError("task %s not available!" % (self.task)) + + return onmt.utils.BertStatistics(loss.item(), n_tokens, + n_correct_tokens, n_sentences, + n_correct_sents, f1) + + def forward(self, batch, outputs): + """ + Args: + batch (Tensor): batch of examples + outputs (tuple of Tensor): (seq_class_log_prob:``(B, 2)``, + prediction_log_prob:``(B, S, vocab)``) + + Returns: + (float, BertStatistics) + * total_loss: total loss of input batch reduced by 'mean'. + * stats: A statistic object. + """ + + assert isinstance(outputs, tuple) + seq_class_log_prob, prediction_log_prob = outputs + if self.task == 'pretraining': + assert list(seq_class_log_prob.size()) == [len(batch), 2] + # masked lm task: token level(loss mean by number of tokens) + gtruth_tokens = batch.lm_labels_ids # (B, S) + bottled_gtruth_tokens = gtruth_tokens.view(-1) # (B, S) + # prediction: (B, S, V) -> (B * S, V) + bottled_prediction_log_prob = self._bottle(prediction_log_prob) + mask_loss = self.criterion(bottled_prediction_log_prob, + bottled_gtruth_tokens) + # next sentence prediction task: sentence level(mean by sentence) + gtruth_sentences = batch.is_next # (B,) + next_loss = self.criterion(seq_class_log_prob, gtruth_sentences) + total_loss = next_loss + mask_loss # total_loss reduced by mean + + elif self.task == 'classification': + assert prediction_log_prob is None + assert hasattr(batch, 'category') + # token level task: Not valide + bottled_prediction_log_prob = None + bottled_gtruth_tokens = None + # sentence level task: loss mean by number of sentences + gtruth_sentences = batch.category + total_loss = self.criterion(seq_class_log_prob, gtruth_sentences) + + elif self.task == 'tagging' or self.task == 'generation': + assert seq_class_log_prob is None + assert hasattr(batch, 'token_labels') + # token level task: loss mean by number of tokens + gtruth_tokens = batch.token_labels # (B, S) + bottled_gtruth_tokens = gtruth_tokens.view(-1) # (B, S) + # prediction: (B, S, V) -> (B * S, V) + bottled_prediction_log_prob = self._bottle(prediction_log_prob) + total_loss = self.criterion(bottled_prediction_log_prob, + bottled_gtruth_tokens) + # sentence level task: Not valide + seq_class_log_prob = None + gtruth_sentences = None + + else: + raise ValueError("task %s not available!" % (self.task)) + + stats = self._stats(total_loss.clone(), + bottled_prediction_log_prob, + bottled_gtruth_tokens, + seq_class_log_prob, + gtruth_sentences) + return total_loss, stats + + class LossComputeBase(nn.Module): """ Class for managing efficient loss computation. Handles diff --git a/onmt/utils/optimizers.py b/onmt/utils/optimizers.py index f27d062034..ccd0162ba2 100644 --- a/onmt/utils/optimizers.py +++ b/onmt/utils/optimizers.py @@ -5,7 +5,7 @@ import operator import functools from copy import copy -from math import sqrt +from math import sqrt, cos, pi import types import importlib from onmt.utils.misc import fn_args @@ -55,6 +55,15 @@ def build_torch_optimizer(model, opt): lr=opt.learning_rate, betas=betas, eps=1e-9) + elif opt.optim == 'bertadam': + optimizer = AdamW( + params, + lr=opt.learning_rate, + betas=betas, + eps=1e-9, + amsgrad=False, + correct_bias=False, + weight_decay=0.01) elif opt.optim == 'sparseadam': dense = [] sparse = [] @@ -123,6 +132,33 @@ def make_learning_rate_decay_fn(opt): rate=opt.learning_rate_decay, decay_steps=opt.decay_steps, start_step=opt.start_decay_steps) + elif opt.decay_method == 'linear': + return functools.partial( + linear_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps) + elif opt.decay_method == 'linearconst': + return functools.partial( + linear_decay, + warmup_steps=opt.warmup_steps) + elif opt.decay_method == 'cosine': + return functools.partial( + cosine_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps, + cycles=opt.cycles if opt.cycles is not None else 0.5) + elif opt.decay_method == 'cosine_hard_restart': + return functools.partial( + cosine_hard_restart_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps, + cycles=opt.cycles if opt.cycles is not None else 1.0) + elif opt.decay_method == 'cosine_warmup_restart': + return functools.partial( + cosine_warmup_restart_decay, + warmup_steps=opt.warmup_steps, + total_steps=opt.train_steps, + cycles=opt.cycles if opt.cycles is not None else 1.0) elif opt.decay_method == 'rsqrt': return functools.partial( rsqrt_decay, warmup_steps=opt.warmup_steps) @@ -165,6 +201,77 @@ def rsqrt_decay(step, warmup_steps): return 1.0 / sqrt(max(step, warmup_steps)) +def linear_decay(step, warmup_steps, total_steps): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, linearly decrease the lr from 1 to 0 over (warmup_steps, train_step) + """ + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step < warmup_steps: + return step / warmup_steps * 1.0 + else: + return max((total_steps - step) / (total_steps - warmup_steps), 0) + + +def linear_constant_decay(step, warmup_steps): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, keep constant. + """ + if step < warmup_steps: + return step / warmup_steps * 1.0 + return 1.0 + + +def cosine_decay(step, warmup_steps, total_steps, cycles=0.5): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, decrease lr from 1 to 0 over (warmup_steps, train_step) + following cosine curve. + """ + + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step < warmup_steps: + return step / warmup_steps * 1.0 + else: + progress = (step - warmup_steps) / (total_steps - warmup_steps) + return 0.5 * (1 + cos(pi * cycles * 2 * progress)) + + +def cosine_hard_restart_decay(step, warmup_steps, total_steps, cycles=1.0): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, decrease the lr from 1 over (warmup_steps, train_step) + following cosine curve. + If `cycles` is different from default(1.0), learning rate follows + `cycles` times a cosine decaying learning rate (with hard restarts). + """ + assert(cycles >= 1.0) + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if step < warmup_steps: + return step / warmup_steps * 1.0 + else: + progress = (step - warmup_steps) / (total_steps - warmup_steps) + return 0.5 * (1 + cos(pi * ((cycles * progress) % 1))) + + +def cosine_warmup_restart_decay(step, warmup_steps, total_steps, cycles=1.0): + """Linearly increase the lr from 0 to 1 over (0, warmup_steps), + Then, decrease the lr from 1 to 0 over (warmup_steps, train_step) + following cosine curve. + """ + if not 0 <= warmup_steps < total_steps: + raise ValueError("Invalid decay: check warmup_step & train_steps") + if not cycles * warmup_steps / total_steps < 1.0: + raise ValueError("Invalid decay: Error for decay! Check cycles!") + warmup_ratio = warmup_steps * cycles / total_steps + progress = (step * cycles / total_steps) % 1 + if progress < warmup_ratio: + return progress / warmup_ratio + else: + progress = (progress - warmup_ratio) / (1 - warmup_ratio) + return 0.5 * (1 + cos(pi * progress)) + + class MultipleOptimizer(object): """ Implement multiple optimizers needed for sparse adam """ @@ -252,7 +359,8 @@ def from_opt(cls, model, opt, checkpoint=None): optim_opt = opt optim_state_dict = None - if opt.train_from and checkpoint is not None: + if opt.train_from and checkpoint is not None \ + and 'optim' in checkpoint: optim = checkpoint['optim'] ckpt_opt = checkpoint['opt'] ckpt_state_dict = {} @@ -535,6 +643,123 @@ def step(self, closure=None): return loss +class AdamW(torch.optim.Optimizer): + r"""Implements Adam algorithm with weight decay fix, compensate for bias + can be turned off (as in BERT) with option correct_bias. + Enable not use correct_bias comparing to torch.optim.adamw. + + Args: + params (iterable): iterable of parameters to optimize or dicts define + parameter groups + lr (float): learning rate + betas (tuple of float): Adam (beta1, beta2). Default: (0.9, 0.999) + eps (float): Adams epsilon. Default: 1e-6 + amsgrad (bool): whether to use the AMSGrad variant of this algorithm + from the paper `On the Convergence of Adam and Beyond`_ + Default: False. + weight_decay (float): Weight decay. Default: 0.01 + correct_bias (bool): whether to use bias correction. Default: True. + + .. _Adam\: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _Decoupled Weight Decay Regularization: + https://arxiv.org/abs/1711.05101 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__(self, params, lr=None, betas=(0.9, 0.999), eps=1e-6, + amsgrad=False, correct_bias=True, weight_decay=0.01): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr) + + " - should be >= 0.0") + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps) + + " - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid betas[0] parameter: {}".format( + betas[0]) + " - should be in [0.0, 1.0)") + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid betas[1] parameter: {}".format( + betas[1]) + " - should be in [0.0, 1.0)") + defaults = dict(lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, + correct_bias=correct_bias, weight_decay=weight_decay) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Adam: not support sparse gradients,' + + 'please consider SparseAdam instead') + amsgrad = group['amsgrad'] + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. + state['max_exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + if amsgrad: + max_exp_avg_sq = state['max_exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay first and second moment running average coefficient + # exp_avg = exp_avg * beta1 + (1-beta1)*grad + exp_avg.mul_(beta1).add_(1 - beta1, grad) + # exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2)*grad**2 + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + if amsgrad: + # Maintains max of all 2nd moment running avg. till now + torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) + # Use the max. for normalizing running avg. of gradient + denom = max_exp_avg_sq.sqrt().add_(group['eps']) + else: + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + # NOTE: AdamW used in Bert has "No bias correction" + if group['correct_bias']: + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + step_size = (step_size * sqrt(bias_correction2) + / bias_correction1) + + p.data.addcdiv_(-step_size, exp_avg, denom) + + # Perform correct weight decay(rather than L2) + p.data.mul_(1 - group['lr'] * group['weight_decay']) + return loss + + class FusedAdam(torch.optim.Optimizer): """Implements Adam algorithm. Currently GPU-only. diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 273dae3dba..de199c060b 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -5,6 +5,7 @@ import onmt.opts as opts from onmt.utils.logging import logger +from onmt.utils import PRETRAINED_VOCAB_ARCHIVE_MAP class ArgumentParser(cfargparse.ArgumentParser): @@ -46,6 +47,9 @@ def update_model_opts(cls, model_opt): if model_opt.copy_attn_type is None: model_opt.copy_attn_type = model_opt.global_attention + if not hasattr(model_opt, 'is_bert'): + model_opt.is_bert = False + if model_opt.alignment_layer is None: model_opt.alignment_layer = -2 model_opt.lambda_align = 0.0 @@ -91,6 +95,16 @@ def ckpt_model_opts(cls, ckpt_opt): @classmethod def validate_train_opts(cls, opt): + if opt.is_bert: + logger.info("WE ARE IN BERT MODE.") + if opt.task_type == "none": + raise ValueError( + "Downstream task should be chosen when use BERT.") + if opt.reuse_embeddings is True: + if opt.task_type != "pretraining": + opt.reuse_embeddings = False + logger.warning( + "reuse_embeddings not available for this task.") if opt.epochs: raise AssertionError( "-epochs is deprecated please use -train_steps.") @@ -164,3 +178,70 @@ def validate_preprocess_args(cls, opt): "Please check path of your src vocab!" assert not opt.tgt_vocab or os.path.isfile(opt.tgt_vocab), \ "Please check path of your tgt vocab!" + + @classmethod + def validate_preprocess_bert_opts(cls, opt): + assert opt.vocab_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ + "Unsupported Pretrain model '%s'" % (opt.vocab_model) + if '-cased' in opt.vocab_model and opt.do_lower_case is True: + logger.warning("The pre-trained model you are loading is " + + "cased model, you shouldn't set `do_lower_case`," + + "we turned it off for you.") + opt.do_lower_case = False + elif '-cased' not in opt.vocab_model and not opt.do_lower_case: + logger.warning("The pre-trained model you are loading is " + + "uncased model, you should set `do_lower_case`, " + + "we turned it on for you.") + opt.do_lower_case = True + + for filename in opt.data: + assert os.path.isfile(filename),\ + "Please check path of %s" % filename + + if opt.task == "tagging": + assert opt.file_type == 'txt' and len(opt.data) == 1,\ + "For sequence tagging, only single txt file is supported." + opt.data = opt.data[0] + + assert len(opt.input_columns) == 1,\ + "For sequence tagging, only one column for input tokens." + opt.input_columns = opt.input_columns[0] + + assert opt.label_column is not None,\ + "For sequence tagging, label column should be given." + + if opt.task == "classification": + if opt.file_type == "csv": + assert len(opt.data) == 1,\ + "For csv, only single file is needed." + opt.data = opt.data[0] + assert len(opt.input_columns) in [1, 2],\ + "Please indicate colomn of sentence A (and B)" + assert opt.label_column is not None,\ + "For csv file, label column should be given." + if opt.delimiter != '\t': + logger.warning("for csv file, we set delimiter to '\t'") + opt.delimiter = '\t' + return opt + + @classmethod + def validate_predict_opts(cls, opt): + if opt.delimiter is None: + if opt.task == 'classification': + opt.delimiter = ' ||| ' + else: + opt.delimiter = ' ' + logger.info("NOTICE: opt.delimiter set to `%s`" % opt.delimiter) + assert opt.vocab_model in PRETRAINED_VOCAB_ARCHIVE_MAP.keys(), \ + "Unsupported Pretrain model '%s'" % (opt.vocab_model) + if '-cased' in opt.vocab_model and opt.do_lower_case is True: + logger.info("WARNING: The pre-trained model you are loading " + + "is cased model, you shouldn't set `do_lower_case`," + + "we turned it off for you.") + opt.do_lower_case = False + elif '-cased' not in opt.vocab_model and not opt.do_lower_case: + logger.info("WARNING: The pre-trained model you are loading " + + "is uncased model, you should set `do_lower_case`, " + + "we turned it on for you.") + opt.do_lower_case = True + return opt diff --git a/onmt/utils/report_manager.py b/onmt/utils/report_manager.py index 7e1d546c2d..422052f58a 100644 --- a/onmt/utils/report_manager.py +++ b/onmt/utils/report_manager.py @@ -73,6 +73,8 @@ def report_training(self, step, num_steps, learning_rate, onmt.utils.Statistics.all_gather_stats(report_stats) self._report_training( step, num_steps, learning_rate, report_stats) + if isinstance(report_stats, onmt.utils.BertStatistics): + return onmt.utils.BertStatistics() return onmt.utils.Statistics() else: return report_stats @@ -128,7 +130,10 @@ def _report_training(self, step, num_steps, learning_rate, "progress", learning_rate, step) - report_stats = onmt.utils.Statistics() + if isinstance(report_stats, onmt.utils.BertStatistics): + report_stats = onmt.utils.BertStatistics() + else: + report_stats = onmt.utils.Statistics() return report_stats @@ -138,7 +143,14 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None): """ if train_stats is not None: self.log('Train perplexity: %g' % train_stats.ppl()) - self.log('Train accuracy: %g' % train_stats.accuracy()) + if train_stats.accuracy() is None: + assert isinstance(train_stats, onmt.utils.BertStatistics) + accuracy = train_stats.sentence_accuracy() + else: + accuracy = train_stats.accuracy() + self.log('Train accuracy: %g' % accuracy) + if hasattr(train_stats, 'f1') and train_stats.f1 != 0: + self.log('Train F1: %g' % train_stats.f1) self.maybe_log_tensorboard(train_stats, "train", @@ -147,7 +159,14 @@ def _report_step(self, lr, step, train_stats=None, valid_stats=None): if valid_stats is not None: self.log('Validation perplexity: %g' % valid_stats.ppl()) - self.log('Validation accuracy: %g' % valid_stats.accuracy()) + if valid_stats.accuracy() is None: + assert isinstance(valid_stats, onmt.utils.BertStatistics) + accuracy = valid_stats.sentence_accuracy() + else: + accuracy = valid_stats.accuracy() + self.log('Validation accuracy: %g' % accuracy) + if hasattr(valid_stats, 'f1') and valid_stats.f1 != 0: + self.log('Validation F1: %g' % valid_stats.f1) self.maybe_log_tensorboard(valid_stats, "valid", diff --git a/onmt/utils/statistics.py b/onmt/utils/statistics.py index 896d98c74d..b3a2629901 100644 --- a/onmt/utils/statistics.py +++ b/onmt/utils/statistics.py @@ -134,3 +134,136 @@ def log_tensorboard(self, prefix, writer, learning_rate, step): writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) writer.add_scalar(prefix + "/lr", learning_rate, step) + + +class BertStatistics(Statistics): + """ Bert Statistics as the loss is reduced by mean. + + Currently calculates: + * accuracy in token/sentence level + * perplexity + * elapsed time + * micro f1 for tagging + """ + def __init__(self, loss=0, n_words=0, n_correct=0, + n_sentence=0, n_correct_sentence=0, f1=0): + super(BertStatistics, self).__init__(loss, n_words, n_correct) + self.n_update = 0 if n_words == 0 and n_sentence == 0 else 1 + self.n_sentence = n_sentence + self.n_correct_sentence = n_correct_sentence + self.f1 = f1 + + def accuracy(self): + """ compute token level accuracy """ + if self.n_words != 0: + return 100 * (self.n_correct / self.n_words) + else: + return None + + def sentence_accuracy(self): + """ compute sentence level accuracy """ + if self.n_sentence != 0: + return 100 * (self.n_correct_sentence / self.n_sentence) + else: + return None + + def xent(self): + """ compute cross entropy """ + return self.loss + + def ppl(self): + """ compute perplexity """ + return math.exp(min(self.loss, 100)) + + def update(self, stat, update_n_src_words=False): + """ + Update statistics by suming values with another `Statistics` object + + Args: + stat (BertStatistics): another statistic object + update_n_src_words (bool): whether to update (sum) `n_src_words` + or not + + """ + assert isinstance(stat, BertStatistics) + # Loss for BERT is computed and reduced by average. + # Which is different from the NMTModel reduced by sum. + self.loss = (self.loss * self.n_update + stat.loss * + stat.n_update) / (self.n_update + stat.n_update) + self.n_update += 1 + self.n_words += stat.n_words + self.n_correct += stat.n_correct + self.n_sentence += stat.n_sentence + self.n_correct_sentence += stat.n_correct_sentence + self.f1 = (self.f1 * self.n_update + stat.f1 * + stat.n_update) / (self.n_update + stat.n_update) + + if update_n_src_words: + self.n_src_words += stat.n_src_words + + def output(self, step, num_steps, learning_rate, start): + """Write out statistics to stdout. + + Args: + step (int): current step + n_batch (int): total batches + start (int): start time of step. + """ + t = self.elapsed_time() + step_fmt = "%2d" % step + if num_steps > 0: + step_fmt = "%s/%5d" % (step_fmt, num_steps) + if self.n_words == 0: # sentence level task: Acc, PPL, X-entropy + logger.info( + ("Step %s; acc(sent):%6.2f; ppl: %5.2f; " + + "xent: %4.2f; lr: %7.5f; %3.0f tok/%3.0f sent/s; %6.0f sec") + % (step_fmt, + self.sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_sentence / (t + 1e-5), + time.time() - start)) + elif self.n_sentence == 0: # token level task: Tok Acc, F1, X-entropy + logger.info( + ("Step %s; acc(token):%6.2f; f1: %5.4f; " + + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + % (step_fmt, + self.accuracy(), + self.f1, + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) + else: # pretraining + logger.info( + ("Step %s; acc(mlm/nx):%6.2f/%6.2f; total ppl: %5.2f; " + + "xent: %4.2f; lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") + % (step_fmt, + self.accuracy(), + self.sentence_accuracy(), + self.ppl(), + self.xent(), + learning_rate, + self.n_src_words / (t + 1e-5), + self.n_words / (t + 1e-5), + time.time() - start)) + sys.stdout.flush() + + def log_tensorboard(self, prefix, writer, learning_rate, step): + """ display statistics to tensorboard """ + t = self.elapsed_time() + writer.add_scalar(prefix + "/xent", self.xent(), step) + writer.add_scalar(prefix + "/ppl", self.ppl(), step) + if self.n_words != 0: # Token level task + writer.add_scalar(prefix + "/accuracy_token", + self.accuracy(), step) + writer.add_scalar(prefix + "/F1", + self.f1, step) + if self.n_sentence != 0: # Sentence level task + writer.add_scalar(prefix + "/accuracy_sent", + self.sentence_accuracy(), step) + writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) + writer.add_scalar(prefix + "/lr", learning_rate, step) diff --git a/predict.py b/predict.py new file mode 100755 index 0000000000..5fdeb53f29 --- /dev/null +++ b/predict.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals + +from onmt.utils.logging import init_logger +from onmt.utils.misc import split_corpus +from onmt.translate.predictor import build_classifier, build_tagger + +import onmt.opts as opts +from onmt.utils.parse import ArgumentParser +from onmt.utils.bert_tokenization import BertTokenizer + + +def main(opt): + logger = init_logger(opt.log_file) + opt = ArgumentParser.validate_predict_opts(opt) + tokenizer = BertTokenizer.from_pretrained( + opt.vocab_model, do_lower_case=opt.do_lower_case) + data_shards = split_corpus(opt.data, opt.shard_size) + if opt.task == 'classification': + classifier = build_classifier(opt, logger) + for i, data_shard in enumerate(data_shards): + logger.info("Classify shard %d." % i) + data = [seq.decode("utf-8") for seq in data_shard] + classifier.classify( + data, + opt.batch_size, + tokenizer, + delimiter=opt.delimiter, + max_seq_len=opt.max_seq_len, + batch_type=opt.batch_type + ) + if opt.task == 'tagging': + tagger = build_tagger(opt, logger) + for i, data_shard in enumerate(data_shards): + logger.info("Tagging shard %d." % i) + data = [seq.decode("utf-8") for seq in data_shard] + tagger.tagging( + data, + opt.batch_size, + tokenizer, + delimiter=opt.delimiter, + max_seq_len=opt.max_seq_len, + batch_type=opt.batch_type + ) + + +def _get_parser(): + parser = ArgumentParser(description='predict.py') + opts.config_opts(parser) + opts.predict_opts(parser) + return parser + + +if __name__ == "__main__": + parser = _get_parser() + + opt = parser.parse_args() + main(opt) diff --git a/pregenerate_bert_training_data.py b/pregenerate_bert_training_data.py new file mode 100755 index 0000000000..a557410196 --- /dev/null +++ b/pregenerate_bert_training_data.py @@ -0,0 +1,424 @@ +""" +This file is lifted from huggingface and adapted for onmt structure. +Ref in https://github.com/huggingface/pytorch-transformers/. +""" +from argparse import ArgumentParser +from pathlib import Path +from tqdm import tqdm, trange +from tempfile import TemporaryDirectory +import shelve + +from random import random, randrange, randint, shuffle, choice +from onmt.utils import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP +from onmt.utils.file_utils import cached_path +from preprocess_bert import build_vocab_from_tokenizer +import numpy as np +import json +from onmt.inputters.inputter import get_bert_fields +from onmt.inputters.dataset_bert import BertDataset, truncate_seq_pair +import torch +import collections + + +class DocumentDatabase: + def __init__(self, reduce_memory=False): + if reduce_memory: + self.temp_dir = TemporaryDirectory() + self.working_dir = Path(self.temp_dir.name) + self.document_shelf_filepath = self.working_dir / 'shelf.db' + self.document_shelf = shelve.open( + str(self.document_shelf_filepath), flag='n', protocol=-1) + self.documents = None + else: + self.documents = [] + self.document_shelf = None + self.document_shelf_filepath = None + self.temp_dir = None + self.doc_lengths = [] + self.doc_cumsum = None + self.cumsum_max = None + self.reduce_memory = reduce_memory + + def add_document(self, document): + if not document: + return + if self.reduce_memory: + current_idx = len(self.doc_lengths) + self.document_shelf[str(current_idx)] = document + else: + self.documents.append(document) + self.doc_lengths.append(len(document)) + + def _precalculate_doc_weights(self): + self.doc_cumsum = np.cumsum(self.doc_lengths) + self.cumsum_max = self.doc_cumsum[-1] + + def sample_doc(self, current_idx, sentence_weighted=True): + # Uses the current_idx to ensure we don't sample the same doc twice + if sentence_weighted: + # With sentence weighting, we sample docs + # proportionally to their sentence length + if (self.doc_cumsum is None + or len(self.doc_cumsum) != len(self.doc_lengths)): + self._precalculate_doc_weights() + rand_start = self.doc_cumsum[current_idx] + rand_end = (rand_start + self.cumsum_max + - self.doc_lengths[current_idx]) + sentence_index = randrange(rand_start, rand_end) % self.cumsum_max + sampled_doc_index = np.searchsorted( + self.doc_cumsum, sentence_index, side='right') + else: + # If sentence weighting is False, chose doc equally + sampled_doc_index = ((current_idx + + randrange(1, len(self.doc_lengths))) + % len(self.doc_lengths)) + assert sampled_doc_index != current_idx + if self.reduce_memory: + return self.document_shelf[str(sampled_doc_index)] + else: + return self.documents[sampled_doc_index] + + def __len__(self): + return len(self.doc_lengths) + + def __getitem__(self, item): + if self.reduce_memory: + return self.document_shelf[str(item)] + else: + return self.documents[item] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + if self.document_shelf is not None: + self.document_shelf.close() + if self.temp_dir is not None: + self.temp_dir.cleanup() + + +MaskedLmInstance = collections.namedtuple("MaskedLmInstance", + ["index", "label"]) + + +def create_masked_lm_predictions(tokens, masked_lm_prob, + max_predictions_per_seq, + whole_word_mask, vocab_dict): + """Creates the predictions for the masked LM. This is mostly copied from + the Huggingface BERT repo, but pregenerate lm_labels_ids.""" + vocab_list = list(vocab_dict.keys()) + cand_indices = [] + for (i, token) in enumerate(tokens): + if token == "[CLS]" or token == "[SEP]": + continue + # Whole Word Masking means that if we mask all of the wordpieces + # corresponding to an original word. When a word has been split into + # WordPieces, the first token does not have any marker and any + # subsequence tokens are prefixed with ##. So whenever we see the ##, + # we append it to the previous set of word indexes. + # + # Note that Whole Word Masking does *not* change the training code + # at all -- we still predict each WordPiece independently, softmaxed + # over the entire vocabulary. + if (whole_word_mask and len(cand_indices) >= 1 + and token.startswith("##")): + cand_indices[-1].append(i) + else: + cand_indices.append([i]) + + num_to_mask = min(max_predictions_per_seq, + max(1, int(round(len(tokens) * masked_lm_prob)))) + shuffle(cand_indices) + masked_lms = [] + covered_indexes = set() + for index_set in cand_indices: + if len(masked_lms) >= num_to_mask: + break + # If adding a whole-word mask would exceed the maximum number of + # predictions, then just skip this candidate. + if len(masked_lms) + len(index_set) > num_to_mask: + continue + is_any_index_covered = False + for index in index_set: + if index in covered_indexes: + is_any_index_covered = True + break + if is_any_index_covered: + continue + for index in index_set: + covered_indexes.add(index) + + masked_token = None + # 80% of the time, replace with [MASK] + if random() < 0.8: + masked_token = "[MASK]" + else: + # 10% of the time, keep original + if random() < 0.5: + masked_token = tokens[index] + # 10% of the time, replace with random word + else: + masked_token = choice(vocab_list) + masked_lms.append(MaskedLmInstance(index=index, + label=tokens[index])) + # Replace true token with masked token + tokens[index] = masked_token + + assert len(masked_lms) <= num_to_mask + masked_lms = sorted(masked_lms, key=lambda x: x.index) + mask_indices = [p.index for p in masked_lms] + masked_token_labels = [p.label for p in masked_lms] + lm_labels_ids = [-1 for _ in tokens] + for (i, token) in zip(mask_indices, masked_token_labels): + lm_labels_ids[i] = vocab_dict[token] + assert len(lm_labels_ids) == len(tokens) + return tokens, mask_indices, masked_token_labels, lm_labels_ids + + +def create_instances_from_document( + doc_database, doc_idx, vocab_dict, max_seq_length, short_seq_prob, + masked_lm_prob, max_predictions_per_seq, whole_word_mask): + """This code is mostly a duplicate of the equivalent function from + HuggingFace BERT's repo. But we use lm_labels_ids rather than + mask_indices and masked_token_labels.""" + document = doc_database[doc_idx] + # Account for [CLS], [SEP], [SEP] + max_num_tokens = max_seq_length - 3 + + # We *usually* want to fill up the entire sequence since we are padding + # to `max_seq_length` anyways, so short sequences are generally wasted + # computation. However, we *sometimes* + # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter + # sequences to minimize the mismatch between pre-training and fine-tuning. + # The `target_seq_length` is just a rough target however, whereas + # `max_seq_length` is a hard limit. + target_seq_length = max_num_tokens + if random() < short_seq_prob: + target_seq_length = randint(2, max_num_tokens) + + # We DON'T just concatenate all of the tokens from a document into a long + # sequence and choose an arbitrary split point because this would make the + # next sentence prediction task too easy. Instead, we split the input into + # segments "A" and "B" based on the actual "sentences" provided by user's + # input. + instances = [] + current_chunk = [] + current_length = 0 + i = 0 + while i < len(document): + segment = document[i] + current_chunk.append(segment) + current_length += len(segment) + if i == len(document) - 1 or current_length >= target_seq_length: + if current_chunk: + # `a_end` is how many segments from `current_chunk` go into + # `A` (first) sentence. + a_end = 1 + if len(current_chunk) >= 2: + a_end = randrange(1, len(current_chunk)) + + tokens_a = [] + for j in range(a_end): + tokens_a.extend(current_chunk[j]) + + tokens_b = [] + + # Random next + if len(current_chunk) == 1 or random() < 0.5: + is_next = False + target_b_length = target_seq_length - len(tokens_a) + + # Sample a random document with longer docs being + # sampled more frequently + random_document = doc_database.sample_doc( + current_idx=doc_idx, sentence_weighted=True) + + random_start = randrange(0, len(random_document)) + for j in range(random_start, len(random_document)): + tokens_b.extend(random_document[j]) + if len(tokens_b) >= target_b_length: + break + # We didn't actually use these segments so we + # "put them back" so they don't go to waste. + num_unused_segments = len(current_chunk) - a_end + i -= num_unused_segments + # Actual next + else: + is_next = True + for j in range(a_end, len(current_chunk)): + tokens_b.extend(current_chunk[j]) + truncate_seq_pair(tokens_a, tokens_b, max_num_tokens) + + assert len(tokens_a) >= 1 + assert len(tokens_b) >= 1 + + tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + \ + tokens_b + ["[SEP]"] + segment_ids = [0 for _ in range(len(tokens_a) + 2)] + \ + [1 for _ in range(len(tokens_b) + 1)] + + tokens, _, _, lm_labels_ids = create_masked_lm_predictions( + tokens, masked_lm_prob, max_predictions_per_seq, + whole_word_mask, vocab_dict) + + instance = { + "tokens": tokens, + "segment_ids": segment_ids, + "is_next": is_next, + # "masked_lm_positions": masked_lm_positions, + # "masked_lm_labels": masked_lm_labels, + "lm_labels_ids": lm_labels_ids} + instances.append(instance) + current_chunk = [] + current_length = 0 + i += 1 + + return instances + + +def build_document_database(input_file, tokenizer, reduce_memory): + with DocumentDatabase(reduce_memory=reduce_memory) as docs: + with input_file.open() as f: + doc = [] + for line in tqdm(f, desc="Loading Dataset", unit=" lines"): + line = line.strip() + if line == "": + docs.add_document(doc) + doc = [] + else: + tokens = tokenizer.tokenize(line) + doc.append(tokens) + if len(doc) != 0: # If didn't end on a newline, still add + docs.add_document(doc) + if len(docs) <= 1: + exit("""ERROR: No document breaks were found in the input file! + These are necessary to ensure that random NextSentences + are not sampled from the same document. Please add blank + lines to indicate breaks between documents in your file. + If your dataset does not contain multiple documents, + blank lines can be inserted at any natural boundary, + such as the ends of chapters, sections or paragraphs.""") + return docs + + +def create_instances_from_docs(doc_database, vocab_dict, args): + docs_instances = [] + for doc_idx in trange(len(doc_database), desc="Document"): + doc_instances = create_instances_from_document( + doc_database, doc_idx, vocab_dict=vocab_dict, + max_seq_length=args.max_seq_len, + short_seq_prob=args.short_seq_prob, + masked_lm_prob=args.masked_lm_prob, + max_predictions_per_seq=args.max_predictions_per_seq, + whole_word_mask=args.do_whole_word_mask) + docs_instances.extend(doc_instances) + return docs_instances + + +def save_data_as_json(instances, json_name): + instances_json = [json.dumps(instance) for instance in instances] + num_instances = 0 + with open(json_name, 'w') as json_file: + for instance in instances_json: + json_file.write(instance + '\n') + num_instances += 1 + return num_instances + + +def _get_parser(): + parser = ArgumentParser() + parser.add_argument('--input_file', type=Path, required=True) + parser.add_argument("--output_dir", type=Path, required=True) + parser.add_argument("--output_name", type=str, default="dataset") + parser.add_argument('--corpus_type', type=str, default="train", + choices=['train', 'valid'], + help="Choose from ['train', 'valid'], " + + "Vocab file will be generate if `train`") + parser.add_argument("--vocab_model", type=str, required=True, + choices=["bert-base-uncased", "bert-large-uncased", + "bert-base-cased", "bert-large-cased", + "bert-base-multilingual-uncased", + "bert-base-multilingual-cased", + "bert-base-chinese", + "bert-base-german-cased", + "bert-large-uncased-whole-word-masking", + "bert-large-cased-whole-word-masking", + "bert-base-cased-finetuned-mrpc"], + help="Pretrained vocab model use to tokenizer text.") + + parser.add_argument("--do_lower_case", action="store_true") + parser.add_argument("--do_whole_word_mask", action="store_true", + help="Whether to use whole word masking.") + parser.add_argument("--reduce_memory", action="store_true", + help="""Reduce memory usage for large datasets + by keeping data on disc rather than in memory""") + + parser.add_argument("--epochs_to_generate", type=int, default=2, + help="Number of epochs of data to pregenerate") + parser.add_argument("--max_seq_len", type=int, default=128) + parser.add_argument("--short_seq_prob", type=float, default=0.1, + help="Prob. of a short sentence as training example") + parser.add_argument("--masked_lm_prob", type=float, default=0.15, + help="Prob. of masking each token for the LM task") + parser.add_argument("--max_predictions_per_seq", type=int, default=20, + help="Max number of tokens to mask in each sequence") + parser.add_argument("--save_json", action="store_true", + help='save a copy of data in json form.') + return parser + + +def main(args): + tokenizer = BertTokenizer.from_pretrained( + args.vocab_model, do_lower_case=args.do_lower_case) + + docs = build_document_database( + args.input_file, tokenizer, args.reduce_memory) + + fields = get_bert_fields() + vocab_dict = tokenizer.vocab + args.output_dir.mkdir(exist_ok=True) + + # Build file corpus.pt + for epoch in trange(args.epochs_to_generate, desc="Epoch"): + docs_instances = create_instances_from_docs(docs, vocab_dict, args) + + # build BertDataset from instances collected from different document + dataset = BertDataset(fields, docs_instances) + epoch_filename = args.output_dir / "{}.{}.{}.pt".format( + args.output_name, args.corpus_type, epoch) + dataset.save(epoch_filename) + print("output file {}, num_example {}, max_seq_len {}".format( + epoch_filename, len(docs_instances), args.max_seq_len)) + + if args.save_json: + json_name = args.output_dir / "{}.{}.{}.json".format( + args.output_name, args.corpus_type, epoch) + num_instances = save_data_as_json(docs_instances, json_name) + metrics_file = args.output_dir / "{}.{}.{}.metrics.json".format( + args.output_name, args.corpus_type, epoch) + with metrics_file.open('w') as metrics_file: + metrics = { + "num_training_examples": num_instances, + "max_seq_len": args.max_seq_len + } + metrics_file.write(json.dumps(metrics)) + + # Build file Vocab.pt + if args.corpus_type == "train": + vocab_file_url = PRETRAINED_VOCAB_ARCHIVE_MAP[args.vocab_model] + vocab_dir = Path.joinpath(args.output_dir, + "%s-vocab.txt" % (args.vocab_model)) + cached_vocab = cached_path(vocab_file_url, cache_dir=vocab_dir) + print("Vocab file is Cached at %s." % cached_vocab) + fields_vocab = build_vocab_from_tokenizer( + fields, tokenizer, None) + bert_vocab_file = Path.joinpath(args.output_dir, + "%s.vocab.pt" % (args.output_name)) + print("Build Fields Vocab file.") + torch.save(fields_vocab, bert_vocab_file) + + +if __name__ == '__main__': + parser = _get_parser() + args = parser.parse_args() + main(args) diff --git a/preprocess_bert.py b/preprocess_bert.py new file mode 100755 index 0000000000..82b3590f75 --- /dev/null +++ b/preprocess_bert.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" + Pre-process Data files and build vocabulary for Bert model. +""" +from onmt.utils.parse import ArgumentParser +from tqdm import tqdm +import csv +from collections import Counter, defaultdict +import torch +import codecs +from random import shuffle +from onmt.utils.bert_tokenization import BertTokenizer +from onmt.inputters.inputter import get_bert_fields, \ + _build_bert_fields_vocab +import onmt.opts as opts +from onmt.inputters.dataset_bert import ClassifierDataset, \ + TaggerDataset +from onmt.utils.logging import init_logger, logger + + +def shuffle_pair_list(list_a, list_b): + assert len(list_a) == len(list_b),\ + "Two list to shuffle should be equal length" + logger.info("Shuffle all instance") + pair_list = list(zip(list_a, list_b)) + shuffle(pair_list) + list_a, list_b = zip(*pair_list) + return list_a, list_b + + +def build_label_vocab_from_path(paths): + labels = [] + for filename in paths: + label = filename.split('/')[-2] + if label not in labels: + labels.append(label) + return labels + + +def _build_bert_vocab(vocab, name, counters): + """ similar to _load_vocab in inputter.py, but build from a vocab list. + in place change counters + """ + vocab_size = len(vocab) + for i, token in enumerate(vocab): + counters[name][token] = vocab_size - i + return vocab, vocab_size + + +def build_vocab_from_tokenizer(fields, tokenizer, named_labels): + logger.info("Building token vocab from BertTokenizer...") + vocab_list = list(tokenizer.vocab.keys()) + counters = defaultdict(Counter) + _, vocab_size = _build_bert_vocab(vocab_list, "tokens", counters) + + if named_labels is not None: + label_name, label_list = named_labels + logger.info("Building label vocab {}...".format(named_labels)) + _, _ = _build_bert_vocab(label_list, label_name, counters) + else: + label_name = None + + fields_vocab = _build_bert_fields_vocab(fields, counters, vocab_size, + label_name) + return fields_vocab + + +def build_save_vocab(fields, tokenizer, label_vocab, opt): + if opt.sort_label_vocab is True: + label_vocab.sort() + if opt.task == "classification": + named_labels = ("category", label_vocab) + if opt.task == "tagging": + named_labels = ("token_labels", label_vocab) + + fields_vocab = build_vocab_from_tokenizer( + fields, tokenizer, named_labels) + bert_vocab_file = opt.save_data + ".vocab.pt" + torch.save(fields_vocab, bert_vocab_file) + + +def create_cls_instances_from_csv(opt): + logger.info("Reading csv with input in column %s, label in column %s" + % (opt.input_columns, opt.label_column)) + with codecs.open(opt.data, "r", encoding="utf-8-sig") as csvfile: + reader = csv.reader(csvfile, delimiter=opt.delimiter, quotechar=None) + lines = list(reader) + if opt.skip_head is True: + lines = lines[1:] + if len(opt.input_columns) == 1: + column_a = int(opt.input_columns[0]) + column_b = None + else: + column_a = int(opt.input_columns[0]) + column_b = int(opt.input_columns[1]) + + instances, labels, label_vocab = [], [], opt.labels + for line in tqdm(lines, desc="Process", unit=" lines"): + label = line[opt.label_column].strip() + if label not in label_vocab: + label_vocab.append(label) + sentence = line[column_a].strip() + if column_b is not None: + sentence_b = line[column_b].strip() + sentence = sentence + ' ||| ' + sentence_b + instances.append(sentence) + labels.append(label) + logger.info("total %d line loaded with skip_head [%s]" + % (len(lines), opt.skip_head)) + + return instances, labels, label_vocab + + +def create_cls_instances_from_files(opt): + instances = [] + labels = [] + label_vocab = build_label_vocab_from_path(opt.data) + for filename in opt.data: + label = filename.split('/')[-2] + with codecs.open(filename, "r", encoding="utf-8") as f: + lines = f.readlines() + print("total {} line of File {} loaded for label: {}.".format( + len(lines), filename, label)) + lines_labels = [label for _ in range(len(lines))] + instances.extend(lines) + labels.extend(lines_labels) + return instances, labels, label_vocab + + +def build_cls_dataset(corpus_type, fields, tokenizer, opt): + """Build classification dataset with vocab file if train set""" + assert corpus_type in ['train', 'valid'] + if opt.file_type == 'csv': + instances, labels, label_vocab = create_cls_instances_from_csv(opt) + else: + instances, labels, label_vocab = create_cls_instances_from_files(opt) + logger.info("Exiting labels:%s" % label_vocab) + if corpus_type == 'train': + build_save_vocab(fields, tokenizer, label_vocab, opt) + + if opt.do_shuffle is True: + instances, labels = shuffle_pair_list(instances, labels) + cls_instances = instances, labels + logger.info("Building %s dataset..." % corpus_type) + dataset = ClassifierDataset( + fields, cls_instances, tokenizer, opt.max_seq_len) + return dataset, len(cls_instances[0]) + + +def create_tag_instances_from_file(opt): + logger.info("Reading tag with token in column %s, tag in column %s" + % (opt.input_columns, opt.label_column)) + sentences, taggings = [], [] + tag_vocab = opt.labels + with codecs.open(opt.data, "r", encoding="utf-8") as f: + lines = f.readlines() + logger.info("total {} line of file {} loaded.".format( + len(lines), opt.data)) + sentence_sofar = [] + for line in tqdm(lines, desc="Process", unit=" lines"): + line = line.strip() + if line == '': + if len(sentence_sofar) > 0: + tokens, tags = zip(*sentence_sofar) + sentences.append(tokens) + taggings.append(tags) + sentence_sofar = [] + else: + elements = line.split(opt.delimiter) + token = elements[opt.input_columns] + tag = elements[opt.label_column] + if tag not in tag_vocab: + tag_vocab.append(tag) + sentence_sofar.append((token, tag)) + logger.info("total {} sentence loaded.".format(len(sentences))) + logger.info("All tags:{}".format(tag_vocab)) + + return sentences, taggings, tag_vocab + + +def build_tag_dataset(corpus_type, fields, tokenizer, opt): + """Build tagging dataset with vocab file if train set""" + assert corpus_type in ['train', 'valid'] + sentences, taggings, tag_vocab = create_tag_instances_from_file(opt) + logger.info("Exiting Tags:%s" % tag_vocab) + if corpus_type == 'train': + build_save_vocab(fields, tokenizer, tag_vocab, opt) + + if opt.do_shuffle is True: + sentences, taggings = shuffle_pair_list(sentences, taggings) + + tag_instances = sentences, taggings + logger.info("Building %s dataset..." % corpus_type) + dataset = TaggerDataset( + fields, tag_instances, tokenizer, opt.max_seq_len) + return dataset, len(tag_instances[0]) + + +def _get_parser(): + parser = ArgumentParser(description='preprocess_bert.py') + opts.config_opts(parser) + opts.preprocess_bert_opts(parser) + return parser + + +def main(opt): + init_logger(opt.log_file) + opt = ArgumentParser.validate_preprocess_bert_opts(opt) + logger.info("Preprocess dataset...") + + fields = get_bert_fields(opt.task) + logger.info("Get fields for Task: '%s'." % opt.task) + + tokenizer = BertTokenizer.from_pretrained( + opt.vocab_model, do_lower_case=opt.do_lower_case) + logger.info("Use pretrained tokenizer: '%s', do_lower_case [%s]" + % (opt.vocab_model, opt.do_lower_case)) + + if opt.task == "classification": + dataset, n_instance = build_cls_dataset( + opt.corpus_type, fields, tokenizer, opt) + + elif opt.task == "tagging": + dataset, n_instance = build_tag_dataset( + opt.corpus_type, fields, tokenizer, opt) + # Save processed data in OpenNMT format + onmt_filename = opt.save_data + ".{}.0.pt".format(opt.corpus_type) + dataset.save(onmt_filename) + logger.info("* save num_example [%d], max_seq_len [%d] to [%s]." + % (n_instance, opt.max_seq_len, onmt_filename)) + + +if __name__ == '__main__': + parser = _get_parser() + opt = parser.parse_args() + main(opt) diff --git a/requirements.opt.txt b/requirements.opt.txt index fdbd2d1ee3..3b73f9bd02 100644 --- a/requirements.opt.txt +++ b/requirements.opt.txt @@ -8,3 +8,5 @@ pyrouge opencv-python git+https://github.com/NVIDIA/apex pretrainedmodels +boto3 +sklearn