Fraud_Detection_with_GNN_v4.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/derms/8d2ad9c94691eac8abc1b4ab08ca7eb0/fraud_detection_with_gnn_v4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"**Colab-Specific Imports**"
],
"metadata": {
"id": "-cEs7vglbZL5"
},
"id": "-cEs7vglbZL5"
},
{
"cell_type": "code",
"source": [
"!pip install pyTigerGraph class_resolver\n",
"import torch\n",
"\n",
"def format_pytorch_version(version):\n",
"    return version.split('+')[0]\n",
"\n",
"TORCH_version = torch.__version__\n",
"TORCH = format_pytorch_version(TORCH_version)\n",
"\n",
"def format_cuda_version(version):\n",
"    return 'cu' + version.replace('.', '')\n",
"\n",
"CUDA_version = torch.version.cuda\n",
"CUDA = format_cuda_version(CUDA_version)\n",
"\n",
"!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n",
"!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n",
"!pip install torch-geometric"
],
"metadata": {
"id": "ZMd9jct4bX0P"
},
"id": "ZMd9jct4bX0P",
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"id": "62e15a33-2a71-403e-b5b2-192fcb8adef3",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": [],
"id": "62e15a33-2a71-403e-b5b2-192fcb8adef3"
},
"source": [
"## **Import Libraries and Functions**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ec21336-bc07-4719-8e2b-6c0d65bac1d9",
"metadata": {
"id": "8ec21336-bc07-4719-8e2b-6c0d65bac1d9"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import warnings\n",
"import torch_geometric.transforms as T\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch_geometric.nn import GCN\n",
"from torch_geometric.nn import GNNExplainer\n",
"from datetime import datetime\n",
"from pyTigerGraph.gds.metrics import Accumulator, Accuracy\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams[\"figure.figsize\"] = (30,15)\n",
"from sklearn.metrics import roc_auc_score, average_precision_score, precision_score, recall_score, roc_curve, precision_recall_curve, f1_score\n",
"\n",
"def threshold_search(y_true, y_proba):\n", | |
" precision , recall, thresholds = precision_recall_curve(y_true, y_proba)\n", | |
" thresholds = np.append(thresholds, 1.0001) \n", | |
" F = 2 / (1/precision + 1/recall)\n", | |
" best_score = np.max(F)\n", | |
" best_th = thresholds[np.argmax(F)]\n", | |
" return best_th \n", | |
"\n", | |
"log_dir = \"work/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", | |
"train_log = SummaryWriter(log_dir+\"/train\")\n", | |
"valid_log = SummaryWriter(log_dir+\"/valid\")\n", | |
"\n", | |
"def sample_vertex():\n", | |
" for i in range(100):\n", | |
" node_idx = int(np.random.choice(np.where(y_pred1==1)[0],1))\n", | |
" try:\n", | |
" x, edge_index = batch.x, batch.edge_index\n", | |
" explainer = GNNExplainer(model, epochs=1)\n", | |
" node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)\n", | |
" ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=batch.y)\n", | |
" return node_idx\n", | |
" except:\n", | |
" pass\n", | |
" \n", | |
"warnings.filterwarnings(\"ignore\")" | |
] | |
}, | |
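{
"cell_type": "markdown",
"id": "f1-threshold-check-md",
"metadata": {},
"source": [
"*Optional sanity check:* `2 / (1/precision + 1/recall)` in `threshold_search` is the harmonic mean of precision and recall, i.e. the F1 score. The minimal sketch below uses toy labels and scores (not data from the graph) to confirm the chosen cutoff lines up with `f1_score` at that threshold."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1-threshold-check",
"metadata": {},
"outputs": [],
"source": [
"# Toy example only: verify that threshold_search picks an F1-maximizing cutoff.\n",
"toy_true = np.array([0, 0, 1, 0, 1, 1, 0, 1])\n",
"toy_proba = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.9, 0.2, 0.7])\n",
"toy_th = threshold_search(toy_true, toy_proba)\n",
"print(\"best threshold:\", toy_th)\n",
"print(\"F1 at that threshold:\", f1_score(toy_true, (toy_proba > toy_th).astype(int)))"
]
},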
{
"cell_type": "markdown",
"id": "709e2a9d-4a27-4da6-91b4-cd3f619fc7dc",
"metadata": {
"id": "709e2a9d-4a27-4da6-91b4-cd3f619fc7dc"
},
"source": [
"## **Connect to TigerGraph Database**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26ba49cd-b7ca-4e77-8643-685b67df78ce",
"metadata": {
"id": "26ba49cd-b7ca-4e77-8643-685b67df78ce"
},
"outputs": [],
"source": [
"from pyTigerGraph import TigerGraphConnection\n",
"\n",
"conn = TigerGraphConnection(\n",
"    host=\"https://<your-server>.i.tgcloud.io/\",  # Change the address to your database server's\n",
"    graphname=\"Ethereum\",\n",
"    username=\"tigergraph\",\n",
"    password=\"tigergraph\"\n",
")\n",
"conn.apiToken = conn.getToken(\"kcuup7k2eb08gkplk17pvqre824bbo1v\", 86400) " | |
] | |
}, | |
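{
"cell_type": "markdown",
"id": "connection-check-md",
"metadata": {},
"source": [
"*Optional:* a quick sketch to confirm the connection works before running any queries. It assumes the `Ethereum` graph exposes the `Account` vertex and `Transaction` edge types used throughout this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "connection-check",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: list the schema of the connected graph.\n",
"print(\"Vertex types:\", conn.getVertexTypes())\n",
"print(\"Edge types:\", conn.getEdgeTypes())"
]
},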
{
"cell_type": "markdown",
"id": "e8e6a0d6-ff6b-4e66-af50-412ddd030765",
"metadata": {
"id": "e8e6a0d6-ff6b-4e66-af50-412ddd030765"
},
"source": [
"## **Add Degree Features inside TigerGraph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "18736e09-6465-4e7a-bc34-0154d418d507",
"metadata": {
"id": "18736e09-6465-4e7a-bc34-0154d418d507"
},
"outputs": [],
"source": [
"conn.runInstalledQuery(\"degrees\")"
]
},
{
"cell_type": "markdown",
"id": "dfa44145-f8de-4946-8b94-f990e92f0312",
"metadata": {
"id": "dfa44145-f8de-4946-8b94-f990e92f0312"
},
"source": [
"## **Add Amounts Features inside TigerGraph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42c4fa3f-6f3c-4c3f-8c71-3a9ef874c9d2",
"metadata": {
"id": "42c4fa3f-6f3c-4c3f-8c71-3a9ef874c9d2"
},
"outputs": [],
"source": [
"conn.runInstalledQuery(\"amounts\")"
]
},
{
"cell_type": "markdown",
"id": "7da014d6-5bf7-4f58-9aca-0013a3779b78",
"metadata": {
"id": "7da014d6-5bf7-4f58-9aca-0013a3779b78"
},
"source": [
"## **Check Count of Positive Labels**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5baf1203-e21c-44fa-a80f-11c2ff119bb6",
"metadata": {
"id": "5baf1203-e21c-44fa-a80f-11c2ff119bb6"
},
"outputs": [],
"source": [
"conn.getVertexCount(\"Account\",\"is_fraud = 1\")"
]
},
{
"cell_type": "markdown",
"id": "8313fd70-b751-423a-a0f0-8836408e1c36",
"metadata": {
"id": "8313fd70-b751-423a-a0f0-8836408e1c36"
},
"source": [
"## **Check Count of Negative Labels**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e12752fc-2e40-4560-a939-8d75d77c2c18",
"metadata": {
"id": "e12752fc-2e40-4560-a939-8d75d77c2c18"
},
"outputs": [],
"source": [
"conn.getVertexCount(\"Account\",\"is_fraud = 0\")"
]
},
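{
"cell_type": "markdown",
"id": "class-imbalance-md",
"metadata": {},
"source": [
"The two counts above are heavily imbalanced, which is why the training loop later up-weights the fraud class in the cross-entropy loss. The minimal sketch below only computes the ratio (assuming `getVertexCount` returns a plain count for a single vertex type); the weight values used later (15.0 and 20.0) are hand-picked rather than derived from this ratio."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "class-imbalance",
"metadata": {},
"outputs": [],
"source": [
"# Rough class-imbalance ratio between legitimate and fraudulent accounts.\n",
"n_fraud = conn.getVertexCount(\"Account\", \"is_fraud = 1\")\n",
"n_legit = conn.getVertexCount(\"Account\", \"is_fraud = 0\")\n",
"print(\"fraud:\", n_fraud, \"legit:\", n_legit, \"ratio legit/fraud: {:.1f}\".format(n_legit / n_fraud))"
]
},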
{
"cell_type": "markdown",
"id": "a9fedffc-f323-4fe8-af4b-3cad2123b529",
"metadata": {
"id": "a9fedffc-f323-4fe8-af4b-3cad2123b529"
},
"source": [
"## **Run Train-Test Split inside TigerGraph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "501ef419-b6f5-4681-925e-cd2e8400466f",
"metadata": {
"id": "501ef419-b6f5-4681-925e-cd2e8400466f"
},
"outputs": [],
"source": [
"%%time\n",
"split = conn.gds.vertexSplitter(is_training=0.8, is_validation=0.2)\n",
"split.run()"
]
},
{
"cell_type": "markdown",
"id": "ad0214c2-f58c-4706-bea6-3420b701e877",
"metadata": {
"id": "ad0214c2-f58c-4706-bea6-3420b701e877"
},
"source": [
"## **Add PageRank Features inside TigerGraph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e636ba6-b7c4-4076-b0eb-7d3957afa1db",
"metadata": {
"id": "8e636ba6-b7c4-4076-b0eb-7d3957afa1db"
},
"outputs": [],
"source": [
"feat = conn.gds.featurizer()\n",
"feat.installAlgorithm(\"tg_pagerank\")\n",
"tg_pagerank_params = {\n",
"    \"v_type\": \"Account\",\n",
"    \"e_type\": \"Transaction\",\n",
"    \"result_attr\": \"pagerank\",\n",
"    \"top_k\": 5\n",
"}\n",
"results = pd.json_normalize(feat.runAlgorithm(\"tg_pagerank\", tg_pagerank_params)[0]['@@top_scores_heap'])\n",
"results"
]
},
{
"cell_type": "markdown",
"id": "56a1ae09-3f81-46b7-9046-1328da207924",
"metadata": {
"id": "56a1ae09-3f81-46b7-9046-1328da207924"
},
"source": [
"## **Add Hyperparameters for NeighborSampler and Graph Neural Network**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0c6390f-c337-4633-8756-118a6e5467be",
"metadata": {
"id": "b0c6390f-c337-4633-8756-118a6e5467be"
},
"outputs": [],
"source": [
"# Hyperparameters\n", | |
"hp = {\"batch_size\": 5000, \"num_neighbors\": 200, \"num_hops\": 3, \"hidden_dim\": 128, \"num_layers\": 2, \"dropout\": 0.05, \"lr\": 0.0075, \"l2_penalty\": 5e-5}" | |
] | |
}, | |
{
"cell_type": "markdown",
"id": "da071c0b-4faf-4dbe-8848-e5a073afc169",
"metadata": {
"id": "da071c0b-4faf-4dbe-8848-e5a073afc169"
},
"source": [
"## **Define Train Neighbor Loader using TigerGraph**\n",
"\n",
"Output is provided directly in PyTorch Geometric Format"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3b9ba960-429d-45c3-a4e2-76a155d4b233",
"metadata": {
"id": "3b9ba960-429d-45c3-a4e2-76a155d4b233"
},
"outputs": [],
"source": [
"train_loader = conn.gds.neighborLoader(\n",
"    v_in_feats=[\"in_degree\",\"out_degree\",\"send_amount\",\"send_min\",\"recv_amount\",\"recv_min\",\"pagerank\"],\n",
"    v_out_labels=[\"is_fraud\"],\n",
"    v_extra_feats=[\"is_training\"],\n",
"    output_format=\"PyG\",\n",
"    batch_size=hp[\"batch_size\"],\n",
"    num_neighbors=hp[\"num_neighbors\"],\n",
"    num_hops=hp[\"num_hops\"],\n",
"    filter_by=\"is_training\",\n",
"    shuffle=True,\n",
"    timeout=600000\n",
")"
]
},
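{
"cell_type": "markdown",
"id": "loader-inspect-md",
"metadata": {},
"source": [
"*Optional:* the sketch below pulls a single subgraph from the database, the same way the training loop iterates the loader, just to confirm the batch arrives as a PyTorch Geometric `Data` object with the expected feature and label tensors. It requires the database connection above to be live and may take a moment."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "loader-inspect",
"metadata": {},
"outputs": [],
"source": [
"# Pull one batch from the train loader to inspect its shape before training.\n",
"for sample_batch in train_loader:\n",
"    print(sample_batch)\n",
"    print(\"node features:\", tuple(sample_batch.x.shape), \"labels:\", tuple(sample_batch.y.shape))\n",
"    break"
]
},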
{
"cell_type": "markdown",
"id": "48054085-ef03-4637-9015-8ad145626da1",
"metadata": {
"id": "48054085-ef03-4637-9015-8ad145626da1"
},
"source": [
"## **Define Validation Neighbor Loader using TigerGraph**\n",
"\n",
"Output is provided directly in PyTorch Geometric Format"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67828279-31d2-47a3-ba2e-04b63c7833ee",
"metadata": {
"id": "67828279-31d2-47a3-ba2e-04b63c7833ee"
},
"outputs": [],
"source": [
"valid_loader = conn.gds.neighborLoader(\n",
"    v_in_feats=[\"in_degree\",\"out_degree\",\"send_amount\",\"send_min\",\"recv_amount\",\"recv_min\",\"pagerank\"],\n",
"    v_out_labels=[\"is_fraud\"],\n",
"    v_extra_feats=[\"is_validation\"],\n",
"    output_format=\"PyG\",\n",
"    batch_size=hp[\"batch_size\"],\n",
"    num_neighbors=hp[\"num_neighbors\"],\n",
"    num_hops=hp[\"num_hops\"],\n",
"    filter_by=\"is_validation\",\n",
"    shuffle=True,\n",
"    timeout=600000\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6d9cf220-0db3-4aa0-9721-7e1c22d0495a",
"metadata": {
"id": "6d9cf220-0db3-4aa0-9721-7e1c22d0495a"
},
"source": [
"## **Define Graph Convolutional Network Architecture using PyTorch Geometric**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4b1f30a-a392-4855-b8c4-27720d1ff0d4",
"metadata": {
"id": "e4b1f30a-a392-4855-b8c4-27720d1ff0d4"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"model = GCN(\n",
"    in_channels=7,\n",
"    hidden_channels=hp[\"hidden_dim\"],\n",
"    num_layers=hp[\"num_layers\"],\n",
"    out_channels=2,\n",
"    dropout=hp[\"dropout\"],\n",
").to(device)\n",
"\n",
"optimizer = torch.optim.Adam(\n",
"    model.parameters(), lr=hp[\"lr\"], weight_decay=hp[\"l2_penalty\"]\n",
")"
]
},
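{
"cell_type": "markdown",
"id": "manual-gcn-md",
"metadata": {},
"source": [
"*For reference:* the `GCN` model wrapper used above is roughly equivalent to stacking `GCNConv` layers by hand with ReLU and dropout in between. The sketch below is illustrative only and is not used by the rest of the notebook; the `ManualGCN` class name is made up here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "manual-gcn",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: an explicit two-layer GCN similar to what the GCN wrapper builds.\n",
"from torch_geometric.nn import GCNConv\n",
"\n",
"class ManualGCN(torch.nn.Module):\n",
"    def __init__(self, in_channels, hidden_channels, out_channels, dropout):\n",
"        super().__init__()\n",
"        self.conv1 = GCNConv(in_channels, hidden_channels)\n",
"        self.conv2 = GCNConv(hidden_channels, out_channels)\n",
"        self.dropout = dropout\n",
"\n",
"    def forward(self, x, edge_index, edge_weight=None):\n",
"        x = self.conv1(x, edge_index, edge_weight).relu()\n",
"        x = F.dropout(x, p=self.dropout, training=self.training)\n",
"        return self.conv2(x, edge_index, edge_weight)\n",
"\n",
"print(ManualGCN(7, hp[\"hidden_dim\"], 2, hp[\"dropout\"]))"
]
},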
{
"cell_type": "markdown",
"id": "1c28c7c4-a261-42e0-b84d-4fd8505f3d0d",
"metadata": {
"id": "1c28c7c4-a261-42e0-b84d-4fd8505f3d0d"
},
"source": [
"## **Run Training Loop with AUC, AUCPR, Precision and Recall as Metrics**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "68eedee7-3350-479a-ae9a-a140859659dd",
"metadata": {
"id": "68eedee7-3350-479a-ae9a-a140859659dd"
},
"outputs": [],
"source": [
"global_steps = 0\n", | |
"logs = {}\n", | |
"for epoch in range(10):\n", | |
" # Train\n", | |
" print(\"Start Training epoch:\", epoch)\n", | |
" model.train()\n", | |
" epoch_train_loss = Accumulator()\n", | |
" \n", | |
" epoch_train_auc = []\n", | |
" epoch_train_prec = []\n", | |
" epoch_train_rec = []\n", | |
" epoch_train_apr = []\n", | |
" epoch_best_thr = []\n", | |
" \n", | |
" # Iterate through the loader to get a stream of subgraphs instead of the whole graph\n", | |
" for bid, batch in enumerate(train_loader):\n", | |
" # print(bid, batch)\n", | |
" if (batch.y.sum()==0):\n", | |
" continue\n", | |
" batchsize = batch.x.shape[0]\n", | |
" norm = T.NormalizeFeatures()\n", | |
" batch = norm(batch).to(device)\n", | |
" batch.x = batch.x.type(torch.FloatTensor)\n", | |
" batch.y = batch.y.type(torch.LongTensor)\n", | |
" \n", | |
" # Forward pass\n", | |
" out = model(batch.x, batch.edge_index, batch.edge_weight)\n", | |
" # Calculate loss\n", | |
" class_weight = torch.FloatTensor([1.0, 15.0])\n", | |
" loss = F.cross_entropy(out[batch.is_training], batch.y[batch.is_training], class_weight)\n", | |
" # f1_loss(batch.y[batch.is_training], out[batch.is_training], is_training=True)\n", | |
" # Backward pass\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" epoch_train_loss.update(loss.item() * batchsize, batchsize)\n", | |
" # Predict on training data\n", | |
" with torch.no_grad():\n", | |
" pred = out.argmax(dim=1)\n", | |
" y_pred = out[batch.is_training][:,1].cpu().numpy()\n", | |
" y_true = batch.y[batch.is_training].cpu().numpy()\n", | |
" # softmax = F.softmax(out, dim=1)[batch.is_training][:,1].cpu().numpy()\n", | |
" best_threshold = threshold_search(y_true, y_pred)\n", | |
" y_pred1 = (y_pred > best_threshold).astype(int)\n", | |
" \n", | |
" epoch_train_auc.append(roc_auc_score(y_true, y_pred))\n", | |
" epoch_train_prec.append(precision_score(y_true, y_pred1))\n", | |
" epoch_train_rec.append(recall_score(y_true, y_pred1))\n", | |
" epoch_train_apr.append(average_precision_score(y_true, y_pred))\n", | |
" epoch_best_thr.append(best_threshold)\n", | |
" \n", | |
" # Log training status after each batch\n", | |
" logs[\"loss\"] = epoch_train_loss.mean\n", | |
" logs[\"auc\"] = np.mean(epoch_train_auc)\n", | |
" logs[\"prec\"] = np.mean(epoch_train_prec)\n", | |
" logs[\"rec\"] = np.mean(epoch_train_rec)\n", | |
" logs[\"apr\"] = np.mean(epoch_train_apr)\n", | |
" logs[\"thr\"] = np.mean(epoch_best_thr)\n", | |
" \n", | |
" print(\n", | |
" \"Epoch {}, Train Batch {}, Loss {:.4f}, AUC {:.4f}, AUCPR {:.4f}, Precision {:.4f}, Recall {:.4f}\".format(\n", | |
" epoch, bid, logs[\"loss\"], logs[\"auc\"], logs[\"apr\"], logs[\"prec\"], logs[\"rec\"]\n", | |
" )\n", | |
" )\n", | |
" train_log.add_scalar(\"Loss\", logs[\"loss\"], global_steps)\n", | |
" train_log.add_scalar(\"AUC\", logs[\"auc\"], global_steps)\n", | |
" train_log.add_scalar(\"AUCPR\", logs[\"apr\"], global_steps)\n", | |
" train_log.flush()\n", | |
" global_steps += 1\n", | |
" # Evaluate\n", | |
" print(\"Start validation epoch:\", epoch)\n", | |
" model.eval()\n", | |
" epoch_val_loss = Accumulator()\n", | |
" epoch_val_prec = []\n", | |
" epoch_val_rec = []\n", | |
" epoch_val_auc = []\n", | |
" epoch_val_apr = []\n", | |
" \n", | |
" for batch in valid_loader:\n", | |
" batchsize = batch.x.shape[0]\n", | |
" norm = T.NormalizeFeatures()\n", | |
" batch = norm(batch).to(device)\n", | |
" with torch.no_grad():\n", | |
" # Forward pass\n", | |
" batch.x = batch.x.type(torch.FloatTensor)\n", | |
" batch.y = batch.y.type(torch.LongTensor)\n", | |
" out = model(batch.x, batch.edge_index) \n", | |
" # Calculate loss\n", | |
" class_weight = torch.FloatTensor([1.0, 20.0])\n", | |
" valid_loss = F.cross_entropy(out[batch.is_validation], batch.y[batch.is_validation], class_weight)\n", | |
" # f1_loss(batch.y[batch.is_validation], out[batch.is_validation])\n", | |
" epoch_val_loss.update(valid_loss.item() * batchsize, batchsize)\n", | |
" # Prediction\n", | |
" pred = out.argmax(dim=1)\n", | |
" y_pred = out[batch.is_validation][:,1].cpu().numpy()\n", | |
" y_true = batch.y[batch.is_validation].cpu().numpy()\n", | |
" # softmax = F.softmax(out, dim=1)[batch.is_validation][:,1].cpu().numpy()\n", | |
" y_pred1 = (y_pred > np.mean(epoch_best_thr)).astype(int)\n", | |
" \n", | |
" epoch_val_auc.append(roc_auc_score(y_true, y_pred))\n", | |
" epoch_val_prec.append(precision_score(y_true, y_pred1))\n", | |
" epoch_val_rec.append(recall_score(y_true, y_pred1))\n", | |
" epoch_val_apr.append(average_precision_score(y_true, y_pred))\n", | |
"\n", | |
" # Log testing result after each epoch\n", | |
" logs[\"val_loss\"] = epoch_val_loss.mean\n", | |
" logs[\"val_prec\"] = np.mean(epoch_val_prec)\n", | |
" logs[\"val_auc\"] = np.mean(epoch_val_auc)\n", | |
" logs[\"val_rec\"] = np.mean(epoch_val_rec)\n", | |
" logs[\"val_apr\"] = np.mean(epoch_val_apr)\n", | |
" print(\n", | |
" \"Epoch {}, Valid Loss {:.4f}, Valid AUC {:.4f}, Valid AUCPR {:.4f}, Valid Precision {:.4f}, Valid Recall {:.4f}\".format(\n", | |
" epoch, logs[\"val_loss\"], logs[\"val_auc\"], logs[\"val_apr\"], logs[\"val_prec\"], logs[\"val_rec\"]\n", | |
" )\n", | |
" )\n", | |
" valid_log.add_scalar(\"Loss\", logs[\"val_loss\"], global_steps)\n", | |
" valid_log.add_scalar(\"AUC\", logs[\"val_auc\"], global_steps)\n", | |
" valid_log.add_scalar(\"AUCPR\", logs[\"val_apr\"], global_steps)\n", | |
" valid_log.flush()" | |
] | |
}, | |
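{
"cell_type": "markdown",
"id": "save-model-md",
"metadata": {},
"source": [
"*Optional:* persist the trained weights so the explainability cells can be re-run later without retraining. The file name `fraud_gcn.pt` is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "save-model",
"metadata": {},
"outputs": [],
"source": [
"# Save the trained weights; uncomment the second line to reload them later.\n",
"torch.save(model.state_dict(), \"fraud_gcn.pt\")\n",
"# model.load_state_dict(torch.load(\"fraud_gcn.pt\", map_location=device))"
]
},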
{
"cell_type": "markdown",
"id": "1918943d-6830-4bdf-8fd9-459e9601a812",
"metadata": {
"id": "1918943d-6830-4bdf-8fd9-459e9601a812"
},
"source": [
"## **Define and Run Explainability Model**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd1a1b0e-ed7a-4b58-94e0-8d22363ca13d",
"metadata": {
"id": "cd1a1b0e-ed7a-4b58-94e0-8d22363ca13d"
},
"outputs": [],
"source": [
"node_idx = sample_vertex()\n",
"x, edge_index = batch.x, batch.edge_index\n",
"explainer = GNNExplainer(model, epochs=100)\n",
"node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)"
]
},
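{
"cell_type": "markdown",
"id": "feat-mask-rank-md",
"metadata": {},
"source": [
"*Quick look:* a minimal sketch that ranks the learned `node_feat_mask`, assuming its entries follow the same feature order as the loaders' `v_in_feats`. The styled tables in the next section give the fuller picture."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feat-mask-rank",
"metadata": {},
"outputs": [],
"source": [
"# Rank the feature mask of the explained vertex from most to least important.\n",
"feature_names = ['in_degree', 'out_degree', 'send_amount', 'send_min', 'recv_amount', 'recv_min', 'pagerank']\n",
"mask_series = pd.Series(node_feat_mask.detach().cpu().numpy(), index=feature_names)\n",
"print(mask_series.sort_values(ascending=False))"
]
},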
{
"cell_type": "markdown",
"id": "4de33869-2490-47aa-a7bc-da2626a0b3d0",
"metadata": {
"id": "4de33869-2490-47aa-a7bc-da2626a0b3d0"
},
"source": [
"## **Show Local Feature Importance**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b29d41d-e4c0-41c6-835f-789191ba931c",
"metadata": {
"id": "4b29d41d-e4c0-41c6-835f-789191ba931c"
},
"outputs": [],
"source": [
"import pandas as pd\n", | |
"from IPython.display import display\n", | |
"\n", | |
"ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y = batch.y)\n", | |
"feature_values = list()\n", | |
"feature_values.append([\"{}\".format(node_idx)] + node_feat_mask.tolist()) # Center ID\n", | |
"important_accounts = set()\n", | |
"for src, dst, attr in G.edges(data=True):\n", | |
" edge_importance = attr[\"att\"]\n", | |
" if edge_importance >= 0.001:\n", | |
" if src == node_idx:\n", | |
" important_accounts.add(dst)\n", | |
" elif dst == node_idx:\n", | |
" important_accounts.add(src)\n", | |
"\n", | |
"subg_accts = list(important_accounts)\n", | |
"for acct_idx in important_accounts:\n", | |
" target_id = acct_idx\n", | |
" node_feat_mask, _ = explainer.explain_node(acct_idx, x, edge_index)\n", | |
" feature_values.append([\"{} ({})\".format(target_id, acct_idx)] + node_feat_mask.tolist())\n", | |
"feature_names = ['in_degree', 'out_degree', 'send_amount', 'send_min', 'recv_amount', 'recv_min', 'pagerank']\n", | |
"df = pd.DataFrame(feature_values, columns=[\"Account ID\"] + feature_names).set_index(\"Account ID\")\n", | |
"print(\"---- Normalized Feature Importance\")\n", | |
"print(\" \")\n", | |
"display(df.style.background_gradient(axis=1))\n", | |
"print(\"---- Normalized Feature Values\")\n", | |
"print(\" \")\n", | |
"feature_df = pd.DataFrame(batch.x.cpu().numpy(), columns = feature_names)\n", | |
"display(feature_df.loc[subg_accts].rename(index=lambda acct: \"{}\".format(node_idx)))" | |
] | |
}, | |
{
"cell_type": "markdown",
"source": [
"## **Show TensorBoard**"
],
"metadata": {
"id": "trn7y606tWQO"
},
"id": "trn7y606tWQO"
},
{
"cell_type": "code",
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir work"
],
"metadata": {
"id": "mnucJhGOtVUX"
},
"id": "mnucJhGOtVUX",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Tigergraph Pytorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
},
"colab": {
"name": "Fraud_Detection_with_GNN_v4.ipynb",
"provenance": [],
"toc_visible": true,
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}