Last active
August 29, 2022 14:29
-
-
Save Pangoraw/07f21ae6ad25bf549db585e5f355ef2a to your computer and use it in GitHub Desktop.
Stable diffusion txt2img using the CPU
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From 4e32b5ebfc3bb54cabf192488e43e768781eeafb Mon Sep 17 00:00:00 2001
From: Paul Berg <[email protected]>
Date: Mon, 29 Aug 2022 16:28:55 +0200
Subject: [PATCH] Allow using the CPU for txt2img

---
 configs/stable-diffusion/v1-inference.yaml |  2 ++
 scripts/txt2img.py                         | 18 ++++++++++++------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
index d4effe5..b7239eb 100644
--- a/configs/stable-diffusion/v1-inference.yaml
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -68,3 +68,5 @@ model:
     cond_stage_config:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      params:
+        device: "cpu"
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index 59c16a1..c49d450 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -45,7 +45,7 @@ def numpy_to_pil(images):
     return pil_images
 
-def load_model_from_config(config, ckpt, verbose=False):
+def load_model_from_config(config, ckpt, device="cuda", verbose=False):
     print(f"Loading model from {ckpt}")
     pl_sd = torch.load(ckpt, map_location="cpu")
     if "global_step" in pl_sd:
@@ -60,7 +60,7 @@ def load_model_from_config(config, ckpt, verbose=False):
         print("unexpected keys:")
         print(u)
 
-    model.cuda()
+    model.to(device)
     model.eval()
     return model
 
@@ -226,6 +226,12 @@ def main():
         choices=["full", "autocast"],
         default="autocast"
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        choices=["cuda", "cpu"],
+        default="cuda",
+    )
     opt = parser.parse_args()
 
     if opt.laion400m:
@@ -237,15 +243,15 @@ def main():
     seed_everything(opt.seed)
 
     config = OmegaConf.load(f"{opt.config}")
-    model = load_model_from_config(config, f"{opt.ckpt}")
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    device = opt.device
+    model = load_model_from_config(config, f"{opt.ckpt}", device)
     model = model.to(device)
 
     if opt.plms:
-        sampler = PLMSSampler(model)
+        sampler = PLMSSampler(model, device=device)
     else:
-        sampler = DDIMSampler(model)
+        sampler = DDIMSampler(model, device=device)
 
     os.makedirs(opt.outdir, exist_ok=True)
     outpath = opt.outdir
-- 
2.34.1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment