Passage Ranking using TFR-BERT

TensorFlow Ranking can handle heterogeneous dense and sparse features, and scales up to millions of data points. However, building and deploying a learning-to-rank model that operates at scale creates challenges beyond designing the model itself. The Ranking library provides workflow utility classes for building distributed training pipelines for large-scale ranking applications. For more information about these features, see the TensorFlow Ranking Overview.

This tutorial shows you how to build a ranking model that uses BERT for scoring. BERT is a highly effective pretrained module that encodes textual features into contextualized word embeddings. We use BERT to initialize the ranking model and fine-tune the model using a ranking loss.

ANTIQUE dataset

In this tutorial, you will build a ranking model for ANTIQUE, a question-answering dataset, using BERT as the scoring function. Bidirectional Encoder Representations from Transformers (BERT) is a transformer-based machine learning technique that has proven effective in many natural language processing (NLP) tasks. Recent work on TFR-BERT has shown BERT to be an effective scoring function for learning-to-rank tasks.

Given a query and a list of answers, the objective of the ranking model is to order the answers so that rank-based metrics, such as NDCG, are as high as possible. For more details about ranking metrics, review the offline metrics described under evaluation measures.
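
To make the NDCG objective concrete, here is a minimal pure-Python sketch. It assumes the common formulation with gain 2^relevance - 1 and a log2 rank discount; the function names dcg and ndcg are illustrative and are not TensorFlow Ranking APIs.

```python
import math


def dcg(relevances, k):
    """Discounted cumulative gain of the top-k items, taken in ranked order."""
    return sum(
        (2 ** rel - 1) / math.log2(rank + 2)  # rank is 0-based, so rank 0 -> log2(2)
        for rank, rel in enumerate(relevances[:k])
    )


def ndcg(relevances_in_ranked_order, k):
    """DCG of the given ordering divided by the DCG of the ideal ordering."""
    ideal = sorted(relevances_in_ranked_order, reverse=True)
    ideal_dcg = dcg(ideal, k)
    if ideal_dcg == 0:
        return 0.0  # no relevant item in the list
    return dcg(relevances_in_ranked_order, k) / ideal_dcg


# Hypothetical graded labels for three answers, listed in the order the model ranked them.
perfect = ndcg([4, 3, 0], k=3)  # best answer first
swapped = ndcg([0, 3, 4], k=3)  # best answer last
```

Here perfect equals 1.0 because the ordering is already ideal, while swapped is lower because the most relevant answer is ranked last.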

ANTIQUE is a publicly available dataset for open-domain non-factoid question answering, collected from Yahoo! Answers. Each question has a list of answers, whose relevance is graded on a scale of 0-4, with 0 for irrelevant and 4 for fully relevant. The list size can vary depending on the query, so we use a fixed "list size" of 50, where the list is either truncated or padded with default values. The dataset is split into 2206 queries for training and 200 queries for testing. For more details, please read the technical paper on arXiv.
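
The truncate-or-pad step can be sketched in a few lines of plain Python. This is an illustration only: fix_list_size is a hypothetical helper, and the padding value of -1 is an assumption chosen to match the default label value used in the label specification later in this tutorial.

```python
LIST_SIZE = 50
PADDING_LABEL = -1  # assumption: same value as the label default used later on


def fix_list_size(labels, list_size=LIST_SIZE, padding_label=PADDING_LABEL):
    """Truncates or pads one query's label list to exactly `list_size` entries."""
    truncated = labels[:list_size]
    return truncated + [padding_label] * (list_size - len(truncated))


short_list = fix_list_size([4, 2, 0])  # 3 real answers, 47 padded slots
long_list = fix_list_size([1] * 60)    # 60 answers, truncated to 50
```

Padded slots carry a label that no real answer can have, so they can be identified and ignored later.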

Setup

Download and install the TensorFlow Ranking and TensorFlow Model Garden packages.

pip install -q tensorflow-ranking tf-models-official

Import TensorFlow Ranking and other libraries used throughout the notebook.

import os
import tensorflow as tf
import tensorflow_ranking as tfr
from official.nlp.configs import encoders
from tensorflow_ranking.extension.premade import tfrbert_task

Data preparation

Download training and test data.

wget -O "/tmp/train.tfrecords" "https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_train_seq_64_elwc.tfrecords"
wget -O "/tmp/test.tfrecords" "https://ciir.cs.umass.edu/downloads/Antique/tf-ranking/antique_test_seq_64_elwc.tfrecords"
mkdir -p /tmp/tfrbert
wget "https://storage.googleapis.com/cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12.tar.gz" -P "/tmp/tfrbert"
mkdir -p /tmp/tfrbert/uncased_L-12_H-768_A-12
tar -xvf /tmp/tfrbert/uncased_L-12_H-768_A-12.tar.gz --strip-components 3 -C "/tmp/tfrbert/uncased_L-12_H-768_A-12/"

tmp/temp_dir/raw/vocab.txt
tmp/temp_dir/raw/bert_model.ckpt.index
tmp/temp_dir/raw/bert_model.ckpt.data-00000-of-00001
tmp/temp_dir/raw/bert_config.json

Overview of TFR-BERT in Orbit

BERT-based ranking models (TFR-BERT) have been shown to be effective for learning-to-rank tasks when using raw textual features for queries and passages in the MSMARCO passage ranking dataset.

Orbit is a flexible, lightweight library designed to make it easy to write custom training loops in TensorFlow. TensorFlow Ranking supports implementing ranking models with Orbit, particularly BERT-based ranking models.

Create a Ranking Task for TFR-BERT

We create a ranking task for the TFR-BERT model, which can be trained using Orbit. The steps to build this are:

  1. Define feature specifications
  2. Define datasets
  3. Set up data and task configurations

Specify Features

Feature specifications are TensorFlow abstractions that capture information about each feature. They help developers and model researchers understand and use a model.

Create feature specifications for context features, example features, and labels, consistent with the input formats for ranking, such as the ELWC (ExampleListWithContext) format.

SEQ_LENGTH = 64
context_feature_spec = {}
example_feature_spec = {
    'input_word_ids': tf.io.FixedLenFeature(
        shape=(SEQ_LENGTH,), dtype=tf.int64,
        default_value=[0] * SEQ_LENGTH),
    'input_mask': tf.io.FixedLenFeature(
        shape=(SEQ_LENGTH,), dtype=tf.int64,
        default_value=[0] * SEQ_LENGTH),
    'input_type_ids': tf.io.FixedLenFeature(
        shape=(SEQ_LENGTH,), dtype=tf.int64,
        default_value=[0] * SEQ_LENGTH)}
label_spec = (
    "relevance",
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.int64, default_value=-1)
)
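
The three example features above are the standard BERT inputs, each of length SEQ_LENGTH. As a rough sketch of how one query-answer pair could be packed into them, the pure-Python helper below lays out [CLS] query [SEP] answer [SEP] and pads to 64 positions. The helper build_bert_inputs, the truncation policy, and the token IDs in the usage line are hypothetical; the special-token IDs are assumed to follow the uncased BERT vocabulary. The downloaded TFRecords already contain these tensors, so this step is not needed to run the tutorial.

```python
SEQ_LENGTH = 64
# Assumed special-token IDs ([PAD]=0, [CLS]=101, [SEP]=102 in the uncased BERT vocab).
PAD_ID, CLS_ID, SEP_ID = 0, 101, 102


def build_bert_inputs(query_ids, answer_ids, seq_length=SEQ_LENGTH):
    """Packs token IDs as [CLS] query [SEP] answer [SEP], padded to seq_length."""
    budget = seq_length - 3  # reserve room for [CLS] and two [SEP] tokens
    query_ids = query_ids[:budget]
    answer_ids = answer_ids[:budget - len(query_ids)]

    word_ids = [CLS_ID] + query_ids + [SEP_ID] + answer_ids + [SEP_ID]
    # Segment 0 covers [CLS] + query + [SEP]; segment 1 covers answer + [SEP].
    type_ids = [0] * (len(query_ids) + 2) + [1] * (len(answer_ids) + 1)
    mask = [1] * len(word_ids)  # 1 marks real tokens, 0 marks padding

    pad = seq_length - len(word_ids)
    return {
        'input_word_ids': word_ids + [PAD_ID] * pad,
        'input_mask': mask + [0] * pad,
        'input_type_ids': type_ids + [0] * pad,
    }


# Made-up token IDs for a two-token query and a three-token answer.
features = build_bert_inputs([2054, 2003], [2009, 2003, 1037])
```

Padding positions are all zeros in every feature, which matches the default_value of [0] * SEQ_LENGTH in the feature specification.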

Define Datasets

We define data configurations for training and validation data, which specify parameters such as path, batch size, and dataset format. These configurations are used to create training and validation datasets.

# Set up data config
# We use a small list size here for demo purposes only. Users can use a larger
# list size on a machine with more memory to train TFR-BERT.
train_data_config = tfrbert_task.TFRBertDataConfig(
    input_path="/tmp/train.tfrecords",
    is_training=True,
    global_batch_size=8,
    list_size=2,
    dataset_fn='tfrecord',
    seq_length=64)

validation_data_config = tfrbert_task.TFRBertDataConfig(
    input_path="/tmp/test.tfrecords",
    is_training=False,
    global_batch_size=8,
    list_size=2,
    dataset_fn='tfrecord',
    seq_length=64)

Define Task

We define a task configuration that specifies the training and validation datasets along with the model. This configuration creates a TFRBertTask object that can be trained using Orbit.

# Set up task config
task_config = tfrbert_task.TFRBertConfig(
    init_checkpoint='/tmp/tfrbert/uncased_L-12_H-768_A-12/bert_model.ckpt',
    train_data=train_data_config,
    validation_data=validation_data_config,
    model=tfrbert_task.TFRBertModelConfig(
        encoder=encoders.EncoderConfig(
            bert=encoders.BertEncoderConfig(num_layers=12))))

# Set up TFRBertTask
task = tfrbert_task.TFRBertTask(
    task_config,
    label_spec=label_spec,
    dataset_fn=tf.data.TFRecordDataset,
    logging_dir='/tmp/model_dir')
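
The task fine-tunes BERT with a ranking loss. This tutorial does not state which loss the task uses, so the sketch below is only an illustration of one widely used listwise choice, softmax cross-entropy over the scores of a single list. The function softmax_ranking_loss is hypothetical and written in plain Python; it is not the TensorFlow Ranking implementation.

```python
import math


def softmax_ranking_loss(scores, labels):
    """Softmax cross-entropy between model scores and labels for one list."""
    # Log-sum-exp with the max subtracted for numerical stability.
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    log_probs = [s - log_norm for s in scores]
    # Each item's log-probability is weighted by its relevance label.
    return -sum(label * lp for label, lp in zip(labels, log_probs))


# Two answers per list, as with list_size=2; only the first one is relevant.
good = softmax_ranking_loss(scores=[2.0, 0.0], labels=[1, 0])
bad = softmax_ranking_loss(scores=[0.0, 2.0], labels=[1, 0])
```

The loss is smaller when the relevant answer receives the higher score (good) than when it receives the lower one (bad), which is what pushes the scorer toward better rankings.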

Train and evaluate the model

We define the training loop here to train and evaluate the model. We define the metrics, create train and eval datasets and train the model for a specific number of training steps.

metrics = task.build_metrics()
model = task.build_model()
task.initialize(model)
train_dataset = task.build_inputs(task_config.train_data)
vali_dataset = task.build_inputs(task_config.validation_data)
train_iterator = iter(train_dataset)
vali_iterator = iter(vali_dataset)
# `lr` is deprecated; use `learning_rate` instead.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-6)

NUM_TRAIN_STEPS = 100
EVAL_STEPS = 10
# Each iteration is a single training step on one batch, although the log
# messages label it "epoch". The same metric objects are updated by training
# and validation steps and are never reset, so the printed values are running
# averages over everything seen so far.
for train_step in range(NUM_TRAIN_STEPS):
  task.train_step(next(train_iterator), model, optimizer, metrics=metrics)
  train_metrics = {m.name: m.result().numpy() for m in metrics}
  print("Training metrics for epoch: " + str(train_step) + " ", train_metrics)

  if train_step % EVAL_STEPS == 0:
    # Evaluate on a batch from the validation set, not the training set.
    task.validation_step(next(vali_iterator), model, metrics=metrics)
    vali_metrics = {m.name: m.result().numpy() for m in metrics}
    print("Validation metrics for epoch: " + str(train_step) + " ",
          vali_metrics)
Training metrics for epoch: 0  {'MAP': 0.9375, 'NDCG@1': 0.73214287, 'NDCG@5': 0.912364, 'NDCG@10': 0.912364, 'MRR@1': 0.875, 'MRR@5': 0.9375, 'MRR@10': 0.9375}
Validation metrics for epoch: 0  {'MAP': 0.96875, 'NDCG@1': 0.66369045, 'NDCG@5': 0.89421266, 'NDCG@10': 0.89421266, 'MRR@1': 0.9375, 'MRR@5': 0.96875, 'MRR@10': 0.96875}
Training metrics for epoch: 1  {'MAP': 0.9583333, 'NDCG@1': 0.6507936, 'NDCG@5': 0.88817185, 'NDCG@10': 0.88817185, 'MRR@1': 0.9166667, 'MRR@5': 0.9583333, 'MRR@10': 0.9583333}
Training metrics for epoch: 2  {'MAP': 0.96875, 'NDCG@1': 0.6577381, 'NDCG@5': 0.89149714, 'NDCG@10': 0.89149714, 'MRR@1': 0.9375, 'MRR@5': 0.96875, 'MRR@10': 0.96875}
Training metrics for epoch: 3  {'MAP': 0.975, 'NDCG@1': 0.68095237, 'NDCG@5': 0.89981496, 'NDCG@10': 0.89981496, 'MRR@1': 0.95, 'MRR@5': 0.975, 'MRR@10': 0.975}
Training metrics for epoch: 4  {'MAP': 0.9791667, 'NDCG@1': 0.71031743, 'NDCG@5': 0.9095955, 'NDCG@10': 0.9095955, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 5  {'MAP': 0.98214287, 'NDCG@1': 0.7091837, 'NDCG@5': 0.9085163, 'NDCG@10': 0.9085163, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}
Training metrics for epoch: 6  {'MAP': 0.9765625, 'NDCG@1': 0.68526787, 'NDCG@5': 0.8999288, 'NDCG@10': 0.8999288, 'MRR@1': 0.953125, 'MRR@5': 0.9765625, 'MRR@10': 0.9765625}
Training metrics for epoch: 7  {'MAP': 0.9791667, 'NDCG@1': 0.7030423, 'NDCG@5': 0.90591866, 'NDCG@10': 0.90591866, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 8  {'MAP': 0.98125, 'NDCG@1': 0.7255953, 'NDCG@5': 0.9132517, 'NDCG@10': 0.9132517, 'MRR@1': 0.9625, 'MRR@5': 0.98125, 'MRR@10': 0.98125}
Training metrics for epoch: 9  {'MAP': 0.97727275, 'NDCG@1': 0.7229437, 'NDCG@5': 0.9117598, 'NDCG@10': 0.9117598, 'MRR@1': 0.95454544, 'MRR@5': 0.97727275, 'MRR@10': 0.97727275}
Training metrics for epoch: 10  {'MAP': 0.9791667, 'NDCG@1': 0.7311508, 'NDCG@5': 0.91436106, 'NDCG@10': 0.91436106, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Validation metrics for epoch: 10  {'MAP': 0.97596157, 'NDCG@1': 0.7339744, 'NDCG@5': 0.9146096, 'NDCG@10': 0.9146096, 'MRR@1': 0.9519231, 'MRR@5': 0.97596157, 'MRR@10': 0.97596157}
Training metrics for epoch: 11  {'MAP': 0.9776786, 'NDCG@1': 0.73511904, 'NDCG@5': 0.9151535, 'NDCG@10': 0.9151535, 'MRR@1': 0.95535713, 'MRR@5': 0.9776786, 'MRR@10': 0.9776786}
Training metrics for epoch: 12  {'MAP': 0.975, 'NDCG@1': 0.7253969, 'NDCG@5': 0.91220075, 'NDCG@10': 0.91220075, 'MRR@1': 0.95, 'MRR@5': 0.975, 'MRR@10': 0.975}
Training metrics for epoch: 13  {'MAP': 0.9765625, 'NDCG@1': 0.73363096, 'NDCG@5': 0.9150943, 'NDCG@10': 0.9150943, 'MRR@1': 0.953125, 'MRR@5': 0.9765625, 'MRR@10': 0.9765625}
Training metrics for epoch: 14  {'MAP': 0.97794116, 'NDCG@1': 0.7366947, 'NDCG@5': 0.91642684, 'NDCG@10': 0.91642684, 'MRR@1': 0.9558824, 'MRR@5': 0.97794116, 'MRR@10': 0.97794116}
Training metrics for epoch: 15  {'MAP': 0.9791667, 'NDCG@1': 0.7433862, 'NDCG@5': 0.9187641, 'NDCG@10': 0.9187641, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 16  {'MAP': 0.9769737, 'NDCG@1': 0.73903507, 'NDCG@5': 0.91679335, 'NDCG@10': 0.91679335, 'MRR@1': 0.95394737, 'MRR@5': 0.9769737, 'MRR@10': 0.9769737}
Training metrics for epoch: 17  {'MAP': 0.978125, 'NDCG@1': 0.7407738, 'NDCG@5': 0.91760796, 'NDCG@10': 0.91760796, 'MRR@1': 0.95625, 'MRR@5': 0.978125, 'MRR@10': 0.978125}
Training metrics for epoch: 18  {'MAP': 0.9791667, 'NDCG@1': 0.73781174, 'NDCG@5': 0.9168396, 'NDCG@10': 0.9168396, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 19  {'MAP': 0.9801136, 'NDCG@1': 0.7464827, 'NDCG@5': 0.91967636, 'NDCG@10': 0.91967636, 'MRR@1': 0.96022725, 'MRR@5': 0.9801136, 'MRR@10': 0.9801136}
Training metrics for epoch: 20  {'MAP': 0.9782609, 'NDCG@1': 0.7380952, 'NDCG@5': 0.91687906, 'NDCG@10': 0.91687906, 'MRR@1': 0.95652175, 'MRR@5': 0.9782609, 'MRR@10': 0.9782609}
Validation metrics for epoch: 20  {'MAP': 0.9791667, 'NDCG@1': 0.74900794, 'NDCG@5': 0.92034245, 'NDCG@10': 0.92034245, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 21  {'MAP': 0.98, 'NDCG@1': 0.7561905, 'NDCG@5': 0.9226987, 'NDCG@10': 0.9226987, 'MRR@1': 0.96, 'MRR@5': 0.98, 'MRR@10': 0.98}
Training metrics for epoch: 22  {'MAP': 0.9807692, 'NDCG@1': 0.7545787, 'NDCG@5': 0.92208344, 'NDCG@10': 0.92208344, 'MRR@1': 0.96153843, 'MRR@5': 0.9807692, 'MRR@10': 0.9807692}
Training metrics for epoch: 23  {'MAP': 0.9814815, 'NDCG@1': 0.75837743, 'NDCG@5': 0.9234321, 'NDCG@10': 0.9234321, 'MRR@1': 0.962963, 'MRR@5': 0.9814815, 'MRR@10': 0.9814815}
Training metrics for epoch: 24  {'MAP': 0.98214287, 'NDCG@1': 0.7568027, 'NDCG@5': 0.9228346, 'NDCG@10': 0.9228346, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}
Training metrics for epoch: 25  {'MAP': 0.98275864, 'NDCG@1': 0.7627258, 'NDCG@5': 0.9247799, 'NDCG@10': 0.9247799, 'MRR@1': 0.9655172, 'MRR@5': 0.98275864, 'MRR@10': 0.98275864}
Training metrics for epoch: 26  {'MAP': 0.98333335, 'NDCG@1': 0.76468253, 'NDCG@5': 0.9253864, 'NDCG@10': 0.9253864, 'MRR@1': 0.96666664, 'MRR@5': 0.98333335, 'MRR@10': 0.98333335}
Training metrics for epoch: 27  {'MAP': 0.983871, 'NDCG@1': 0.765361, 'NDCG@5': 0.9257851, 'NDCG@10': 0.9257851, 'MRR@1': 0.9677419, 'MRR@5': 0.983871, 'MRR@10': 0.983871}
Training metrics for epoch: 28  {'MAP': 0.984375, 'NDCG@1': 0.7671131, 'NDCG@5': 0.92632234, 'NDCG@10': 0.92632234, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}
Training metrics for epoch: 29  {'MAP': 0.98295456, 'NDCG@1': 0.7638889, 'NDCG@5': 0.92527056, 'NDCG@10': 0.92527056, 'MRR@1': 0.96590906, 'MRR@5': 0.98295456, 'MRR@10': 0.98295456}
Training metrics for epoch: 30  {'MAP': 0.9834559, 'NDCG@1': 0.7708334, 'NDCG@5': 0.9274685, 'NDCG@10': 0.9274685, 'MRR@1': 0.9669118, 'MRR@5': 0.9834559, 'MRR@10': 0.9834559}
Validation metrics for epoch: 30  {'MAP': 0.98392856, 'NDCG@1': 0.76513606, 'NDCG@5': 0.9256894, 'NDCG@10': 0.9262823, 'MRR@1': 0.9678571, 'MRR@5': 0.98392856, 'MRR@10': 0.98392856}
Training metrics for epoch: 31  {'MAP': 0.984375, 'NDCG@1': 0.765377, 'NDCG@5': 0.92589486, 'NDCG@10': 0.9264713, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}
Training metrics for epoch: 32  {'MAP': 0.9831081, 'NDCG@1': 0.765444, 'NDCG@5': 0.92567044, 'NDCG@10': 0.92623127, 'MRR@1': 0.9662162, 'MRR@5': 0.9831081, 'MRR@10': 0.9831081}
Training metrics for epoch: 33  {'MAP': 0.98355263, 'NDCG@1': 0.768797, 'NDCG@5': 0.92667186, 'NDCG@10': 0.92721796, 'MRR@1': 0.96710527, 'MRR@5': 0.98355263, 'MRR@10': 0.98355263}
Training metrics for epoch: 34  {'MAP': 0.98397434, 'NDCG@1': 0.7728938, 'NDCG@5': 0.92802, 'NDCG@10': 0.9285521, 'MRR@1': 0.96794873, 'MRR@5': 0.98397434, 'MRR@10': 0.98397434}
Training metrics for epoch: 35  {'MAP': 0.984375, 'NDCG@1': 0.775, 'NDCG@5': 0.92878187, 'NDCG@10': 0.92930067, 'MRR@1': 0.96875, 'MRR@5': 0.984375, 'MRR@10': 0.984375}
Training metrics for epoch: 36  {'MAP': 0.9832317, 'NDCG@1': 0.77482575, 'NDCG@5': 0.92850894, 'NDCG@10': 0.9290151, 'MRR@1': 0.9664634, 'MRR@5': 0.9832317, 'MRR@10': 0.9832317}
Training metrics for epoch: 37  {'MAP': 0.98214287, 'NDCG@1': 0.770975, 'NDCG@5': 0.9271499, 'NDCG@10': 0.927644, 'MRR@1': 0.96428573, 'MRR@5': 0.98214287, 'MRR@10': 0.98214287}
Training metrics for epoch: 38  {'MAP': 0.98255813, 'NDCG@1': 0.7743632, 'NDCG@5': 0.9282532, 'NDCG@10': 0.9287358, 'MRR@1': 0.96511626, 'MRR@5': 0.98255813, 'MRR@10': 0.98255813}
Training metrics for epoch: 39  {'MAP': 0.98295456, 'NDCG@1': 0.7746212, 'NDCG@5': 0.928235, 'NDCG@10': 0.9287066, 'MRR@1': 0.96590906, 'MRR@5': 0.98295456, 'MRR@10': 0.98295456}
Training metrics for epoch: 40  {'MAP': 0.98194444, 'NDCG@1': 0.77685183, 'NDCG@5': 0.9288045, 'NDCG@10': 0.9292657, 'MRR@1': 0.9638889, 'MRR@5': 0.98194444, 'MRR@10': 0.98194444}
Validation metrics for epoch: 40  {'MAP': 0.98097825, 'NDCG@1': 0.7758799, 'NDCG@5': 0.9284471, 'NDCG@10': 0.9288983, 'MRR@1': 0.9619565, 'MRR@5': 0.98233694, 'MRR@10': 0.98233694}
Training metrics for epoch: 41  {'MAP': 0.98138297, 'NDCG@1': 0.7730496, 'NDCG@5': 0.927543, 'NDCG@10': 0.92798454, 'MRR@1': 0.96276593, 'MRR@5': 0.98271275, 'MRR@10': 0.98271275}
Training metrics for epoch: 42  {'MAP': 0.9791667, 'NDCG@1': 0.7703373, 'NDCG@5': 0.9263746, 'NDCG@10': 0.9268069, 'MRR@1': 0.9583333, 'MRR@5': 0.98046875, 'MRR@10': 0.98046875}
Training metrics for epoch: 43  {'MAP': 0.97959185, 'NDCG@1': 0.77502424, 'NDCG@5': 0.9278771, 'NDCG@10': 0.9283007, 'MRR@1': 0.9591837, 'MRR@5': 0.9808673, 'MRR@10': 0.9808673}
Training metrics for epoch: 44  {'MAP': 0.97875, 'NDCG@1': 0.7705952, 'NDCG@5': 0.9264264, 'NDCG@10': 0.92684144, 'MRR@1': 0.9575, 'MRR@5': 0.98, 'MRR@10': 0.98}
Training metrics for epoch: 45  {'MAP': 0.9791667, 'NDCG@1': 0.7708916, 'NDCG@5': 0.9264465, 'NDCG@10': 0.9268534, 'MRR@1': 0.9583333, 'MRR@5': 0.98039216, 'MRR@10': 0.98039216}
Training metrics for epoch: 46  {'MAP': 0.9795673, 'NDCG@1': 0.7732371, 'NDCG@5': 0.9271634, 'NDCG@10': 0.9275625, 'MRR@1': 0.95913464, 'MRR@5': 0.9807692, 'MRR@10': 0.9807692}
Training metrics for epoch: 47  {'MAP': 0.9799528, 'NDCG@1': 0.77279866, 'NDCG@5': 0.9270702, 'NDCG@10': 0.92746174, 'MRR@1': 0.9599057, 'MRR@5': 0.9811321, 'MRR@10': 0.9811321}
Training metrics for epoch: 48  {'MAP': 0.9780093, 'NDCG@1': 0.7718253, 'NDCG@5': 0.92525107, 'NDCG@10': 0.9256354, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 49  {'MAP': 0.9784091, 'NDCG@1': 0.77077913, 'NDCG@5': 0.925101, 'NDCG@10': 0.9254783, 'MRR@1': 0.9590909, 'MRR@5': 0.9795455, 'MRR@10': 0.9795455}
Training metrics for epoch: 50  {'MAP': 0.9776786, 'NDCG@1': 0.7694515, 'NDCG@5': 0.92459637, 'NDCG@10': 0.92496693, 'MRR@1': 0.95758927, 'MRR@5': 0.97879463, 'MRR@10': 0.97879463}
Validation metrics for epoch: 50  {'MAP': 0.9780702, 'NDCG@1': 0.7703634, 'NDCG@5': 0.9249188, 'NDCG@10': 0.92528284, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 51  {'MAP': 0.9773707, 'NDCG@1': 0.7690887, 'NDCG@5': 0.9244347, 'NDCG@10': 0.9247925, 'MRR@1': 0.95689654, 'MRR@5': 0.9784483, 'MRR@10': 0.9784483}
Training metrics for epoch: 52  {'MAP': 0.97775424, 'NDCG@1': 0.76432604, 'NDCG@5': 0.9228256, 'NDCG@10': 0.9231773, 'MRR@1': 0.9576271, 'MRR@5': 0.9788136, 'MRR@10': 0.9788136}
Training metrics for epoch: 53  {'MAP': 0.978125, 'NDCG@1': 0.7664682, 'NDCG@5': 0.9235073, 'NDCG@10': 0.9238531, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 54  {'MAP': 0.9784836, 'NDCG@1': 0.7661983, 'NDCG@5': 0.9234862, 'NDCG@10': 0.9238264, 'MRR@1': 0.9590164, 'MRR@5': 0.9795082, 'MRR@10': 0.9795082}
Training metrics for epoch: 55  {'MAP': 0.97883064, 'NDCG@1': 0.76766515, 'NDCG@5': 0.92405087, 'NDCG@10': 0.92438555, 'MRR@1': 0.9596774, 'MRR@5': 0.9798387, 'MRR@10': 0.9798387}
Training metrics for epoch: 56  {'MAP': 0.9791667, 'NDCG@1': 0.7696523, 'NDCG@5': 0.9246806, 'NDCG@10': 0.92501, 'MRR@1': 0.96031743, 'MRR@5': 0.98015875, 'MRR@10': 0.98015875}
Training metrics for epoch: 57  {'MAP': 0.9785156, 'NDCG@1': 0.7679501, 'NDCG@5': 0.92400306, 'NDCG@10': 0.9243273, 'MRR@1': 0.9589844, 'MRR@5': 0.9794922, 'MRR@10': 0.9794922}
Training metrics for epoch: 58  {'MAP': 0.97884613, 'NDCG@1': 0.76767397, 'NDCG@5': 0.92397565, 'NDCG@10': 0.92429495, 'MRR@1': 0.9596154, 'MRR@5': 0.9798077, 'MRR@10': 0.9798077}
Training metrics for epoch: 59  {'MAP': 0.9782197, 'NDCG@1': 0.7655122, 'NDCG@5': 0.92325014, 'NDCG@10': 0.92356455, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 60  {'MAP': 0.97761196, 'NDCG@1': 0.76341504, 'NDCG@5': 0.92254627, 'NDCG@10': 0.92285603, 'MRR@1': 0.95708954, 'MRR@5': 0.9785448, 'MRR@10': 0.9785448}
Validation metrics for epoch: 60  {'MAP': 0.97794116, 'NDCG@1': 0.76216733, 'NDCG@5': 0.9222364, 'NDCG@10': 0.9227017, 'MRR@1': 0.9577206, 'MRR@5': 0.9788603, 'MRR@10': 0.9788603}
Training metrics for epoch: 61  {'MAP': 0.9782609, 'NDCG@1': 0.7623361, 'NDCG@5': 0.9223936, 'NDCG@10': 0.9228522, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 62  {'MAP': 0.9785714, 'NDCG@1': 0.76164967, 'NDCG@5': 0.92231655, 'NDCG@10': 0.92276853, 'MRR@1': 0.9589286, 'MRR@5': 0.9794643, 'MRR@10': 0.9794643}
Training metrics for epoch: 63  {'MAP': 0.97711265, 'NDCG@1': 0.75796443, 'NDCG@5': 0.9210157, 'NDCG@10': 0.9214613, 'MRR@1': 0.9559859, 'MRR@5': 0.97799295, 'MRR@10': 0.97799295}
Training metrics for epoch: 64  {'MAP': 0.9765625, 'NDCG@1': 0.7581018, 'NDCG@5': 0.9209682, 'NDCG@10': 0.9214076, 'MRR@1': 0.9548611, 'MRR@5': 0.9774306, 'MRR@10': 0.9774306}
Training metrics for epoch: 65  {'MAP': 0.97602737, 'NDCG@1': 0.7577462, 'NDCG@5': 0.9208503, 'NDCG@10': 0.92128366, 'MRR@1': 0.9537671, 'MRR@5': 0.9768836, 'MRR@10': 0.9768836}
Training metrics for epoch: 66  {'MAP': 0.9755068, 'NDCG@1': 0.75546974, 'NDCG@5': 0.9200356, 'NDCG@10': 0.92046314, 'MRR@1': 0.9527027, 'MRR@5': 0.9763514, 'MRR@10': 0.9763514}
Training metrics for epoch: 67  {'MAP': 0.97583336, 'NDCG@1': 0.7563492, 'NDCG@5': 0.9203415, 'NDCG@10': 0.9207634, 'MRR@1': 0.9533333, 'MRR@5': 0.9766667, 'MRR@10': 0.9766667}
Training metrics for epoch: 68  {'MAP': 0.9761513, 'NDCG@1': 0.7581454, 'NDCG@5': 0.9209124, 'NDCG@10': 0.9213287, 'MRR@1': 0.95394737, 'MRR@5': 0.9769737, 'MRR@10': 0.9769737}
Training metrics for epoch: 69  {'MAP': 0.97564936, 'NDCG@1': 0.7578077, 'NDCG@5': 0.92080134, 'NDCG@10': 0.92121226, 'MRR@1': 0.9529221, 'MRR@5': 0.97646105, 'MRR@10': 0.97646105}
Training metrics for epoch: 70  {'MAP': 0.97596157, 'NDCG@1': 0.75862336, 'NDCG@5': 0.92108566, 'NDCG@10': 0.92149127, 'MRR@1': 0.95352566, 'MRR@5': 0.97676283, 'MRR@10': 0.97676283}
Validation metrics for epoch: 70  {'MAP': 0.97626585, 'NDCG@1': 0.75761, 'NDCG@5': 0.9207071, 'NDCG@10': 0.9211076, 'MRR@1': 0.9541139, 'MRR@5': 0.977057, 'MRR@10': 0.977057}
Training metrics for epoch: 71  {'MAP': 0.9765625, 'NDCG@1': 0.75796133, 'NDCG@5': 0.9209201, 'NDCG@10': 0.92131555, 'MRR@1': 0.9546875, 'MRR@5': 0.97734374, 'MRR@10': 0.97734374}
Training metrics for epoch: 72  {'MAP': 0.9768519, 'NDCG@1': 0.75727516, 'NDCG@5': 0.92068696, 'NDCG@10': 0.9210776, 'MRR@1': 0.9552469, 'MRR@5': 0.97762346, 'MRR@10': 0.97762346}
Training metrics for epoch: 73  {'MAP': 0.97713417, 'NDCG@1': 0.7589286, 'NDCG@5': 0.9212119, 'NDCG@10': 0.9215977, 'MRR@1': 0.95579267, 'MRR@5': 0.97789633, 'MRR@10': 0.97789633}
Training metrics for epoch: 74  {'MAP': 0.97740966, 'NDCG@1': 0.7618331, 'NDCG@5': 0.92216116, 'NDCG@10': 0.92254233, 'MRR@1': 0.9563253, 'MRR@5': 0.97816265, 'MRR@10': 0.97816265}
Training metrics for epoch: 75  {'MAP': 0.9776786, 'NDCG@1': 0.761267, 'NDCG@5': 0.9219771, 'NDCG@10': 0.92235374, 'MRR@1': 0.9568452, 'MRR@5': 0.97842264, 'MRR@10': 0.97842264}
Training metrics for epoch: 76  {'MAP': 0.97794116, 'NDCG@1': 0.759874, 'NDCG@5': 0.9215532, 'NDCG@10': 0.9219254, 'MRR@1': 0.95735294, 'MRR@5': 0.9786765, 'MRR@10': 0.9786765}
Training metrics for epoch: 77  {'MAP': 0.9781977, 'NDCG@1': 0.75754434, 'NDCG@5': 0.9208438, 'NDCG@10': 0.92121166, 'MRR@1': 0.95784885, 'MRR@5': 0.9789244, 'MRR@10': 0.9789244}
Training metrics for epoch: 78  {'MAP': 0.9784483, 'NDCG@1': 0.7573208, 'NDCG@5': 0.9208061, 'NDCG@10': 0.92116976, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.9791667}
Training metrics for epoch: 79  {'MAP': 0.9786932, 'NDCG@1': 0.758861, 'NDCG@5': 0.92129385, 'NDCG@10': 0.9216534, 'MRR@1': 0.9588068, 'MRR@5': 0.97940344, 'MRR@10': 0.97940344}
Training metrics for epoch: 80  {'MAP': 0.97893256, 'NDCG@1': 0.7603666, 'NDCG@5': 0.9217707, 'NDCG@10': 0.9221262, 'MRR@1': 0.95926964, 'MRR@5': 0.9796348, 'MRR@10': 0.9796348}
Validation metrics for epoch: 80  {'MAP': 0.9791667, 'NDCG@1': 0.75925934, 'NDCG@5': 0.92172426, 'NDCG@10': 0.92167276, 'MRR@1': 0.9583333, 'MRR@5': 0.9791667, 'MRR@10': 0.97986114}
Training metrics for epoch: 81  {'MAP': 0.9793956, 'NDCG@1': 0.7595501, 'NDCG@5': 0.9217872, 'NDCG@10': 0.92173624, 'MRR@1': 0.9587912, 'MRR@5': 0.9793956, 'MRR@10': 0.9800824}
Training metrics for epoch: 82  {'MAP': 0.97961956, 'NDCG@1': 0.7606108, 'NDCG@5': 0.92218626, 'NDCG@10': 0.92213583, 'MRR@1': 0.9592391, 'MRR@5': 0.97961956, 'MRR@10': 0.98029894}
Training metrics for epoch: 83  {'MAP': 0.9798387, 'NDCG@1': 0.7601127, 'NDCG@5': 0.9220197, 'NDCG@10': 0.92196983, 'MRR@1': 0.9596774, 'MRR@5': 0.9798387, 'MRR@10': 0.9805108}
Training metrics for epoch: 84  {'MAP': 0.9800532, 'NDCG@1': 0.7596252, 'NDCG@5': 0.92185676, 'NDCG@10': 0.9218074, 'MRR@1': 0.9601064, 'MRR@5': 0.9800532, 'MRR@10': 0.9807181}
Training metrics for epoch: 85  {'MAP': 0.9802632, 'NDCG@1': 0.7612783, 'NDCG@5': 0.9224118, 'NDCG@10': 0.922363, 'MRR@1': 0.9605263, 'MRR@5': 0.9802632, 'MRR@10': 0.98092103}
Training metrics for epoch: 86  {'MAP': 0.98046875, 'NDCG@1': 0.7626489, 'NDCG@5': 0.9228422, 'NDCG@10': 0.92279387, 'MRR@1': 0.9609375, 'MRR@5': 0.98046875, 'MRR@10': 0.9811198}
Training metrics for epoch: 87  {'MAP': 0.9806701, 'NDCG@1': 0.7639913, 'NDCG@5': 0.9232637, 'NDCG@10': 0.92321587, 'MRR@1': 0.9613402, 'MRR@5': 0.9806701, 'MRR@10': 0.9813144}
Training metrics for epoch: 88  {'MAP': 0.9808673, 'NDCG@1': 0.76372707, 'NDCG@5': 0.92320555, 'NDCG@10': 0.9231582, 'MRR@1': 0.9617347, 'MRR@5': 0.9808673, 'MRR@10': 0.9815051}
Training metrics for epoch: 89  {'MAP': 0.9810606, 'NDCG@1': 0.76611364, 'NDCG@5': 0.92398125, 'NDCG@10': 0.9239344, 'MRR@1': 0.9621212, 'MRR@5': 0.9810606, 'MRR@10': 0.9816919}
Training metrics for epoch: 90  {'MAP': 0.980625, 'NDCG@1': 0.7645835, 'NDCG@5': 0.9234557, 'NDCG@10': 0.92340934, 'MRR@1': 0.96125, 'MRR@5': 0.980625, 'MRR@10': 0.98125}
Validation metrics for epoch: 90  {'MAP': 0.98081684, 'NDCG@1': 0.7650285, 'NDCG@5': 0.923649, 'NDCG@10': 0.9235569, 'MRR@1': 0.9616337, 'MRR@5': 0.98081684, 'MRR@10': 0.98143566}
Training metrics for epoch: 91  {'MAP': 0.9810049, 'NDCG@1': 0.76581484, 'NDCG@5': 0.92394495, 'NDCG@10': 0.92385375, 'MRR@1': 0.9620098, 'MRR@5': 0.9810049, 'MRR@10': 0.9816176}
Training metrics for epoch: 92  {'MAP': 0.9811893, 'NDCG@1': 0.7662392, 'NDCG@5': 0.9240845, 'NDCG@10': 0.9239942, 'MRR@1': 0.9623786, 'MRR@5': 0.9811893, 'MRR@10': 0.98179615}
Training metrics for epoch: 93  {'MAP': 0.9813702, 'NDCG@1': 0.7678001, 'NDCG@5': 0.9246149, 'NDCG@10': 0.9245255, 'MRR@1': 0.96274036, 'MRR@5': 0.9813702, 'MRR@10': 0.98197114}
Training metrics for epoch: 94  {'MAP': 0.9815476, 'NDCG@1': 0.7676306, 'NDCG@5': 0.92459214, 'NDCG@10': 0.92450356, 'MRR@1': 0.96309525, 'MRR@5': 0.9815476, 'MRR@10': 0.98214287}
Training metrics for epoch: 95  {'MAP': 0.9817217, 'NDCG@1': 0.7698228, 'NDCG@5': 0.9253036, 'NDCG@10': 0.92521584, 'MRR@1': 0.9634434, 'MRR@5': 0.9817217, 'MRR@10': 0.9823113}
Training metrics for epoch: 96  {'MAP': 0.9807243, 'NDCG@1': 0.7683023, 'NDCG@5': 0.9247515, 'NDCG@10': 0.92466456, 'MRR@1': 0.9614486, 'MRR@5': 0.9807243, 'MRR@10': 0.9813084}
Training metrics for epoch: 97  {'MAP': 0.9797454, 'NDCG@1': 0.76714087, 'NDCG@5': 0.92425805, 'NDCG@10': 0.9241719, 'MRR@1': 0.9594907, 'MRR@5': 0.9797454, 'MRR@10': 0.9803241}
Training metrics for epoch: 98  {'MAP': 0.9799312, 'NDCG@1': 0.7682942, 'NDCG@5': 0.9246202, 'NDCG@10': 0.92453486, 'MRR@1': 0.9598624, 'MRR@5': 0.9799312, 'MRR@10': 0.9805046}
Training metrics for epoch: 99  {'MAP': 0.9801136, 'NDCG@1': 0.767695, 'NDCG@5': 0.9244149, 'NDCG@10': 0.92433035, 'MRR@1': 0.96022725, 'MRR@5': 0.9801136, 'MRR@10': 0.98068184}
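
The values in this log change slowly from step to step because the metric objects accumulate state across calls: each printed number is an average over all lists seen so far, not a per-batch score. The pure-Python sketch below illustrates that behaviour with mean reciprocal rank. RunningMean and reciprocal_rank are hypothetical helpers, not TensorFlow Ranking APIs, and the two batches are made up.

```python
def reciprocal_rank(ranked_labels):
    """Returns 1 / rank of the first relevant item (label > 0), or 0 if none."""
    for rank, label in enumerate(ranked_labels, start=1):
        if label > 0:
            return 1.0 / rank
    return 0.0


class RunningMean:
    """Accumulates a mean across batches until reset() is called."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.total = 0.0
        self.count = 0

    def update(self, values):
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        return self.total / self.count if self.count else 0.0


# Two made-up batches of 8 lists with 2 items each, already in ranked order.
batch_1 = [[1, 0]] * 7 + [[0, 1]]  # one list has its relevant item ranked second
batch_2 = [[1, 0]] * 8             # every list has its relevant item ranked first

mrr = RunningMean()
mrr.update([reciprocal_rank(lst) for lst in batch_1])
after_first = mrr.result()   # mean over 8 lists
mrr.update([reciprocal_rank(lst) for lst in batch_2])
after_second = mrr.result()  # mean over all 16 lists, not just batch_2

mrr.reset()
mrr.update([reciprocal_rank(lst) for lst in batch_2])
batch_2_only = mrr.result()  # per-batch value after an explicit reset
```

To report training and validation metrics separately, one option is to reset the metric state before each validation step and again before resuming training.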