
Using Tensorflow Ranking BERT (TFR-BERT): an end-to-end example

Recently an interesting paper (Han et al., “Learning-to-Rank with BERT in TF-Ranking”) appeared on arXiv that combines Tensorflow Ranking with BERT to rank (or re-rank) BERT-encoded queries and documents. This seems like a major step for ranking algorithms, combining the power of large language model embeddings with ranking tasks. The paper reports that the TFR-BERT model achieves high performance when ranking answers in the MS MARCO dataset.

The code was recently released, but unfortunately without end-to-end examples for training a model and generating predictions on unseen data. As I’m not a Tensorflow or Tensorflow Ranking expert, nor a frequent Python user, getting it running involved many non-obvious steps that took some digging to figure out. Since this is the first successful end-to-end demonstration of TFR-BERT that I could find, I’ve put together example code and a short tutorial in the hope that it helps other folks get running more quickly.

A word of caution: TFR-BERT appears to have large computational requirements, even by the standards of large language models. Rough estimates are given in the FAQ at the bottom of this tutorial.

Preconditions

In this example I’m using a conda environment that has Python 3.7.7, and the Tensorflow (GPU) and other supporting dependencies installed. Here is the requirements.txt for my conda environment.

Before starting, I would recommend reading the instructions in the README.md of the Tensorflow Ranking repository and installing any additional dependencies (such as bazel). The TFR-BERT extension and its README live under the tensorflow_ranking/extension path, and are also worth reading beforehand.

Step 1: Clone the example repository

This tutorial uses wrappers, helpers, and example data that I’ve added to a fork of the official code base, so I would recommend cloning it to get started. Once you’re an expert, you can clone the official Tensorflow Ranking repository here, which includes the TFR-BERT code.

git clone https://github.com/cognitiveailab/ranking.git

Step 2: Download BERT checkpoints in Tensorflow 2 format

Google recently released checkpoints for smaller versions of BERT that run faster on more modest hardware. These include BERT-Tiny, BERT-Mini, BERT-Small, and BERT-Medium, to complement the existing BERT-Base and BERT-Large models. TFR-BERT is a bit of a heavy model and requires fairly serious computational resources, so like the TFR team I would recommend debugging/developing with smaller checkpoints until you need to scale.
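The checkpoint folder names used in the scripts in this tutorial (e.g. uncased_L-12_H-768_A-12_TF2) follow the naming scheme of Google’s BERT releases, where L is the number of layers, H the hidden size, and A the number of attention heads. When juggling several sizes during development, a small helper like the one below can report which model a given BERT_DIR points to. This is a hypothetical convenience function, not part of either repository, and the size table reflects the published configurations as I understand them, so double-check it against the release notes.

```python
import re

# Folder names such as "uncased_L-12_H-768_A-12_TF2" encode
# L = number of layers, H = hidden size, A = number of attention heads.
_NAME_PATTERN = re.compile(r"L-(\d+)_H-(\d+)_A-(\d+)")

# (layers, hidden size) -> model name, for the released BERT sizes.
_KNOWN_SIZES = {
    (2, 128): "BERT-Tiny",
    (4, 256): "BERT-Mini",
    (4, 512): "BERT-Small",
    (8, 512): "BERT-Medium",
    (12, 768): "BERT-Base",
    (24, 1024): "BERT-Large",
}


def describe_checkpoint(dir_name):
    """Return (layers, hidden, heads, model_name) parsed from a checkpoint folder name."""
    match = _NAME_PATTERN.search(dir_name)
    if match is None:
        raise ValueError(f"no L-/H-/A- fields found in {dir_name!r}")
    layers, hidden, heads = (int(group) for group in match.groups())
    name = _KNOWN_SIZES.get((layers, hidden), "unknown size")
    return layers, hidden, heads, name


for folder in ["uncased_L-4_H-256_A-4_TF2", "uncased_L-12_H-768_A-12_TF2"]:
    print(folder, "->", describe_checkpoint(folder))
```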

TFR-BERT requires BERT checkpoints in the Tensorflow 2 (TF2) format, which (as of this writing) are hard to find pre-generated, since checkpoints are typically released in TF1 format. Converting from TF1 to TF2 takes a bit of tinkering with the conversion script, or you’re welcome to use the TF2 checkpoints linked below that I’ve converted (though, again, I’m not a Tensorflow expert, so if you see something amiss, please send me a note).
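If you’re unsure whether a downloaded checkpoint is already in TF2 format, one heuristic is to list its variable names. TF2 object-based checkpoints (written with tf.train.Checkpoint) store values under keys ending in .ATTRIBUTES/VARIABLE_VALUE, while TF1 name-based checkpoints use plain scope paths such as bert/embeddings/word_embeddings. The sketch below is an assumption-based check, not part of the TFR-BERT code: the helper names are hypothetical, and only tf.train.list_variables comes from Tensorflow itself.

```python
def looks_like_tf2_checkpoint(variable_names):
    """Heuristic: TF2 object-based checkpoints store values under keys ending in
    '.ATTRIBUTES/VARIABLE_VALUE'; TF1 name-based checkpoints use plain scope paths."""
    return any(name.endswith(".ATTRIBUTES/VARIABLE_VALUE") for name in variable_names)


def checkpoint_variable_names(checkpoint_path):
    """List the variable names stored in a checkpoint (requires TensorFlow)."""
    # Imported lazily so the heuristic above can be used without TensorFlow installed.
    import tensorflow as tf
    return [name for name, _shape in tf.train.list_variables(checkpoint_path)]
```

For example, looks_like_tf2_checkpoint(checkpoint_variable_names(BERT_DIR + "/bert_model.ckpt")) should return True for a converted checkpoint and False for an original TF1 release.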

Step 3: Convert your ranking problems into an appropriate format

Internally, TFR-BERT loads training and evaluation files containing lists of BERT-encoded query-document pairs, converted into the ELWC (ExampleListWithContext) format and saved as TFRecords. Producing these by hand is a little challenging, so I’ve put together a quick utility and a set of helper functions to convert between a simple JSON format and ELWC.

The input JSON format:

{"rankingProblems": [
    {"queryText": "Where can you buy cat food?",
     "documents": [
        {"relevance": 3, "docText": "The pet food store"},
        {"relevance": 1, "docText": "Bicycles have two wheels"},
        {"relevance": 3, "docText": "The grocery store"},
        {"relevance": 2, "docText": "Cats eat cat food"}
        ]
    },
    {"queryText": "Where can you go swimming?",
     "documents": [
        {"relevance": 2, "docText": "At the lake"},
        {"relevance": 3, "docText": "In a swimming pool"}, 
        {"relevance": 1, "docText": "In a cloud"},
        {"relevance": 1, "docText": "On a pile of rocks"},
        {"relevance": 1, "docText": "In a garden"}        
        ]
    },
    {"queryText": "What helps to build a campfire?",
     "documents": [
        {"relevance": 1, "docText": "Rocks"},
        {"relevance": 2, "docText": "Tinder"}, 
        {"relevance": 3, "docText": "Wood"}, 
        {"relevance": 3, "docText": "Match"},
        {"relevance": 1, "docText": "Potato"},
        {"relevance": 1, "docText": "Can of soup"},
        {"relevance": 1, "docText": "Marshmallow"},
        {"relevance": 1, "docText": "Hot dog"},
        {"relevance": 1, "docText": "Rice"},
        {"relevance": 1, "docText": "Pot and pan"}
        ]
    }
    ]
}

Here, rankingProblems is a list of the ranking problems in your train, development, or test set. Each ranking problem has a query string (queryText) and a list of documents. Each document is an object containing the document text (docText) and a gold relevance score (relevance) represented as an integer, where higher scores mean the document is more relevant to the query. The document list is unordered and can be stored in any order (as shown).

I’ve put toy train and evaluation examples in the repository, to help illustrate how you might convert your own data into this JSON format.
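Before converting, it can save a round trip to validate your JSON against the structure described above. The following is a minimal sketch (the function names are my own, not the conversion tool’s) that checks each ranking problem has a string queryText and a non-empty documents list whose entries have a string docText and an integer relevance.

```python
import json


def validate_ranking_problems(data):
    """Check the rankingProblems structure and return the list of problems if valid."""
    problems = data.get("rankingProblems")
    if not isinstance(problems, list):
        raise ValueError("top-level 'rankingProblems' list is missing")
    for i, problem in enumerate(problems):
        if not isinstance(problem.get("queryText"), str):
            raise ValueError(f"problem {i}: 'queryText' must be a string")
        documents = problem.get("documents")
        if not isinstance(documents, list) or not documents:
            raise ValueError(f"problem {i}: 'documents' must be a non-empty list")
        for j, doc in enumerate(documents):
            if not isinstance(doc.get("docText"), str):
                raise ValueError(f"problem {i}, document {j}: 'docText' must be a string")
            relevance = doc.get("relevance")
            # bool is a subclass of int in Python, so reject it explicitly.
            if isinstance(relevance, bool) or not isinstance(relevance, int):
                raise ValueError(f"problem {i}, document {j}: 'relevance' must be an integer")
    return problems


def load_ranking_problems(path):
    """Read a JSON file in the format above and validate it."""
    with open(path) as f:
        return validate_ranking_problems(json.load(f))
```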

Conversion script (JSON to TFRecord):

Once you have the data in JSON format, you need to convert it into the TFRecord format used by TFR-BERT. A conversion script that runs the tool is available here:

#!/bin/bash
BERT_DIR="/home/peter/github/tensorflow/ranking/uncased_L-12_H-768_A-12_TF2"  && \
python tensorflow_ranking/extension/examples/tfrbert_convert_json_to_elwc.py \
    --vocab_file=${BERT_DIR}/vocab.txt \
    --sequence_length=128 \
    --input_file=/home/peter/github/peter-ranking/ranking/TFRBertExample-eval.json \
    --output_file=eval.toy.elwc.tfrecord \
    --do_lower_case 

The critical bits here are ensuring that BERT_DIR points to the BERT model checkpoint you’re using, that input_file points to your input JSON, and that output_file is the TFRecord file you’d like generated. sequence_length should be set to the maximum sequence length your model will be trained on (commonly 128 tokens), and --do_lower_case should be set if you’re using uncased BERT models. Successfully running this script should output something similar to:

(tfranking-bert) peter@neutronium:~/github/peter-ranking/ranking$ ./tfrbert_convert_json_to_elwc.sh  
2020-09-07 17:34:52.464134: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
Utility to convert between JSON and ELWC for TFR-Bert

Model Parameters:  
Vocabulary filename: /home/peter/github/tensorflow/ranking/uncased_L-12_H-768_A-12_TF2/vocab.txt
sequence_length: 128
do_lower_case: True

Input file:  /home/peter/github/peter-ranking/ranking/TFRBertExample-eval.json
Output file: eval.toy.elwc.tfrecord
Success.

Run this to convert each file (train, development, and test) in your dataset.
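If you have several splits, the conversion can be scripted instead of editing the shell script for each file. This sketch rebuilds the same command line as the shell script above for each split and, unless dry_run is set, runs it with subprocess. The helper names and the {split}.toy.elwc.tfrecord output naming (chosen to match the file names the training script in Step 4 expects) are my own assumptions.

```python
import shlex
import subprocess

# Path of the conversion tool, relative to the repository root (as in the script above).
CONVERTER = "tensorflow_ranking/extension/examples/tfrbert_convert_json_to_elwc.py"


def conversion_command(bert_dir, input_json, output_tfrecord,
                       sequence_length=128, do_lower_case=True):
    """Build the same command line as the shell script above, for one file."""
    cmd = [
        "python", CONVERTER,
        f"--vocab_file={bert_dir}/vocab.txt",
        f"--sequence_length={sequence_length}",
        f"--input_file={input_json}",
        f"--output_file={output_tfrecord}",
    ]
    if do_lower_case:  # only for uncased BERT checkpoints
        cmd.append("--do_lower_case")
    return cmd


def convert_splits(bert_dir, splits, dry_run=True):
    """splits maps a split name (e.g. "train") to its input JSON file."""
    commands = []
    for split, input_json in splits.items():
        cmd = conversion_command(bert_dir, input_json, f"{split}.toy.elwc.tfrecord")
        commands.append(cmd)
        # shlex.join needs Python 3.8+, so quote manually for Python 3.7.
        print(" ".join(shlex.quote(part) for part in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raise on the first failed conversion
    return commands
```

For the toy data, convert_splits(BERT_DIR, {"train": "TFRBertExample-train.json", "eval": "TFRBertExample-eval.json"}, dry_run=False) would produce both TFRecord files; leaving dry_run=True just prints the commands so you can check them first.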

Step 4: Train your TFR-BERT model

This step largely proceeds as in the official TFR-BERT documentation, using the tfrbert_example.py training example. Here’s a script to run it:

#!/bin/bash
#BERT_DIR="/home/peter/github/tensorflow/ranking/uncased_L-4_H-256_A-4_TF2"  && \
BERT_DIR="/home/peter/github/tensorflow/ranking/uncased_L-12_H-768_A-12_TF2"  && \
OUTPUT_DIR="/tmp/tfr/model-petertoy-bertbase/" && \
DATA_DIR="/home/peter/github/peter-ranking/ranking" && \
rm -rf "${OUTPUT_DIR}" && \
bazel build -c opt \
tensorflow_ranking/extension/examples:tfrbert_example_py_binary && \
./bazel-bin/tensorflow_ranking/extension/examples/tfrbert_example_py_binary \
   --train_input_pattern=${DATA_DIR}/train.toy.elwc.tfrecord \
   --eval_input_pattern=${DATA_DIR}/eval.toy.elwc.tfrecord \
   --bert_config_file=${BERT_DIR}/bert_config.json \
   --bert_init_ckpt=${BERT_DIR}/bert_model.ckpt \
   --bert_max_seq_length=128 \
   --model_dir="${OUTPUT_DIR}" \
   --list_size=15 \
   --loss=softmax_loss \
   --train_batch_size=1 \
   --eval_batch_size=1 \
   --learning_rate=1e-5 \
   --num_train_steps=500 \
   --num_eval_steps=10 \
   --checkpoint_secs=500 \
   --num_checkpoints=2

There are a lot of knobs to turn here. For the purposes of this example, the critical bits are that train_input_pattern and eval_input_pattern match the TFRecord files you generated from your JSON dataset in Step 3. BERT_DIR should point to your BERT model (it’s helpful to keep a few different, smaller, faster models quickly available during development, commented out as shown). OUTPUT_DIR is where your model will be saved; note that the script deletes this directory (rm -rf) each time it’s run.

Finally, list_size defines the maximum number of documents in each ranking problem (15 in the script above). Increasing it increases the memory requirements of your model (see the FAQ below for estimates), so even for modest-sized lists you may find yourself training with a batch size of 1. num_train_steps sets the number of training steps before completion, and this number is often quite high in the few examples I’ve seen (e.g. the documentation’s example script for the ANTIQUE dataset uses 100000 training steps), so expect to need some serious GPU or TPU hours for training.
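Because list_size caps the number of documents per ranking problem, it’s worth checking your data against the training flags before launching a long run. The sketch below (a hypothetical helper that takes the list under "rankingProblems" from the JSON format in Step 3) reports the longest list and roughly how many passes over the data a given num_train_steps amounts to. It assumes train_batch_size counts ranking problems (lists) per step, and that lists longer than list_size are truncated during training; I haven’t confirmed the exact truncation behavior.

```python
def check_training_flags(problems, list_size, train_batch_size, num_train_steps):
    """Return (longest list length, approximate epochs) for a set of ranking problems.

    Assumes train_batch_size counts ranking problems (lists) per training step.
    """
    longest = max(len(problem["documents"]) for problem in problems)
    if longest > list_size:
        # Assumed: lists longer than list_size are truncated, so some documents
        # would never be seen during training.
        print(f"warning: longest list has {longest} documents, but list_size={list_size}")
    epochs = num_train_steps * train_batch_size / len(problems)
    return longest, epochs
```

With the toy training file from Step 3 (lists of 4, 5, and 10 documents) and the flags in the script above (list_size=15, train_batch_size=1, num_train_steps=500), this reports a longest list of 10 and about 167 passes over the three ranking problems.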

After training, you should have several trained models in your ${OUTPUT_DIR}/export/ folder, typically the most recent model as well as the best model (evaluated in terms of lowest loss). Plenty of output will stream by during training, but the end will likely look something like this:

INFO:tensorflow:SavedModel written to: /tmp/tfr/model-petertoy-testtrain/export/best_model_by_loss/temp-1599526275/saved_model.pb
I0907 17:51:17.281934 139878410606400 builder_impl.py:426] SavedModel written to: /tmp/tfr/model-petertoy-testtrain/export/best_model_by_loss/temp-1599526275/saved_model.pb

INFO:tensorflow:Loss for final step: 12.917309.
I0907 17:51:17.350415 139878410606400 estimator.py:352] Loss for final step: 12.917309.

If you see something different, such as a stream of errors (particularly if you’re working from the official examples with the ANTIQUE dataset), sift through the output for out-of-memory (OOM) errors. You can deal with these by switching to a smaller model (e.g. BERT-Mini or BERT-Tiny), reducing the list size, or (of course) finding a system with more GPU memory.

After running, the export directory should contain the model files inside a numbered subdirectory (the model’s “version number”).

Step 5: Predictions: Set up a Tensorflow Serving prediction server

The training/evaluation procedure does not generate predictions, and there is no official example of how to do this. Here we’ll set up a prediction mechanism.

Again, I’m not a Tensorflow expert, but there appear to be two ways of generating predictions: (1) directly, by loading the model and calling the predict() method on the Estimator, or (2) indirectly, by using the Tensorflow Serving model server to load your model, then sending queries (and receiving prediction scores) over a socket. The latter seems much more common and better supported, so that’s the approach described here.
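For approach (2), the model server started later in this step exposes a REST endpoint (--rest_api_port=8501, model name tfrbert). The client script in Step 6 handles requests for you, but as a rough sketch of what travels over the wire: Tensorflow Serving’s REST predict API accepts a JSON body with an "instances" list, and binary inputs such as serialized protos must be base64-encoded inside a {"b64": ...} object. This assumes the exported signature takes serialized ELWC protos, as the training data format suggests, and the input key is an assumption too: the name input_ranking_data used in the example is a guess, so inspect your exported model’s signature with saved_model_cli show --dir <version_dir> --all before relying on it. The helper names are hypothetical.

```python
import base64
import json
import urllib.request


def build_predict_request(serialized_elwc, input_key, model_name="tfrbert",
                          host="localhost", port=8501):
    """Return (url, json_body) for a Tensorflow Serving REST predict call.

    serialized_elwc: bytes of one serialized ExampleListWithContext proto.
    input_key: name of the serving signature's input (an assumption; verify it).
    """
    url = f"http://{host}:{port}/v1/models/{model_name}:predict"
    # Binary values must be base64-encoded and wrapped as {"b64": ...} in JSON.
    encoded = base64.b64encode(serialized_elwc).decode("ascii")
    body = json.dumps({"instances": [{input_key: {"b64": encoded}}]})
    return url, body


def send_predict_request(url, body, timeout=30):
    """POST the request and return the "predictions" field of the response."""
    request = urllib.request.Request(
        url, data=body.encode("utf-8"), headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.load(response)["predictions"]
```

For example, url, body = build_predict_request(elwc_bytes, input_key="input_ranking_data") followed by send_predict_request(url, body) should return one list of document scores per request instance, if the key guess matches your model.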

There are a lot of tutorials for setting up a Tensorflow Serving model server, and they vary depending on your serving preference (CPU vs GPU) and whether you prefer the model server to be in a docker container. I preferred to get up and running quickly, and found this tutorial on installing a model server using apt-get on Ubuntu to be the simplest.

Once you’ve set up Tensorflow Serving, and assuming you chose the same method as I did (apt-get, no container), the model server can be started with a script such as this one:

#!/bin/bash
export MODEL_DIR=/tmp/tfr/model-petertoy-bertbase/export/latest_model/
tensorflow_model_server \
  --rest_api_port=8501 \
  --model_name=tfrbert \
  --model_base_path="${MODEL_DIR}"

A somewhat counter-intuitive point here is that MODEL_DIR should not point directly at the model files (i.e. the version-numbered folder), but rather at the parent folder containing one or more version-numbered folders with the actual model(s). After running the training step above, an example of this directory structure can be found in ${OUTPUT_DIR}/export/latest_model/, which is what the example script points to.
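To catch the most common misconfiguration (pointing MODEL_DIR at the version folder itself), a quick check like the following lists the numeric version subfolders under a candidate MODEL_DIR. It is a hypothetical helper, not part of Tensorflow Serving. Non-numeric folders, such as the temp-<timestamp> directories that appear in the export log, are skipped; under its default version policy, Tensorflow Serving serves the highest-numbered version.

```python
import os


def served_versions(model_base_path):
    """Return the numeric version folders under model_base_path, in ascending order.

    Raises ValueError if there are none, which usually means MODEL_DIR points at
    a version folder itself rather than at its parent.
    """
    versions = sorted(
        int(entry)
        for entry in os.listdir(model_base_path)
        if entry.isdigit() and os.path.isdir(os.path.join(model_base_path, entry))
    )
    if not versions:
        raise ValueError(f"no numeric version folders found in {model_base_path}")
    latest = os.path.join(model_base_path, str(versions[-1]))
    if not os.path.isfile(os.path.join(latest, "saved_model.pb")):
        print(f"warning: {latest} has no saved_model.pb")
    return versions
```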

Step 6: Predictions: Finally generating predictions

With the model server up, we can now connect to it using the client-side prediction example and generate predictions for our ranking problems. This code takes ranking problems in the JSON format as input, sends each one individually to the model server, and exports a ranked list with document scores added. This example run script generates predictions for both the train and evaluation toy data:

#!/bin/bash
BERT_DIR="/home/peter/github/tensorflow/ranking/uncased_L-12_H-768_A-12_TF2"  && \
python tensorflow_ranking/extension/examples/tfrbert_client_predict_from_json.py \
    --vocab_file=${BERT_DIR}/vocab.txt \
    --sequence_length=128 \
    --input_file=TFRBertExample-train.json \
    --output_file=train.scoresOut.json \
    --do_lower_case 

python tensorflow_ranking/extension/examples/tfrbert_client_predict_from_json.py \
    --vocab_file=${BERT_DIR}/vocab.txt \
    --sequence_length=128 \
    --input_file=TFRBertExample-eval.json \
    --output_file=eval.scoresOut.json \
    --do_lower_case 

And here’s an example of the script running:

(tfranking-bert) peter@neutronium:~/github/peter-ranking/ranking$ ./tfrbert_predict_from_json.sh  
2020-09-07 23:33:13.007881: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
* Running with arguments: Namespace(do_lower_case=True, input_file='TFRBertExample-train.json', output_file='train.scoresOut.json', sequence_length=128, vocab_file='/home/peter/github/tensorflow/ranking/uncased_L-12_H-768_A-12_TF2/vocab.txt')
* Generating predictions for JSON ranking problems (filename: TFRBertExample-train.json)

Predicting 1 / 3 (33.33%)
Predicting 2 / 3 (66.67%)
Predicting 3 / 3 (100.00%)

* exportRankingOutput(): Exporting scores to JSON (train.scoresOut.json)
* Total execution time: 0:00:01.241

2020-09-07 23:33:15.467087: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
* Running with arguments: Namespace(do_lower_case=True, input_file='TFRBertExample-eval.json', output_file='eval.scoresOut.json', sequence_length=128, vocab_file='/home/peter/github/tensorflow/ranking/uncased_L-12_H-768_A-12_TF2/vocab.txt')
* Generating predictions for JSON ranking problems (filename: TFRBertExample-eval.json)

Predicting 1 / 3 (33.33%)
Predicting 2 / 3 (66.67%)
Predicting 3 / 3 (100.00%)

* exportRankingOutput(): Exporting scores to JSON (eval.scoresOut.json)
* Total execution time: 0:00:01.492

The JSON output adds a score for each document generated from the prediction model, and returns the document lists for each query sorted by these predicted scores. Here’s an example output file with predictions on the toy training set, which should of course be very good since we’re training and evaluating on the same data:

 {
    "rankingProblemsOutput": [
        {
            "queryText": "Where can you buy cat food?",
            "documents": [
                {
                    "relevance": 3,
                    "docText": "The grocery store",
                    "score": 0.44388255
                },
                {
                    "relevance": 3,
                    "docText": "The pet food store",
                    "score": 0.40264943
                },
                {
                    "relevance": 2,
                    "docText": "Cats eat cat food",
                    "score": -0.15662411
                },
                {
                    "relevance": 1,
                    "docText": "Bicycles have two wheels",
                    "score": -0.8667503
                }
            ]
        },
        {
            "queryText": "Where can you go swimming?",
            "documents": [
                {
                    "relevance": 3,
                    "docText": "In a swimming pool",
                    "score": 0.48036778
                },
                {
                    "relevance": 2,
                    "docText": "At the lake",
                    "score": -0.10280942
                },
                {
                    "relevance": 1,
                    "docText": "On a pile of rocks",
                    "score": -0.7149895
                },
                {
                    "relevance": 1,
                    "docText": "In a cloud",
                    "score": -0.7245462
                },
                {
                    "relevance": 1,
                    "docText": "In a garden",
                    "score": -0.75645095
                }
            ]
        },
        {
            "queryText": "What helps to build a campfire?",
            "documents": [
                {
                    "relevance": 3,
                    "docText": "Wood",
                    "score": -0.008856705
                },
                {
                    "relevance": 3,
                    "docText": "Match",
                    "score": -0.05323608
                },
                {
                    "relevance": 2,
                    "docText": "Tinder",
                    "score": -0.42123765
                },
                {
                    "relevance": 1,
                    "docText": "Pot and pan",
                    "score": -1.0901607
                },
                {
                    "relevance": 1,
                    "docText": "Rocks",
                    "score": -1.1488856
                },
                {
                    "relevance": 1,
                    "docText": "Rice",
                    "score": -1.1492822
                },
                {
                    "relevance": 1,
                    "docText": "Can of soup",
                    "score": -1.154706
                },
                {
                    "relevance": 1,
                    "docText": "Hot dog",
                    "score": -1.1791945
                },
                {
                    "relevance": 1,
                    "docText": "Potato",
                    "score": -1.2208372
                },
                {
                    "relevance": 1,
                    "docText": "Marshmallow",
                    "score": -1.2553226
                }
            ]
        }
    ]
}

Pre-generated predictions on the toy evaluation set are also available, which illustrate performance on unseen toy data:

 {
    "rankingProblemsOutput": [
        {
            "queryText": "Where can you buy dog food?",
            "documents": [
                {
                    "relevance": 3,
                    "docText": "The pet food store",
                    "score": 0.4045086
                },
                {
                    "relevance": 1,
                    "docText": "Cars are for driving place to place",
                    "score": -0.087259516
                },
                {
                    "relevance": 2,
                    "docText": "Dogs eat dog food",
                    "score": -0.14085332
                },
                {
                    "relevance": 1,
                    "docText": "Red strawberries grow on strawberry plants",
                    "score": -0.97852093
                }
            ]
        },
        {
            "queryText": "Where can you go rock climbing?",
            "documents": [
                {
                    "relevance": 3,
                    "docText": "In a climbing gym",
                    "score": 0.14498216
                },
                {
                    "relevance": 1,
                    "docText": "In a swimming pool",
                    "score": 0.14445521
                },
                {
                    "relevance": 1,
                    "docText": "At the lake",
                    "score": -0.2063725
                },
                {
                    "relevance": 3,
                    "docText": "On a mountain cliff",
                    "score": -0.23254125
                },
                {
                    "relevance": 1,
                    "docText": "On a pile of rocks",
                    "score": -0.59751
                },
                {
                    "relevance": 1,
                    "docText": "In a cloud",
                    "score": -0.73646384
                },
                {
                    "relevance": 1,
                    "docText": "In a garden",
                    "score": -0.77151793
                }
            ]
        },
        {
            "queryText": "What parts are most important for a computer?",
            "documents": [
                {
                    "relevance": 3,
                    "docText": "Hard drive",
                    "score": -0.16578771
                },
                {
                    "relevance": 3,
                    "docText": "CPU",
                    "score": -0.27822962
                },
                {
                    "relevance": 3,
                    "docText": "Keyboard",
                    "score": -0.34836236
                },
                {
                    "relevance": 2,
                    "docText": "Printer",
                    "score": -0.44954214
                },
                {
                    "relevance": 1,
                    "docText": "Soldering iron",
                    "score": -0.44981638
                },
                {
                    "relevance": 2,
                    "docText": "Scanner",
                    "score": -0.512414
                },
                {
                    "relevance": 1,
                    "docText": "Street",
                    "score": -0.68731207
                },
                {
                    "relevance": 1,
                    "docText": "Mouse",
                    "score": -0.696253
                },
                {
                    "relevance": 2,
                    "docText": "Monitor",
                    "score": -0.74959594
                },
                {
                    "relevance": 1,
                    "docText": "Lamp",
                    "score": -0.9156646
                },
                {
                    "relevance": 1,
                    "docText": "Tree",
                    "score": -0.9277943
                },
                {
                    "relevance": 1,
                    "docText": "Couch",
                    "score": -0.9586045
                },
                {
                    "relevance": 1,
                    "docText": "Electronics factory",
                    "score": -1.0005659
                }
            ]
        }
    ]
}

That’s it! Congratulations, if you’ve made it this far, then you’ve successfully taken the first steps in using TFR-BERT to generate ranked lists. You can now translate your own data into the JSON format, and write an importer to use the ranked lists in your downstream tasks.
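As a starting point for such an importer, here is a minimal sketch (the function name is hypothetical) that reads a *.scoresOut.json file produced in Step 6 and returns the top-k document texts for each query. It re-sorts by score defensively and, because it is keyed on queryText, assumes query strings are unique within a file.

```python
import json


def top_k_documents(scores_json_path, k=1):
    """Map each queryText to its k highest-scoring docTexts from a *.scoresOut.json file."""
    with open(scores_json_path) as f:
        problems = json.load(f)["rankingProblemsOutput"]
    result = {}
    for problem in problems:
        # Sort by score even though the exporter already does, in case the file was edited.
        documents = sorted(problem["documents"], key=lambda d: d["score"], reverse=True)
        result[problem["queryText"]] = [d["docText"] for d in documents[:k]]
    return result
```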

Frequently Asked Questions

Q: What are the memory requirements of the model as they relate to list size?

As of writing, I’m not aware of any official figures. Empirically, on a 24GB Titan RTX (batch size 1, sequence length 128 tokens, BERT-base-uncased), I’m able to fit a list of approximately 80 documents in GPU memory before hitting out-of-memory errors during training. Using this as a guide, and assuming memory use scales roughly linearly with list size (about 3.3 documents per GB), we can estimate:

  • Lists of up to 35 documents per 11GB card (e.g. RTX 2080 Ti)
  • Lists of up to 80 documents per 24GB card (e.g. Titan RTX, RTX 3090)
  • Lists of up to 105 documents per 32GB card (e.g. V100)
  • Lists of up to 130 documents per 40GB card (e.g. A100)
  • Lists of up to 425 documents per 4x32GB cards (e.g. 4xV100)
  • Lists of up to 530 documents per 4x40GB cards (e.g. 4xA100)
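The figures above come from scaling that single measurement linearly (80/24, or roughly 3.3 documents per GB) and rounding down to a multiple of five; the sketch below reproduces them. It’s only as good as that assumption: it ignores fixed costs such as model weights and optimizer state, which likely makes small-card estimates optimistic, and the multi-card rows assume a list can be spread across GPUs, which is an assumption rather than something measured. The function name is hypothetical.

```python
def estimated_max_list_size(gpu_memory_gb, ref_documents=80, ref_memory_gb=24, step=5):
    """Linearly extrapolate list size from one measured point, rounded down to a multiple of step.

    Default reference point: about 80 documents fit on a 24GB card
    (BERT-base-uncased, batch size 1, sequence length 128).
    """
    documents = (gpu_memory_gb * ref_documents) // ref_memory_gb
    return (documents // step) * step


for memory_gb in [11, 24, 32, 40, 4 * 32, 4 * 40]:
    print(memory_gb, "GB ->", estimated_max_list_size(memory_gb), "documents")
```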

Q: What are the training time requirements of the model as they relate to list size?

There are no official benchmarks as of writing, but empirically, training time appears to scale linearly with list size. The above graph shows total runtime (training, one evaluation cycle, and model export) per 100 training cycles on toy data with different list sizes and differently sized BERT models, measured on a Titan RTX (batch size 1, sequence length 128 tokens) in a workstation with very fast M.2 SSD I/O. From this, we can roughly gauge that:

  • BERT-Tiny is about 8-10X as fast as BERT-Base
  • BERT-Medium is about 2X as fast as BERT-Base
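Given linear scaling in list size and the rough relative speeds above, a single measured run can be extrapolated to other configurations. This is a back-of-the-envelope sketch with hypothetical names: the 9x factor for BERT-Tiny is simply the midpoint of the 8-10X range, and the measured runtime is whatever you observe on your own hardware.

```python
# Speed of each model relative to BERT-Base, from the rough figures above
# (9.0 is the midpoint of the 8-10X range quoted for BERT-Tiny).
RELATIVE_SPEED = {"BERT-Base": 1.0, "BERT-Medium": 2.0, "BERT-Tiny": 9.0}


def estimate_runtime(measured_seconds, measured_list_size, measured_model,
                     target_list_size, target_model):
    """Scale a measured runtime linearly in list size and by relative model speed."""
    per_document = measured_seconds / measured_list_size
    scaled = per_document * target_list_size
    # A faster target model divides the time; a slower one multiplies it.
    return scaled * RELATIVE_SPEED[measured_model] / RELATIVE_SPEED[target_model]
```

For example, if 100 training cycles with lists of 10 documents took 90 seconds on BERT-Base, estimate_runtime(90.0, 10, "BERT-Base", 10, "BERT-Tiny") predicts about 10 seconds for BERT-Tiny at the same list size.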

The Google TFR-BERT paper describes their MS MARCO experiment using Google TPU v3s and reranking lists of size 12 (with BERT-Large), so it’s likely that TFR-BERT’s computational requirements are somewhat steep, particularly for large datasets, model sizes, and list sizes. For reference when estimating possible model runtime, the TFR-BERT team mentioned that their MS MARCO models were fine-tuned on only 5% of the corpus, and that each experiment took approximately 1 day to 1 week to complete. If this was on a 32-core TPU v3 pod, we might ballpark the training cost (as of writing) of replicating these models at between $1k and $8k per experiment, and perhaps between $20k and $50k to replicate the entire paper. This is definitely interesting work on a large dataset with a lot of compute behind it.
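The ballpark is simple arithmetic: cost = hourly rate × 24 hours × days. Working backwards from the $1k (1 day) and $8k (1 week) endpoints gives the hourly pod rate those figures imply, roughly $42 to $48 per hour. These are not quoted TPU prices; plug in current pricing for your own estimate. The function names are hypothetical.

```python
HOURS_PER_DAY = 24


def experiment_cost(hourly_rate_usd, days):
    """Cost of keeping the hardware busy for the given number of days."""
    return hourly_rate_usd * HOURS_PER_DAY * days


def implied_hourly_rate(total_cost_usd, days):
    """Hourly rate that a total cost over the given number of days corresponds to."""
    return total_cost_usd / (HOURS_PER_DAY * days)


# The $1k (1 day) and $8k (1 week) endpoints correspond to these hourly rates:
print(round(implied_hourly_rate(1000, 1), 2))  # 41.67
print(round(implied_hourly_rate(8000, 7), 2))  # 47.62
```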

Q: How do I generate evaluation scores (MAP, NDCG, etc) for a given train/dev/test set?

You can either use the evaluation metric code built into Tensorflow, or write your own scorer that takes the ranked prediction output generated in this tutorial as input.
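If you go the write-your-own route, NDCG is straightforward to compute from the *.scoresOut.json files. Here is a minimal sketch with hypothetical function names, using the common exponential gain 2^rel - 1 and a log2(rank + 1) discount with ranks starting at 1. Other gain and discount conventions exist, so numbers may differ from other toolkits. Also note that the toy data uses 1 as its lowest relevance grade, which this gain still counts as slightly relevant; subtract 1 from each label first if you want those documents to contribute nothing.

```python
import json
import math


def dcg(relevances, k=None):
    """Discounted cumulative gain: sum of (2^rel - 1) / log2(rank + 1), ranks starting at 1."""
    top = relevances if k is None else relevances[:k]
    return sum((2 ** rel - 1) / math.log2(i + 2) for i, rel in enumerate(top))


def ndcg(relevances, k=None):
    """DCG of the given ranking divided by the DCG of the ideal (sorted) ranking."""
    ideal = dcg(sorted(relevances, reverse=True), k)
    return dcg(relevances, k) / ideal if ideal > 0 else 0.0


def mean_ndcg(scores_json_path, k=None):
    """Average NDCG@k over every ranking problem in a *.scoresOut.json file."""
    with open(scores_json_path) as f:
        problems = json.load(f)["rankingProblemsOutput"]
    values = []
    for problem in problems:
        ranked = sorted(problem["documents"], key=lambda d: d["score"], reverse=True)
        values.append(ndcg([d["relevance"] for d in ranked], k))
    return sum(values) / len(values)
```

On the toy training output shown in Step 6, every query happens to be ranked in ideal order, so mean_ndcg on train.scoresOut.json should give 1.0; the evaluation output scores lower (e.g. “On a mountain cliff” is ranked fourth despite a relevance of 3).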

Acknowledgements

The TFR-BERT authors kindly answered a number of e-mails and GitHub issues while I was figuring out how to create a working end-to-end example. Alexander Zagniotov has spent quite a bit of time with Tensorflow Ranking, and has kindly posted detailed GitHub issue responses and code snippets that helped create this tutorial.