Skip to content

Commit 36bca54

Browse files
committed
tokenization abstract class - tests for examples
1 parent a4f9805 commit 36bca54

33 files changed

+815
-566
lines changed

examples/run_squad.py

+400
Large diffs are not rendered by default.

examples/test_examples.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# coding=utf-8
2+
# Copyright 2018 HuggingFace Inc..
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import sys
20+
import unittest
21+
import argparse
22+
23+
try:
24+
# python 3.4+ can use builtin unittest.mock instead of mock package
25+
from unittest.mock import patch
26+
except ImportError:
27+
from mock import patch
28+
29+
import run_bert_squad as rbs
30+
31+
def get_setup_file():
32+
parser = argparse.ArgumentParser()
33+
parser.add_argument('-f')
34+
args = parser.parse_args()
35+
return args.f
36+
37+
class ExamplesTests(unittest.TestCase):
38+
39+
def test_run_squad(self):
40+
testargs = ["prog", "-f", "/home/test/setup.py"]
41+
with patch.object(sys, 'argv', testargs):
42+
setup = get_setup_file()
43+
assert setup == "/home/test/setup.py"
44+
# rbs.main()
45+
46+
47+
if __name__ == "__main__":
48+
unittest.main()

pytorch_transformers/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .tokenization_gpt2 import GPT2Tokenizer
66
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
77
from .tokenization_xlm import XLMTokenizer
8+
from .tokenization_utils import (PreTrainedTokenizer, clean_up_tokenization)
89

910
from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
1011
BertForMaskedLM, BertForNextSentencePrediction,
@@ -26,11 +27,10 @@
2627
from .modeling_xlm import (XLMConfig, XLMModel,
2728
XLMWithLMHeadModel, XLMForSequenceClassification,
2829
XLMForQuestionAnswering)
30+
from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
31+
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
2932

3033
from .optimization import BertAdam
3134
from .optimization_openai import OpenAIAdam
3235

3336
from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
34-
35-
from .model_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
36-
PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)

pytorch_transformers/modeling_bert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from torch.nn import CrossEntropyLoss, MSELoss
3030

3131
from .file_utils import cached_path
32-
from .model_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
32+
from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer
3333

3434
logger = logging.getLogger(__name__)
3535

pytorch_transformers/modeling_gpt2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from torch.nn.parameter import Parameter
3232

3333
from .file_utils import cached_path
34-
from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
34+
from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
3535
PreTrainedModel, prune_conv1d_layer, SequenceSummary)
3636
from .modeling_bert import BertLayerNorm as LayerNorm
3737

pytorch_transformers/modeling_openai.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from torch.nn.parameter import Parameter
3232

3333
from .file_utils import cached_path
34-
from .model_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
34+
from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
3535
PreTrainedModel, prune_conv1d_layer, SequenceSummary)
3636
from .modeling_bert import BertLayerNorm as LayerNorm
3737

pytorch_transformers/modeling_transfo_xl.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from .modeling_bert import BertLayerNorm as LayerNorm
3838
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
3939
from .file_utils import cached_path
40-
from .model_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
40+
from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel
4141

4242
logger = logging.getLogger(__name__)
4343

pytorch_transformers/model_utils.py pytorch_transformers/modeling_utils.py

-6
Original file line numberDiff line numberDiff line change
@@ -598,9 +598,3 @@ def prune_layer(layer, index, dim=None):
598598
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
599599
else:
600600
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
601-
602-
def clean_up_tokenization(out_string):
603-
out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
604-
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
605-
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
606-
return out_string

pytorch_transformers/modeling_xlm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from torch.nn import CrossEntropyLoss, MSELoss
3636

3737
from .file_utils import cached_path
38-
from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
38+
from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
3939
prune_linear_layer, SequenceSummary, SQuADHead)
4040

4141
logger = logging.getLogger(__name__)

pytorch_transformers/modeling_xlnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from torch.nn import CrossEntropyLoss, MSELoss
3333

3434
from .file_utils import cached_path
35-
from .model_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
35+
from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel,
3636
SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits)
3737

3838

pytorch_transformers/tests/model_utils_test.py

-50
This file was deleted.

pytorch_transformers/tests/modeling_bert_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
BertForTokenClassification, BertForMultipleChoice)
2727
from pytorch_transformers.modeling_bert import PRETRAINED_MODEL_ARCHIVE_MAP
2828

29-
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
29+
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
3030

3131

3232
class BertModelTest(unittest.TestCase):

pytorch_transformers/tests/modeling_gpt2_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pytorch_transformers import (GPT2Config, GPT2Model,
2929
GPT2LMHeadModel, GPT2DoubleHeadsModel)
3030

31-
from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
31+
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
3232

3333
class GPT2ModelTest(unittest.TestCase):
3434

pytorch_transformers/tests/modeling_openai_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel,
2525
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel)
2626

27-
from .model_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
27+
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, GPTModelTester)
2828

2929
class OpenAIModelTest(unittest.TestCase):
3030

pytorch_transformers/tests/modeling_transfo_xl_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel)
2929
from pytorch_transformers.modeling_transfo_xl import PRETRAINED_MODEL_ARCHIVE_MAP
3030

31-
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
31+
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
3232

3333
class TransfoXLModelTest(unittest.TestCase):
3434
class TransfoXLModelTester(object):

examples/tests/examples_tests.py pytorch_transformers/tests/modeling_utils_test.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,10 @@
1616
from __future__ import division
1717
from __future__ import print_function
1818

