Fine-tuning a Transformer for text classification
I have found myself trying to classify texts often enough that I decided to be lazy and write a solution I could always rely on.
The following script is an example of how to fine-tune a HuggingFace transformer for text classification. This example uses only two classes, identifying whether the sentiment of a text is positive or negative, but nothing stops you from adding more classes as long as you predict a single class per text.
There are no difficult imports as far as I can tell. You should be able to install all the required libraries with the command
pip install transformers datasets evaluate nltk
Note that you also need PyTorch installed (transformers does not install it for you), and the evaluation metrics in the test branch use scikit-learn under the hood.
Update: I have changed the script to extract the function get_dataset_from_json_files(). You can use this function as-is to replace the calls to get_data(). I also updated the prediction code to give an example of how to make predictions for a lot of texts at once.
import json
import numpy as np
import os
import random
import tempfile
# These are all HuggingFace libraries
from datasets import load_dataset
import evaluate
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification,\
TrainingArguments, Trainer, TextClassificationPipeline
# KeyDataset is a util that will just output the item we're interested in.
from transformers.pipelines.pt_utils import KeyDataset
# Library for importing a sentiment classification dataset.
# Only here for demo purposes, as you would use your own dataset.
from nltk.corpus import movie_reviews
# The DistilBERT tokenizer is always the same, so we declare it here globally.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
def get_data():
""" Returns a dataset built from a toy corpora.
Returns
-------
Dataset
An object containing all of our training, validation, and test data.
This object stores all the information about a text, namely, the raw
text (which is stored both as text and as numeric tokens) and the
proper class for every text.
Notes
-----
    If you have texts in this format already then you can skip most of this
    function and simply jump to the `get_dataset_from_json_files` function.
References
----------
Most of this function is taken from the code in
https://huggingface.co/course/chapter5/5?fw=pt.
Note that that link is very confusing at times, so if you read that code
remember to be patient.
"""
# We read all of our (toy) data from NLTK. The exception catches the case
# where we haven't downloaded it yet and downloads it.
try:
all_ids = movie_reviews.fileids()
except LookupError:
import nltk
nltk.download('movie_reviews')
all_ids = movie_reviews.fileids()
# We identify texts by their ID. Given that we have those IDs now, we
# proceed to shuffle them and assign each one to a specific split.
# Note that we fix the random seed to ensure that the datasets are always
# the same across runs. This is not strictly necessary after creating a
# dataset object (because we want to reuse it), but it is useful if you
# delete that dataset and want to recreate it from scratch.
random.seed(42)
random.shuffle(all_ids)
num_records = len(all_ids)
    # This dictionary holds the names of the temporary files we will create.
tmp_files = dict()
    # We assign the IDs we read before to their data splits.
    # Note that we use a fixed 80/10/10 split.
for split in ['train', 'val', 'test']:
if split == 'train':
split_ids = all_ids[:int(0.8 * num_records)]
elif split == 'val':
split_ids = all_ids[int(0.8 * num_records):int(0.9 * num_records)]
else:
split_ids = all_ids[int(0.9 * num_records):]
all_data = []
        # We read the data one record at a time and store it as dictionaries.
for text_id in split_ids:
# We read the text.
text = movie_reviews.raw(text_id)
# We read the class and store it as an integer.
if movie_reviews.categories(text_id)[0] == 'neg':
label = 0
else:
label = 1
all_data.append({'id': text_id, 'text': text, 'label': label})
# Step two is to save this data as a list of JSON records in a temporary
# file. If you use your own data you can simply generate these files as
# JSON from scratch and save yourself this step.
# Note that we also save the name of the temporary file in a dictionary
# so we can re-read it later.
tmp_file = tempfile.NamedTemporaryFile(delete=False)
tmp_files[split] = tmp_file.name
tmp_file.write(json.dumps(all_data).encode())
tmp_file.close()
# Step three is to create a data loader that will read the data we just
# saved to disk.
dataset = get_dataset_from_json_files(tmp_files['train'], tmp_files['val'], tmp_files['test'])
# Remove the temporary files and return the tokenized dataset.
for _, filename in tmp_files.items():
os.unlink(filename)
return dataset
def get_dataset_from_json_files(train_file, val_file=None, test_file=None):
""" Given a set of properly-formatted files, it reads them and returns
a Dataset object containing all of them.
Parameters
----------
train_file : str
Path to a file containing training data.
    val_file : str, optional
        Path to a file containing validation data.
    test_file : str, optional
        Path to a file containing test data.
Returns
-------
Dataset
An object containing all of our training, validation, and test data.
This object stores all the information about a text, namely, the raw
text (which is stored both as text and as numeric tokens) and the
proper class for every text.
Notes
-----
    IMPORTANT: the first time you run this code, the resulting dataset is
    cached on disk. The console will tell you where (on Linux it is under
    /home/<user>/.cache/huggingface/datasets/...). This cache is reused in all
    subsequent calls, so if you change your dataset remember to remove the
    cached files first!
This function is a thin wrapper around the `load_dataset` function where
we hard-coded the file format to use JSON.
If you don't want to read that function's documentation, it is enough to
provide files whose content is simply a JSON list of dictionaries, like so:
    [{"id": "1234", "text": "My first text", "label": 0}, {"id": "5678", "text": "My second text", "label": 1}]
"""
files_dict = {'train': train_file}
if val_file is not None:
files_dict['val'] = val_file
if test_file is not None:
files_dict['test'] = test_file
dataset = load_dataset('json', data_files=files_dict)
# Tokenize the data
def tokenize_function(examples):
return tokenizer(examples['text'], padding='max_length', truncation=True)
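    # Note: padding='max_length' pads every text to the model's maximum input
    # size (512 tokens for DistilBERT) and truncation=True cuts longer texts
    # down to that size, so all examples end up with the same length.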
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Return the tokenized dataset.
return tokenized_dataset
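# A minimal usage sketch (the file names below are made up for illustration):
# if you already have JSON files in the format shown in the docstring above,
# you can skip get_data() entirely and build the dataset directly, e.g.
#
#     dataset = get_dataset_from_json_files('my_train.json',
#                                           val_file='my_val.json',
#                                           test_file='my_test.json')
#
# Only the training file is required; validation and test files are optional.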
if __name__ == '__main__':
    # The mode in which we want to use this script:
# 'train': start from a pre-trained model, fine-tune it to a dataset, and then
# save the resulting model
# 'test': evaluate the performance of a fine-tuned model over the test data.
# 'predict': make predictions for unseen texts.
mode = 'predict'
assert mode in ['train', 'test', 'predict'], f'Invalid mode {mode}'
# We define here the batch size and training epochs we want to use.
batch_size = 16
epochs = 10
# Where to save the trained model.
trained_model_dir = 'training_output_dir'
if mode == 'train':
# Train mode. Code adapted mostly from:
# https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer.
# Collect the training data
inputs = get_data()
        # Since we don't have a fine-tuned model yet, we download a pre-trained
        # DistilBERT model from HuggingFace. Note that we are hard-coding the
        # number of classes here!
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
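        # If your problem has more than two classes you would simply change
        # num_labels here (e.g. num_labels=3 for a three-class problem) and
        # adjust the id2label mappings in the test/predict branches to match;
        # the 'negative'/'positive' names used there are just this demo's labels.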
# Define training arguments
# You can define a loooot more hyperparameters - see them all in
# https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
training_args = TrainingArguments(output_dir=trained_model_dir,
evaluation_strategy='epoch',
per_device_train_batch_size=batch_size,
num_train_epochs=epochs,
learning_rate=1e-5)
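        # evaluation_strategy='epoch' means the Trainer will run an evaluation
        # pass over the eval_dataset (our validation split) at the end of every
        # training epoch, reporting the metrics defined below.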
# Define training evaluation metrics
metric = evaluate.load('accuracy')
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return metric.compute(predictions=predictions, references=labels)
        # Define the trainer object and start training.
        # Checkpoints will be saved automatically every 500 steps (the default).
trainer = Trainer(model=model,
args=training_args,
train_dataset=inputs['train'],
eval_dataset=inputs['val'],
compute_metrics=compute_metrics)
trainer.train()
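        # The Trainer only writes periodic checkpoints into trained_model_dir.
        # If you also want an explicit final copy of the fine-tuned weights you
        # could call trainer.save_model(trained_model_dir) here; the test and
        # predict branches below load from one of the automatic checkpoints.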
elif mode == 'test':
        # Perform a whole run over the test set.
# This code is adapted mostly from:
# https://huggingface.co/docs/evaluate/base_evaluator
# Some useful information also here:
# https://huggingface.co/docs/datasets/metrics
        # Collect the test data
inputs = get_data()
# Create our evaluator
task_evaluator = evaluate.evaluator("text-classification")
        # Use the model we trained before to predict over the test data.
        # Note that we provide the class names by hand. They must match the
        # label_mapping passed to the evaluator below, and they also make the
        # prediction branch easier to read.
model = DistilBertForSequenceClassification.from_pretrained(f'./{trained_model_dir}/checkpoint-1000/',
num_labels=2,
id2label={0: 'negative', 1: 'positive'})
        # Define the evaluation parameters. We use the common classification
        # metrics, but there are plenty more.
eval_results = task_evaluator.compute(
model_or_pipeline=model,
tokenizer=tokenizer,
data=inputs['test'],
metric=evaluate.combine(['accuracy', 'precision', 'recall', 'f1']),
label_mapping={"negative": 0, "positive": 1}
)
        # `eval_results` is a dictionary with the same keys we defined in
        # the `metric` parameter, plus some time measurements.
print(eval_results)
elif mode == 'predict':
        # Make predictions for individual texts.
        # Same as above, we load the model we trained before, this time to
        # classify a couple of example texts.
model = DistilBertForSequenceClassification.from_pretrained(f'./{trained_model_dir}/checkpoint-1000/',
num_labels=2,
id2label={0: 'negative', 1: 'positive'})
# We build a text classification pipeline.
# Note that `top_k=None` gives us probabilities for every class while
# `top_k=1` returns values for the best class only.
        # Inspired by https://discuss.huggingface.co/t/i-have-trained-my-classifier-now-how-do-i-do-predictions/3625
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer,
batch_size=batch_size, top_k=None)
sequences = ['Hello, my dog is cute', 'Hello, my dog is not cute']
predictions = pipe(sequences)
print(predictions)
# The above print shows something like:
# [[{'label': 'positive', 'score': 0.574},
# {'label': 'negative', 'score': 0.426}],
# [{'label': 'negative', 'score': 0.868},
# {'label': 'positive', 'score': 0.132}]]
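        # Because top_k=None returns the classes sorted from most to least
        # likely, predictions[i][0] is always the top prediction for text i;
        # the batch loop below relies on the same property.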
# For a more heavy-duty approach we now classify an entire dataset.
inputs = get_data()
test_inputs = inputs['test']
outputs = []
# Note that 'text' is the key we defined in the `get_data` function.
for out in pipe(KeyDataset(test_inputs, 'text'), batch_size=batch_size, truncation=True, max_length=512):
outputs.append(out)
# Now that we collected all outputs we massage them a little bit for
# a more friendly format. For every prediction we will print a line like:
# 'pos/cv093_13951.txt: positive (0.996)'
        for i in range(len(test_inputs)):
            text_id = test_inputs[i]['id']
            pred_label = outputs[i][0]['label']
            pred_prob = outputs[i][0]['score']
            print(f'{text_id}: {pred_label} ({pred_prob:.3f})')