How to use BERT Question Answering in TensorFlow with Python
Introduction
TensorFlow is an open-source, end-to-end platform for building machine learning applications. It contains a rich set of tools and libraries for various tasks, with a primary focus on training and inference of neural networks. TensorFlow is used in several applications, including:
- Search Engines - for deploying deep neural networks for search ranking like Google's RankBrain.
- Automobiles - building neural networks designed for autonomous driving.
- Education - designing models to filter toxic chat messages in classrooms.
- Healthcare - for building neural networks to detect potential health complications from patient data.
- Text summarization and sentiment analysis.
Bidirectional Encoder Representations from Transformers (BERT) is a natural language model that uses a transformer-based machine learning technique to accomplish several common language functions, including:
- Sentiment analysis - emotional inference such as determining the polarity of a movie's reviews.
- Natural language inference - determining whether a hypothesis logically follows from a premise.
- Named entity recognition - extracting information from unstructured text into predefined categories such as names of persons, locations, and quantities.
- Question answering - building conversational question-answering systems such as chatbots.
- Text generation - generation of text indistinguishable from human-written texts.
- Text prediction - predictive text suggestions like Gmail's suggestive texts when composing emails.
- Text summarization.
BERT was pre-trained on about 2.5 billion words from English Wikipedia and 800 million words from the BooksCorpus dataset.
This guide covers how to implement question-answering with BERT in TensorFlow using Python.
Prerequisites
- Working knowledge of Python.
- Properly installed and configured Python toolchain, including pip (Python version >= 3.7).
Setting Up The Project Virtual Environment
Create an isolated virtual environment for your application:
Install the virtualenv Python package:
$ pip install virtualenv
Create the project directory:
$ mkdir bert_QA
Navigate into the new directory:
$ cd bert_QA
Create the virtual environment:
$ python3 -m venv env
This creates a new folder named env containing scripts to control the virtual environment, including program libraries.
Activate the virtual environment:
$ source env/bin/activate
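To confirm that the environment is active before installing packages, a short check can be run with the environment's interpreter. This is a minimal sketch that assumes the environment was created with the venv module as shown above; the function name in_virtual_environment is illustrative.

```python
import sys


def in_virtual_environment() -> bool:
    """Return True when the interpreter is running inside a venv."""
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("Virtual environment active:", in_virtual_environment())
    print("Interpreter prefix:", sys.prefix)
```

If the first line prints False, the activate script has probably not been sourced in the current shell.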
Installing TensorFlow
To install TensorFlow, enter the following command:
$ pip install tensorflow
TFLite Model Maker
TFLite Model Maker is a library that simplifies training a TensorFlow Lite model using custom datasets. It uses transfer learning to reduce the amount of training data and time required.
The TFLite Model Maker library reduces the complexity of converting and adapting a TensorFlow neural-network model to particular input data when deploying for on-device ML applications. These models can be fine-tuned to work on memory- and CPU-constrained devices like smartphones without sacrificing much accuracy.
This guide uses the TFLite Model Maker library to fine-tune a BERT model for question answering.
Installing TFLite Model Maker
To install the TFLite Model Maker library:
Clone the repository:
$ git clone https://github.com/tensorflow/examples
Install the requirements:
$ pip install -r examples/tensorflow_examples/lite/model_maker/requirements.txt
Install the package:
$ pip install -e examples/tensorflow_examples/lite/model_maker/pip_package/
Building The Lite Model
To create the fine-tuned lite model responsible for question answering, create a lite_model_gen.py file within the working directory:
$ touch lite_model_gen.py
Importing Libraries
Import the required libraries by adding the following lines to the lite_model_gen.py file:
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import question_answer
from tflite_model_maker.question_answer import DataLoader
The names imported from the TFLite Model Maker library serve the following purposes:
- model_spec: used to choose the model specification representing the model.
- question_answer: used for data loading and model training for question answering.
- DataLoader: provides generic utilities for loading custom data during model retraining.
Choosing a Model Spec
The TFLite Model Maker library supports three model specifications for question answering:
- BERT-Base - this is the standard BERT model used widely for NLP tasks.
- MobileBERT - is a compact version of BERT-Base that is about 4x smaller and almost 6x faster. MobileBERT achieves competitive results despite being smaller than BERT-Base and is more suitable for on-device use cases in power-constrained devices like smartphones.
- MobileBERT-SQuAD - this model utilizes the same architecture as the MobileBERT, but the initial model is already retrained on the Stanford Question Answering Dataset (SQuAD) 1.1. SQuAD is a reading comprehension dataset that consists of question-and-answer pairs posed by crowdworkers on a set of Wikipedia articles. SQuAD 1.1 contains 100,000+ question-answer pairs on 500+ articles.
To use the MobileBERT-SQuAD specification, add the following lines:
# Model specification representing the model
spec = model_spec.get('mobilebert_qa_squad')
model_spec provides a get function that takes the name of the model specification as an argument. For question-answering tasks with BERT, it accepts any one of three strings:
- mobilebert_qa_squad: specifies MobileBERT-SQuAD.
- mobilebert_qa: specifies MobileBERT.
- bert_qa: specifies BERT-Base.
Getting Training Data
This guide uses the TriviaQA dataset for model training. TriviaQA is a large-scale dataset for reading comprehension and question answering containing over 650k question-answer-evidence triples.
To load the dataset for training, it must first be converted to SQuAD 1.1 format using the TriviaQA conversion Python script. This guide uses pre-converted datasets for training and validation.
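For orientation, the SQuAD 1.1 layout nests articles, paragraphs, and questions, and each answer records the character offset where it begins in the paragraph text. The record below is a hand-written, hypothetical example (not taken from TriviaQA), and the helper names iter_questions and misaligned_answers are illustrative. The sketch checks that every answer_start offset actually points at the answer text.

```python
# A hand-written, hypothetical record in SQuAD 1.1 layout.
squad_example = {
    "version": "1.1",
    "data": [
        {
            "title": "Alexander the Great",
            "paragraphs": [
                {
                    "context": "Until the age of 16, Alexander was tutored by Aristotle.",
                    "qas": [
                        {
                            "id": "example-0001",
                            "question": "Who tutored Alexander?",
                            "answers": [
                                {"text": "Aristotle", "answer_start": 46}
                            ],
                        }
                    ],
                }
            ],
        }
    ],
}


def iter_questions(dataset):
    """Yield (context, qa) pairs from a SQuAD 1.1 style dictionary."""
    for article in dataset["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                yield paragraph["context"], qa


def misaligned_answers(dataset):
    """Return ids of questions whose answer_start does not match the context."""
    bad_ids = []
    for context, qa in iter_questions(dataset):
        for answer in qa["answers"]:
            start = answer["answer_start"]
            end = start + len(answer["text"])
            if context[start:end] != answer["text"]:
                bad_ids.append(qa["id"])
    return bad_ids


if __name__ == "__main__":
    print("Questions:", len(list(iter_questions(squad_example))))
    print("Misaligned answers:", misaligned_answers(squad_example))
```

The same check can be applied to a downloaded dataset by loading it with json.load and passing the resulting dictionary to misaligned_answers.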
To download the converted datasets, add the following lines:
# Download archived version of already converted datasets
train_data_path = tf.keras.utils.get_file(
    fname='triviaqa-web-train-8000.json',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json')
validation_data_path = tf.keras.utils.get_file(
    fname='triviaqa-verified-web-dev.json',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json')
The tf.keras.utils.get_file function downloads a file from a given URL if it's not already present in the cache. fname specifies the file's name, and origin specifies the original URL of the file. The function returns the path to the downloaded file.
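The download-if-missing behavior can be sketched with the standard library alone. This is not the Keras implementation; cached_download and its fetch parameter are hypothetical names used for illustration, and the example URL is a placeholder. (By default, Keras typically keeps downloaded files under ~/.keras/datasets.)

```python
import os
import tempfile
import urllib.request


def cached_download(fname, origin, cache_dir, fetch=urllib.request.urlretrieve):
    """Return the local path for fname, downloading origin only on a cache miss."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, fname)
    if not os.path.exists(path):
        # fetch(url, destination) is called only when the file is missing.
        fetch(origin, path)
    return path


if __name__ == "__main__":
    # Offline demonstration: a stand-in fetcher writes a placeholder file
    # instead of contacting the network.
    def write_placeholder(url, destination):
        print("fetching", url)
        with open(destination, "w", encoding="utf-8") as handle:
            handle.write("{}")

    with tempfile.TemporaryDirectory() as cache:
        url = "https://example.invalid/data.json"  # placeholder URL
        first = cached_download("data.json", url, cache, fetch=write_placeholder)
        second = cached_download("data.json", url, cache, fetch=write_placeholder)
        print(first == second)  # the second call reuses the cached file
```

Because the second call finds the file on disk, the fetcher runs only once, which is why re-running the training script does not download the datasets again.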
Loading the Datasets
Load the datasets by adding the following lines:
# Fetch the training and validation data
train_data = DataLoader.from_squad(train_data_path, spec, is_training=True)
validation_data = DataLoader.from_squad(validation_data_path, spec, is_training=False)
DataLoader.from_squad loads the passed datasets in SQuAD format and preprocesses the text according to the given model_spec. This method takes the filename and model_spec as arguments, with optional arguments including is_training, a boolean indicating whether the data is for training.
Creating the Model
To create the TFLite model:
# Create the model
model = question_answer.create(train_data, model_spec=spec)
The question_answer.create function loads the data and trains the model for question answering. It takes two required arguments: the training data and the model specification. It also takes optional arguments:
- batch_size=None: batch size for training.
- epochs=2: number of epochs for training.
- steps_per_epoch=None: number of batches of samples to process before declaring one epoch finished and starting the next. By default, an epoch runs until the input dataset is exhausted.
- shuffle=False: a boolean determining whether the input data should be shuffled.
This function returns a model instance for question answering.
Evaluating the Model
To evaluate the model using the validation dataset:
# Evaluate the model
model.evaluate(validation_data)
Calling the evaluate method on the model object returns a dictionary of metrics, including the F1 score (final_f1) and exact_match.
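To make the two metrics concrete, the sketch below scores a single prediction against a single reference answer, following the normalization commonly used for SQuAD-style evaluation (lowercasing, removing punctuation and articles, collapsing whitespace). It is an illustration under that assumption; the Model Maker library's internal implementation may differ, and the function names are not part of its API.

```python
import collections
import re
import string


def normalize_answer(text):
    """Lowercase, strip punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """1.0 when the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1_score(prediction, truth):
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


if __name__ == "__main__":
    print(exact_match("Aristotle.", "Aristotle"))
    print(f1_score("tutored by Aristotle.", "Aristotle"))
```

A dataset-level score is then the average of these per-question values, which is why a partially overlapping answer raises F1 without raising exact match.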
Exporting the Model
To export the TFLite model for use with on-device question answering:
# Export the model
model.export(export_dir='.')
This exports the model to the current working directory in the default TFLITE format. Other export formats include VOCAB and SAVED_MODEL.
Model Generation Code
The final lite_model_gen.py code:
import tensorflow as tf
from tflite_model_maker import model_spec
from tflite_model_maker import question_answer
from tflite_model_maker.question_answer import DataLoader
# Model specification representing the model
spec = model_spec.get('mobilebert_qa_squad')
# Download archived version of already converted datasets
train_data_path = tf.keras.utils.get_file(
    fname='triviaqa-web-train-8000.json',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json')
validation_data_path = tf.keras.utils.get_file(
    fname='triviaqa-verified-web-dev.json',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json')
# Fetch the training and validation data
train_data = DataLoader.from_squad(train_data_path, spec, is_training=True)
validation_data = DataLoader.from_squad(validation_data_path, spec, is_training=False)
# Create the model
model = question_answer.create(train_data, model_spec=spec)
# Evaluate the model
model.evaluate(validation_data)
# Export the model
model.export(export_dir='.')
Running the Code
Run the code:
$ python3 lite_model_gen.py
Note: Runtime ranges from about an hour to well over a day, depending on CPU performance or the presence of a GPU. In the sample output below, each of the two epochs took roughly 20 hours.
The output looks similar to this:
…
….
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json
32571392/32570663 [==============================] - 0s 0us/step
32579584/32570663 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json
1171456/1167744 [==============================] - 0s 0us/step
1179648/1167744 [==============================] - 0s 0us/step
INFO:tensorflow:Retraining the models...
INFO:tensorflow:Retraining the models...
Epoch 1/2
1067/1067 [==============================] - 70867s 66s/step - loss: 1.1337 - start_positions_loss: 1.1310 - end_positions_loss: 1.1363
Epoch 2/2
1067/1067 [==============================] - 70983s 67s/step - loss: 0.7942 - start_positions_loss: 0.7934 - end_positions_loss: 0.7949
INFO:tensorflow:Made predictions for 200 records.
INFO:tensorflow:Made predictions for 200 records.
…
…
{'exact_match': 0.5986394557823129, 'final_f1': 0.6728435963129841}
…
…
INFO:tensorflow:TensorFlow Lite model exported successfully: ./model.tflite
After it runs successfully, the lite model is exported to the current project directory.
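The per-epoch timings in the sample log can be converted into hours to judge whether a run is practical on a given machine. The figures below are copied from that log; they describe one particular CPU run and will differ on other hardware.

```python
# Per-epoch wall-clock seconds and step count copied from the sample log.
epoch_seconds = [70867, 70983]
steps_per_epoch = 1067

for number, seconds in enumerate(epoch_seconds, start=1):
    hours = seconds / 3600
    per_step = seconds / steps_per_epoch
    print(f"Epoch {number}: {hours:.1f} h, {per_step:.1f} s/step")

total_hours = sum(epoch_seconds) / 3600
print(f"Total training time: {total_hours:.1f} h")
```

For this log the total comes to a little over 39 hours, so passing a smaller epochs or steps_per_epoch value to question_answer.create may be worth considering for a first trial run.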
Building Question Answering
To integrate the exported model for question answering, this guide uses the TFLite Support toolkit, which provides libraries for integrating TFLite models on different platforms.
To install:
$ pip install tflite-support
Now, create the question-answer script:
$ touch question_answer.py
Importing Library
Import the BertQuestionAnswerer class:
from tflite_support.task.text import BertQuestionAnswerer
The BertQuestionAnswerer class performs question answering on text.
Create the BertQuestionAnswerer Object
Create the BertQuestionAnswerer object from the exported lite model:
# Create the BertQuestionAnswerer object from a TensorFlow lite model
question_answerer = BertQuestionAnswerer.create_from_file("./model.tflite")
This returns a BertQuestionAnswerer object created from the model file.
Create Question Answer Context
A context is needed for question answering. This context could be a passage or a sentence from which questions are asked. The following text about Alexander the Great is obtained from Wikipedia. This will be used as a context for question answering:
# Create context for question answering
context = "Alexander the Great was a king of the ancient Greek kingdom of Macedon. He succeeded his father Philip II to the throne in 336 BC at the age of 20, and spent most of his ruling years conducting a lengthy military campaign throughout Western Asia and Egypt. By the age of thirty, he had created one of the largest empires in history, stretching from Greece to northwestern India. He was undefeated in battle and is widely considered to be one of history's greatest and most successful military commanders. Until the age of 16, Alexander was tutored by Aristotle. In 335 BC, shortly after his assumption of kingship over Macedon, he campaigned in the Balkans and reasserted control over Thrace and Illyria before marching on the city of Thebes, which was subsequently destroyed in battle. Alexander then led the League of Corinth, and used his authority to launch the pan-Hellenic project envisaged by his father, assuming leadership over all Greeks in their conquest of Persia."
Create Questions
Create a dictionary to hold question-answer pairs based on the above context:
# Create questions
questions = {
    "Who is Alexander the Great": None,
    "Until the age of 16, who tutored Alexander": None,
}
Answering Questions
To answer the questions in the dictionary:
# Answer questions
for question in questions.keys():
    answer = question_answerer.answer(context, question)
    questions[question] = answer
print(questions)
The answer method takes two arguments, a context and a question. It answers the question based on the context and returns a QuestionAnswererResult, which holds a list of probable answers generated by the BertQuestionAnswerer.
Final Question Answering Code
The full question_answer.py code:
from tflite_support.task.text import BertQuestionAnswerer
# Create the BertQuestionAnswerer object from a TensorFlow lite model
question_answerer = BertQuestionAnswerer.create_from_file("./model.tflite")
# Create context for question answering
context = "Alexander the Great was a king of the ancient Greek kingdom of Macedon. He succeeded his father Philip II to the throne in 336 BC at the age of 20, and spent most of his ruling years conducting a lengthy military campaign throughout Western Asia and Egypt. By the age of thirty, he had created one of the largest empires in history, stretching from Greece to northwestern India. He was undefeated in battle and is widely considered to be one of history's greatest and most successful military commanders. Until the age of 16, Alexander was tutored by Aristotle. In 335 BC, shortly after his assumption of kingship over Macedon, he campaigned in the Balkans and reasserted control over Thrace and Illyria before marching on the city of Thebes, which was subsequently destroyed in battle. Alexander then led the League of Corinth, and used his authority to launch the pan-Hellenic project envisaged by his father, assuming leadership over all Greeks in their conquest of Persia."
# Create questions
questions = {
    "Who is Alexander the Great": None,
    "Until the age of 16, who tutored Alexander": None,
}
# Answer questions
for question in questions.keys():
    answer = question_answerer.answer(context, question)
    questions[question] = answer
print(questions)
Running the Code
Run the above code:
$ python3 question_answer.py
This yields the following result:
{
'Who is Alexander the Great': QuestionAnswererResult(answers=[
QaAnswer(pos=Pos(start=12, end=20, logit=-1.621170163154602), text='king of the ancient Greek kingdom of Macedon.'),
QaAnswer(pos=Pos(start=12, end=27, logit=-2.1207242012023926), text='king of the ancient Greek kingdom of Macedon. He succeeded his father Philip II'),
QaAnswer(pos=Pos(start=19, end=20, logit=-3.1698760986328125), text='Macedon.'),
QaAnswer(pos=Pos(start=26, end=27, logit=-3.3418025970458984), text='Philip II'),
QaAnswer(pos=Pos(start=12, end=12, logit=-3.3852314949035645), text='king')]),
'Until the age of 16, who tutored Alexander': QuestionAnswererResult(answers=[
QaAnswer(pos=Pos(start=121, end=121, logit=7.933090686798096), text='Aristotle.'),
QaAnswer(pos=Pos(start=118, end=121, logit=1.3499608039855957), text='tutored by Aristotle.'),
QaAnswer(pos=Pos(start=121, end=122, logit=1.0493016242980957), text='Aristotle.'),
QaAnswer(pos=Pos(start=110, end=121, logit=0.37497782707214355), text='Until the age of 16, Alexander was tutored by Aristotle.'),
QaAnswer(pos=Pos(start=118, end=119, logit=-5.260964870452881), text='tutored')])
}
Each returned QuestionAnswererResult object contains a list of QaAnswer objects that represent probable answers to the question posed. pos marks the position of the answer (start and end indices plus a logit score, where a higher logit indicates a more likely answer), while text holds the answer text.
For the first question, two of the five probable answers are correct. For the second question, four of the five are correct.
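In practice, an application usually needs one answer rather than the full candidate list. The sketch below picks the candidate with the highest logit. To stay runnable without a model file, it uses hypothetical stand-in dataclasses that mirror the field names printed in the output above; with the real library, the object returned by the answer method would be passed in instead. The helper names best_answer and relative_scores are illustrative.

```python
import math
from dataclasses import dataclass
from typing import List


# Hypothetical stand-ins mirroring the fields shown in the printed result.
@dataclass
class Pos:
    start: int
    end: int
    logit: float


@dataclass
class QaAnswer:
    pos: Pos
    text: str


@dataclass
class QuestionAnswererResult:
    answers: List[QaAnswer]


def best_answer(result):
    """Return the candidate answer with the highest logit."""
    return max(result.answers, key=lambda answer: answer.pos.logit)


def relative_scores(result):
    """Softmax over candidate logits: a rough ranking, not a calibrated probability."""
    logits = [answer.pos.logit for answer in result.answers]
    peak = max(logits)  # subtract the maximum for numerical stability
    weights = [math.exp(logit - peak) for logit in logits]
    total = sum(weights)
    return [weight / total for weight in weights]


# Three of the candidates from the second question in the output above.
sample = QuestionAnswererResult(answers=[
    QaAnswer(Pos(121, 121, 7.933090686798096), "Aristotle."),
    QaAnswer(Pos(118, 121, 1.3499608039855957), "tutored by Aristotle."),
    QaAnswer(Pos(118, 119, -5.260964870452881), "tutored"),
])

top = best_answer(sample)
print(top.text, top.pos.logit)
```

A large gap between the top logit and the rest, as in this sample, suggests the model strongly prefers one span; a small gap, as in the first question, suggests the answer is less certain.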
Conclusion
This guide covered how to use BERT in TensorFlow by building a lite model for question answering and using the TFLite Support library to answer questions within a context.