Sentence Transformer + Setfit classification head for inference without installing setfit
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers.losses import CosineSimilarityLoss
from sentence_transformers import SentenceTransformer
from setfit import SetFitModel, SetFitTrainer, sample_dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from pathlib import Path
import time
from tqdm import trange
import torch
from sklearn.metrics import classification_report
class SetFitModelFixed(SetFitModel):
    @classmethod
    def from_pretrained(cls, *args, include_layers=None, **kwargs):
        # Allow dropping layers similar to miniLM-v2-L3 and L6
        obj = super(SetFitModelFixed, cls).from_pretrained(*args, **kwargs)
        if include_layers is not None:
            auto_model = obj.model_body._modules['0'].auto_model
            auto_model.config.num_hidden_layers = len(include_layers)
            auto_model.encoder.layer = torch.nn.ModuleList([
                l
                for i, l in enumerate(auto_model.encoder.layer)
                if i in include_layers
            ])
            obj.model_body._modules['0'].auto_model = auto_model
        return obj
    def fit(
        self,
        x_train: List[str],
        y_train: Union[List[int], List[List[int]]],
        num_epochs: int,
        batch_size: Optional[int] = None,
        learning_rate: Optional[float] = None,
        body_learning_rate: Optional[float] = None,
        l2_weight: Optional[float] = None,
        max_length: Optional[int] = None,
        show_progress_bar: Optional[bool] = None,
        class_weights: Optional[List[float]] = None,
    ) -> None:
        if self.has_differentiable_head:  # train with PyTorch
            device = self.model_body.device
            self.model_body.train()
            self.model_head.train()

            dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
            criterion = self.model_head.get_loss_fn()
            if hasattr(self, "class_weights"):
                print(f"Using {self.class_weights=}")
                # This hack lets us bypass passing class weights via the trainer, which is a TODO
                class_weights = self.class_weights
            if class_weights is not None:
                print(f"Using {class_weights=}")
                criterion.weight = torch.Tensor(class_weights).to(self.model_head.device)
            optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
            for epoch_idx in trange(num_epochs, desc="Epoch", disable=not show_progress_bar):
                for batch in dataloader:
                    features, labels = batch
                    optimizer.zero_grad()

                    # move the batch to the model's device
                    features = {k: v.to(device) for k, v in features.items()}
                    labels = labels.to(device)

                    outputs = self.model_body(features)
                    if self.normalize_embeddings:
                        embeddings = outputs["sentence_embedding"]
                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                        outputs["sentence_embedding"] = embeddings
                    outputs = self.model_head(outputs)
                    logits = outputs["logits"]

                    loss = criterion(logits, labels)
                    loss.backward()
                    optimizer.step()
                # halve the learning rate every 5 epochs
                scheduler.step()
        else:  # train with sklearn
            embeddings = self.model_body.encode(x_train, normalize_embeddings=self.normalize_embeddings)
            self.model_head.fit(embeddings, y_train)
    def save_pretrained(self, model_save_path, class_names, non_askic_eligible_idx):
        super().save_pretrained(model_save_path)
        print("Saving extra items: model_head, model_head.config, sentence_transformer_classifier.config")
        torch.save(self.model_head, Path(model_save_path) / "model_head.pt")
        torch.save(self.model_head.state_dict(), Path(model_save_path) / "model_head.state_dict.pt")
        with open(Path(model_save_path) / "model_head.config.json", "w+") as fp:
            json.dump(self.model_head.get_config_dict(), fp)
        with open(Path(model_save_path) / "sentence_transformer_classifier.config.json", "w+") as fp:
            json.dump(dict(
                class_names=list([c.item() for c in class_names]),
                marginalize_negative=list(non_askic_eligible_idx)
            ), fp)
# class_weights = compute_class_weight("balanced", classes=class_names, y=df_train[label_col])
include_layers = {0, 5, 11}  # default=None, select every 5th layer
model = SetFitModelFixed.from_pretrained(
    model_type,
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
    normalize_embeddings=normalize_embeddings,
    multi_target_strategy=multi_target_strategy,  # "one-vs-rest"
    model_kwargs={"max_seq_length": 128},
    include_layers=include_layers,
)
model.class_weights = class_weights
def merge_models(base, other_models):
    # Average the weights of `base` and `other_models` in place (simple parameter averaging).
    for k, v in base.state_dict().items():
        print(k, base.state_dict()[k].shape)
        base.state_dict()[k] += sum([other_model.state_dict()[k] for other_model in other_models])
        base.state_dict()[k] /= len(other_models) + 1
        print(base.state_dict()[k].shape)


"""Usage:
merge_models(model.model_body, [model_A.model_body, model_B.model_body])
merge_models(model.model_head, [model_A.model_head, model_B.model_head])
"""
import os
import subprocess
import logging
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin
from instalog import instalog
from sentence_transformers import SentenceTransformer, models
from torch import nn

# Stdlib logger for the warning in ClassificationHead (assumed here; the original may wire this through instalog).
logger = logging.getLogger(__name__)
class ClassificationHead(models.Dense):
    """
    A classification head that supports multi-class classification for end-to-end training.
    Binary classification is treated as 2-class classification.

    To be compatible with Sentence Transformers, we inherit `Dense` from:
    https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Dense.py

    Args:
        in_features (`int`, *optional*):
            The embedding dimension from the output of the sentence_transformer body. If `None`, defaults to `LazyLinear`.
        out_features (`int`, defaults to `2`):
            The number of targets. If `out_features` is set to 1 for binary classification, it is changed to 2 and treated as 2-class classification.
        temperature (`float`, defaults to `1.0`):
            A scaling factor for the logits. Higher values make the model less confident and lower values make
            it more confident.
        eps (`float`, defaults to `1e-5`):
            A value for numerical stability when scaling logits.
        bias (`bool`, *optional*, defaults to `True`):
            Whether to add bias to the head.
        device (`torch.device` or `str`, *optional*):
            The device the model will be sent to. If `None`, will check whether GPU is available.
        multitarget (`bool`, defaults to `False`):
            Enable multi-target classification by making `out_features` binary predictions instead
            of a single multinomial prediction.
    """
    def __init__(
        self,
        in_features: Optional[int] = None,
        out_features: int = 2,
        temperature: float = 1.0,
        eps: float = 1e-5,
        bias: bool = True,
        device: Optional[Union[torch.device, str]] = None,
        multitarget: bool = False,
    ) -> None:
        super(models.Dense, self).__init__()  # init on models.Dense's parent: nn.Module
        if out_features == 1:
            logger.warning(
                "Change `out_features` from 1 to 2 since we use `CrossEntropyLoss` for binary classification."
            )
            out_features = 2
        if in_features is not None:
            self.linear = nn.Linear(in_features, out_features, bias=bias)
        else:
            self.linear = nn.LazyLinear(out_features, bias=bias)
        self.in_features = in_features
        self.out_features = out_features
        self.temperature = temperature
        self.eps = eps
        self.bias = bias
        # default to CUDA when available, unless a device was passed explicitly
        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.multitarget = multitarget

        self.to(self._device)
        self.apply(self._init_weight)
    def forward(
        self,
        features: Union[Dict[str, torch.Tensor], torch.Tensor],
        temperature: Optional[float] = None,
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor]]:
        """
        ClassificationHead can accept embeddings in:
        1. Output format (`dict`) from Sentence-Transformers.
        2. Pure `torch.Tensor`.

        Args:
            features (`Dict[str, torch.Tensor]` or `torch.Tensor`):
                The embeddings from the encoder. If using the `dict` format,
                make sure the embeddings are stored under the key 'sentence_embedding';
                the outputs are then added under the keys 'logits' and 'probs'.
            temperature (`float`, *optional*):
                A scaling factor for the logits. Higher values make the model less
                confident and lower values make it more confident.
                Will override the temperature given during initialization.

        Returns:
            [`Dict[str, torch.Tensor]` or `Tuple[torch.Tensor]`]
        """
        temperature = temperature or self.temperature
        is_features_dict = False  # whether `features` is a dict or not
        if isinstance(features, dict):
            assert "sentence_embedding" in features
            is_features_dict = True
        x = features["sentence_embedding"] if is_features_dict else features

        logits = self.linear(x)
        logits = logits / (temperature + self.eps)
        if self.multitarget:  # multiple targets per item
            probs = torch.sigmoid(logits)
        else:  # one target per item
            probs = nn.functional.softmax(logits, dim=-1)
        if is_features_dict:
            features.update(
                {
                    "logits": logits,
                    "probs": probs,
                }
            )
            return features
        return logits, probs
    def predict_proba(self, x_test: torch.Tensor) -> torch.Tensor:
        self.eval()
        return self(x_test)[1]

    def predict(self, x_test: torch.Tensor) -> torch.Tensor:
        probs = self.predict_proba(x_test)
        if self.multitarget:
            return torch.where(probs >= 0.5, 1, 0)
        return torch.argmax(probs, dim=-1)

    def get_loss_fn(self):
        if self.multitarget:  # if sigmoid output
            return torch.nn.BCEWithLogitsLoss()
        return torch.nn.CrossEntropyLoss()

    def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]:
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "temperature": self.temperature,
            "bias": self.bias,
            "device": self.device.type,  # store the string of the device, instead of `torch.device`
        }

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the model is placed.

        Reference from: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py#L869
        """
        return next(self.parameters()).device

    @staticmethod
    def _init_weight(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 1e-2)

    def __repr__(self):
        return f"{type(self).__name__}({self.get_config_dict()})"
@dataclass
class SentenceTransformerClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        model_body: Optional[SentenceTransformer] = None,
        model_head: Optional[ClassificationHead] = None,
        multi_target_strategy: Optional[str] = None,
        l2_weight: float = 1e-2,
        normalize_embeddings: bool = False,
        class_names: Optional[List[str]] = None,
        marginalize_negative: Optional[List[int]] = None,
    ):
        super().__init__()
        self.model_body = model_body
        self.model_head = model_head
        self.multi_target_strategy = multi_target_strategy
        self.l2_weight = l2_weight
        self.normalize_embeddings = normalize_embeddings
        self.class_names = class_names
        self.marginalize_negative = marginalize_negative

    def forward(self, inputs):
        embeddings = self.model_body.encode(
            inputs,
            normalize_embeddings=self.normalize_embeddings,
            convert_to_tensor=True,
        )
        outputs = self.model_head(embeddings)
        return outputs
    def _output_type_conversion(
        self, outputs: torch.Tensor, as_numpy: bool = False
    ) -> Union[torch.Tensor, np.ndarray]:
        """Return `outputs` in the desired type:

        * NumPy array if `as_numpy` is True.
        * Torch tensor otherwise.

        Returns:
            Union[torch.Tensor, np.ndarray]: The input, correctly converted to the desired type.
        """
        if as_numpy:
            outputs = outputs.detach().cpu().numpy()
        return outputs

    def _marginalize_prediction(self, preds: torch.Tensor) -> torch.Tensor:
        # P(positive) = 1 - sum of the probabilities assigned to the `marginalize_negative` classes
        if not self.marginalize_negative:
            raise RuntimeError(
                f"Please initialize marginalize_negative. Found: {self.marginalize_negative}."
            )
        return 1 - preds[:, self.marginalize_negative].sum(axis=-1)
    def predict(
        self,
        x_test: List[str],
        as_numpy: bool = False,
        marginalize: bool = False,
        threshold: float = 0.5,
    ) -> Union[torch.Tensor, np.ndarray]:
        with torch.no_grad():
            embeddings = self.model_body.encode(
                x_test,
                normalize_embeddings=self.normalize_embeddings,
                convert_to_tensor=True,
            )
            if marginalize:
                outputs = self.model_head.predict_proba(embeddings)
                outputs = self._marginalize_prediction(outputs) > threshold
            else:
                outputs = self.model_head.predict(embeddings)
        return self._output_type_conversion(outputs, as_numpy=as_numpy)

    def predict_proba(
        self, x_test: List[str], as_numpy: bool = False, marginalize: bool = False
    ) -> Union[torch.Tensor, np.ndarray]:
        with torch.no_grad():
            embeddings = self.model_body.encode(
                x_test,
                normalize_embeddings=self.normalize_embeddings,
                convert_to_tensor=True,
            )
            outputs = self.model_head.predict_proba(embeddings)
            if marginalize:
                outputs = self._marginalize_prediction(outputs)
        return self._output_type_conversion(outputs, as_numpy=as_numpy)
class_names = [
    "Brand Comparison",
    "Comparative or Substitutes",
    "Complements and Pairings",
    "Conceptual Attributes",
    "Contextual Ideas",
    "Health",
    "Instacart",
    "Occassions",
    "Product",
    "Products with attributes",
    "Shopping Lists",
]
# Class indices treated as negative: `_marginalize_prediction` returns 1 minus the sum of their probabilities.
non_askic_eligible_idx = [5, 6, 8, 9]
def load_classifier(model_path):
    head_config = {
        "in_features": 384,
        "out_features": 11,
        "temperature": 1.0,
        "bias": True,
        "device": "cpu",
    }
    model_body = SentenceTransformer(model_path).to("cpu")
    model_head = ClassificationHead(**head_config)
    model_head_state_dict = torch.load(
        Path(model_path) / "model_head.state_dict.pt", map_location=torch.device("cpu")
    )
    model_head.load_state_dict(model_head_state_dict)
    st_model = SentenceTransformerClassifier(
        model_body=model_body,
        model_head=model_head,
        class_names=class_names,
        marginalize_negative=non_askic_eligible_idx,
    )
    st_model = st_model.eval()
    return st_model
class Classifier(object):
    def __init__(self, model_path, threshold) -> None:
        self.model_path = model_path
        self.threshold = threshold
        self.model = load_classifier(model_path)

    @lru_cache(maxsize=100_000)
    def __call__(self, query):
        prediction = self.model.predict_proba([query], marginalize=True)
        output = dict(
            is_true=prediction[0] > self.threshold,
            true_prob=prediction[0],
        )
        return output
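A minimal inference sketch for the classes above; the model directory, threshold, and query strings are placeholders. `Classifier` always marginalizes, so `true_prob` is the probability that the query does not fall into any of the `non_askic_eligible_idx` classes; `load_classifier` can also be used directly for raw multi-class probabilities.

clf = Classifier("setfit-classifier", threshold=0.5)
print(clf("gluten free snacks for a road trip"))
# -> dict with `is_true` (0-d bool tensor) and `true_prob` (0-d float tensor)

# Raw per-class probabilities without marginalization:
st_model = load_classifier("setfit-classifier")
probs = st_model.predict_proba(["decorations for a birthday party"], as_numpy=True)
predicted_label = class_names[int(probs.argmax(axis=-1)[0])]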