@cycyyy
Created March 6, 2021 03:52
from deepctr_torch.inputs import SparseFeat, DenseFeat
import numpy as np
import torch
from torch import nn
import torch.utils.data as td
import torch.nn.functional as F
from tqdm import tqdm
import sys
MAX = sys.maxsize
#torch.set_deterministic(True)
torch.manual_seed(0)
np.random.seed(0)
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import prepare_data
dpath = os.getenv('BBPATH', '..')
device = torch.device("cuda:0")
#device = torch.device("cpu")
# small_dataset selects MovieLens-1M (regression target) instead of Avazu (binary click prediction).
small_dataset = False
print(device, small_dataset)
# use_dram keeps the full embedding tables in host DRAM and stages the rows needed for
# each batch through small GPU-resident cache tables.
use_dram = True
class SparseOnlyModel(torch.nn.Module):
    def __init__(self, feature_columns, hidden_size, batch_size, binary=False, dims=128):
        super(SparseOnlyModel, self).__init__()
        self.binary = binary
        # self.embedding_tables = nn.ModuleList()
        self.cache_tables = nn.ModuleList()
        # The full-size embedding tables are kept in a plain Python list so they are not
        # registered as submodules: .to(device) and model.parameters() skip them, and only
        # the per-batch cache tables (batch_size rows each) live on the GPU.
        self.embedding_tables = []
        input_size = 0
        for feature_column in feature_columns:
            self.embedding_tables.append(nn.Embedding(feature_column.vocabulary_size, dims, sparse=True))
            self.cache_tables.append(nn.Embedding(batch_size, dims, sparse=True))
            input_size += dims
        # current_mapping[i, j] records which original id is cached in row j of cache table i;
        # unused rows are marked with MAX.
        self.current_mapping = torch.full((len(feature_columns), batch_size), MAX, dtype=torch.long, device=device)
        self.fc1 = nn.Linear(input_size, hidden_size[0])
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size[0], hidden_size[1])
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size[1], hidden_size[2])
        self.relu3 = nn.ReLU()
        self.fc4 = nn.Linear(hidden_size[2], 1)
        if binary:
            self.sigmoid = nn.Sigmoid()
    def init_dram(self):
        # Keep the full embedding tables in host DRAM; the cache tables stay on the GPU.
        if not use_dram:
            return
        for i in range(0, len(self.embedding_tables)):
            self.embedding_tables[i].to('cpu')
    def get_mapped_idx(self, x, y):
        # index = torch.argsort(x)
        # sorted_x = x[index]
        # sorted_index = torch.searchsorted(sorted_x, y)
        # return torch.take(index, sorted_index)
        # x is a row of current_mapping (sorted unique ids padded with MAX), so
        # searchsorted returns, for every raw id in y, its row in the cache table.
        return torch.searchsorted(x, y)
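    # A minimal worked example of the mapping above (hypothetical values):
    #   current_mapping[i] = [3, 7, 9, MAX, ...]   # sorted unique ids, MAX-padded
    #   x[:, i]            = [7, 3, 9, 7]          # raw ids in the batch
    #   searchsorted(...)  = [1, 0, 2, 1]          # rows to read from cache table i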
    def load(self, x):
        veces = []
        for i in range(0, len(self.embedding_tables)):
            # Unique ids appearing in this batch for feature i, in sorted order.
            unique, _ = torch.unique(x[:, i], sorted=True, return_inverse=True)
            self.current_mapping[i, 0:len(unique)] = unique
            self.current_mapping[i, len(unique):] = MAX
            # Copy the rows needed for this batch from the host table into the GPU cache
            # table (the index must be on CPU to index the host-resident weights).
            self.cache_tables[i].weight.data[0:len(unique)] = self.embedding_tables[i].weight.data[unique.cpu()].to(device)
            real_idx = self.get_mapped_idx(self.current_mapping[i], x[:, i])
            veces.append(self.cache_tables[i](real_idx))
        return torch.cat(veces, 1)
    def store(self):
        if use_dram:
            for i in range(0, len(self.embedding_tables)):
                # Number of valid (non-MAX) cache rows = position of the first MAX marker.
                validate_items = torch.where(self.current_mapping[i] == MAX)[0]
                if len(validate_items) == 0:
                    validate_items = len(self.current_mapping[i])
                else:
                    validate_items = validate_items[0]
                # Write the updated cache rows back to the host-resident table.
                self.embedding_tables[i].weight.data[self.current_mapping[i][0:validate_items].cpu()] = self.cache_tables[i].weight.data[0:validate_items].to('cpu')
            self.current_mapping[:] = MAX
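    # Per-batch cycle when use_dram is True: load() stages the rows needed for the batch
    # from host DRAM into the GPU cache tables, forward/backward and optimizer.step()
    # update only those cached rows, and store() writes them back to the host tables.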
    def forward(self, x):
        x = x.to(device)
        if use_dram:
            # Look up embeddings through the GPU cache tables.
            x = self.load(x)
        else:
            # Direct lookup in the full embedding tables.
            veces = []
            for i in range(0, len(self.embedding_tables)):
                veces.append(self.embedding_tables[i](x[:, i]))
            x = torch.cat(veces, 1)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.relu3(x)
        x = self.fc4(x)
        if self.binary:
            return self.sigmoid(x)
        return x
def get_movielens():
    return prepare_data.build_movielens1m(path=dpath+"/movielens/ml-1m", cache_folder=dpath+"/.cache")
def get_criteo():
    # return prepare_data.build_criteo(path=dpath+"/criteo/train.txt", cache_folder=dpath+"/.cache")
    # Currently loads the Avazu dataset rather than Criteo.
    return prepare_data.build_avazu(path=dpath+"/avazu/train", cache_folder=dpath+"/.cache")
def generate_input():
    if small_dataset:
        feature_columns, _, raw_data, input_data, target = get_movielens()
    else:
        feature_columns, _, raw_data, input_data, target = get_criteo()
    y = raw_data[target].to_numpy()
    del raw_data
    # Keep only the sparse (categorical) features; dense features are ignored.
    feature_list = []
    x = []
    for feature_column in feature_columns:
        if isinstance(feature_column, SparseFeat):
            feature_list.append(feature_column)
            x.append(input_data[feature_column.embedding_name].to_numpy())
    x = np.array(x).T[:]
    y = y[:]
    train_tensor_data = td.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
    return train_tensor_data, feature_list
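# generate_input() returns a TensorDataset of (x, y): each row of x holds the sparse-feature
# ids for one sample (one column per SparseFeat), and y is the target column.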
def train(batch_size, epoch, device):
    train_tensor_data, feature_list = generate_input()
    train_loader = td.DataLoader(dataset=train_tensor_data, batch_size=batch_size)
    # MovieLens ratings are a regression target; Avazu/Criteo clicks are binary.
    if small_dataset:
        binary = False
    else:
        binary = True
    model = SparseOnlyModel(feature_list, [512, 256, 64], batch_size, binary).to(device)
    print(model)
    model.init_dram()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    if small_dataset:
        loss_func = F.mse_loss
    else:
        loss_func = F.binary_cross_entropy
    for e in range(epoch):
        total_loss = 0.0
        with tqdm(enumerate(train_loader), total=len(train_loader)) as t:
            for index, (x, y) in t:
                optimizer.zero_grad()
                # x = x.to(device)
                pred_y = model(x)
                y = y.to(device).float()
                # Squeeze the (batch, 1) prediction to match the (batch,) target.
                loss = loss_func(pred_y.squeeze(-1), y)
                total_loss += loss.item()
                loss.backward()
                optimizer.step()
                # Flush the updated cache rows back to the host tables after each step.
                model.store()
        print(e, ":", total_loss / len(train_loader))
if small_dataset:
    train(2048, 10, device)
else:
    train(8192, 1, device)
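# Example invocation (a sketch; assumes a prepare_data module one directory up and a data
# layout of <BBPATH>/movielens/ml-1m or <BBPATH>/avazu/train; the script name
# "sparse_only_model.py" is only illustrative):
#   BBPATH=/path/to/data python sparse_only_model.py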