Fine-tuning a Transformer for text classification

I have found myself often enough trying to classify a text that I decided to be lazy and write a solution that I could always rely on.

The following script is an example of how to fine-tune a HuggingFace transformer for text classification. This example uses only two classes, identifying whether the sentiment of a text is positive or negative, but nothing stops you from adding more classes as long as you predict just one class per text.

There are no difficult imports that I can tell - I think you can install all the required libraries with the command

pip install transformers datasets evaluate nltk

but I'll double check just in case.

Update: I have changed the script to extract the function get_dataset_from_json_files(). You can use this function as is to replace the calls to get_data(). I also updated the prediction code to give an example of how to make predictions for a lot of texts at once.
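
If you want to go down that road with your own data, here is a minimal sketch of how the input files could be generated. The helper `write_json_records`, the toy records, and the file name `train.json` are all made up for illustration and are not part of the script. The only thing taken from the script is the record format with `id`, `text`, and `label` keys.

```python
import json
import os
import tempfile

# Keys that every record needs in this sketch (same ones the script uses).
REQUIRED_KEYS = {'id', 'text', 'label'}

def write_json_records(records, path):
    """Hypothetical helper: save a list of dicts as a single JSON list."""
    # Validate everything first so a bad record never leaves a half-written file.
    for record in records:
        missing = REQUIRED_KEYS - set(record)
        if missing:
            raise ValueError(f'Record {record} is missing keys: {sorted(missing)}')
    with open(path, 'w', encoding='utf-8') as fp:
        json.dump(records, fp)

# Toy records, standing in for your own texts.
my_records = [
    {'id': '1234', 'text': 'My first text', 'label': 0},
    {'id': '5678', 'text': 'My second text', 'label': 1},
]
out_dir = tempfile.mkdtemp()
train_path = os.path.join(out_dir, 'train.json')
write_json_records(my_records, train_path)
```

With a file like that in place, `get_dataset_from_json_files(train_path)` should be able to pick it up, and the validation and test files can be produced the same way.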

import json
import numpy as np
import os
import random
import tempfile
# These are all HuggingFace libraries
from datasets import load_dataset
import evaluate
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification,\
     TrainingArguments, Trainer, TextClassificationPipeline
# KeyDataset is a util that will just output the item we're interested in.
from transformers.pipelines.pt_utils import KeyDataset
# Library for importing a sentiment classification dataset.
# Only here for demo purposes, as you would use your own dataset.
from nltk.corpus import movie_reviews

# The BERT tokenizer is always the same so we declare it here globally.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

def get_data():
    """ Returns a dataset built from a toy corpus.

    Returns
    -------
    datasets.DatasetDict
        An object containing all of our training, validation, and test data.
        This object stores all the information about a text, namely, the raw
        text (which is stored both as text and as numeric tokens) and the
        proper class for every text.

    Notes
    -----
    If you have texts in this format already then you can skip most of this
    function and simply jump to the `get_dataset_from_json_files` function.

    Most of this function is taken from the code in
    Note that that link is very confusing at times, so if you read that code
    remember to be patient.
    """
    # We read all of our (toy) data from NLTK. The exception catches the case
    # where we haven't downloaded it yet and downloads it.
    try:
        all_ids = movie_reviews.fileids()
    except LookupError:
        import nltk
        nltk.download('movie_reviews')
        all_ids = movie_reviews.fileids()
    # We identify texts by their ID. Given that we have those IDs now, we
    # proceed to shuffle them and assign each one to a specific split.
    # Note that we fix the random seed to ensure that the datasets are always
    # the same across runs. This is not strictly necessary after creating a
    # dataset object (because we want to reuse it), but it is useful if you
    # delete that dataset and want to recreate it from scratch.
    # The seed value itself is arbitrary: any fixed number works.
    random.seed(1234)
    all_ids = list(all_ids)
    random.shuffle(all_ids)
    num_records = len(all_ids)
    # This dictionary holds the names of the temporary files we will create.
    tmp_files = dict()
    # We assign the IDs we read before to their data splits.
    # Note that we use a fixed 80/10/10 split.
    for split in ['train', 'val', 'test']:
        if split == 'train':
            split_ids = all_ids[:int(0.8 * num_records)]
        elif split == 'val':
            split_ids = all_ids[int(0.8 * num_records):int(0.9 * num_records)]
        else:
            split_ids = all_ids[int(0.9 * num_records):]
        all_data = []
        # We read the data one record at a time and store it as dictionaries.
        for text_id in split_ids:
            # We read the text.
            text = movie_reviews.raw(text_id)
            # We read the class and store it as an integer.
            if movie_reviews.categories(text_id)[0] == 'neg':
                label = 0
            else:
                label = 1
            all_data.append({'id': text_id, 'text': text, 'label': label})
        # Step two is to save this data as a list of JSON records in a temporary
        # file. If you use your own data you can simply generate these files as
        # JSON from scratch and save yourself this step.
        # Note that we also save the name of the temporary file in a dictionary
        # so we can re-read it later.
        tmp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False)
        json.dump(all_data, tmp_file)
        tmp_file.close()
        tmp_files[split] = tmp_file.name

    # Step three is to create a data loader that will read the data we just
    # saved to disk.
    dataset = get_dataset_from_json_files(tmp_files['train'], tmp_files['val'], tmp_files['test'])

    # Remove the temporary files and return the tokenized dataset.
    for _, filename in tmp_files.items():
        os.unlink(filename)
    return dataset

def get_dataset_from_json_files(train_file, val_file=None, test_file=None):
    """ Given a set of properly-formatted files, it reads them and returns
    a Dataset object containing all of them.

    Parameters
    ----------
    train_file : str
        Path to a file containing training data.
    val_file : str, optional
        Path to a file containing validation data.
    test_file : str, optional
        Path to a file containing test data.

    Returns
    -------
    datasets.DatasetDict
        An object containing all of our training, validation, and test data.
        This object stores all the information about a text, namely, the raw
        text (which is stored both as text and as numeric tokens) and the
        proper class for every text.

    Notes
    -----
    IMPORTANT: the first time you run this code, the resulting dataset is saved
    to a temporary file. The console will tell you where it is (in Linux it is
    /home/<user>/.cache/huggingface/datasets/...). This temporary location is
    used in all subsequent calls, so if you change your dataset remember to
    remove this cached file first!

    This function is a thin wrapper around the `load_dataset` function where
    we hard-coded the file format to use JSON.
    If you don't want to read that function's documentation, it is enough to
    provide files whose content is simply a JSON list of dictionaries, like so:
    [{"id": "1234", "text": "My first text", "label": 0}, {"id": "5678", "text": "My second text", "label": 1}]
    """
    files_dict = {'train': train_file}
    if val_file is not None:
        files_dict['val'] = val_file
    if test_file is not None:
        files_dict['test'] = test_file
    dataset = load_dataset('json', data_files=files_dict)

    # Tokenize the data
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding='max_length', truncation=True)
    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    # Return the tokenized dataset.
    return tokenized_dataset

if __name__ == '__main__':
    # In which mode we want to use this script
    # 'train': start from a pre-trained model, fine-tune it to a dataset, and then
    #          save the resulting model
    # 'test': evaluate the performance of a fine-tuned model over the test data.
    # 'predict': make predictions for unseen texts.
    mode = 'predict'
    assert mode in ['train', 'test', 'predict'], f'Invalid mode {mode}'
    # We define here the batch size and training epochs we want to use.
    batch_size = 16
    epochs = 10
    # Where to save the trained model.
    trained_model_dir = 'training_output_dir'
    if mode == 'train':
        # Train mode. Code adapted mostly from:
        # https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer.
        # Collect the training data
        inputs = get_data()
        # Since we are starting from scratch we download a pre-trained bert model
        # from HuggingFace. Note that we are hard-coding the number of classes here!
        model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
        # Define training arguments
        # You can define a loooot more hyperparameters - see them all in
        # https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
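        training_args = TrainingArguments(output_dir=trained_model_dir,
                                          per_device_train_batch_size=batch_size,
                                          per_device_eval_batch_size=batch_size,
                                          num_train_epochs=epochs)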
        # Define training evaluation metrics
        metric = evaluate.load('accuracy')

        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            predictions = np.argmax(logits, axis=-1)
            return metric.compute(predictions=predictions, references=labels)
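        # Define the trainer object and start training.
        # The model will be saved automatically every 500 steps.
        trainer = Trainer(model=model,
                          args=training_args,
                          train_dataset=inputs['train'],
                          eval_dataset=inputs['val'],
                          compute_metrics=compute_metrics)
        trainer.train()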
    elif mode == 'test':
        # Perform a whole run over the validation set.
        # This code is adapted mostly from:
        # https://huggingface.co/docs/evaluate/base_evaluator
        # Some useful information also here:
        # https://huggingface.co/docs/datasets/metrics

        # Collect the validation data
        inputs = get_data()
        # Create our evaluator
        task_evaluator = evaluate.evaluator("text-classification")

        # Use the model we trained before to predict over the validation data.
        # Note that we are providing the name of the classes by hand. We are not
        # using these labels here, but it is useful for the prediction branch.
        model = DistilBertForSequenceClassification.from_pretrained(f'./{trained_model_dir}/checkpoint-1000/',
                                                                    id2label={0: 'negative', 1: 'positive'})
        # Define the evaluation parameters. We use the common text evaluation
        # measures, but there are plenty more.
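        eval_results = task_evaluator.compute(
            model_or_pipeline=model,
            tokenizer=tokenizer,
            # Swap in inputs['test'] to get numbers for the test split instead.
            data=inputs['val'],
            metric=evaluate.combine(['accuracy', 'precision', 'recall', 'f1']),
            label_mapping={"negative": 0, "positive": 1}
        )
        # `eval_results` is a dictionary with the same keys we defined in
        # the `metric` parameters, plus some time measures.
        print(eval_results)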
    elif mode == 'predict':
        # Make predictions for individual texts.
        # Same as above, we use the model we trained before to predict a single text.
        model = DistilBertForSequenceClassification.from_pretrained(f'./{trained_model_dir}/checkpoint-1000/',
                                                                    id2label={0: 'negative', 1: 'positive'})
        # We build a text classification pipeline.
        # Note that `top_k=None` gives us probabilities for every class while
        # `top_k=1` returns values for the best class only.
        # Inspired by https://discuss.huggingface.co/t/i-have-trained-my-classifier-now-how-do-i-do-predictions/3625
        pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer,
                                          batch_size=batch_size, top_k=None)
        sequences = ['Hello, my dog is cute', 'Hello, my dog is not cute']
        predictions = pipe(sequences)
        print(predictions)
        # The above print shows something like:
        # [[{'label': 'positive', 'score': 0.574},
        #   {'label': 'negative', 'score': 0.426}],
        #  [{'label': 'negative', 'score': 0.868},
        #  {'label': 'positive', 'score': 0.132}]]

        # For a more heavy-duty approach we now classify an entire dataset.
        inputs = get_data()
        test_inputs = inputs['test']
        outputs = []
        # Note that 'text' is the key we defined in the `get_data` function.
        for out in pipe(KeyDataset(test_inputs, 'text'), batch_size=batch_size, truncation=True, max_length=512):
            outputs.append(out)
        # Now that we collected all outputs we massage them a little bit for
        # a more friendly format. For every prediction we will print a line like:
        # 'pos/cv093_13951.txt: positive (0.996)'
        for i in range(len(test_inputs)):
            id = test_inputs[i]['id']
            pred_label = outputs[i][0]['label']
            pred_prob = outputs[i][0]['score']
            print(f'{id}: {pred_label} ({pred_prob})')
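
One small note on that last loop: taking `outputs[i][0]` works because, as far as I can tell, the pipeline returns the classes sorted by score when `top_k=None`. If you would rather not depend on that ordering, here is a minimal sketch that picks the best class explicitly. The helper name `best_prediction` is made up for this example, and the data is just the sample output shown in the script's comments.

```python
def best_prediction(scores):
    """Return (label, score) for the highest-scoring class of a single text."""
    best = max(scores, key=lambda item: item['score'])
    return best['label'], best['score']

# Sample pipeline output copied from the comments in the script.
example_output = [
    [{'label': 'positive', 'score': 0.574}, {'label': 'negative', 'score': 0.426}],
    [{'label': 'negative', 'score': 0.868}, {'label': 'positive', 'score': 0.132}],
]
for scores in example_output:
    label, score = best_prediction(scores)
    print(f'{label} ({score:.3f})')
# Prints:
# positive (0.574)
# negative (0.868)
```

The same function also works if you later add more classes, since it only looks at the scores and not at the position of each class in the list.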

Thoughts on ChatGPT, LLMs, and search engines

Now that everyone is talking about Large Language Models (LLMs) in general and ChatGPT in particular I thought I would share a couple thoughts I've been having about this technology that I haven't seen anywhere else. But first, let's talk about chess.

Chess is a discipline notorious for coming out on top when the robots came to take its lunch. Once it became clear that even the humblest of PCs can play better than a World Chess Champion, the chess community started using these chess engines to improve their games and learn new tactics. Having adopted these engines as a fact of life, chess is as popular as ever (if not more). Sure, a computer can do a "better" job in a fraction of the time, but who cares? It's not like there's an unmet economic need for industrial-strength chess players.

In contrast, one field that isn't doing as well is digital art. With the advent of diffusion models like Midjourney and Stable Diffusion many artists are worried that their livelihoods are now at stake and are letting the world know, ranging from comics (I like this one more but this one hits closer to home) all the way to class action lawsuits. My quick take is that these models are here to stay and that they won't necessarily destroy art but they will probably cripple the art business the same way cheap restaurants download food pictures from the internet instead of paying a professional photographer.

LLMs are unusual in that sense because they aren't disrupting a field as much as a means of communication. "Chess player" is something one does, and so is "artist"1. But "individual who uses language" is in a different class altogether, and while some of the effects of these technologies are easy to guess, others not so much.

Random predictions

What are we likely to see in the near future? The easiest prediction is a rise in plagiarism. Students are already using ChatGPT to generate essays regardless of whether the output makes sense or not. And spam will follow closely behind: we are already awash in repost spam and ramblings disguised as recipes, but once people start submitting auto-generated sci-fi stories to magazines we will see garbage in any medium that offers even the slightest of rewards, financial or not. Do you want a high-karma account on Reddit to establish yourself as not-a-spammer and use it to push products? Just put your payment information here and the robots will comment for you. No human interaction needed2.

What I find more interesting and likely more disruptive is the replacement of search engines with LLMs. If I ask ChatGPT right now about my PhD advisor I get an answer - this particular answer happens to be all wrong3, but let's pretend that it's not4. This information came from some website, but the system is not telling me which one. And here there's both an opportunity and a risk. The opportunity is the chance of cutting through the spam: if I ask for a recipe for poached eggs and I get a recipe for poached eggs then I no longer have to wade through long-winded essays on how poached eggs remind someone of evenings at their grandma's house. On the other hand, this also means that all the information we collectively placed on the internet would be used for the profit of some company without even the meagre attributions we currently get.

On the long tail, and entering into guessing territory, it would be tragic if people started writing like ChatGPT. These systems have a particular writing style composed of multiple short sentences and it's not hard to imagine that young, impressionable people may start copying this output once it is widespread enough. This has happened before with SMS, so I don't see why it couldn't happen again.

Pointless letters and moving forward

One positive way to move forward would be to accept that a lot of our daily communication is so devoid of content that even a computer with no understanding of the real world can do it and work on that.

When I left my last job I auto-generated a goodbye e-mail with GPT-3, and the result was so incredibly generic that no one would have been able to learn anything from it. On the other hand, I doubt anyone would have noticed: once you've read a hundred references to "the good memories" you no longer stop to wonder whether there were any good memories to begin with. I didn't send that auto-generated e-mail. In fact, I didn't send anything: I had already said goodbye in person to the people that knew me and there was no reason to say anything else. The amount of information that was conveyed was exactly the same, but my solution wasted less of other people's time.

Maybe this is our opportunity to freshen up our writing and start writing interestingly, both in form (long sentences for the win!) and in content. The most straightforward solution would be cursing: since these models have to be attractive to would-be investors, they are strictly programmed not to use curse words or NSFW content (I just tried). So there's a style that no AI will be copying in the near future.


  1. Note that this is a simplification for the sake of the argument. As someone who often said "being a programmer is both what I do and what I am" I am aware that "artist" (like so many other professions) isn't just a job but also a way of looking at the world.
  2. I have noticed that people on Reddit will upvote anything without reading it first, so this is not a high bar to clear.
  3. The answer mashed together several researchers into one. One could argue that I got more researchers per researcher, which is definitely a take.
  4. The answer doesn't have to be correct - all it takes is for the person using the system to believe that the answer is correct, something we are already seeing despite the overwhelming evidence to the contrary.