Skip to content

Instantly share code, notes, and snippets.

@franckjay
Last active March 25, 2024 02:54
Show Gist options
  • Save franckjay/0bffd79f05ffa0b8aeca5366c09328d9 to your computer and use it in GitHub Desktop.
Save franckjay/0bffd79f05ffa0b8aeca5366c09328d9 to your computer and use it in GitHub Desktop.
DataSet for PyTorch with multiple embeddings
class DictDataset(Dataset):
    """Map-style dataset serving scaled float features, id columns and targets.

    The frame returned by ``build_pandas_ranking`` is split by column name:

    * columns starting with ``"user_id"``        -> user categorical columns
    * columns starting with ``"product_id"``     -> item categorical columns
    * columns starting with ``"combined_score"`` -> target columns
    * columns whose name contains none of ``"user_id"``, ``"product_id"``,
      ``"combined_score"`` or ``"target"``       -> continuous ("float")
      features, which are passed through ``scaler``.

    Columns that merely *contain* one of those keywords without starting with
    a recognised prefix (for example any ``"target"`` column) stay in
    ``data_df`` but are not returned by ``__getitem__``.

    Each item is a 4-tuple of tensors:
    ``(float_features, user_ids, product_ids, targets / norm_target)``.

    NOTE(review): every tensor is built with ``torch.Tensor(...)``, which
    produces the default floating dtype, so the id columns come back as
    floats. Embedding lookups presumably need a ``.long()`` cast downstream
    -- confirm with the consuming model.
    """

    # A column whose name contains any of these substrings is not treated as
    # a continuous feature.
    _NON_FLOAT_KEYWORDS = ("user_id", "product_id", "combined_score", "target")

    def __init__(self, data_dict, norm_target=1, scaler=None):
        """Build the frame, fit (or reuse) the scaler and split the columns.

        Args:
            data_dict: Passed unchanged to ``build_pandas_ranking``; the
                result is used as a pandas DataFrame.
            norm_target: Divisor applied to the target values on every
                ``__getitem__`` call.
            scaler: Already-fitted object exposing ``transform``. When
                ``None``, a ``RobustScaler`` is fitted on the float feature
                columns of this dataset and kept in ``self.scaler`` so it
                can be handed to another dataset.
        """
        self.norm_target = norm_target
        self.data_df = build_pandas_ranking(data_dict)
        self.scaler = scaler

        columns = list(self.data_df.columns)

        # Build out the features that are continuous variables
        self.float_features = [
            feat
            for feat in columns
            if not any(keyword in feat for keyword in self._NON_FLOAT_KEYWORDS)
        ]
        self.u_cats = [feat for feat in columns if feat.startswith("user_id")]
        self.i_cats = [feat for feat in columns if feat.startswith("product_id")]
        self.targets = [
            feat for feat in columns if feat.startswith("combined_score")
        ]
        logging.debug(
            "Float features: %s",
            [feat for feat in self.float_features if "embedding" not in feat],
        )

        if self.scaler is None:
            # If we haven't trained a scaler yet, do so here
            self.scaler = RobustScaler().fit(self.data_df[self.float_features])

        # Scale the data from the float features into its own dataframe
        self.float_df = pd.DataFrame(
            self.scaler.transform(self.data_df[self.float_features]),
            columns=self.float_features,
        )
        # Drop these from the normal DF as they are un-scaled
        self.data_df = self.data_df.drop(self.float_features, axis=1)

        # Select each column group once. Doing ``data_df[cols]`` inside
        # ``__getitem__`` would copy whole columns for every single sample.
        # These are snapshots taken at construction time, like ``float_df``.
        self._user_df = self.data_df[self.u_cats]
        self._item_df = self.data_df[self.i_cats]
        self._target_df = self.data_df[self.targets]

    def __len__(self) -> int:
        """Return the number of rows (samples) in the dataset."""
        return len(self.data_df)

    def __getitem__(self, index):
        """Return the tensors for the row(s) at positional ``index``.

        Rows are looked up positionally (``iloc``), so the original frame's
        index labels do not matter.

        Returns:
            Tuple ``(float_features, user_ids, product_ids, targets)`` of
            ``torch.Tensor``; targets are divided by ``self.norm_target``.
        """
        return (
            torch.Tensor(self.float_df.iloc[index].values),
            torch.Tensor(self._user_df.iloc[index].values),
            torch.Tensor(self._item_df.iloc[index].values),
            torch.Tensor(self._target_df.iloc[index].values / self.norm_target),
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment