Created
October 10, 2022 01:43
-
-
Save crosstyan/a2dd8ba6e479e6002c553acd7ee050b5 to your computer and use it in GitHub Desktop.
Source: the SD Training Labs Discord server.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Target file: `modules\prompt_parser.py`.
`v2.pt` can be loaded by putting it in the main folder of the repo and adding the following code
--------------------------------------------------------------------------- | |
import torch | |
from torch import nn | |
from modules import devices | |
class VectorAdjustPrior(nn.Module):
    """Learned adjustment applied to a batch of conditioning vectors.

    Each token vector is concatenated with the mean of all token vectors in
    its own sequence and projected down to ``inter_dim`` features. That
    projection is then concatenated back onto the original token vector and
    projected back to ``hidden_size`` features.

    Args:
        hidden_size: Width of each conditioning vector (last dim of the
            tensor passed to ``forward``).
        inter_dim: Width of the intermediate projection.
    """

    def __init__(self, hidden_size, inter_dim=64):
        super().__init__()
        # Input is [sequence mean, token] -> 2 * hidden_size features.
        self.vector_proj = nn.Linear(hidden_size * 2, inter_dim, bias=True)
        # Input is [token, projected features] -> hidden_size + inter_dim.
        self.out_proj = nn.Linear(hidden_size + inter_dim, hidden_size, bias=True)

    def forward(self, z):
        """Return the adjusted conditioning, with the same shape as ``z``.

        Args:
            z: Tensor of shape ``(batch, seq, hidden_size)``.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        b, s = z.shape[0:2]
        # One copy of each sequence's mean per token, in batch-major order so
        # row i lines up with row i of ``z.reshape(b * s, -1)`` below.
        # (``.repeat(s, 1)`` would tile the whole batch instead and pair
        # tokens with another sequence's mean whenever b > 1.)
        x1 = torch.mean(z, dim=1).repeat_interleave(s, dim=0)
        x2 = z.reshape(b * s, -1)
        x = torch.cat((x1, x2), dim=1)
        x = self.vector_proj(x)
        x = torch.cat((x2, x), dim=1)
        x = self.out_proj(x)
        return x.reshape(b, s, -1)

    @classmethod
    def load_model(cls, model_path, hidden_size=768, inter_dim=64):
        """Build a model and load its weights from a checkpoint file.

        The checkpoint must be a mapping with a ``"state_dict"`` entry.
        Tensors are loaded onto the CPU first, so a checkpoint saved on a GPU
        can be opened on any machine. The model is then moved to
        ``devices.device``.

        Args:
            model_path: Path to the checkpoint file (e.g. ``v2.pt``).
            hidden_size: Must match the checkpoint's layer sizes.
            inter_dim: Must match the checkpoint's layer sizes.

        Returns:
            The loaded model in eval mode, on ``devices.device``.
        """
        model = cls(hidden_size=hidden_size, inter_dim=inter_dim)
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints obtained from a trusted source.
        checkpoint = torch.load(model_path, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
        model.to(devices.device)
        return model
# load_model already moves the model to devices.device; forcing .cuda() here
# would override that choice and fail on machines without CUDA.
vap = VectorAdjustPrior.load_model('v2.pt')
------------------------------------------------------------------ | |
after the `import lark` line, and adding
`conds = vap(conds)`
between the `conds = model.get_learned_conditioning(texts)` and
`cond_schedule = []` lines.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment