Fine Tuning BERT for Text Classification

Lawrence Emenike, MSc, ACCA
Nov 28, 2023

Image generated with DALL·E 3

BERT, short for Bidirectional Encoder Representations from Transformers, revolutionised how machines understand human language by interpreting each word in light of the context on both its left and its right. Transformers, the architecture behind BERT, set new standards in NLP by using self-attention to model the relationships between all the words in a sentence. In this post, I share the practical steps of fine-tuning a pre-trained BERT model with TensorFlow for a text classification task, showing how this combination can be put to work effectively.

Setting Up the Environment

Before diving into the model, I set up the environment. TensorFlow is the primary library for this project, handling model training and evaluation; TensorFlow Hub supplies the pre-trained BERT model, pandas handles data loading, and Matplotlib produces the plots.
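A quick way to confirm these libraries are installed before starting is to check whether Python can locate each one. The sketch below assumes the import names `tensorflow`, `tensorflow_hub`, `pandas` and `matplotlib`; `check_environment` is a hypothetical helper, not part of any of these libraries.

```python
import importlib.util

# Import names of the libraries used in this walkthrough (assumed; adjust to your setup).
REQUIRED = ["tensorflow", "tensorflow_hub", "pandas", "matplotlib"]


def check_environment(modules):
    """Return {module_name: bool} saying whether each module can be found (without importing it)."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_environment(REQUIRED).items():
        print(f"{name:15s} {'found' if found else 'MISSING'}")
```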

Data Preparation

Our first step is preparing the data. For this demonstration, I used the Quora Insincere Questions Classification dataset (you can substitute your own), which is available for download as a compressed file. After downloading and decompressing the file, I read the data into a pandas DataFrame. The next step is preprocessing, which uses tokenization and encoding to transform the raw text into a format BERT can process.
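With pandas, reading the archive can be a single call such as `pd.read_csv("train.csv.zip", compression="zip")`, which works when the archive holds one CSV file (the file name here is hypothetical). To show what that involves, the standard-library sketch below reads the CSV straight from the archive without decompressing it by hand. It assumes the Kaggle column names `question_text` and `target`; `read_zipped_csv` is an illustrative helper, not the author's code.

```python
import csv
import io
import zipfile


def read_zipped_csv(path_or_file, text_col="question_text", label_col="target"):
    """Read the single CSV inside a .zip archive into a list of (text, label) pairs.

    Column names default to those assumed for the Quora CSV.
    """
    with zipfile.ZipFile(path_or_file) as archive:
        # Assumes the archive holds exactly one CSV file (raises ValueError otherwise).
        (member,) = [n for n in archive.namelist() if n.endswith(".csv")]
        with archive.open(member) as raw:
            # Decode the binary member as UTF-8 text for the csv module.
            reader = csv.DictReader(io.TextIOWrapper(raw, encoding="utf-8", newline=""))
            return [(row[text_col], int(row[label_col])) for row in reader]
```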

Analysing Data Distribution and Preparing Datasets

After loading the Quora Insincere Questions dataset into a pandas DataFrame, the first step was to visualise the distribution of the target classes. The histogram plot revealed a significant imbalance with a larger number of sincere questions (label 0) compared to insincere ones (label 1). Such an imbalance is a common challenge in classification tasks as it can bias the model towards the majority class.
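The imbalance can be quantified alongside the plot. In pandas, `df["target"].value_counts(normalize=True)` gives the share of each class; the sketch below does the same counting with the standard library, assuming labels are 0 (sincere) and 1 (insincere) as described above. The toy labels in the usage line are made up, not the dataset's real counts.

```python
from collections import Counter


def class_distribution(labels):
    """Return {label: (count, share_of_total)}, ordered by label."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: (n, n / total) for label, n in sorted(counts.items())}


# Toy labels for illustration only.
print(class_distribution([0, 0, 0, 1]))  # → {0: (3, 0.75), 1: (1, 0.25)}
```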

Next, I split the data into training and validation sets, stratifying by label so that both sets reflected the original class distribution. Stratification does not remove the imbalance, but it keeps the validation set representative, which is crucial for evaluating the model's performance accurately.
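Stratification can be sketched in a few lines: group the examples by label, then send the same fraction of each group to the validation set. The version below is an illustrative pure-Python stand-in, since the post does not say which tool performed the split; scikit-learn's `train_test_split(..., stratify=labels)` is a common choice that does the same job. `stratified_split`, its 10% default and its seed are assumptions.

```python
import random
from collections import defaultdict


def stratified_split(examples, labels, val_fraction=0.1, seed=42):
    """Split (example, label) pairs so each class keeps the same share in both sets."""
    by_class = defaultdict(list)
    for x, y in zip(examples, labels):
        by_class[y].append((x, y))
    rng = random.Random(seed)
    train, val = [], []
    for y in sorted(by_class):
        items = by_class[y]
        rng.shuffle(items)  # random choice of which examples go to validation
        n_val = round(len(items) * val_fraction)
        val.extend(items[:n_val])
        train.extend(items[n_val:])
    # Mix the classes back together so neither set is ordered by label.
    rng.shuffle(train)
    rng.shuffle(val)
    return train, val
```

With 90 examples of class 0 and 10 of class 1 and `val_fraction=0.2`, the validation set receives 18 and 2 examples respectively, preserving the 9:1 ratio.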

I then previewed a few elements of the datasets to confirm that the text inputs and labels were correctly paired and ready for model consumption, a quick sanity check before committing to model training.

Downloading and Preparing the BERT Model

I then downloaded a pre-trained BERT model from TensorFlow Hub. Its weights come from BERT's original pre-training, so it already produces context-aware representations of text; fine-tuning only has to adapt it to our classification task.

The crux of BERT's text processing is converting raw text into input features the model can digest: the text is split into WordPiece tokens from BERT's vocabulary, and each token is mapped to its ID in that vocabulary. I also set the maximum sequence length BERT would handle (longer inputs are truncated and shorter ones padded) and chose a suitable batch size for training.
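BERT's WordPiece tokenizer splits each word greedily into the longest pieces found in the vocabulary, marks non-initial pieces with a `##` prefix, and turns a word it cannot fully cover into `[UNK]`. The sketch below shows that greedy longest-match step on a tiny made-up vocabulary. A real BERT tokenizer also normalises accents, splits punctuation and uses the vocabulary file shipped with the model, none of which is modelled here; `TOY_VOCAB`, `wordpiece`, `tokenize` and `convert_tokens_to_ids` are hypothetical names.

```python
# Made-up vocabulary mapping token -> ID (real BERT vocabularies are far larger).
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
             "why": 4, "do": 5, "cat": 6, "##s": 7, "purr": 8}


def wordpiece(word, vocab, unk="[UNK]", max_chars=100):
    """Greedy longest-match-first WordPiece split of one lower-cased word."""
    if len(word) > max_chars:
        return [unk]
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        while start < end:  # try the longest remaining substring first
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk]  # no piece matched: the whole word becomes [UNK]
        pieces.append(piece)
        start = end
    return pieces


def tokenize(text, vocab):
    """Lower-case, split on whitespace, then apply WordPiece to each word."""
    return [p for word in text.lower().split() for p in wordpiece(word, vocab)]


def convert_tokens_to_ids(tokens, vocab):
    """Map each token to its vocabulary ID."""
    return [vocab[t] for t in tokens]


print(tokenize("Why do cats purr", TOY_VOCAB))  # → ['why', 'do', 'cat', '##s', 'purr']
```

Note that `dogs` becomes a single `[UNK]` even though `do` is in the vocabulary: once the remainder `##gs` cannot be matched, the whole word is discarded.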

Integrating BERT with TensorFlow’s Data Pipeline

To prepare the data for the BERT model, I first wrapped each text and its label in an InputExample using the BERT library's constructor, then converted these examples into the features BERT uses: input_ids, input_mask and segment_ids. I defined a function, to_feature, that took a text and its label and returned them in this format.
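To make the three features concrete, the sketch below builds them by hand for a single sentence: truncate to the maximum sequence length, wrap the tokens in `[CLS]` and `[SEP]`, map tokens to IDs, then pad. `to_feature_sketch` and its toy vocabulary are hypothetical, and whitespace splitting stands in for WordPiece to keep it short; the original to_feature relies on the BERT library instead.

```python
# Hypothetical toy vocabulary (token -> ID).
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
             "is": 4, "this": 5, "a": 6, "sincere": 7, "question": 8}


def to_feature_sketch(text, label, vocab=TOY_VOCAB, max_seq_length=8):
    """Turn one text/label pair into BERT-style single-sentence features."""
    tokens = text.lower().split()  # simplification: whitespace split instead of WordPiece
    tokens = tokens[: max_seq_length - 2]  # leave room for [CLS] and [SEP]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]
    input_mask = [1] * len(input_ids)   # 1 marks a real token
    segment_ids = [0] * len(input_ids)  # single sentence: every token is in segment 0
    pad = max_seq_length - len(input_ids)
    input_ids += [vocab["[PAD]"]] * pad
    input_mask += [0] * pad             # 0 marks padding the model should ignore
    segment_ids += [0] * pad
    return input_ids, input_mask, segment_ids, label


print(to_feature_sketch("Is this a sincere question?", 0))
# → ([2, 4, 5, 6, 7, 1, 3, 0], [1, 1, 1, 1, 1, 1, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0], 0)
```

Here `question?` maps to `[UNK]` (ID 1) because the toy tokenizer does not split off punctuation.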

Dataset.map() does not execute its function eagerly: it traces the function into a TensorFlow graph (even though TensorFlow 2 otherwise runs eagerly by default), so the function has to be expressed in TensorFlow operations. Because to_feature is ordinary Python code, I wrapped it with tf.py_function, which lets arbitrary Python run as an operation inside the graph.

This wrapper, to_feature_map, was crucial for ensuring that the input data reached BERT with the correct shape and type. It also showed the flexibility of TensorFlow's Data API: with Dataset.map(), I could apply to_feature_map to every element of the dataset and prepare all of the data for training.

Crafting a TensorFlow Input Pipeline

For my BERT model to process the data efficiently, I created a robust input pipeline using TensorFlow’s tf.data API. This API is a powerful tool for handling large datasets and streamlining data preprocessing for deep learning models.

Here’s how I constructed the input pipeline:

  1. Mapping the Function: First, I mapped the to_feature_map function over the datasets to convert the raw text and labels into the format expected by BERT.
  2. Shuffling and Batching: Next, I shuffled the training data so that the model would not pick up any patterns from the order of the examples. Then I batched the data; the batch size sets how many samples the model processes before each update of its parameters.
  3. Prefetching: Finally, I used .prefetch to allow later data batches to be prepared while the current batch is being processed. This helps to optimise the training speed by reducing the time spent waiting for data.
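In TensorFlow these steps are typically chained as `dataset.map(...).shuffle(buffer_size).batch(batch_size).prefetch(tf.data.AUTOTUNE)`. `shuffle` does not reorder the whole dataset at once: per the `tf.data` documentation, it fills a buffer of `buffer_size` elements, samples randomly from it and refills each emptied slot with the next element. The pure-Python sketch below mimics that buffered shuffle and the batching step so their behaviour is easy to inspect; it is an illustration, not `tf.data` itself, and prefetching (which overlaps data preparation with training in the background) is not modelled.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in shuffled order using a fixed-size buffer, like tf.data's shuffle()."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]          # emit a random buffered element...
        buffer[i] = item         # ...and put the incoming element in its slot
    rng.shuffle(buffer)          # input exhausted: drain the rest in random order
    yield from buffer


def batch(items, batch_size, drop_remainder=False):
    """Group items into lists of batch_size, like tf.data's batch()."""
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current and not drop_remainder:
        yield current  # final, smaller batch


batches = list(batch(buffered_shuffle(range(10), buffer_size=4, seed=0), batch_size=3))
print([len(b) for b in batches])  # → [3, 3, 3, 1]
```

A buffer at least as large as the dataset gives a full uniform shuffle, while a buffer of 1 leaves the order unchanged, which is why the buffer size is worth choosing deliberately.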

The resulting tf.data.Datasets were then ready to feed into the keras.Model.fit method: each element was a (features, labels) tuple, matching the inputs my BERT model expected.

Inspecting the datasets' element specifications confirmed that the pipeline was configured correctly: the input features (input_word_ids, input_mask and input_type_ids, the names the TensorFlow Hub model uses for the input_ids, input_mask and segment_ids produced earlier) and the labels all had the expected shapes and types. Getting this right before training avoids shape and type errors once training and validation begin.

Building and Fine-Tuning the BERT Model for Text Classification

With my input pipeline in place, I turned to constructing the model. I wrote a function, create_model, that initialises the necessary BERT inputs (input_word_ids, input_mask and input_type_ids) and defines the architecture. These inputs were fed into the pre-trained BERT layer prepared earlier, which produces contextual representations of the input text.

To adapt BERT to text classification, I added a dropout layer for regularisation, followed by a single-unit dense layer with a sigmoid activation that outputs the probability that a question is insincere (label 1). I then compiled the model with the Adam optimizer and the binary cross-entropy loss, a standard pairing for binary classification, and chose binary accuracy as the metric to monitor during training.
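In Keras this head corresponds to `tf.keras.layers.Dropout(rate)` followed by `tf.keras.layers.Dense(1, activation="sigmoid")`, compiled with `tf.keras.losses.BinaryCrossentropy()` and `tf.keras.metrics.BinaryAccuracy()`. To show what those pieces compute, the sketch below writes them out in plain Python for a single feature vector (for instance BERT's pooled output, which is an assumption about which output was used). Weights, rates and inputs are arbitrary illustrative values; the Keras layers also operate on batches and learn their weights during training.

```python
import math
import random


def dropout(values, rate, rng, training=True):
    """Inverted dropout: zero each value with probability `rate`, rescale survivors by 1/(1-rate)."""
    if not training or rate == 0.0:
        return list(values)  # dropout is a no-op at inference time
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]


def sigmoid_dense(values, weights, bias):
    """One-unit dense layer with sigmoid activation: predicted P(label == 1)."""
    z = sum(v * w for v, w in zip(values, weights)) + bias
    return 1.0 / (1.0 + math.exp(-z))


def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean binary cross-entropy; probabilities are clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(y_true)


def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Share of predictions that land on the correct side of the threshold."""
    hits = sum(int(p > threshold) == y for y, p in zip(y_true, y_prob))
    return hits / len(y_true)


features = dropout([0.2, -0.4, 0.1], rate=0.1, rng=random.Random(0))
prob = sigmoid_dense(features, weights=[0.5, -0.3, 0.8], bias=0.0)
print(prob, binary_cross_entropy([1], [prob]))
```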

Running model.summary() printed the model's architecture layer by layer, confirming that each layer had the expected output shape and parameter count. The model was then ready for fine-tuning, in which it learns from the training data and adjusts its weights to better classify questions as sincere or insincere. This step is critical, as it adapts the general-purpose pre-trained model to the nuances of the specific dataset and task.

Training and Evaluating the BERT Model

With the model defined and compiled, I trained it for 4 epochs, each epoch iteratively adjusting the model's parameters to minimise the loss and improve accuracy on the text classification task.

The verbose output gave real-time feedback during training, reporting the loss and binary accuracy on the training set and, at the end of each epoch, on the validation set. These metrics are the key indicators of how well the model is learning and whether it generalises beyond the training data to new, unseen inputs.

After training, I moved on to evaluating the model by plotting its training history with Matplotlib. The plots show how accuracy and loss evolved across epochs for both training and validation, which is essential for spotting overfitting or underfitting and for deciding how to adjust the training strategy or architecture.

Addressing Overfitting in Model Training

The training and validation loss and accuracy plots indicate a classic case of overfitting — while the training accuracy increases and loss decreases across epochs, the validation accuracy plateaus and the validation loss actually starts to increase. This suggests that the model is fitting the training data more closely than it should, at the expense of its ability to generalise to new data.
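This pattern can also be checked programmatically from the `History.history` dictionary that `model.fit` returns, which holds one value per epoch under keys such as `loss` and `val_loss`. The helper below is a hypothetical sketch: it finds the epoch with the lowest validation loss and flags overfitting when training loss kept falling after that epoch while validation loss rose. The numbers in the usage line are made up for illustration and are not this project's results.

```python
def diagnose_overfitting(history):
    """Return the best epoch by validation loss and whether the curves diverge after it.

    `history` follows the layout of Keras's History.history: {"loss": [...], "val_loss": [...]}.
    Epochs are numbered from 1.
    """
    loss, val_loss = history["loss"], history["val_loss"]
    best = min(range(len(val_loss)), key=val_loss.__getitem__)  # first epoch with lowest val loss
    # Overfitting signature: after the best epoch, training loss keeps
    # falling while validation loss rises.
    diverging = (best < len(val_loss) - 1
                 and loss[-1] < loss[best]
                 and val_loss[-1] > val_loss[best])
    return {"best_epoch": best + 1, "overfitting": diverging}


example = {"loss": [0.12, 0.09, 0.07, 0.05], "val_loss": [0.11, 0.10, 0.11, 0.12]}  # made-up values
print(diagnose_overfitting(example))  # → {'best_epoch': 2, 'overfitting': True}
```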

To mitigate overfitting, I can employ several strategies:

  1. Data Augmentation: Generate new training samples by altering existing ones in ways that preserve their labels. For text data, this might involve synonym replacement, sentence shuffling, or back-translation.
  2. Regularisation: Increase the dropout rate or add L1/L2 weight penalties to the dense layer, which discourage large weights.
  3. Early Stopping: Monitor the validation loss and stop training when it begins to increase, even if the training loss continues to decrease.
  4. Reduce Model Complexity: Use a smaller model, for example a BERT variant with fewer Transformer layers or a smaller hidden size, so that there is less capacity to memorise the training data.
  5. Freeze Pre-trained Layers: Leverage transfer learning more effectively by keeping some or all of the BERT layers frozen during the initial phase of training, so that at first only the classification head is updated.
  6. Cross-validation: Implement k-fold cross-validation to ensure that the model’s performance is consistent across different subsets of the data.
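Of these, early stopping is the simplest to add. In Keras it is a callback, for example `tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=1, restore_best_weights=True)` passed to `model.fit(..., callbacks=[...])`. The sketch below replays the same patience rule on a list of already-recorded validation losses, to show at which epoch training would stop and which epoch's weights would be kept; the function and the loss values are hypothetical.

```python
def early_stopping_epoch(val_losses, patience=1, min_delta=0.0):
    """Return (stop_epoch, best_epoch) under patience-based early stopping on val_loss.

    An epoch counts as an improvement only if its loss beats the best so far by more
    than min_delta; training stops once `patience` epochs pass without improvement.
    Epochs are numbered from 1.
    """
    best_loss, best_epoch, wait = float("inf"), 0, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            best_loss, best_epoch, wait = loss, epoch, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop here; best_epoch's weights would be restored
    return len(val_losses), best_epoch    # ran all epochs without triggering


print(early_stopping_epoch([0.11, 0.10, 0.11, 0.12], patience=1))  # → (3, 2)
```

A larger `patience` tolerates a noisy validation curve at the cost of a few extra epochs; with `patience=2`, the same losses would stop training at epoch 4 and still keep epoch 2's weights.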

Applying these techniques requires careful consideration and iterative experimentation to find the right balance that improves the model’s generalization ability.

Conclusion

In this journey, I’ve walked through the process of preparing data for BERT, crafting an input pipeline, building and fine-tuning the model, and finally, evaluating its performance. The insights gained from visualising the training history are invaluable — they point towards the next steps in the model development process, which, in this case, involves addressing overfitting.

Fine-tuning BERT for text classification is an iterative process that often requires multiple rounds of training and evaluation. Interpreting the results, diagnosing issues like overfitting and implementing strategies to counteract them are critical skills in a data scientist's toolbox.

Through this project, the power of transfer learning with BERT for complex NLP tasks has been demonstrated, while also highlighting the importance of a systematic approach to model training and evaluation. The path forward includes refining the model further, with the ultimate goal of creating an AI solution that is not only powerful but also robust and reliable when faced with real-world data.
