7c0h

Fine-tuning a Transformer for text classification

I have found myself trying to classify texts often enough that I decided to be lazy and write a solution I could always rely on.

The following script is an example of how to fine-tune a pre-trained HuggingFace transformer (DistilBERT) for text classification. This example uses only two classes, identifying whether the sentiment of a text is positive or negative, but nothing stops you from adding more classes as long as every text gets exactly one label.
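To give an idea of what adding more classes involves, here is a minimal sketch for a hypothetical three-class problem (the 'neutral' class is made up for illustration). The helper build_label_maps() is also hypothetical and not part of any HuggingFace library. It only builds the id-to-name and name-to-id dictionaries that the script currently hard-codes as num_labels=2, id2label={0: 'negative', 1: 'positive'} and label_mapping={"negative": 0, "positive": 1}.

```python
def build_label_maps(label_names):
    """Assign an integer id to every class name and return both directions.

    Hypothetical helper for illustration; not part of any library.
    """
    id2label = {i: name for i, name in enumerate(label_names)}
    label2id = {name: i for i, name in id2label.items()}
    return id2label, label2id


# Hypothetical three-class setup ('neutral' is only an example class).
id2label, label2id = build_label_maps(["negative", "neutral", "positive"])
num_labels = len(id2label)

# Every record in your JSON files then stores the integer id of its class.
record = {"id": "0001", "text": "It was fine, I guess.", "label": label2id["neutral"]}

print(num_labels, id2label, record["label"])
# In the script, these values would replace the hard-coded ones, e.g.
# DistilBertForSequenceClassification.from_pretrained(..., num_labels=num_labels,
#                                                     id2label=id2label, label2id=label2id)
# and label_mapping=label2id in the evaluator.
```

Since every text still gets exactly one label, the rest of the script (argmax over the logits, picking the top-scoring class) works unchanged.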

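There are no difficult imports that I can tell. I think you can install all the required libraries with the command

pip install torch transformers datasets evaluate nltk scikit-learn

but I'll double check just in case. PyTorch is needed because the model and the Trainer run on it, and scikit-learn is used under the hood by the accuracy, precision, recall, and F1 metrics from evaluate. Recent versions of transformers may also ask you to install accelerate before you can use the Trainer.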
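If you want to check that everything is in place before running the script, a small sketch like the following asks Python's import system whether each package can be found, without actually importing it. The package list is my assumption of what the script needs: torch, because the model runs on PyTorch, and sklearn (installed via pip as scikit-learn), which the evaluate metrics use. Note that the import name is not always the same as the pip name.

```python
import importlib.util

# Import names (not pip names) of the packages the script relies on.
# Assumption: torch backs the model/Trainer, and sklearn (pip: scikit-learn)
# is used by the evaluate metrics.
REQUIRED = ["torch", "transformers", "datasets", "evaluate", "nltk", "sklearn"]


def missing_packages(names):
    """Return the top-level package names that Python cannot locate.

    importlib.util.find_spec() looks a package up without importing it
    and returns None when it is not installed.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```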

Update: I have changed the script to extract the function get_dataset_from_json_files(). If your data is already stored as JSON files, you can call this function as is instead of get_data(). I also updated the prediction code to show how to make predictions for many texts at once.
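If you already have your own data, all get_dataset_from_json_files() needs is one file per split containing a JSON list of records with the keys 'id', 'text' and 'label'. Here is a minimal sketch for writing such a file with the standard library. The helper write_json_records(), the records, and the file name train.json are made-up examples.

```python
import json
import os
import tempfile


def write_json_records(records, path):
    """Save `records` as one JSON list of dictionaries, the layout that
    get_dataset_from_json_files() reads. Hypothetical helper."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f)


# Made-up example records; replace them with your own texts and labels.
train_records = [
    {"id": "1234", "text": "My first text", "label": 0},
    {"id": "5678", "text": "My second text", "label": 1},
]

# A throwaway directory for the demo; use any path you like for real data.
out_dir = tempfile.mkdtemp()
train_path = os.path.join(out_dir, "train.json")
write_json_records(train_records, train_path)

# With the script's functions in scope you would then call, for example:
# dataset = get_dataset_from_json_files(train_path)
print(train_path)
```

Using json.dump() instead of writing the text by hand guarantees valid JSON (double-quoted keys and strings), which the JSON loader requires.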

import json
import numpy as np
import os
import random
import tempfile
# These are all HuggingFace libraries
from datasets import load_dataset
import evaluate
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification,\
     TrainingArguments, Trainer, TextClassificationPipeline
# KeyDataset is a utility that yields only the field we're interested in from each record.
from transformers.pipelines.pt_utils import KeyDataset
# Library for importing a sentiment classification dataset.
# Only here for demo purposes, as you would use your own dataset.
from nltk.corpus import movie_reviews

# The DistilBERT tokenizer is the same everywhere, so we declare it here globally.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def get_data():
    """ Returns a dataset built from a toy corpora.

    Returns
    -------
    Dataset
        An object containing all of our training, validation, and test data.
        This object stores all the information about a text, namely, the raw
        text (which is stored both as text and as numeric tokens) and the
        proper class for every text.

    Notes
    -----
    If you have texts in this format already then you can skip most of this
    function and simply jump to the `get_dataset_from_json_files` function.

    References
    ----------
    Most of this function is taken from the code in
    https://huggingface.co/course/chapter5/5?fw=pt.
    Note that that link is very confusing at times, so if you read that code
    remember to be patient.
    """
    # We read all of our (toy) data from NLTK. The exception catches the case
    # where we haven't downloaded it yet and downloads it.
    try:
        all_ids = movie_reviews.fileids()
    except LookupError:
        import nltk
        nltk.download('movie_reviews')
        all_ids = movie_reviews.fileids()
    # We identify texts by their ID. Given that we have those IDs now, we
    # proceed to shuffle them and assign each one to a specific split.
    # Note that we fix the random seed to ensure that the splits are always
    # the same across runs. This matters because get_data() is called again in
    # the 'test' and 'predict' modes, and those modes must see the same test
    # split that was held out during training.
    random.seed(42)
    random.shuffle(all_ids)
    num_records = len(all_ids)
    # This dictionary holds the names of the temporary files we will create.
    tmp_files = dict()
    # We assign the IDs we read before to their data splits.
    # Note that we use a fixed 80/10/10 split.
    for split in ['train', 'val', 'test']:
        if split == 'train':
            split_ids = all_ids[:int(0.8 * num_records)]
        elif split == 'val':
            split_ids = all_ids[int(0.8 * num_records):int(0.9 * num_records)]
        else:
            split_ids = all_ids[int(0.9 * num_records):]
        all_data = []
        # We read the data one record at a time and store it as dictionaries.
        for text_id in split_ids:
            # We read the text.
            text = movie_reviews.raw(text_id)
            # We read the class and store it as an integer.
            if movie_reviews.categories(text_id)[0] == 'neg':
                label = 0
            else:
                label = 1
            all_data.append({'id': text_id, 'text': text, 'label': label})
        # Step two is to save this data as a list of JSON records in a temporary
        # file. If you use your own data you can simply generate these files as
        # JSON from scratch and save yourself this step.
        # Note that we also save the name of the temporary file in a dictionary
        # so we can re-read it later.
        tmp_file = tempfile.NamedTemporaryFile(delete=False)
        tmp_files[split] = tmp_file.name
        tmp_file.write(json.dumps(all_data).encode())
        tmp_file.close()

    # Step three is to create a data loader that will read the data we just
    # saved to disk.
    dataset = get_dataset_from_json_files(tmp_files['train'], tmp_files['val'], tmp_files['test'])

    # Remove the temporary files and return the tokenized dataset.
    for _, filename in tmp_files.items():
        os.unlink(filename)
    return dataset


def get_dataset_from_json_files(train_file, val_file=None, test_file=None):
    """ Given a set of properly-formatted files, it reads them and returns
    a Dataset object containing all of them.

    Parameters
    ----------
    train_file : str
        Path to a file containing training data.
    val_file : str, optional
        Path to a file containing validation data.
    test_file : str, optional
        Path to a file containing test data.

    Returns
    -------
    DatasetDict
        An object containing all of our training, validation, and test data.
        This object stores all the information about a text, namely, the raw
        text (which is stored both as text and as numeric tokens) and the
        proper class for every text.

    Notes
    -----
    IMPORTANT: the first time you run this code, the resulting dataset is saved
    to a temporary file. The console will tell you where it is (in Linux it is
    /home/<user>/.cache/huggingface/datasets/...). This temporary location is
    used in all subsequent calls, so if you change your dataset remember to
    remove this cached file first!

    This function is a thin wrapper around the `load_dataset` function where
    we hard-coded the file format to use JSON.
    If you don't want to read that function's documentation, it is enough to
    provide files whose content is simply a JSON list of dictionaries, like so:
    [{"id": "1234", "text": "My first text", "label": 0}, {"id": "5678", "text": "My second text", "label": 1}]
    """
    files_dict = {'train': train_file}
    if val_file is not None:
        files_dict['val'] = val_file
    if test_file is not None:
        files_dict['test'] = test_file
    dataset = load_dataset('json', data_files=files_dict)

    # Tokenize the data
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True)
    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Return the tokenized dataset.
    return tokenized_dataset


if __name__ == '__main__':
    # The mode in which we want to use this script:
    # 'train': start from a pre-trained model, fine-tune it to a dataset, and then
    #          save the resulting model
    # 'test': evaluate the performance of a fine-tuned model over the test data.
    # 'predict': make predictions for unseen texts.
    mode = 'predict'
    assert mode in ['train', 'test', 'predict'], f'Invalid mode {mode}'
    # We define here the batch size and training epochs we want to use.
    batch_size = 16
    epochs = 10
    # Where to save the trained model.
    trained_model_dir = 'training_output_dir'
    if mode == 'train':
        # Train mode. Code adapted mostly from:
        # https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer.
        # Collect the training data
        inputs = get_data()
        # Since we are starting from scratch, we download a pre-trained DistilBERT
        # model from HuggingFace. Note that we are hard-coding the number of classes here!
        model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
        # Define training arguments
        # You can define a loooot more hyperparameters - see them all in
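        # https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
        # In recent versions of transformers, `evaluation_strategy` has been
        # renamed to `eval_strategy` (the old name was deprecated and later removed).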
        training_args = TrainingArguments(output_dir=trained_model_dir,
                                          evaluation_strategy='epoch',
                                          per_device_train_batch_size=batch_size,
                                          num_train_epochs=epochs,
                                          learning_rate=1e-5)
        # Define training evaluation metrics
        metric = evaluate.load('accuracy')

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            predictions = np.argmax(logits, axis=-1)
            return metric.compute(predictions=predictions, references=labels)
        # Define the trainer object and start training.
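        # The Trainer saves a checkpoint every 500 training steps (the default
        # `save_steps`). With 1,600 training texts, a batch size of 16, and 10
        # epochs we run 1,000 steps on a single device, which is why the other
        # modes load `checkpoint-1000`.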
        trainer = Trainer(model=model,
                          args=training_args,
                          train_dataset=inputs['train'],
                          eval_dataset=inputs['val'],
                          compute_metrics=compute_metrics)
        trainer.train()
    elif mode == 'test':
        # Perform a whole run over the test set.
        # This code is adapted mostly from:
        # https://huggingface.co/docs/evaluate/base_evaluator
        # Some useful information also here:
        # https://huggingface.co/docs/datasets/metrics

        # Collect the data (only the test split is used here)
        inputs = get_data()
        # Create our evaluator
        task_evaluator = evaluate.evaluator("text-classification")

        # Use the model we trained before to predict over the test data.
        # Note that we are providing the names of the classes by hand. The
        # evaluator needs them: the pipeline outputs these names, and
        # `label_mapping` below maps them back to the integer labels in the data.
        # They are also useful for the prediction branch.
        model = DistilBertForSequenceClassification.from_pretrained(f'./{trained_model_dir}/checkpoint-1000/',
                                                                    num_labels=2,
                                                                    id2label={0: 'negative', 1: 'positive'})
        # Define the evaluation parameters. We use the usual classification
        # measures (accuracy, precision, recall, and F1), but there are plenty more.
        eval_results = task_evaluator.compute(
            model_or_pipeline=model,
            tokenizer=tokenizer,
            data=inputs['test'],
            metric=evaluate.combine(['accuracy', 'precision', 'recall', 'f1']),
            label_mapping={"negative": 0, "positive": 1}
        )
        # `eval_results` is a dictionary with the same keys we defined in
        # the `metric` parameters, plus some time measures.
        print(eval_results)
    elif mode == 'predict':
        # Make predictions for individual texts.
        # As above, we load the model we trained before, this time to classify a couple of texts.
        model = DistilBertForSequenceClassification.from_pretrained(f'./{trained_model_dir}/checkpoint-1000/',
                                                                    num_labels=2,
                                                                    id2label={0: 'negative', 1: 'positive'})
        # We build a text classification pipeline.
        # Note that `top_k=None` gives us probabilities for every class while
        # `top_k=1` returns values for the best class only.
        # Inspired by https://discuss.huggingface.co/t/i-have-trained-my-classifier-now-how-do-i-do-predictions/3625
        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer,
                                          batch_size=batch_size, top_k=None)
        sequences = ['Hello, my dog is cute', 'Hello, my dog is not cute']
        predictions = pipe(sequences)
        print(predictions)
        # The above print shows something like:
        # [[{'label': 'positive', 'score': 0.574},
        #   {'label': 'negative', 'score': 0.426}],
        #  [{'label': 'negative', 'score': 0.868},
        #   {'label': 'positive', 'score': 0.132}]]

        # For a more heavy-duty approach we now classify an entire dataset.
        inputs = get_data()
        test_inputs = inputs['test']
        outputs = []
        # Note that 'text' is the key we defined in the `get_data` function.
        for out in pipe(KeyDataset(test_inputs, 'text'), batch_size=batch_size, truncation=True, max_length=512):
            outputs.append(out)
        # Now that we collected all outputs we massage them a little bit for
        # a more friendly format. For every prediction we will print a line like:
        # 'pos/cv093_13951.txt: positive (0.996)'
        # With `top_k=None` the scores for every text are sorted from best to
        # worst, so the first entry is the predicted class.
        for example, scores in zip(test_inputs, outputs):
            text_id = example['id']
            pred_label = scores[0]['label']
            pred_prob = scores[0]['score']
            print(f'{text_id}: {pred_label} ({pred_prob:.3f})')
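The 80/10/10 split in get_data() boils down to three slices of the shuffled ID list. The sketch below reproduces the same boundaries, with the same int() rounding, so you can see how many texts end up in each split. split_bounds() is a hypothetical helper, not part of the script. NLTK's movie_reviews corpus has 2,000 reviews, so the test split used by both the 'test' and 'predict' modes holds 200 of them.

```python
def split_bounds(num_records, train_end=0.8, val_end=0.9):
    """Return (start, stop) slice bounds for the train, val, and test splits,
    using the same int() rounding as get_data(). Hypothetical helper."""
    train_stop = int(train_end * num_records)
    val_stop = int(val_end * num_records)
    return {
        "train": (0, train_stop),
        "val": (train_stop, val_stop),
        "test": (val_stop, num_records),
    }


bounds = split_bounds(2000)
for split, (start, stop) in bounds.items():
    print(f"{split}: ids[{start}:{stop}] -> {stop - start} texts")
```

Because consecutive slices share their boundaries, every ID lands in exactly one split.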
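The 'test' and 'predict' modes load checkpoint-1000 by name, so it helps to know where that number comes from. The Trainer names each checkpoint directory after the global step at which it was saved, and by default it saves every 500 steps. This sketch computes those steps under my assumptions: a single device (num_devices=1), no gradient accumulation, and the default save_steps=500. checkpoint_steps() is a hypothetical helper. Depending on your transformers version, a final checkpoint may also be written at the end of training.

```python
import math


def checkpoint_steps(num_train_examples, batch_size, epochs,
                     save_steps=500, num_devices=1):
    """Return the total number of training steps and the steps at which the
    Trainer would write a checkpoint-<step> directory.

    Assumes no gradient accumulation and that the last, smaller batch of
    each epoch is kept (so it still counts as one step).
    """
    steps_per_epoch = math.ceil(num_train_examples / (batch_size * num_devices))
    total_steps = steps_per_epoch * epochs
    saves = list(range(save_steps, total_steps + 1, save_steps))
    return total_steps, saves


# 80% of the 2,000 movie reviews go to the training split.
total, saves = checkpoint_steps(num_train_examples=1600, batch_size=16, epochs=10)
print(total, saves)  # 1000 [500, 1000]
print([f"checkpoint-{step}" for step in saves])
```

If you change batch_size or epochs in the script, the name of the last checkpoint changes too, so the hard-coded checkpoint-1000 path needs updating.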
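The prediction loop reads the first entry of each result, which relies on the pipeline returning the scores sorted from best to worst (as the example output for the two 'dog' sentences shows). If you would rather not depend on that ordering, you can pick the top class explicitly. best_prediction() is a hypothetical helper, and the data below simply copies the example output printed by the script.

```python
def best_prediction(scores):
    """Given the list of {'label', 'score'} dicts returned for one text
    when top_k=None, return the highest-scoring entry. Hypothetical helper."""
    return max(scores, key=lambda item: item["score"])


# Same shape as the example output printed by the script.
texts = ["Hello, my dog is cute", "Hello, my dog is not cute"]
predictions = [
    [{"label": "positive", "score": 0.574}, {"label": "negative", "score": 0.426}],
    [{"label": "negative", "score": 0.868}, {"label": "positive", "score": 0.132}],
]

for text, scores in zip(texts, predictions):
    best = best_prediction(scores)
    print(f"{text}: {best['label']} ({best['score']:.3f})")
```

max() makes one pass over the handful of class scores, so the extra safety costs practically nothing even for many classes.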