19-
import os
2019
import unittest
21-
import json
22-
import random
23-
import shutil
24-
import pytest
25-
26-
import torch
2720

2821
from pytorch_transformers import PretrainedConfig, PreTrainedModel
29-
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP, PRETRAINED_CONFIG_ARCHIVE_MAP
22+
from pytorch_transformers.modeling_bert import BertModel, BertConfig, PRETRAINED_MODEL_ARCHIVE_MAP
3023

3124

3225
class ModelUtilsTest(unittest.TestCase):

pytorch_transformers/tests/modeling_xlm_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification)
2424
from pytorch_transformers.modeling_xlm import PRETRAINED_MODEL_ARCHIVE_MAP
2525

26-
from .model_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
26+
from .modeling_tests_commons import (create_and_check_commons, ConfigTester, ids_tensor)
2727

2828

2929
class XLMModelTest(unittest.TestCase):

pytorch_transformers/tests/modeling_xlnet_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering)
2929
from pytorch_transformers.modeling_xlnet import PRETRAINED_MODEL_ARCHIVE_MAP
3030

31-
from .model_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
31+
from .modeling_tests_commons import ConfigTester, create_and_check_commons, ids_tensor
3232

3333
class XLNetModelTest(unittest.TestCase):
3434
class XLNetModelTester(object):

pytorch_transformers/tests/tokenization_bert_test.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
BertTokenizer,
2525
WordpieceTokenizer,
2626
_is_control, _is_punctuation,
27-
_is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
27+
_is_whitespace)
2828

2929
from .tokenization_tests_commons import create_and_check_tokenizer_commons
3030

@@ -49,14 +49,6 @@ def test_full_tokenizer(self):
4949

5050
os.remove(vocab_file)
5151

52-
@pytest.mark.slow
53-
def test_tokenizer_from_pretrained(self):
54-
cache_dir = "/tmp/pytorch_transformers_test/"
55-
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
56-
tokenizer = BertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
57-
shutil.rmtree(cache_dir)
58-
self.assertIsNotNone(tokenizer)
59-
6052
def test_chinese(self):
6153
tokenizer = BasicTokenizer()
6254

pytorch_transformers/tests/tokenization_gpt2_test.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
import os
1818
import unittest
1919
import json
20-
import shutil
21-
import pytest
2220

23-
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
21+
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
2422

2523
from .tokenization_tests_commons import create_and_check_tokenizer_commons
2624

@@ -56,13 +54,6 @@ def test_full_tokenizer(self):
5654
os.remove(vocab_file)
5755
os.remove(merges_file)
5856

59-
# @pytest.mark.slow
60-
def test_tokenizer_from_pretrained(self):
61-
cache_dir = "/tmp/pytorch_transformers_test/"
62-
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
63-
tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
64-
shutil.rmtree(cache_dir)
65-
self.assertIsNotNone(tokenizer)
6657

6758
if __name__ == '__main__':
6859
unittest.main()

pytorch_transformers/tests/tokenization_openai_test.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import shutil
2121
import pytest
2222

23-
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
23+
from pytorch_transformers.tokenization_openai import OpenAIGPTTokenizer
2424

2525
from.tokenization_tests_commons import create_and_check_tokenizer_commons
2626

@@ -58,14 +58,6 @@ def test_full_tokenizer(self):
5858
self.assertListEqual(
5959
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
6060

61-
@pytest.mark.slow
62-
def test_tokenizer_from_pretrained(self):
63-
cache_dir = "/tmp/pytorch_transformers_test/"
64-
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
65-
tokenizer = OpenAIGPTTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
66-
shutil.rmtree(cache_dir)
67-
self.assertIsNotNone(tokenizer)
68-
6961

7062
if __name__ == '__main__':
7163
unittest.main()

pytorch_transformers/tests/tokenization_transfo_xl_test.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import shutil
2121
import pytest
2222

23-
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
23+
from pytorch_transformers.tokenization_transfo_xl import TransfoXLTokenizer
2424

2525
from.tokenization_tests_commons import create_and_check_tokenizer_commons
2626

@@ -59,13 +59,6 @@ def test_full_tokenizer_no_lower(self):
5959
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
6060
["HeLLo", "!", "how", "Are", "yoU", "?"])
6161

62-
@pytest.mark.slow
63-
def test_tokenizer_from_pretrained(self):
64-
cache_dir = "/tmp/pytorch_transformers_test/"
65-
for model_name in list(PRETRAINED_VOCAB_ARCHIVE_MAP.keys())[:1]:
66-
tokenizer = TransfoXLTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
67-
shutil.rmtree(cache_dir)
68-
self.assertIsNotNone(tokenizer)
6962

7063
if __name__ == '__main__':
7164
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# coding=utf-8
2+
# Copyright 2018 HuggingFace Inc..
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import unittest
20+
21+
from pytorch_transformers import PreTrainedTokenizer
22+
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer
23+
24+
class TokenizerUtilsTest(unittest.TestCase):
25+
def check_tokenizer_from_pretrained(self, tokenizer_class):
26+
s3_models = list(tokenizer_class.max_model_input_sizes.keys())
27+
for model_name in s3_models[:1]:
28+
tokenizer = tokenizer_class.from_pretrained(model_name)
29+
self.assertIsNotNone(tokenizer)
30+
self.assertIsInstance(tokenizer, PreTrainedTokenizer)
31+
32+
def test_pretrained_tokenizers(self):
33+
self.check_tokenizer_from_pretrained(GPT2Tokenizer)
34+
35+
if __name__ == "__main__":
36+
unittest.main()

0 commit comments

Comments (0)