Created February 12, 2024 08:15
-
-
Save laksjdjf/dd317efdb6e4320dfae9203aca5c6290 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Add this node from sampling/custom_sampling/schedulers,
enter the timesteps as text such as "999,893,...,156",
and connect its output to SamplerCustom.
'''
import torch | |
class TextScheduler:
    """Scheduler node that builds a sigma schedule from user-typed timesteps.

    The ``timesteps`` text is a list of numbers such as ``"999,893,...,156"``.
    Entries may be separated by commas and/or whitespace (including newlines),
    and empty entries are ignored. Each timestep is converted to a sigma with
    ``model.model.model_sampling.sigma``, and a terminal sigma of 0 is appended.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: a model, multiline timestep text, and a verbose flag."""
        return {
            "required": {
                "model": ("MODEL",),
                "timesteps": ("STRING", {"multiline": True}),
                "verbose": ("BOOLEAN",),
            }
        }

    RETURN_TYPES = ("SIGMAS",)
    CATEGORY = "sampling/custom_sampling/schedulers"
    FUNCTION = "get_sigmas"

    @staticmethod
    def _parse_timesteps(text):
        """Parse comma- and/or whitespace-separated numbers into a list of floats.

        Raises:
            ValueError: if no timesteps are given or an entry is not a number.
        """
        # Treat commas as whitespace so "999, 893", "999\n893" and a trailing
        # comma all parse; str.split() with no argument drops empty tokens.
        tokens = text.replace(",", " ").split()
        if not tokens:
            raise ValueError("TextScheduler: no timesteps given")
        values = []
        for token in tokens:
            try:
                values.append(float(token))
            except ValueError:
                raise ValueError(f"TextScheduler: invalid timestep {token!r}") from None
        return values

    def get_sigmas(self, model, timesteps, verbose):
        """Convert the timestep text into a SIGMAS tensor.

        Args:
            model: model object exposing ``model.model.model_sampling.sigma``.
            timesteps: text containing the timesteps, e.g. ``"999,893,...,156"``.
            verbose: when true, print the resulting sigma list.

        Returns:
            A one-element tuple holding the sigmas with a trailing 0 appended.

        Raises:
            ValueError: if the text contains no timesteps or a non-numeric entry.
        """
        values = self._parse_timesteps(timesteps)
        sigmas = model.model.model_sampling.sigma(torch.tensor(values))
        # Build the terminal zero from the computed sigmas so its dtype and
        # device match, instead of concatenating an int64 CPU tensor.
        sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
        if verbose:
            print("sigmas:", sigmas.tolist())
        return (sigmas,)
# Module-level registry mapping each exported node name to its class,
# presumably read by the host application to register the node.
NODE_CLASS_MAPPINGS = {"TextScheduler": TextScheduler}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment