{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "QZ_E1Ma1jKYT",
        "-6LGAl2AGj4j",
        "Ooplt3-YJJRb",
        "gUMZZ3XKSb1A"
      ],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
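        "Now apply the reranker to the retrieved text chunks: each chunk is re-scored against the query, and the chunks are then reordered from most to least relevant."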
      ],
      "metadata": {
        "id": "phZHoiIHTVfT"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the BAAI/bge-reranker-large reranking model.\n",
        "# use_fp16=True speeds up inference at the cost of a slight drop in ranking quality.\n",
        "reranker = FlagReranker(\n",
        "    'BAAI/bge-reranker-large',\n",
        "    use_fp16=True,\n",
        ")"
      ],
      "metadata": {
        "id": "RNaWdHLy4idV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Sanity-check the reranker on a toy example before applying it to the retrieved chunks.\n",
        "# compute_score takes [query, passage] pairs and returns one raw (unnormalized) relevance score per pair;\n",
        "# the unrelated passage 'hi' should score far below the passage that actually answers the query.\n",
        "scores_test = reranker.compute_score([\n",
        "    ['what is panda?', 'hi'],\n",
        "    ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.'],\n",
        "])\n",
        "print(scores_test)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Dgllh5mwTDun",
        "outputId": "c29671a3-a78e-41d9-d1b4-97033051f2ad"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[-5.59765625, 5.76171875]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Pair the query with every retrieved chunk; the reranker scores each [query, chunk] pair.\n",
        "query_and_chunks = [[query, chunk_text] for chunk_text in context_chunks_init]"
      ],
      "metadata": {
        "id": "W8O4-B3F4yTv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Score every [query, chunk] pair with the reranker (higher = more relevant).\n",
        "# NOTE: FlagEmbedding's compute_score returns a bare float instead of a list when given a single pair;\n",
        "# np.atleast_1d(...).tolist() normalizes both cases to a flat list of floats so the argsort and indexing below always work.\n",
        "scores_reranker = np.atleast_1d(reranker.compute_score(query_and_chunks)).tolist()\n",
        "print(scores_reranker)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "URNQfNks5NFo",
        "outputId": "3da4869b-934e-491a-a628-f65e5cae6e67"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[2.400390625, -1.37890625, -2.18359375, -3.408203125, -0.469482421875, -4.87890625, -3.712890625, -2.5546875]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Chunk indices ordered from most to least relevant according to the reranker.\n",
        "# Negating the scores yields a descending order; kind='stable' keeps equally-scored chunks in their original retrieval order.\n",
        "max_idx_reranked = np.argsort(-np.asarray(scores_reranker), kind='stable')\n",
        "print(max_idx_reranked)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RpIcjK7b5pGr",
        "outputId": "eb7b07a3-447e-4763-aa05-5e5054c1ffc4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[0 4 1 2 7 3 6 5]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(f\"Query: {query}\\n\")\n",
        "# Rebuild the context list in reranked order, showing each chunk's score and heading line.\n",
        "context_chunks = []\n",
        "for idx in max_idx_reranked:\n",
        "  chunk_lines = context_chunks_init[idx].split('\\n')\n",
        "  # The second line appears to hold the chunk's heading; fall back to the first line for single-line chunks instead of raising IndexError.\n",
        "  heading = chunk_lines[1] if len(chunk_lines) > 1 else chunk_lines[0]\n",
        "  print(f\"Score: {scores_reranker[idx]:.3f}\")\n",
        "  print(heading)\n",
        "  print(\"--------\")\n",
        "  context_chunks.append(context_chunks_init[idx])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "je05CsAJ515Y",
        "outputId": "89c0034c-ea6c-4c8a-ae15-e62a506ac7a5"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Query: how do i get student loan\n",
            "\n",
            "Score: 2.400\n",
            "Borrowing money as a student \n",
            "--------\n",
            "Score: -0.469\n",
            "Tips to successfully apply for a loan \n",
            "--------\n",
            "Score: -1.379\n",
            "Paying off your student loan early \n",
            "--------\n",
            "Score: -2.184\n",
            "How to budget as a student \n",
            "--------\n",
            "Score: -2.555\n",
            "Ways to borrow \n",
            "--------\n",
            "Score: -3.408\n",
            "What is a loan? \n",
            "--------\n",
            "Score: -3.713\n",
            "Should you get a student credit card? \n",
            "--------\n",
            "Score: -4.879\n",
            "How much can I borrow? \n",
            "--------\n"
          ]
        }
      ]
    }
  ]
}