# Source: GitHub gist nousr/4cdca926ca5ad21a89f0e80d341d4b15
# Created: January 25, 2023 20:44
# NOTE: GitHub flagged this file as possibly containing bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
@lru_cache(maxsize=None)
def load_safety_model(clip_model):
    """Load the NSFW safety classifier matching ``clip_model``, downloading it on first use.

    Results are memoized per ``clip_model`` by ``lru_cache``, so each classifier is
    downloaded and loaded at most once per process.

    Args:
        clip_model: Name of the CLIP model whose embeddings will be classified. Supported
            values are ``"ViT-L/14"``, ``"ViT-B/32"`` and ``"open_clip:ViT-H-14"``.

    Returns:
        For ``"open_clip:ViT-H-14"``, a PyTorch ``nn.Module`` in eval mode with weights
        loaded on CPU; otherwise a Keras model loaded with autokeras' custom objects.

    Raises:
        ValueError: If no safety model exists for ``clip_model``.
    """
    # clip_model -> (sub-directory of the cache folder, embedding dimension, download URL)
    specs = {
        "ViT-L/14": (
            "clip_autokeras_binary_nsfw",
            768,
            "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip",
        ),
        "ViT-B/32": (
            "clip_autokeras_nsfw_b32",
            512,
            "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_nsfw_b32.zip",
        ),
        "open_clip:ViT-H-14": (
            "h14_nsfw_detector",
            1024,
            "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/raw/main/h14_nsfw.pth",
        ),
    }
    # Validate before touching the cache folder so unsupported models fail fast.
    if clip_model not in specs:
        raise ValueError(f"Safety model for {clip_model} not available.")
    model_subdir, dim, url_model = specs[clip_model]

    from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel

    cache_folder = get_cache_folder(clip_model)
    model_dir = os.path.join(cache_folder, model_subdir)

    # One random batch is pushed through the loaded model before returning it
    # (presumably a warm-up / sanity check that the input dimension matches).
    fake_batch_size = 10**3
    fake_batch = np.random.rand(fake_batch_size, dim).astype("float32")

    if clip_model == "open_clip:ViT-H-14":
        import torch  # pylint: disable=import-outside-toplevel
        from torch import nn  # pylint: disable=import-outside-toplevel

        class H14_NSFW_Detector(nn.Module):  # pylint: disable=invalid-name
            """MLP mapping a vector of size ``input_size`` to one output value (no final activation)."""

            def __init__(self, input_size=1024):
                super().__init__()
                self.input_size = input_size
                # Layer order must stay as-is: state-dict keys are "layers.<index>.*".
                self.layers = nn.Sequential(
                    nn.Linear(self.input_size, 1024),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(1024, 2048),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(2048, 1024),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(1024, 256),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(256, 128),
                    nn.ReLU(),
                    nn.Dropout(0.2),
                    nn.Linear(128, 16),
                    nn.Linear(16, 1),
                )

            def forward(self, x):
                return self.layers(x)

        weights_path = os.path.join(model_dir, "h14_nsfw.pth")
        if not os.path.exists(weights_path):
            # The H-14 detector is a bare .pth file, not a zip archive: download it
            # straight into model_dir. Go through a temporary name so an interrupted
            # download is not mistaken for a complete file on the next call.
            os.makedirs(model_dir, exist_ok=True)
            tmp_path = weights_path + ".part"
            urlretrieve(url_model, tmp_path)
            os.replace(tmp_path, weights_path)

        # NOTE(review): torch.load unpickles the file; consider weights_only=True if the
        # installed torch version supports it.
        state = torch.load(weights_path, map_location="cpu")
        loaded_model = H14_NSFW_Detector(input_size=dim)
        loaded_model.load_state_dict(state)
        # eval() disables the Dropout layers so predictions are deterministic.
        loaded_model.eval()
        with torch.no_grad():
            loaded_model(torch.from_numpy(fake_batch))
    else:
        import zipfile  # pylint: disable=import-outside-toplevel

        import autokeras as ak  # pylint: disable=import-outside-toplevel
        from tensorflow.keras.models import load_model  # pylint: disable=import-outside-toplevel

        if not os.path.exists(model_dir):
            os.makedirs(cache_folder, exist_ok=True)
            # Name the archive after the downloaded file rather than reusing the
            # ViT-L/14 archive name for every model.
            path_to_zip_file = os.path.join(cache_folder, os.path.basename(url_model))
            urlretrieve(url_model, path_to_zip_file)
            # Assumes the archive contains a top-level folder named model_subdir — TODO confirm.
            with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
                zip_ref.extractall(cache_folder)

        loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
        loaded_model.predict(fake_batch, batch_size=fake_batch_size)

    return loaded_model
# (GitHub page footer: sign in to GitHub to comment on this gist.)