Skip to content

Instantly share code, notes, and snippets.

@rbiswasfc
Last active March 14, 2025 04:22
Show Gist options
  • Save rbiswasfc/2f95833c8c0702b525d4cada23600b9f to your computer and use it in GitHub Desktop.
Save rbiswasfc/2f95833c8c0702b525d4cada23600b9f to your computer and use it in GitHub Desktop.
EEDI Data Prep
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"id": "0e1161c0",
"cell_type": "code",
"source": "import ast\nimport os\nimport random\nimport kagglehub\nimport json\n\nimport pandas as pd\nimport numpy as np\nfrom tqdm.auto import tqdm\nfrom copy import deepcopy\n\npd.options.display.max_colwidth = None",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "6ea16a33",
"cell_type": "code",
"source": "from collections import defaultdict\nfrom sentence_transformers import SentenceTransformer\nfrom sklearn.metrics.pairwise import cosine_similarity",
"execution_count": 2,
"outputs": []
},
{
"metadata": {},
"id": "bee10d2f",
"cell_type": "markdown",
"source": "# Competition data"
},
{
"metadata": {
"trusted": true
},
"id": "2f55fcde",
"cell_type": "code",
"source": "# Resolve the competition files through kagglehub (it returns its local cache path when the data is\n# already downloaded) instead of a hardcoded, user-specific absolute path.\n# Set EEDI_COMP_DIR to point at another local copy of the competition data.\ndata_dir = os.environ.get(\"EEDI_COMP_DIR\") or kagglehub.competition_download(\"eedi-mining-misconceptions-in-mathematics\")\n\ndf = pd.read_csv(os.path.join(data_dir, \"train.csv\"))\ncontent_df = pd.read_csv(os.path.join(data_dir, \"misconception_mapping.csv\"))",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "dc94e09c",
"cell_type": "code",
"source": "# Attach the fold assignment to each question, then hold out one fold for validation.\nfold_df = pd.read_parquet(\"../data/scratch/five_folds.parquet\")\ndf = df.merge(fold_df, on=\"QuestionId\")\n\nfold = 0\nin_valid_fold = df[\"kfold\"] == fold\ntrain_df = df[~in_valid_fold].copy()\nvalid_df = df[in_valid_fold].copy()",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "0a650f76",
"cell_type": "code",
"source": "# (rows, columns) of the training and validation splits\n(train_df.shape, valid_df.shape)",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "((1495, 16), (374, 16))"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "ccd5a816",
"cell_type": "code",
"source": "def collect_misconception_ids(frame):\n    \"\"\"Return the set of misconception ids referenced by any option (A-D) of any question in `frame`.\"\"\"\n    ids = set()\n    for letter in [\"A\", \"B\", \"C\", \"D\"]:\n        ids.update(frame[f\"Misconception{letter}Id\"].dropna())\n    return ids\n\n\nvalid_misconception_ids = collect_misconception_ids(valid_df)\nprint(f\"# of validation misconceptions: {len(valid_misconception_ids)}\")\n\ntrain_misconception_ids = collect_misconception_ids(train_df)\nprint(f\"# of training misconceptions: {len(train_misconception_ids)}\")\n\n# keep only validation misconceptions that never appear in training\nvalid_misconception_ids = valid_misconception_ids - train_misconception_ids\nprint(f\"# of new validation misconceptions: {len(valid_misconception_ids)}\")",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "# of validation misconceptions: 435\n# of training misconceptions: 1378\n# of new validation misconceptions: 226\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "c1147d9f",
"cell_type": "code",
"source": "# QuestionIds of the held-out validation fold, in first-seen order\nvqids = valid_df[\"QuestionId\"].drop_duplicates().tolist()",
"execution_count": 7,
"outputs": []
},
{
"metadata": {},
"id": "2b0322f7",
"cell_type": "markdown",
"source": "# Ranker Predictions"
},
{
"metadata": {
"trusted": true
},
"id": "67937121",
"cell_type": "code",
"source": "data_dir = kagglehub.dataset_download(\"conjuring92/eedi-ranker-silver-v3-teacher\")\n# teacher-scored (query, candidate misconception) pairs: train split and validation split\nranker_df, ranker_df_valid = (\n    pd.read_parquet(os.path.join(data_dir, fname)) for fname in (\"train.parquet\", \"valid_ff.parquet\")\n)",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "1eeb78db",
"cell_type": "code",
"source": "# positives (label 1) vs negatives (label 0) among the ranker candidates\nranker_df[\"label\"].value_counts()",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "label\n0 441400\n1 18948\nName: count, dtype: int64"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "f6debfb2",
"cell_type": "code",
"source": "# fixed seed so the displayed example is reproducible on Restart & Run All\nranker_df.sample(n=1, random_state=42)",
"execution_count": 11,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_id</th>\n <th>SubjectName</th>\n <th>ConstructName</th>\n <th>QuestionText</th>\n <th>CorrectAnswerText</th>\n <th>InCorrectAnswerText</th>\n <th>MisconceptionName</th>\n <th>AllOptionText</th>\n <th>label</th>\n <th>teacher_score</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>231856</th>\n <td>60489_A</td>\n <td>473</td>\n <td>Fractal Geometry and Iterative Patterns</td>\n <td>Interpret a pictogram where the symbols are not evenly spaced</td>\n <td>![A pictogram showing a symbol of a triangle. The symbol is repeated 5 times in a row for the first category, then 4 times with the second triangle symbol being slightly further apart from the others for the second category, then 3 times with the third triangle symbol being slightly further apart from the others for the third category.]() Which shape has the greatest frequency?</td>\n <td>They all have the same frequency</td>\n <td>The first shape</td>\n <td>Underestimates the impact of the size of images in a misleading statistical diagram</td>\n <td>\\n- The first shape\\n- The second shape\\n- The third shape\\n- They all have the same frequency</td>\n <td>0</td>\n <td>-0.625</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id content_id SubjectName \\\n231856 60489_A 473 Fractal Geometry and Iterative Patterns \n\n ConstructName \\\n231856 Interpret a pictogram where the symbols are not evenly spaced \n\n QuestionText \\\n231856 ![A pictogram showing a symbol of a triangle. The symbol is repeated 5 times in a row for the first category, then 4 times with the second triangle symbol being slightly further apart from the others for the second category, then 3 times with the third triangle symbol being slightly further apart from the others for the third category.]() Which shape has the greatest frequency? \n\n CorrectAnswerText InCorrectAnswerText \\\n231856 They all have the same frequency The first shape \n\n MisconceptionName \\\n231856 Underestimates the impact of the size of images in a misleading statistical diagram \n\n AllOptionText \\\n231856 \\n- The first shape\\n- The second shape\\n- The third shape\\n- They all have the same frequency \n\n label teacher_score \n231856 0 -0.625 "
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "926ae101",
"cell_type": "code",
"source": "# Positives (label == 1) that the teacher scores below the cutoff are treated as noisy.\nteacher_cutoff = 2.0  # alternative tried: 3.0\nis_positive = ranker_df[\"label\"] == 1\nis_low_teacher_score = ranker_df[\"teacher_score\"] < teacher_cutoff\nbad_df = ranker_df.loc[is_positive & is_low_teacher_score].copy()\nbad_df.shape",
"execution_count": 16,
"outputs": [
{
"data": {
"text/plain": "(2178, 11)"
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "fe4578b8",
"cell_type": "code",
"source": "# QuestionIds of the noisy positives (one entry per row of bad_df, so ids can repeat).\n# Ids below SYNTHETIC_QID_START are exempt from removal -- presumably the original competition\n# questions (NOTE(review): confirm the id offset used by the synthetic data).\nSYNTHETIC_QID_START = 2000\nbad_qids = [int(query_id.split(\"_\")[0]) for query_id in bad_df[\"query_id\"]]\nbad_qids = [qid for qid in bad_qids if qid >= SYNTHETIC_QID_START]\nlen(bad_qids)",
"execution_count": 17,
"outputs": [
{
"data": {
"text/plain": "2085"
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "7cf76715",
"cell_type": "code",
"source": "# query_id looks like \"<QuestionId>_<option letter>\"; recover the integer QuestionId\nranker_df[\"QuestionId\"] = ranker_df[\"query_id\"].str.split(\"_\").str[0].astype(int)",
"execution_count": 18,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "7bf8faca",
"cell_type": "code",
"source": "# drop every candidate row that belongs to a question with a noisy positive\nranker_df = ranker_df.loc[~ranker_df[\"QuestionId\"].isin(bad_qids)].reset_index(drop=True)",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "247f144a",
"cell_type": "code",
"source": "# (rows, columns) of the ranker data after removing noisy questions\n(len(ranker_df), len(ranker_df.columns))",
"execution_count": 20,
"outputs": [
{
"data": {
"text/plain": "(384909, 12)"
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"id": "e3e41be3",
"cell_type": "markdown",
"source": "# MCQ Data"
},
{
"metadata": {
"trusted": true
},
"id": "7c4b7758",
"cell_type": "code",
"source": "data_dir = kagglehub.dataset_download(\"conjuring92/eedi-silver-v3\")\nmcq_df = pd.read_csv(os.path.join(data_dir, \"train.csv\"))\n# NOTE: rebinds content_df, replacing the competition misconception mapping loaded earlier\ncontent_df = pd.read_csv(os.path.join(data_dir, \"misconception_mapping.csv\"))\n\n# add a Misconception<X>Name column next to each option's Misconception<X>Id\nfor option in \"ABCD\":\n    mcq_df = (\n        mcq_df.merge(content_df, how=\"left\", left_on=f\"Misconception{option}Id\", right_on=\"MisconceptionId\")\n        .rename(columns={\"MisconceptionName\": f\"Misconception{option}Name\"})\n        .drop(columns=\"MisconceptionId\")\n    )\nmcq_df.shape",
"execution_count": 22,
"outputs": [
{
"data": {
"text/plain": "(12473, 20)"
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "7e808cce",
"cell_type": "code",
"source": "# drop MCQs whose question had a noisy positive in the ranker data\nmcq_df = mcq_df.loc[~mcq_df[\"QuestionId\"].isin(bad_qids)].copy()\nmcq_df.shape",
"execution_count": 23,
"outputs": [
{
"data": {
"text/plain": "(10594, 20)"
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "da0ab038",
"cell_type": "code",
"source": "# FULLFIT=True keeps every MCQ. Set it to False to exclude the validation-fold questions (vqids)\n# so that fold stays unseen during training.\nFULLFIT = True\n\nif not FULLFIT:\n    mcq_df = mcq_df[~mcq_df[\"QuestionId\"].isin(vqids)].reset_index(drop=True)\n    print(mcq_df.shape)",
"execution_count": 24,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "4b740b64",
"cell_type": "code",
"source": "# fixed seed so the displayed example is reproducible on Restart & Run All\nmcq_df.sample(n=1, random_state=42)",
"execution_count": 25,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>QuestionId</th>\n <th>ConstructId</th>\n <th>ConstructName</th>\n <th>SubjectId</th>\n <th>SubjectName</th>\n <th>CorrectAnswer</th>\n <th>QuestionText</th>\n <th>AnswerAText</th>\n <th>AnswerBText</th>\n <th>AnswerCText</th>\n <th>AnswerDText</th>\n <th>MisconceptionAId</th>\n <th>MisconceptionBId</th>\n <th>MisconceptionCId</th>\n <th>MisconceptionDId</th>\n <th>source</th>\n <th>MisconceptionAName</th>\n <th>MisconceptionBName</th>\n <th>MisconceptionCName</th>\n <th>MisconceptionDName</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>496</th>\n <td>300644</td>\n <td>-1</td>\n <td>Convert between g and tonne</td>\n <td>-1</td>\n <td>Weight Units</td>\n <td>D</td>\n <td>A large container weighs \\( 2 \\mathrm{~tonnes} \\). What is this in grams?</td>\n <td>\\( 2000 \\mathrm{~g} \\)</td>\n <td>\\( 20000 \\mathrm{~g} \\)</td>\n <td>\\( 200000 \\mathrm{~g} \\)</td>\n <td>\\( 2000000 \\mathrm{~g} \\)</td>\n <td>666.0</td>\n <td>765.0</td>\n <td>784.0</td>\n <td>NaN</td>\n <td>group</td>\n <td>Thinks grams and tonnes are the same</td>\n <td>Thinks there are 10kg in a tonne</td>\n <td>Thinks there are 100g in a kilogram</td>\n <td>NaN</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " QuestionId ConstructId ConstructName SubjectId \\\n496 300644 -1 Convert between g and tonne -1 \n\n SubjectName CorrectAnswer \\\n496 Weight Units D \n\n QuestionText \\\n496 A large container weighs \\( 2 \\mathrm{~tonnes} \\). What is this in grams? \n\n AnswerAText AnswerBText \\\n496 \\( 2000 \\mathrm{~g} \\) \\( 20000 \\mathrm{~g} \\) \n\n AnswerCText AnswerDText MisconceptionAId \\\n496 \\( 200000 \\mathrm{~g} \\) \\( 2000000 \\mathrm{~g} \\) 666.0 \n\n MisconceptionBId MisconceptionCId MisconceptionDId source \\\n496 765.0 784.0 NaN group \n\n MisconceptionAName MisconceptionBName \\\n496 Thinks grams and tonnes are the same Thinks there are 10kg in a tonne \n\n MisconceptionCName MisconceptionDName \n496 Thinks there are 100g in a kilogram NaN "
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "c87e8570",
"cell_type": "code",
"source": "def count_nonempty_misconceptions(df):\n    \"\"\"Print and return how many MisconceptionAName..MisconceptionDName cells in `df` are non-null.\"\"\"\n    name_cols = [f'Misconception{letter}Name' for letter in 'ABCD']\n    total_count = int(df[name_cols].notna().to_numpy().sum())\n    print(f\"Total non-empty MisconceptionNames: {total_count}\")\n    return total_count\n\n# trailing ';' stops the returned count from being displayed in addition to the printed line\ncount_nonempty_misconceptions(mcq_df);",
"execution_count": 26,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Total non-empty MisconceptionNames: 16706\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "01965ed5",
"cell_type": "code",
"source": "# one row per question; drop questions with no labelled misconception on any option\nmisconception_id_cols = [f\"Misconception{letter}Id\" for letter in \"ABCD\"]\nmcq_df = mcq_df.drop_duplicates(subset=[\"QuestionId\"]).dropna(subset=misconception_id_cols, how=\"all\")\nmcq_df.shape",
"execution_count": 27,
"outputs": [
{
"data": {
"text/plain": "(10594, 20)"
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "216f048c",
"cell_type": "code",
"source": "# Drop questions whose *correct* option carries a misconception id (presumably a labelling error).\n# Built per option with vectorised masks: a row-wise apply(axis=1) returns a DataFrame on an\n# empty frame, which breaks the boolean mask.\ncorrect_has_misconception = pd.Series(False, index=mcq_df.index)\nfor letter in \"ABCD\":\n    correct_has_misconception |= (mcq_df[\"CorrectAnswer\"] == letter) & mcq_df[f\"Misconception{letter}Id\"].notna()\nmcq_df = mcq_df[~correct_has_misconception]\nprint(f\"Shape after dropping: {mcq_df.shape}\")",
"execution_count": 28,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Shape after dropping: (10594, 20)\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "a399f944",
"cell_type": "code",
"source": "# Optional sanity check (disabled): inspect the last rows of the cleaned MCQ data.\n# mcq_df.tail()",
"execution_count": 29,
"outputs": []
},
{
"metadata": {},
"id": "bbe88b94",
"cell_type": "markdown",
"source": "# Candidate and teacher-score lookup maps"
},
{
"metadata": {
"trusted": true
},
"id": "10d037fb",
"cell_type": "code",
"source": "# cleaned MCQ table and the silver-v3 misconception mapping\n(mcq_df.shape, content_df.shape)",
"execution_count": 30,
"outputs": [
{
"data": {
"text/plain": "((10594, 20), (4791, 2))"
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "8d0169be",
"cell_type": "code",
"source": "# pred_map: query_id -> candidate content_ids, in file order\n# teacher_score_map: \"query_id|content_id\" -> teacher logit\npred_map = defaultdict(list)\nteacher_score_map = {}\n\n# zip over columns instead of iterrows: same values, much faster on ~385k rows\nfor query_id, content_id, teacher_score in zip(\n    ranker_df[\"query_id\"], ranker_df[\"content_id\"], ranker_df[\"teacher_score\"]\n):\n    key = f\"{query_id}|{content_id}\"\n    if key not in teacher_score_map:  # never list the same candidate twice for one query\n        pred_map[query_id].append(content_id)\n    teacher_score_map[key] = teacher_score",
"execution_count": 31,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "4f72dd9f",
"cell_type": "code",
"source": "# Add the validation-split candidates to the same maps.\n# NOTE(review): unlike ranker_df, ranker_df_valid is not filtered by bad_qids -- confirm this is intended.\nfor query_id, content_id, teacher_score in zip(\n    ranker_df_valid[\"query_id\"], ranker_df_valid[\"content_id\"], ranker_df_valid[\"teacher_score\"]\n):\n    key = f\"{query_id}|{content_id}\"\n    if key not in teacher_score_map:  # a pair already seen in the train split must not be listed twice\n        pred_map[query_id].append(content_id)\n    teacher_score_map[key] = teacher_score",
"execution_count": 32,
"outputs": []
},
{
"metadata": {},
"id": "15abfd6f",
"cell_type": "markdown",
"source": "# Denoise hard negatives (HN) using teacher scores"
},
{
"metadata": {
"trusted": true
},
"id": "ea035d8b",
"cell_type": "code",
"source": "# one row per query with its candidate content_ids in file order\npred_df = pd.DataFrame({\"query_id\": list(pred_map.keys()), \"content_ids\": list(pred_map.values())})\npred_df.head()",
"execution_count": 33,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1_A</td>\n <td>[2142, 2398, 2581, 2068, 838, 1755, 418, 2372, 5727, 167, 1871, 143, 2078, 2277, 2070, 891, 1256, 2256, 4523, 1535, 519, 1421, 1606, 113, 628]</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1_B</td>\n <td>[143, 891, 167, 418, 2078, 220, 979, 519, 2068, 1540, 2372, 4522, 5207, 1755, 3610, 2567, 1593, 4523, 628, 80, 113, 59, 1871, 1153]</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1_C</td>\n <td>[2142, 143, 2277, 2070, 167, 1079, 3464, 418, 547, 2581, 838, 320, 519, 1755, 113, 5727, 3412, 2436, 891, 1421, 2068, 1535, 3632, 688]</td>\n </tr>\n <tr>\n <th>3</th>\n <td>2_A</td>\n <td>[1287, 676, 4964, 1073, 1866, 5156, 2386, 3475, 3974, 276, 1521, 2555, 4245, 1408, 4136, 310, 5661, 4297, 3908, 632, 5550, 306, 1797, 4261, 912]</td>\n </tr>\n <tr>\n <th>4</th>\n <td>2_C</td>\n <td>[1287, 306, 2551, 1338, 5550, 5156, 1408, 2319, 3908, 365, 1975, 2439, 4261, 3974, 5661, 397, 1073, 4136, 4245, 3438, 1059, 691, 4379, 3417]</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n0 1_A \n1 1_B \n2 1_C \n3 2_A \n4 2_C \n\n content_ids \n0 [2142, 2398, 2581, 2068, 838, 1755, 418, 2372, 5727, 167, 1871, 143, 2078, 2277, 2070, 891, 1256, 2256, 4523, 1535, 519, 1421, 1606, 113, 628] \n1 [143, 891, 167, 418, 2078, 220, 979, 519, 2068, 1540, 2372, 4522, 5207, 1755, 3610, 2567, 1593, 4523, 628, 80, 113, 59, 1871, 1153] \n2 [2142, 143, 2277, 2070, 167, 1079, 3464, 418, 547, 2581, 838, 320, 519, 1755, 113, 5727, 3412, 2436, 891, 1421, 2068, 1535, 3632, 688] \n3 [1287, 676, 4964, 1073, 1866, 5156, 2386, 3475, 3974, 276, 1521, 2555, 4245, 1408, 4136, 310, 5661, 4297, 3908, 632, 5550, 306, 1797, 4261, 912] \n4 [1287, 306, 2551, 1338, 5550, 5156, 1408, 2319, 3908, 365, 1975, 2439, 4261, 3974, 5661, 397, 1073, 4136, 4245, 3438, 1059, 691, 4379, 3417] "
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "7a14b088",
"cell_type": "code",
"source": "# spot-check a single query\npred_df.query(\"query_id == '0_D'\")",
"execution_count": 34,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>15855</th>\n <td>0_D</td>\n <td>[1672, 1005, 1507, 2532, 1392, 706, 2488, 2306, 315, 1345, 1516, 4377, 5251, 328, 5594, 3524, 3353, 2518, 871, 4557, 158, 4465, 1999, 987, 1963, 2051, 2181, 2449, 638, 5622, 5528, 4789, 3110, 2586, 256, 1226, 488, 1316, 1011, 5587, 4149, 1670, 657, 4051, 5241, 1090, 1941, 1336]</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n15855 0_D \n\n content_ids \n15855 [1672, 1005, 1507, 2532, 1392, 706, 2488, 2306, 315, 1345, 1516, 4377, 5251, 328, 5594, 3524, 3353, 2518, 871, 4557, 158, 4465, 1999, 987, 1963, 2051, 2181, 2449, 638, 5622, 5528, 4789, 3110, 2586, 256, 1226, 488, 1316, 1011, 5587, 4149, 1670, 657, 4051, 5241, 1090, 1941, 1336] "
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "66e9c87e",
"cell_type": "code",
"source": "# query_id (\"<QuestionId>_<option>\") -> ground-truth misconception id, for every labelled option\ntrue_map = {}\nfor letter in \"ABCD\":\n    id_col = f\"Misconception{letter}Id\"\n    labelled = mcq_df.loc[mcq_df[id_col].notna(), [\"QuestionId\", id_col]]\n    for question_id, misconception_id in zip(labelled[\"QuestionId\"], labelled[id_col]):\n        true_map[f\"{question_id}_{letter}\"] = int(misconception_id)\n\n# queries without a ground-truth label map to NaN\npred_df[\"true_id\"] = pred_df[\"query_id\"].map(true_map)\npred_df.head()",
"execution_count": 35,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n <th>true_id</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>1_A</td>\n <td>[2142, 2398, 2581, 2068, 838, 1755, 418, 2372, 5727, 167, 1871, 143, 2078, 2277, 2070, 891, 1256, 2256, 4523, 1535, 519, 1421, 1606, 113, 628]</td>\n <td>2142</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1_B</td>\n <td>[143, 891, 167, 418, 2078, 220, 979, 519, 2068, 1540, 2372, 4522, 5207, 1755, 3610, 2567, 1593, 4523, 628, 80, 113, 59, 1871, 1153]</td>\n <td>143</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1_C</td>\n <td>[2142, 143, 2277, 2070, 167, 1079, 3464, 418, 547, 2581, 838, 320, 519, 1755, 113, 5727, 3412, 2436, 891, 1421, 2068, 1535, 3632, 688]</td>\n <td>2142</td>\n </tr>\n <tr>\n <th>3</th>\n <td>2_A</td>\n <td>[1287, 676, 4964, 1073, 1866, 5156, 2386, 3475, 3974, 276, 1521, 2555, 4245, 1408, 4136, 310, 5661, 4297, 3908, 632, 5550, 306, 1797, 4261, 912]</td>\n <td>1287</td>\n </tr>\n <tr>\n <th>4</th>\n <td>2_C</td>\n <td>[1287, 306, 2551, 1338, 5550, 5156, 1408, 2319, 3908, 365, 1975, 2439, 4261, 3974, 5661, 397, 1073, 4136, 4245, 3438, 1059, 691, 4379, 3417]</td>\n <td>1287</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n0 1_A \n1 1_B \n2 1_C \n3 2_A \n4 2_C \n\n content_ids \\\n0 [2142, 2398, 2581, 2068, 838, 1755, 418, 2372, 5727, 167, 1871, 143, 2078, 2277, 2070, 891, 1256, 2256, 4523, 1535, 519, 1421, 1606, 113, 628] \n1 [143, 891, 167, 418, 2078, 220, 979, 519, 2068, 1540, 2372, 4522, 5207, 1755, 3610, 2567, 1593, 4523, 628, 80, 113, 59, 1871, 1153] \n2 [2142, 143, 2277, 2070, 167, 1079, 3464, 418, 547, 2581, 838, 320, 519, 1755, 113, 5727, 3412, 2436, 891, 1421, 2068, 1535, 3632, 688] \n3 [1287, 676, 4964, 1073, 1866, 5156, 2386, 3475, 3974, 276, 1521, 2555, 4245, 1408, 4136, 310, 5661, 4297, 3908, 632, 5550, 306, 1797, 4261, 912] \n4 [1287, 306, 2551, 1338, 5550, 5156, 1408, 2319, 3908, 365, 1975, 2439, 4261, 3974, 5661, 397, 1073, 4136, 4245, 3438, 1059, 691, 4379, 3417] \n\n true_id \n0 2142 \n1 143 \n2 2142 \n3 1287 \n4 1287 "
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "046d7647",
"cell_type": "code",
"source": "# (rows, columns) of the cleaned MCQ table\n(len(mcq_df), len(mcq_df.columns))",
"execution_count": 36,
"outputs": [
{
"data": {
"text/plain": "(10594, 20)"
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "25cdb523",
"cell_type": "code",
"source": "# spot-check the same query now that true_id is attached\npred_df.query(\"query_id == '0_D'\")",
"execution_count": 37,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n <th>true_id</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>15855</th>\n <td>0_D</td>\n <td>[1672, 1005, 1507, 2532, 1392, 706, 2488, 2306, 315, 1345, 1516, 4377, 5251, 328, 5594, 3524, 3353, 2518, 871, 4557, 158, 4465, 1999, 987, 1963, 2051, 2181, 2449, 638, 5622, 5528, 4789, 3110, 2586, 256, 1226, 488, 1316, 1011, 5587, 4149, 1670, 657, 4051, 5241, 1090, 1941, 1336]</td>\n <td>1672</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n15855 0_D \n\n content_ids \\\n15855 [1672, 1005, 1507, 2532, 1392, 706, 2488, 2306, 315, 1345, 1516, 4377, 5251, 328, 5594, 3524, 3353, 2518, 871, 4557, 158, 4465, 1999, 987, 1963, 2051, 2181, 2449, 638, 5622, 5528, 4789, 3110, 2586, 256, 1226, 488, 1316, 1011, 5587, 4149, 1670, 657, 4051, 5241, 1090, 1941, 1336] \n\n true_id \n15855 1672 "
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "954e0191",
"cell_type": "code",
"source": "# keep only queries that have a ground-truth misconception\npred_df = pred_df.dropna(subset=[\"true_id\"]).reset_index(drop=True)\npred_df.shape",
"execution_count": 38,
"outputs": [
{
"data": {
"text/plain": "(16706, 3)"
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "eed3b900",
"cell_type": "code",
"source": "def get_true_content_index(row):\n    \"\"\"Position of the ground-truth id in the row's candidate list, or -1 when it was not retrieved.\"\"\"\n    candidates = row['content_ids']\n    target = row['true_id']\n    return candidates.index(target) if target in candidates else -1\n\npred_df['true_content_index'] = pred_df.apply(get_true_content_index, axis=1)",
"execution_count": 39,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "6eb1d214",
"cell_type": "code",
"source": "# Optional (disabled): distribution of the ground-truth position among the candidates (-1 = not retrieved).\n# pred_df.true_content_index.value_counts()",
"execution_count": 40,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "ff915a1f",
"cell_type": "code",
"source": "# NaN labels were dropped above, so true_id can now be stored as an integer\npred_df = pred_df.astype({\"true_id\": int})",
"execution_count": 41,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "35118650",
"cell_type": "code",
"source": "# fixed seed so the displayed rows are reproducible on Restart & Run All\npred_df.sample(5, random_state=42)",
"execution_count": 42,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n <th>true_id</th>\n <th>true_content_index</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>8622</th>\n <td>72415_B</td>\n <td>[1925, 74, 3383, 1066, 2377, 1115, 2376, 1226, 2058, 845, 2252, 108, 1011, 4203, 969, 162, 1411, 939, 2386, 102, 373, 558, 1058, 4167]</td>\n <td>1925</td>\n <td>0</td>\n </tr>\n <tr>\n <th>15058</th>\n <td>403236_A</td>\n <td>[1736, 4687, 4697, 5234, 4475, 4482, 427, 299, 4302, 4545, 3108, 4686, 581, 4039, 4851, 1902, 17, 3099, 348, 13, 3615, 1339, 3177, 1013]</td>\n <td>1736</td>\n <td>0</td>\n </tr>\n <tr>\n <th>4602</th>\n <td>11518_D</td>\n <td>[4388, 5056, 440, 210, 4027, 2253, 5364, 3813, 672, 3012, 4198, 469, 215, 1474, 3811, 5078, 3751, 936, 2284, 713, 1240, 3113, 3317, 3200, 3658]</td>\n <td>4388</td>\n <td>0</td>\n </tr>\n <tr>\n <th>6363</th>\n <td>14516_B</td>\n <td>[768, 1375, 2560, 530, 607, 274, 4317, 5660, 174, 973, 4447, 2440, 5657, 5698, 4220, 4578, 4459, 1763, 4452, 1201, 4481, 5006, 5631, 4357]</td>\n <td>768</td>\n <td>0</td>\n </tr>\n <tr>\n <th>9729</th>\n <td>300787_A</td>\n <td>[2345, 649, 1780, 660, 1678, 4452, 4176, 2528, 1292, 743, 881, 2353, 4999, 1668, 1802, 4417, 653, 4048, 4588, 828, 2450, 71, 1671, 2090, 584]</td>\n <td>2345</td>\n <td>0</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n8622 72415_B \n15058 403236_A \n4602 11518_D \n6363 14516_B \n9729 300787_A \n\n content_ids \\\n8622 [1925, 74, 3383, 1066, 2377, 1115, 2376, 1226, 2058, 845, 2252, 108, 1011, 4203, 969, 162, 1411, 939, 2386, 102, 373, 558, 1058, 4167] \n15058 [1736, 4687, 4697, 5234, 4475, 4482, 427, 299, 4302, 4545, 3108, 4686, 581, 4039, 4851, 1902, 17, 3099, 348, 13, 3615, 1339, 3177, 1013] \n4602 [4388, 5056, 440, 210, 4027, 2253, 5364, 3813, 672, 3012, 4198, 469, 215, 1474, 3811, 5078, 3751, 936, 2284, 713, 1240, 3113, 3317, 3200, 3658] \n6363 [768, 1375, 2560, 530, 607, 274, 4317, 5660, 174, 973, 4447, 2440, 5657, 5698, 4220, 4578, 4459, 1763, 4452, 1201, 4481, 5006, 5631, 4357] \n9729 [2345, 649, 1780, 660, 1678, 4452, 4176, 2528, 1292, 743, 881, 2353, 4999, 1668, 1802, 4417, 653, 4048, 4588, 828, 2450, 71, 1671, 2090, 584] \n\n true_id true_content_index \n8622 1925 0 \n15058 1736 0 \n4602 4388 0 \n6363 768 0 \n9729 2345 0 "
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "92e30f5f",
"cell_type": "code",
"source": "# Optional (disabled): re-check the ground-truth position distribution after the dtype cast.\n# pred_df.true_content_index.value_counts()",
"execution_count": 43,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "5d40c017",
"cell_type": "code",
"source": "def add_teacher_scores(row):\n    \"\"\"Teacher logits aligned with the row's content_ids (0 for a pair with no recorded score).\"\"\"\n    query_id = row['query_id']\n    return [teacher_score_map.get(f\"{query_id}|{cid}\", 0) for cid in row['content_ids']]\n\npred_df['teacher_logits'] = pred_df.apply(add_teacher_scores, axis=1)",
"execution_count": 44,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "51f5486c",
"cell_type": "code",
"source": "# spot-check the same query with its teacher logits attached\npred_df.query(\"query_id == '0_D'\")",
"execution_count": 45,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n <th>true_id</th>\n <th>true_content_index</th>\n <th>teacher_logits</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>15855</th>\n <td>0_D</td>\n <td>[1672, 1005, 1507, 2532, 1392, 706, 2488, 2306, 315, 1345, 1516, 4377, 5251, 328, 5594, 3524, 3353, 2518, 871, 4557, 158, 4465, 1999, 987, 1963, 2051, 2181, 2449, 638, 5622, 5528, 4789, 3110, 2586, 256, 1226, 488, 1316, 1011, 5587, 4149, 1670, 657, 4051, 5241, 1090, 1941, 1336]</td>\n <td>1672</td>\n <td>0</td>\n <td>[3.5, 5.249999903142452, 5.499999988824129, 5.9999997690320015, 4.249999910593033, 4.812499798834324, 2.187500074505806, 3.937499776482582, 4.812499836087227, -0.5624999403953552, 4.187499985098839, -0.125, 1.6875000298023224, 3.937499962747097, 3.2500000298023224, 1.5000000596046448, 1.4375, 3.1250000447034836, -0.9375, 1.1874999403953552, 0.0625, 2.499999985098839, -0.625, -1.4999999701976776, 3.187500074505806, -1.3749999701976776, 2.375000014901161, 0.5, -1.1875000596046448, 0.1875, -0.4375, 1.7499999105930328, 1.0000000298023224, 1.2500000596046448, -1.0625, -0.5625000596046448, -0.625, -2.4999999403953552, -0.5, 1.3124999403953552, -0.8750000596046448, -1.8749999403953552, 0.0, -2.6875000298023224, -0.7500000596046448, -1.1875000298023224, 1.8749999552965164, 0.1875]</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n15855 0_D \n\n content_ids \\\n15855 [1672, 1005, 1507, 2532, 1392, 706, 2488, 2306, 315, 1345, 1516, 4377, 5251, 328, 5594, 3524, 3353, 2518, 871, 4557, 158, 4465, 1999, 987, 1963, 2051, 2181, 2449, 638, 5622, 5528, 4789, 3110, 2586, 256, 1226, 488, 1316, 1011, 5587, 4149, 1670, 657, 4051, 5241, 1090, 1941, 1336] \n\n true_id true_content_index \\\n15855 1672 0 \n\n teacher_logits \n15855 [3.5, 5.249999903142452, 5.499999988824129, 5.9999997690320015, 4.249999910593033, 4.812499798834324, 2.187500074505806, 3.937499776482582, 4.812499836087227, -0.5624999403953552, 4.187499985098839, -0.125, 1.6875000298023224, 3.937499962747097, 3.2500000298023224, 1.5000000596046448, 1.4375, 3.1250000447034836, -0.9375, 1.1874999403953552, 0.0625, 2.499999985098839, -0.625, -1.4999999701976776, 3.187500074505806, -1.3749999701976776, 2.375000014901161, 0.5, -1.1875000596046448, 0.1875, -0.4375, 1.7499999105930328, 1.0000000298023224, 1.2500000596046448, -1.0625, -0.5625000596046448, -0.625, -2.4999999403953552, -0.5, 1.3124999403953552, -0.8750000596046448, -1.8749999403953552, 0.0, -2.6875000298023224, -0.7500000596046448, -1.1875000298023224, 1.8749999552965164, 0.1875] "
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "fef36b89",
"cell_type": "code",
"source": "def stable_softmax(x, temp=1.0, axis=-1):\n    \"\"\"Numerically stable softmax of `x / temp` along `axis`.\n\n    Subtracting the max before exponentiating prevents overflow for large logits.\n    Accepts a 1-D list of scores (as used below) or a batched N-D array; empty input\n    returns an empty array instead of raising.\n\n    Raises:\n        ValueError: if `temp` is not positive (temp=0 would otherwise produce NaNs).\n    \"\"\"\n    if temp <= 0:\n        raise ValueError(f\"temp must be positive, got {temp}\")\n    x = np.asarray(x, dtype=float) / temp\n    if x.size == 0:\n        return x\n    x_max = np.max(x, axis=axis, keepdims=True)\n    exp_x = np.exp(x - x_max)\n    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)",
"execution_count": 46,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "25e8a828",
"cell_type": "code",
"source": "# per-query probability distribution over its candidates, from the teacher logits\npred_df = pred_df.assign(teacher_probs=pred_df['teacher_logits'].map(stable_softmax))",
"execution_count": 47,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "88703d1b",
"cell_type": "code",
"source": "# fixed seed so the displayed example is reproducible on Restart & Run All\npred_df.sample(random_state=42)",
"execution_count": 48,
"outputs": [
{
"data": {
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>query_id</th>\n <th>content_ids</th>\n <th>true_id</th>\n <th>true_content_index</th>\n <th>teacher_logits</th>\n <th>teacher_probs</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>2618</th>\n <td>1382_B</td>\n <td>[172, 1081, 2078, 1166, 4701, 1246, 1510, 1048, 633, 2170, 5017, 1909, 546, 684, 907, 1514, 1206, 5457, 4658, 217, 2147, 896, 2128, 1916]</td>\n <td>172</td>\n <td>0</td>\n <td>[7.312499975785613, 2.0625, 3.312500111758709, -0.8125, 1.1874999701976776, 0.9375000298023224, 1.5624999403953552, 2.3124999552965164, -1.5625, -0.8125, -1.2500000596046448, 1.3125, 0.125, 1.5625, 6.687499949708581, 2.062500014901161, 3.687500111758709, -0.8749999403953552, 0.8125, 1.2499999701976776, 1.3124999701976776, -0.1875, 0.4375, 2.937499925494194]</td>\n <td>[0.612506021827765, 0.0032141366969789405, 0.011218440638378482, 0.00018132918428097404, 0.001339851475094765, 0.0010434774401992801, 0.0019494723350276118, 0.004127033027136988, 8.56538416901439e-05, 0.00018132918428097404, 0.00011707491365117373, 0.0015182506715903333, 0.0004630402934190451, 0.0019494724512252213, 0.3278508396706364, 0.0032141367448733096, 0.016322734814244625, 0.00017034301453609357, 0.0009208655814488337, 0.0014262644710112977, 0.001518250626342938, 0.0003387675154974123, 0.0006329010413373982, 0.007710312539352721]</td>\n </tr>\n </tbody>\n</table>\n</div>",
"text/plain": " query_id \\\n2618 1382_B \n\n content_ids \\\n2618 [172, 1081, 2078, 1166, 4701, 1246, 1510, 1048, 633, 2170, 5017, 1909, 546, 684, 907, 1514, 1206, 5457, 4658, 217, 2147, 896, 2128, 1916] \n\n true_id true_content_index \\\n2618 172 0 \n\n teacher_logits \\\n2618 [7.312499975785613, 2.0625, 3.312500111758709, -0.8125, 1.1874999701976776, 0.9375000298023224, 1.5624999403953552, 2.3124999552965164, -1.5625, -0.8125, -1.2500000596046448, 1.3125, 0.125, 1.5625, 6.687499949708581, 2.062500014901161, 3.687500111758709, -0.8749999403953552, 0.8125, 1.2499999701976776, 1.3124999701976776, -0.1875, 0.4375, 2.937499925494194] \n\n teacher_probs \n2618 [0.612506021827765, 0.0032141366969789405, 0.011218440638378482, 0.00018132918428097404, 0.001339851475094765, 0.0010434774401992801, 0.0019494723350276118, 0.004127033027136988, 8.56538416901439e-05, 0.00018132918428097404, 0.00011707491365117373, 0.0015182506715903333, 0.0004630402934190451, 0.0019494724512252213, 0.3278508396706364, 0.0032141367448733096, 0.016322734814244625, 0.00017034301453609357, 0.0009208655814488337, 0.0014262644710112977, 0.001518250626342938, 0.0003387675154974123, 0.0006329010413373982, 0.007710312539352721] "
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "2a4c088f",
"cell_type": "code",
"source": "# Teacher probability assigned to the gold misconception.\n# true_content_index == -1 is a sentinel (presumably: gold id not among the retrieved candidates) and scores 0.\npred_df['pos_score'] = pred_df.apply(\n    lambda r: 0 if r['true_content_index'] == -1 else r['teacher_probs'][r['true_content_index']],\n    axis=1,\n)\npred_df['pos_score'].describe()",
"execution_count": 49,
"outputs": [
{
"data": {
"text/plain": "count 16706.000000\nmean 0.477579\nstd 0.319052\nmin 0.000087\n25% 0.174123\n50% 0.465565\n75% 0.775886\nmax 0.999168\nName: pos_score, dtype: float64"
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "ed051fb2",
"cell_type": "code",
"source": "# Distinct gold misconception ids in pred_df; fill_to_n draws its random padding negatives from this pool.\nall_content_ids = set(pred_df['true_id'].unique().tolist())\nprint(len(all_content_ids))",
"execution_count": 50,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "4118\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "356b8dd4",
"cell_type": "code",
"source": "def filter_content_ids(row):\n    \"\"\"Return retrieved candidates that the teacher scores clearly below the positive.\n\n    A candidate is kept only if its teacher probability is strictly below row['cutoff']\n    (pos_score * margin); candidates scoring close to the positive are dropped as likely\n    false negatives. The gold id is also excluded explicitly so it can never be emitted\n    as a negative.\n    \"\"\"\n    return [\n        cid\n        for cid, score in zip(row['content_ids'], row['teacher_probs'])\n        if score < row['cutoff'] and cid != row['true_id']\n    ]\n\n\ndef fill_to_n(row, n=24, rng=None):\n    \"\"\"Truncate or pad row['filtered_content_ids'] to exactly n negative ids.\n\n    Padding is sampled from the global all_content_ids pool, excluding ids already\n    present AND the row's own true_id -- otherwise the positive could be sampled back\n    in as a negative (always possible when the positive was not retrieved and every\n    candidate was filtered out).\n\n    Args:\n        row: pred_df-style row with 'filtered_content_ids' and 'true_id'.\n        n: target list length.\n        rng: optional random.Random for reproducible sampling; defaults to the\n            global random module.\n\n    Raises:\n        ValueError: if the pool has fewer candidates than the padding required.\n    \"\"\"\n    filtered = row['filtered_content_ids']\n    if len(filtered) >= n:\n        return filtered[:n]\n    rng = random if rng is None else rng\n    excluded = set(filtered) | {row['true_id']}\n    # sorted() fixes the candidate order so a seeded rng gives reproducible picks\n    candidates = sorted(all_content_ids - excluded)\n    return filtered + rng.sample(candidates, n - len(filtered))",
"execution_count": 51,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "114fd2c1",
"cell_type": "code",
"source": "# Keep only candidates whose teacher probability is below margin * (positive's probability);\n# anything scoring at least that high is treated as a likely false negative and dropped.\n# NOTE(review): pos_score is 0 when true_content_index == -1, so the cutoff is 0 and every\n# candidate for that query is dropped (fill_to_n then pads it entirely with random ids).\nmargin = 0.9\n\ncurr_df = pred_df.copy(deep=True)  # what copy.deepcopy delegates to for a DataFrame\ncurr_df['cutoff'] = curr_df['pos_score'] * margin\ncurr_df['filtered_content_ids'] = curr_df.apply(filter_content_ids, axis=1)\ncurr_df['num_filtered'] = curr_df['filtered_content_ids'].map(len)",
"execution_count": 52,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "1a04a9d2",
"cell_type": "code",
"source": "# curr_df['num_filtered'].value_counts()",
"execution_count": 53,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "315ed324",
"cell_type": "code",
"source": "# curr_df['num_filtered'].value_counts()",
"execution_count": 54,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "27be88ab",
"cell_type": "code",
"source": "# Pad or truncate each query's negatives to a fixed length (fill_to_n default: 24), then index them by query_id.\ncurr_df['final_content_ids'] = curr_df.apply(fill_to_n, axis=1)\ncurr_df['final_num_filtered'] = curr_df['final_content_ids'].map(len)\n\nneg_map = {\n    qid: cids\n    for qid, cids in zip(curr_df['query_id'], curr_df['final_content_ids'])\n}",
"execution_count": 55,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "534bbac4",
"cell_type": "code",
"source": "curr_df['final_num_filtered'].value_counts()",
"execution_count": 56,
"outputs": [
{
"data": {
"text/plain": "final_num_filtered\n24 16706\nName: count, dtype: int64"
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "5e041615",
"cell_type": "code",
"source": "len(neg_map['780_D'])",
"execution_count": 57,
"outputs": [
{
"data": {
"text/plain": "24"
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"id": "8fe176ea",
"cell_type": "markdown",
"source": "# Save"
},
{
"metadata": {
"trusted": true
},
"id": "ea320d41",
"cell_type": "code",
"source": "save_dir = \"../data/embedding_mix/silver_v3\"\nos.makedirs(save_dir, exist_ok=True)",
"execution_count": 58,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "cec0e856",
"cell_type": "code",
"source": "# Persist query_id -> list of hard-negative misconception ids.\nwith open(os.path.join(save_dir, \"hn_mapping.json\"), \"w\") as f:\n    json.dump(neg_map, f)",
"execution_count": 59,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "9834b807",
"cell_type": "code",
"source": "# Columns exported to train.csv.\nkeep_cols = [\n    # question metadata\n    'QuestionId', 'ConstructId', 'ConstructName', 'SubjectId', 'SubjectName',\n    # question and answer text\n    'CorrectAnswer', 'QuestionText', 'AnswerAText', 'AnswerBText', 'AnswerCText', 'AnswerDText',\n    # per-option misconception labels\n    'MisconceptionAId', 'MisconceptionBId', 'MisconceptionCId', 'MisconceptionDId',\n]\n\nvalid_df = valid_df.loc[:, keep_cols].copy()\nmcq_df = mcq_df.loc[:, keep_cols].copy()\n\n# Only mcq_df is exported; valid_df is projected but NOT appended\n# (an earlier variant used pd.concat([mcq_df, valid_df]).reset_index(drop=True)).\nff_df = mcq_df.copy()\nff_df.to_csv(f\"{save_dir}/train.csv\", index=False)",
"execution_count": 60,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "9c7a7bf1",
"cell_type": "code",
"source": "content_df.to_csv(f\"{save_dir}/misconception_mapping.csv\", index=False)",
"execution_count": 61,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "38b090ae",
"cell_type": "code",
"source": "# Persist teacher_score_map (built earlier in the notebook) next to the hard negatives.\nwith open(os.path.join(save_dir, \"teacher_mapping.json\"), \"w\") as f:\n    json.dump(teacher_score_map, f)",
"execution_count": 62,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "7ccc177a",
"cell_type": "code",
"source": "kagglehub.dataset_upload(\"conjuring92/eedi-embed-mix-silver-v3\", save_dir)",
"execution_count": 63,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "Uploading Dataset https://www.kaggle.com/datasets/conjuring92/eedi-embed-mix-silver-v3 ...\nStarting upload for file ../data/embedding_mix/silver_v3/teacher_mapping.json\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "Uploading: 100%|█| 13.7M/13.7M [00:08<00:00, 1.60"
},
{
"name": "stdout",
"output_type": "stream",
"text": "Upload successful: ../data/embedding_mix/silver_v3/teacher_mapping.json (13MB)\nStarting upload for file ../data/embedding_mix/silver_v3/hn_mapping.json\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "\nUploading: 100%|█| 2.52M/2.52M [00:02<00:00, 881k"
},
{
"name": "stdout",
"output_type": "stream",
"text": "Upload successful: ../data/embedding_mix/silver_v3/hn_mapping.json (2MB)\nStarting upload for file ../data/embedding_mix/silver_v3/misconception_mapping.csv\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "\nUploading: 100%|█| 323k/323k [00:02<00:00, 154kB/"
},
{
"name": "stdout",
"output_type": "stream",
"text": "Upload successful: ../data/embedding_mix/silver_v3/misconception_mapping.csv (315KB)\nStarting upload for file ../data/embedding_mix/silver_v3/train.csv\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "\nUploading: 100%|█| 3.13M/3.13M [00:02<00:00, 1.09"
},
{
"name": "stdout",
"output_type": "stream",
"text": "Upload successful: ../data/embedding_mix/silver_v3/train.csv (3MB)\n"
},
{
"name": "stderr",
"output_type": "stream",
"text": "\n"
},
{
"name": "stdout",
"output_type": "stream",
"text": "Your dataset instance has been created.\nFiles are being processed...\nSee at: https://www.kaggle.com/datasets/conjuring92/eedi-embed-mix-silver-v3\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "9110b40c",
"cell_type": "code",
"source": "# teacher_score_map",
"execution_count": 96,
"outputs": []
},
{
"metadata": {},
"id": "d850043b",
"cell_type": "markdown",
"source": "# Sanity Checks"
},
{
"metadata": {
"trusted": true
},
"id": "6be32429",
"cell_type": "code",
"source": "# Re-load the saved artifacts from save_dir and validate their schema and cross-file consistency.\nimport os\nimport json\nimport pandas as pd\n\nsave_dir = \"../data/embedding_mix/silver_v3\"\nfiles_to_check = [\"hn_mapping.json\", \"teacher_mapping.json\", \"misconception_mapping.csv\", \"train.csv\"]\n\nfor file in files_to_check:\n    assert os.path.exists(os.path.join(save_dir, file)), f\"{file} does not exist in {save_dir}\"\n\nwith open(os.path.join(save_dir, \"hn_mapping.json\"), \"r\") as f:\n    hn_mapping = json.load(f)\n\nwith open(os.path.join(save_dir, \"teacher_mapping.json\"), \"r\") as f:\n    teacher_mapping = json.load(f)\n\nmisconception_df = pd.read_csv(os.path.join(save_dir, \"misconception_mapping.csv\"))\ntrain_df = pd.read_csv(os.path.join(save_dir, \"train.csv\"))\n\nassert len(hn_mapping) > 0, \"hn_mapping is empty\"\nsample_key = next(iter(hn_mapping))\nassert isinstance(hn_mapping[sample_key], list), \"hn_mapping values should be lists\"\n\nassert len(teacher_mapping) > 0, \"teacher_mapping is empty\"\nsample_key = next(iter(teacher_mapping))\nassert isinstance(teacher_mapping[sample_key], (int, float)), \"teacher_mapping values should be numeric\"\n\nassert not misconception_df.empty, \"misconception_df is empty\"\nassert set(misconception_df.columns) == {\"MisconceptionName\", \"MisconceptionId\"}, \"Unexpected columns in misconception_df\"\n\nassert not train_df.empty, \"train_df is empty\"\nexpected_columns = {\n    'QuestionId', 'ConstructId', 'ConstructName', 'SubjectId', 'SubjectName',\n    'CorrectAnswer', 'QuestionText', 'AnswerAText', 'AnswerBText', 'AnswerCText', 'AnswerDText',\n    'MisconceptionAId', 'MisconceptionBId', 'MisconceptionCId', 'MisconceptionDId'\n}\nassert set(train_df.columns) == expected_columns, \"Unexpected columns in train_df\"\n\nassert not train_df['QuestionId'].isna().any(), \"NaN values found in QuestionId column\"\nassert not train_df['CorrectAnswer'].isna().any(), \"NaN values found in CorrectAnswer column\"\n\n# Check data types\nassert train_df['QuestionId'].dtype == 'int64', \"QuestionId should be int64\"\nassert train_df['CorrectAnswer'].isin(['A', 'B', 'C', 'D']).all(), \"CorrectAnswer should only contain A, B, C, or D\"\n\n# Check consistency between files\ntrain_misconceptions = set(train_df['MisconceptionAId'].dropna()) | set(train_df['MisconceptionBId'].dropna()) | \\\n    set(train_df['MisconceptionCId'].dropna()) | set(train_df['MisconceptionDId'].dropna())\nmapping_misconceptions = set(misconception_df['MisconceptionId'])\nassert train_misconceptions.issubset(mapping_misconceptions), \"Misconceptions in train_df not found in misconception_df\"\n\ntrain_question_ids = set(train_df['QuestionId'])\nhn_question_ids = set(int(key.split('_')[0]) for key in hn_mapping.keys())\nassert train_question_ids.issubset(hn_question_ids), \"Not all questions in train_df have entries in hn_mapping\"\n\nprint(\"All sanity checks passed successfully!\")",
"execution_count": 64,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "All sanity checks passed successfully!\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "b1dcab3b",
"cell_type": "code",
"source": "len(train_question_ids.difference(set(hn_question_ids)))",
"execution_count": 65,
"outputs": [
{
"data": {
"text/plain": "0"
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "64bb0a93",
"cell_type": "code",
"source": "# hn_mapping['1829_D']",
"execution_count": 66,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "824b871c",
"cell_type": "code",
"source": "# train_question_ids",
"execution_count": 67,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "f1abff55",
"cell_type": "code",
"source": "assert train_df['QuestionId'].nunique() == len(train_df), \"Duplicate QuestionIds found in train_df\"\n\nhn_misconceptions = set()\nfor misconceptions in hn_mapping.values():\n hn_misconceptions.update(misconceptions)\nassert hn_misconceptions.issubset(set(misconception_df['MisconceptionId'])), \"Some MisconceptionIds in hn_mapping not found in misconception_df\"\n\nassert all(len(key.split('|')) == 2 for key in teacher_mapping.keys()), \"Unexpected format in teacher_mapping keys\"\n\ntrain_misconceptions = pd.concat([train_df[f'Misconception{letter}Id'].dropna() for letter in 'ABCD'])\nassert set(train_misconceptions).issubset(set(misconception_df['MisconceptionId'])), \"Some MisconceptionIds in train_df not found in misconception_df\"\n\nmisconception_counts = train_df[['MisconceptionAId', 'MisconceptionBId', 'MisconceptionCId', 'MisconceptionDId']].notna().sum(axis=1)\nassert misconception_counts.max() <= 3, \"Some questions have more than 3 misconceptions\"\n\nfor _, row in train_df.iterrows():\n correct_answer = row['CorrectAnswer']\n assert pd.isna(row[f'Misconception{correct_answer}Id']), f\"Correct answer has a misconception for QuestionId {row['QuestionId']}\"\n\nprint(\"All additional sanity checks passed successfully!\")",
"execution_count": 68,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "All additional sanity checks passed successfully!\n"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "0ee49cb9",
"cell_type": "code",
"source": "train_df.QuestionId.value_counts()",
"execution_count": 69,
"outputs": [
{
"data": {
"text/plain": "QuestionId\n300000 1\n15238 1\n15245 1\n15248 1\n15251 1\n ..\n401476 1\n401478 1\n401486 1\n401490 1\n1856 1\nName: count, Length: 10594, dtype: int64"
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": true
},
"id": "1eb18664",
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "cca950ed",
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "34bdbd50",
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "7167a0df",
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"id": "1909808a",
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3 (ipykernel)",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.10.15",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"vscode": {
"interpreter": {
"hash": "9966d838c5789fe326f76162c2a6fc0341b2fd9319a92dbbd869a89bb7177318"
}
},
"gist": {
"id": "",
"data": {
"description": "20_embed_dataset_v3.ipynb",
"public": false
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment