@hamelsmu
Last active August 27, 2019 01:27
#!/usr/bin/env python
# coding: utf-8
# This notebook illustrates the use of a utility, `InferenceWrapper.df_to_emb` that can be used to perform inference in bulk.
# - **checkpointed model** (2.29 GB):
# `https://storage.googleapis.com/issue_label_bot/model/lang_model/models_22zkdqlr/best_22zkdqlr.pth`
# This file imports https://github.com/kubeflow/code-intelligence/blob/master/Issue_Embeddings/flask_app/inference.py
from inference import InferenceWrapper, pass_through
import pandas as pd
import numpy as np
# !wget https://storage.googleapis.com/issue_label_bot/model/lang_model/models_22zkdqlr/trained_model_22zkdqlr.pkl
# #### Create an `InferenceWrapper` object:
wrapper = InferenceWrapper(model_path='/ds/notebooks',
                           model_file_name='trained_model_22zkdqlr.pkl')
# #### Load the GFI Dataset
gfidf = pd.read_csv('gfi_data_all.csv')
train_mask = gfidf.split_name != 'test'
test_mask = gfidf.split_name == 'test'
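The two boolean masks above partition the DataFrame by its `split_name` column: everything that is not labeled `test` is treated as training data. A minimal sketch of the same idea on a toy frame (column names follow the notebook; the data is illustrative):

```python
import pandas as pd

# Toy frame mimicking the GFI dataset's split_name column
df = pd.DataFrame({
    'split_name': ['train', 'valid', 'test', 'train'],
    'body': ['a', 'b', 'c', 'd'],
})

train_mask = df.split_name != 'test'
test_mask = df.split_name == 'test'

# Every row lands in exactly one partition
assert train_mask.sum() + test_mask.sum() == len(df)
print(df[train_mask].body.tolist())  # ['a', 'b', 'd']
```

Note that `valid` rows fall into the training partition here, which matches the notebook's `!= 'test'` definition.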
# Truncate unusually long issue bodies at the training set's 95th-percentile length
len_cutoff = int(gfidf[train_mask].body.str.len().quantile(.95))
print(f'95th percentile body length: {len_cutoff:,}')
gfidf.body = gfidf.body.str[:len_cutoff]
assert gfidf.body.str.len().max() <= len_cutoff
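The truncation step can be isolated as follows: compute the 95th-percentile body length, then slice every string to that cutoff. A self-contained sketch with toy strings (the real notebook computes the cutoff on the training split only):

```python
import pandas as pd

# Toy issue bodies: two short, one pathologically long
bodies = pd.Series(['short', 'medium length text', 'x' * 1000])

# 95th percentile of character lengths (linear interpolation by default)
cutoff = int(bodies.str.len().quantile(0.95))

# String slicing is NaN-safe and leaves shorter bodies untouched
truncated = bodies.str[:cutoff]
assert truncated.str.len().max() <= cutoff
```

Computing the cutoff on the training split and applying it everywhere avoids leaking test-set statistics into preprocessing.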
# # Perform Batch Inference To Get Embeddings
# This retrieves a document embedding for each issue.
train_embeddings = wrapper.df_to_emb(gfidf[train_mask])
test_embeddings = wrapper.df_to_emb(gfidf[test_mask])
with open('gfi_train_emb.npy', 'wb') as f:
    np.save(f, train_embeddings)
with open('gfi_test_emb.npy', 'wb') as f:
    np.save(f, test_embeddings)
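The saved `.npy` files can be read back later with `np.load`, which restores the array's shape and dtype exactly. A quick round-trip sketch (the random array stands in for the output of `wrapper.df_to_emb`; the temp path is illustrative):

```python
import os
import tempfile
import numpy as np

# Stand-in for an embedding matrix: (n_issues, embedding_dim)
emb = np.random.rand(4, 8).astype(np.float32)

# Round-trip through an .npy file, as the notebook does with gfi_train_emb.npy
path = os.path.join(tempfile.gettempdir(), 'emb_roundtrip.npy')
with open(path, 'wb') as f:
    np.save(f, emb)
loaded = np.load(path)

assert loaded.shape == emb.shape
assert loaded.dtype == emb.dtype
```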