Language Models [03] - Tokenizer

The first component we are going to build is a tokenizer. For a character-level model it is very simple, but we will create an implementation that can be extended later with ease.

Abstract

This post is part of a series:

  1. Introduction and theory
  2. Dataset exploratory analysis
  3. Tokenizer (you are here)
  4. Training the model
  5. Evaluation of language model
  6. Experiments with the model
  7. Exercises for you

Requirements

Before we code, let’s think about requirements. Rough bullet points can be:

  • splitting text into characters
  • creating a vocabulary from a corpus
  • encoding text (changing tokens to integers)
  • decoding text (integers to tokens)
  • handling out-of-vocabulary tokens
  • serializing to file
  • loading from file

Quite a lot for a simple tokenizer!

Info
Complete source code is located at the bottom of the post.

Creating the tokenizer interface

First, I’m going to define the tokenizer interface - a list of methods and their responsibilities. Create a file char_lm/tokenizer.py and fill it with the required imports and the tokenizer class scaffold.

import logging
from collections import Counter, defaultdict
from pathlib import Path
from typing import List

import click
import dill
import torch
from torch.nn.utils.rnn import pad_sequence

logger = logging.getLogger(__name__)


class CharacterTokenizer:
    def __init__(self) -> None:
        pass

    def fit(self, corpus_file: Path):
        """Fits tokenizer on a text file.
        Vocabulary is created from all characters in the file."""
        pass

    def get_vocab_size(self) -> int:
        """Returns size of the vocabulary"""
        pass

    def encode(self, s: str, include_special_tokens=True) -> torch.LongTensor:
        """Encodes a single string"""
        pass

    def encode_batch(self, texts: List[str], include_special_tokens=True) -> torch.LongTensor:
        """Encodes a batch of strings and pads them to the longest one"""
        pass

    def decode(self, t: torch.LongTensor) -> str:
        """Decodes string from 1-D tensor"""
        pass

    def to_file(self, path: Path) -> None:
        """Saves the tokenizer to a file"""
        pass

    @staticmethod
    def from_file(path: Path):
        """Instantiates a tokenizer from file"""
        pass

Info
If you are not familiar with type hints - I encourage you to give them a try!

These methods should cover all our requirements. We are going to implement them more or less in the given order.

Implementation

Init

In __init__, we define the special tokens. The model will use sos, eos, oov and pad. To make life easier, they will be represented as single characters (instead of strings like [sos]). Additionally, we fix the index of the pad token - it is going to help us later.

def __init__(self) -> None:
    self.sos_token = "^"
    self.eos_token = "$"
    self.oov_token = "@"
    self.pad_token = "#"
    self.pad_index = 0

Every time a special token is referenced, it will be through these fields. This makes it easy to replace them later if necessary.

Fit

The fit method takes a path to a text file and uses it to create a vocabulary and token-index mapping.

Info
Every time I need a file path, I use Path class from pathlib module. It is a high-level replacement for ordinary strings making path manipulations safer. You can read more here: https://treyhunner.com/2018/12/why-you-should-be-using-pathlib/
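
For example, a tiny illustration of Path in action (separate from the tokenizer code):

from pathlib import Path

corpus_file = Path("data") / "dataset" / "city" / "full.txt"  # joins parts with the correct separator
print(corpus_file.suffix)   # '.txt'
print(corpus_file.exists()) # True only if the file is actually there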

We load all text from the file and remove newlines - as city names should not contain them.

def fit(self, corpus_file: Path):
    with corpus_file.open("rt") as f:
        text = f.read().replace("\n", "")

To get the unique characters, we can use Counter from collections and log the ten most common ones as a sanity check.
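
Inside fit, this comes down to two lines (they also appear in the full source at the bottom of the post):

vocab = Counter(text)
logger.info(f"Most frequent tokens {vocab.most_common(10)}")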

Next, we construct a mapping from the index of the character to the actual character. We need to include all special tokens and all tokens from the vocabulary.

self.idx2tok = list(
    [self.pad_token, self.sos_token, self.eos_token, self.oov_token]
    + list(vocab.keys())
)

We also need a reverse one:

self.tok2idx = {tok: idx for idx, tok in enumerate(self.idx2tok)}

After fitting, they can look like this:

idx2tok
['#', '^', '$', '@', 'A', 'b', 'e', 'v', 'i', 'l', 'r', 'n', 'a', 't', 'd', 'm', 's',
 'o', 'g', 'k', 'x', ' ', 'C', 'y', 'c', 'p', 'u', 'h', 'f', 'U', 'B', 'M', 'L', 'w',
 'z', 'S', 'H', 'D', 'I', 'P', 'E', 'T', 'q', 'F', 'R', 'G', 'J', 'K', 'W', 'O', 'V',
 'N', 'Q', 'Y', '-', "'", '.', 'j', 'Z', 'X']

tok2idx
{'#': 0, '^': 1, '$': 2, '@': 3, 'A': 4, 'b': 5, 'e': 6, 'v': 7, 'i': 8, 'l': 9,
 'r': 10, 'n': 11, 'a': 12, 't': 13, 'd': 14, 'm': 15, 's': 16, 'o': 17, 'g': 18,
 'k': 19, 'x': 20, ' ': 21, 'C': 22, 'y': 23, 'c': 24, 'p': 25, 'u': 26, 'h': 27,
 'f': 28, 'U': 29, 'B': 30, 'M': 31, 'L': 32, 'w': 33, 'z': 34, 'S': 35, 'H': 36,
 'D': 37, 'I': 38, 'P': 39, 'E': 40, 'T': 41, 'q': 42, 'F': 43, 'R': 44, 'G': 45,
 'J': 46, 'K': 47, 'W': 48, 'O': 49, 'V': 50, 'N': 51, 'Q': 52, 'Y': 53, '-': 54,
 "'": 55, '.': 56, 'j': 57, 'Z': 58, 'X': 59}

Special tokens are located at the top, followed by the regular tokens in the order in which they first appear in the corpus (Counter preserves insertion order). With both mappings, we can easily convert between text and its numeric representation.

Warning
Notice that pad is first in the list, as I want it to have index 0.

Out of vocabulary

If we try to get the index of an out-of-vocabulary token, we get an error:

self.tok2idx['?']
Traceback (most recent call last):
  File "<string>", line 1, in <module>
KeyError: '?'

Clearly, ? is out of vocabulary and does not exist in tok2idx. In such a situation, we should return the oov token instead. The easiest way is to wrap tok2idx in a defaultdict.

self.tok2idx = defaultdict(
    lambda: self.tok2idx[self.oov_token], self.tok2idx
)

Try again:

tok2idx['?']
3
tok2idx[';']
3

We got an index - is it correct?

idx2tok[3]
'@'

Yes, @ is our oov token.

The last small step is to log a simple summary:

logger.info(f"Created vocabulary with {self.get_vocab_size()} tokens")

We can do a quick manual test - create a tokenizer and fit it on the full dataset. At the end of the file, add:

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO) # ignore this line if you use print
    ct = CharacterTokenizer()
    ct.fit(Path("data/dataset/city/full.txt"))
INFO:__main__:Most frequent tokens [('e', 25216), ('l', 22518), ('a', 21478), ('o', 17837), ('i', 16520), ('r', 15699), ('n', 15127), (' ', 12802), ('t', 11361), ('s', 9208)]
INFO:__main__:Created vocabulary with 60 tokens

Looks good.

Get vocab size

This one is trivial:

def get_vocab_size(self) -> int:
    return len(self.tok2idx)

Encode

The heart of the tokenizer. Sometimes we want to include the special tokens and sometimes not - we can handle both scenarios.

def encode(self, s: str, include_special_tokens=True) -> torch.LongTensor:
    if include_special_tokens:
        s = [self.sos_token] + list(s) + [self.eos_token]
    
    indices = [self.tok2idx[c] for c in s]
    return torch.LongTensor(indices)

If special tokens are required, we prepend sos and append eos to the input string. To encode, we just replace every character with its numeric representation and cast the result to a long tensor. Thanks to the defaultdict, we already have oov support!
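
A quick check using the vocabulary fitted earlier (the exact indices depend on the corpus, so treat the numbers as illustrative):

ct.encode("ab?", include_special_tokens=False)
tensor([12,  5,  3])

The character ? is not in the vocabulary, so it is mapped to the oov index 3.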

Encode batch

As we train models in batches, it is helpful to have a method for encoding a whole batch directly. The first step is to apply encode to each string in the list. As the strings will have different lengths, we need to pad them to the longest one. For this purpose, we use pad_sequence from PyTorch. By default, it returns a result of shape max sequence len x batch size, but I found it easier to work with batch size x max sequence len. Fortunately, there is an argument for that: batch_first. The value used for padding is the index of the pad token; instead of hard-coding 0, I use the field self.pad_index in case it changes later.

def encode_batch(self, texts: List[str], include_special_tokens=True) -> torch.LongTensor:
    encoded = [
        self.encode(t, include_special_tokens=include_special_tokens)
        for t in texts
    ]

    padded = pad_sequence(
        encoded, batch_first=True, padding_value=self.pad_index
    )

    return padded
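
A small illustration with the vocabulary fitted earlier (again, the exact indices are corpus-dependent):

ct.encode_batch(["a", "ab"])
tensor([[ 1, 12,  2,  0],
        [ 1, 12,  5,  2]])

The shorter sequence is padded with the pad index 0 up to the length of the longest one, and the result has shape batch size x max sequence len.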

Decode

This one is similar - we just go the other way around. Here we do not need to care about oov at all.

def decode(self, t: torch.LongTensor) -> str:
    decoded = [self.idx2tok[i] for i in t]
    return "".join(decoded)

This is a good place for the next testing round. We can try something like this: take an example from the dataset and encode it. Next, decode the encoded value and check whether it matches the original.

sample_sequence = "Abbeville"
print("Sequence:", sample_sequence)
encoded = ct.encode(sample_sequence)
print("Encoded:", encoded)
decoded = ct.decode(encoded)
print("Decoded:", decoded)

We will get:

Sequence: Abbeville
Encoded: tensor([1, 4, 5, 5, 6, 7, 8, 9, 9, 6, 2])
Decoded: ^Abbeville$

Encoding added sos and eos (values 1 and 2). When we decode, the special tokens are visible in the output.

To file

As the tokenizer needs fitting, we want to save it to a file - I do not want to fit it every time before using the model. I use dill instead of pickle, because dill can serialize the lambda we used as the defaultdict's default factory (the standard pickle module cannot pickle lambdas).
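
A minimal illustration of the difference (not part of the tokenizer code):

import pickle
from collections import defaultdict

import dill

d = defaultdict(lambda: -1, {"a": 1})
dill.dumps(d)    # works
pickle.dumps(d)  # raises an error - pickle cannot serialize the lambda default factory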

def to_file(self, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        dill.dump(self, f)

From file

If there is save, there is also load.

@staticmethod
def from_file(path: Path):
    with path.open("rb") as f:
        return dill.load(f)
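
Usage is then a simple round trip (the path below is just an example, and restored is an arbitrary name):

ct.to_file(Path("models/tokenizer"))
restored = CharacterTokenizer.from_file(Path("models/tokenizer"))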

Command line interface

We have finished our tokenizer. Now let's create a simple command line interface for it. As arguments, we take the corpus file and the file where the fitted tokenizer should be saved.

@click.command()
@click.option("-i", "--corpus-path", type=str, required=True)
@click.option("-o", "--output-path", type=str, required=True)
def main(corpus_path, output_path):
    logging.basicConfig(level=logging.INFO)
    corpus_path = Path(corpus_path)
    output_path = Path(output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer = CharacterTokenizer()
    tokenizer.fit(corpus_path)
    tokenizer.to_file(output_path)

    logger.info(f"Saved tokenizer to: {output_path}")


if __name__ == "__main__":
    main()

Example run:

python char_lm/tokenizer.py -i data/dataset/city/train.txt -o models/tokenizer
INFO:__main__:Most frequent tokens [('e', 16069), ('l', 14244), ('a', 13752), ('o', 11395), ('i', 10539), ('r', 10018), ('n', 9778), (' ', 8099), ('t', 7249), ('s', 5867)]
INFO:__main__:Created vocabulary with 60 tokens
INFO:__main__:Saved tokenizer to: models/tokenizer

Unit testing

During development we tested the tokenizer manually. But this is not enough - that’s why we will also create basic unit tests.

Tip
I will not give an introduction to unit testing here, but even if you have never used it, I encourage you to give it a try. For a tutorial on pytest, check out: https://www.guru99.com/pytest-tutorial.html

First, let's create a file tests/test_tokenizer.py. It will contain all unit tests for the tokenizer.

First test and fixtures

To start simple, we can test the method get_vocab_size and see if it returns the expected size. For example, if our corpus contains 10 distinct characters, we expect the vocab size to be 14 (10 + sos + eos + pad + oov). Ok, but how do we know how many unique characters are in the file? We could use our tokenizer… No. That would be pointless - we cannot use the code under test to test itself.

Instead, we should count the characters manually (or use a different method that we consider safe). But this creates another issue. What if the corpus is big? Should we scroll endlessly through it, taking notes about letters? Of course not. Instead, we use a data fixture - a small artificial data file.

Create a file tests/fixtures/tiny_corpus.txt with the following content:

aaab
aaac
aaad
aaae
aaaf

For such a small corpus, it's easy to spot the unique characters: abcdef. Together with the special tokens, that gives 10 tokens.

For the first unit test, fill tests/test_tokenizer.py with the following content:

from pathlib import Path

import pytest
import torch

from char_lm.tokenizer import CharacterTokenizer

def test_vocab_size():
    corpus_path = Path("tests/fixtures/tiny_corpus.txt")
    tokenizer = CharacterTokenizer()

    tokenizer.fit(corpus_path)
    
    # 6 letters + 4 special tokens
    assert tokenizer.get_vocab_size() == 6 + 4

We have a single test function. It starts with the required initialization - creating a tokenizer. Next, we perform the fit operation using the test corpus. Last but not least, we check whether the method under test behaves correctly. Here it is very simple - we just compare the actual output with the expected one.

In order to run the test, just execute the command pytest (remember to install it first).

pytest
============================= test session starts ==============================
platform linux -- Python 3.8.5, pytest-5.4.3, py-1.10.0, pluggy-0.13.1
rootdir: /home/xevaquor/artofai/char-lm
plugins: anyio-2.1.0
collected 1 item                                                               

tests/test_tokenizer.py .                                                [100%]

============================== 1 passed in 0.26s ===============================

From the output we see that pytest found one test case and it passed! Let's try to make it fail. Change assert tokenizer.get_vocab_size() == 6 + 4 to assert tokenizer.get_vocab_size() == 6 + 3 and run it again. The output is as follows:

============================= test session starts ==============================
platform linux -- Python 3.8.5, pytest-5.4.3, py-1.10.0, pluggy-0.13.1
rootdir: /home/xevaquor/artofai/char-lm
plugins: anyio-2.1.0
collected 1 item                                                               

tests/test_tokenizer.py F                                                [100%]

=================================== FAILURES ===================================
_______________________________ test_vocab_size ________________________________

    def test_vocab_size():
        corpus_path = Path("tests/fixtures/tiny_corpus.txt")
        tokenizer = CharacterTokenizer()
        tokenizer.fit(corpus_path)
    
        # 6 letters + 4 special tokens
        assert tokenizer.get_vocab_size() == 6 + 3
E       assert 10 == (6 + 3)
E        +  where 10 = <bound method CharacterTokenizer.get_vocab_size of <char_lm.tokenizer.CharacterTokenizer object at 0x7f925eae0d90>>()
E        +    where <bound method CharacterTokenizer.get_vocab_size of <char_lm.tokenizer.CharacterTokenizer object at 0x7f925eae0d90>> = <char_lm.tokenizer.CharacterTokenizer object at 0x7f925eae0d90>.get_vocab_size

tests/test_tokenizer.py:23: AssertionError
=========================== short test summary info ============================
FAILED tests/test_tokenizer.py::test_vocab_size - assert 10 == (6 + 3)
============================== 1 failed in 0.36s ===============================

We see what failed and where - the returned value was 10, but the expected one was 6 + 3 = 9. Revert the change and we can add some more tests.

Test encode-decode

An important property of our tokenizer is reversibility - when we encode a string and then decode the result, we should get the original string back. Let's test it!

def test_encode_decode_without_special_chars():
    corpus_path = Path("tests/fixtures/tiny_corpus.txt")
    tokenizer = CharacterTokenizer()

    tokenizer.fit(corpus_path)

    text = "abc"
    encoded = tokenizer.encode(text, include_special_tokens=False)
    decoded = tokenizer.decode(encoded)

    assert decoded == text

First, the tokenizer is instantiated and fitted. Next, a test string abc is defined. It is encoded (without adding sos/eos) and then decoded. Finally, we check that the result of this round trip is equal to the original input.

============================= test session starts ==============================
platform linux -- Python 3.8.5, pytest-5.4.3, py-1.10.0, pluggy-0.13.1
rootdir: /home/xevaquor/artofai/char-lm
plugins: anyio-2.1.0
collected 2 items                                                              

tests/test_tokenizer.py ..                                               [100%]

============================== 2 passed in 0.27s ===============================

It works :)

Fixtures once again

Before we add more tests, you might spot an issue. We create an instance of the tokenizer in every test - even though it is exactly the same. Code duplication is of course bad. To deal with it, we can create another kind of fixture - one that creates the tokenizer. Having it, inside a test we can tell pytest “I need an instance of CharacterTokenizer, give me one” - without worrying about how to create it inside the test. Let's see the example.

Tip
To read more about fixtures see: https://docs.pytest.org/en/stable/fixture.html

Fixture creation is very simple. We create a function and decorate it with the pytest.fixture decorator. Inside this function, we perform all the steps needed to build the tokenizer and just return it.

@pytest.fixture
def tokenizer():
    corpus_path = Path("tests/fixtures/tiny_corpus.txt")
    tokenizer = CharacterTokenizer()
    tokenizer.fit(corpus_path)
    return tokenizer

Having it, we can modify our first test to use this fixture:

def test_vocab_size(tokenizer: CharacterTokenizer):
    # 6 letters + 4 special tokens
    assert tokenizer.get_vocab_size() == 6 + 4

As you can see, we got rid of the creation code inside the test and just added a tokenizer argument. Pytest will understand that test_vocab_size needs a tokenizer. It will then search the registered fixtures for one named tokenizer. When it finds it, it will execute our test function with the value returned from the fixture.

Similarly, we can modify the second test case.

def test_encode_decode_without_special_chars(tokenizer: CharacterTokenizer):
    text = "abc"
    encoded = tokenizer.encode(text, include_special_tokens=False)
    decoded = tokenizer.decode(encoded)

    assert decoded == text

Now we can add as many test cases as we like without writing the tokenizer creation code again and again.

Question
Think about tests you could add! In the source code for this post, I included some suggestions.

Summary

In this part, you learned how to build a nice tokenizer for the model. What is more, you now have unit tests for your code, which help ensure it works as intended.

Question
Try to add a method batch_decode - for decoding multiple sequences at once.

Full source code

Full code - char_lm/tokenizer.py

import logging
from collections import Counter, defaultdict
from pathlib import Path
from typing import List

import click
import dill
import torch
from torch.nn.utils.rnn import pad_sequence

logger = logging.getLogger(__name__)


class CharacterTokenizer:
    def __init__(self) -> None:
        self.sos_token = "^"
        self.eos_token = "$"
        self.oov_token = "@"
        self.pad_token = "#"
        self.pad_index = 0

    def fit(self, corpus_file: Path):
        """Fits tokenizer on text file. Vocabulary is created from all characters from file.

        Args:
            corpus_file (Path): Text file location
        """
        with corpus_file.open("rt") as f:
            text = f.read().replace("\n", "")

        vocab = Counter(text)
        logger.info(f"Most frequent tokens {vocab.most_common(10)}")
        self.idx2tok = list(
            [self.pad_token, self.sos_token, self.eos_token, self.oov_token]
            + list(vocab.keys())
        )
        self.tok2idx = {tok: idx for idx, tok in enumerate(self.idx2tok)}
        self.tok2idx = defaultdict(
            lambda: self.tok2idx[self.oov_token], self.tok2idx
        )
        logger.info(f"Created vocabulary with {self.get_vocab_size()} tokens")

    def get_vocab_size(self) -> int:
        """Returns size of the vocabulary

        Returns:
            int: Vocabulary size
        """
        return len(self.tok2idx)

    def encode_batch(
        self, texts: List[str], include_special_tokens=True
    ) -> torch.LongTensor:
        """Encodes a batch of strings

        Args:
            texts (List[str], required): List of strings to be encoded
            include_special_tokens (bool, optional): To include sos/eos tokens. Defaults to True.

        Returns:
            torch.LongTensor: Encoded strings (2D tensor)
        """
        encoded = [
            self.encode(t, include_special_tokens=include_special_tokens)
            for t in texts
        ]

        padded = pad_sequence(
            encoded, batch_first=True, padding_value=self.pad_index
        )

        return padded

    def encode(self, s: str, include_special_tokens=True) -> torch.LongTensor:
        """Encodes a single string

        Args:
            s (str): String to encode
            include_special_tokens (bool, optional): If to add sos/eos. Defaults to True.

        Returns:
            torch.LongTensor: 1-D encoded tensor
        """
        if include_special_tokens:
            s = [self.sos_token] + list(s) + [self.eos_token]

        indices = [self.tok2idx[c] for c in s]
        return torch.LongTensor(indices)

    def decode(self, t: torch.LongTensor) -> str:
        """Decodes a string from a 1-D tensor

        Args:
            t (torch.LongTensor): 1-D encoded tensor

        Returns:
            str: Decoded string
        """
        decoded = [self.idx2tok[i] for i in t]
        return "".join(decoded)

    def to_file(self, path: Path):
        """Serializes tokenizer to a file

        Args:
            path (Path): File to save tokenizer to
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as f:
            dill.dump(self, f)

    @staticmethod
    def from_file(path: Path):
        """Deserializes tokenizer from a file

        Args:
            path (Path): Path to read the tokenizer from

        Returns:
            CharacterTokenizer: The deserialized tokenizer
        """
        with path.open("rb") as f:
            return dill.load(f)


@click.command()
@click.option("-i", "--corpus-path", type=str, required=True)
@click.option("-o", "--output-path", type=str, required=True)
def main(corpus_path, output_path):
    logging.basicConfig(level=logging.INFO)
    corpus_path = Path(corpus_path)
    output_path = Path(output_path)

    output_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer = CharacterTokenizer()
    tokenizer.fit(corpus_path)
    tokenizer.to_file(output_path)

    logger.info(f"Saved tokenizer to: {output_path}")


if __name__ == "__main__":
    main()

tests/test_tokenizer.py

from pathlib import Path

import pytest
import torch

from char_lm.tokenizer import CharacterTokenizer


@pytest.fixture
def tokenizer():
    corpus_path = Path("tests/fixtures/tiny_corpus.txt")
    tokenizer = CharacterTokenizer()
    tokenizer.fit(corpus_path)
    return tokenizer


def test_vocab_size(tokenizer: CharacterTokenizer):
    # 6 letters + 4 special tokens
    assert tokenizer.get_vocab_size() == 6 + 4


def test_encode_decode_without_special_chars(tokenizer: CharacterTokenizer):
    text = "abc"
    encoded = tokenizer.encode(text, include_special_tokens=False)
    decoded = tokenizer.decode(encoded)

    assert decoded == text


def test_encode_decode_with_special_chars(tokenizer: CharacterTokenizer):
    text = "abc"
    encoded = tokenizer.encode(text, include_special_tokens=True)
    decoded = tokenizer.decode(encoded)

    assert decoded == tokenizer.sos_token + text + tokenizer.eos_token


def test_encode_with_oov(tokenizer: CharacterTokenizer):
    text = "abcX"
    encoded = tokenizer.encode(text, include_special_tokens=False)
    assert encoded[3] == tokenizer.tok2idx[tokenizer.oov_token]


def test_encode_decode_with_oov(tokenizer: CharacterTokenizer):
    text = "abcX"
    encoded = tokenizer.encode(text, include_special_tokens=True)
    decoded = tokenizer.decode(encoded)
    assert (
        decoded
        == tokenizer.sos_token
        + "abc"
        + tokenizer.oov_token
        + tokenizer.eos_token
    )


def test_encode_batch_with_var_size(tokenizer: CharacterTokenizer):
    texts = ["a", "ab", "abc"]
    batch = tokenizer.encode_batch(texts)
    # batch x seq len; 3x5
    assert batch.shape == (3, 5)
    # 2 pads in 0th example
    assert torch.equal(batch[0, 3:], torch.LongTensor([0, 0]))
    # 0 pads in last example
    assert torch.all(batch[2, :] != torch.LongTensor([0, 0, 0, 0, 0]))


def test_serialize_deserialize(tokenizer: CharacterTokenizer, tmp_path):
    serialization_path = tmp_path / "tokenizer.pickle"
    texts = ["a", "b", "abc"]
    encoded = tokenizer.encode_batch(texts)

    tokenizer.to_file(serialization_path)

    loaded_tok = CharacterTokenizer.from_file(serialization_path)
    loaded_encoded = loaded_tok.encode_batch(texts)

    assert torch.all(encoded == loaded_encoded)