diff --git a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py index 9164de9f..cd68727c 100644 --- a/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py +++ b/extensions-builtin/sd_forge_controlnet/scripts/controlnet.py @@ -7,6 +7,7 @@ import torch import modules.scripts as scripts from modules import shared, script_callbacks, masking, images from modules.ui_components import InputAccordion +from modules.api.api import decode_base64_to_image import gradio as gr from lib_controlnet import global_state, external_code diff --git a/modules/api/api.py b/modules/api/api.py index 25ce7ca0..78d10969 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -24,6 +24,7 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin +from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import Any @@ -724,7 +725,7 @@ class Api: def get_sd_models(self): import modules.sd_models as sd_models - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename} for x in sd_models.checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in sd_models.checkpoints_list.values()] def get_sd_vaes(self): import modules.sd_vae as sd_vae diff --git a/modules/initialize.py b/modules/initialize.py index dd55d6c3..ec4d58a4 100644 --- a/modules/initialize.py +++ b/modules/initialize.py @@ -10,6 +10,16 @@ from threading import Thread from modules.timer import startup_timer +class HiddenPrints: + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout + + def imports(): logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... 
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -25,8 +35,16 @@ def imports(): import gradio # noqa: F401 startup_timer.record("import gradio") - from modules import paths, timer, import_hook, errors # noqa: F401 - startup_timer.record("setup paths") + with HiddenPrints(): + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") + + import ldm.modules.encoders.modules # noqa: F401 + import ldm.modules.diffusionmodules.model + startup_timer.record("import ldm") + + import sgm.modules.encoders.modules # noqa: F401 + startup_timer.record("import sgm") from modules import shared_init shared_init.initialize() @@ -123,6 +141,11 @@ def initialize_rest(*, reload_script_modules=False): textual_inversion.textual_inversion.list_textual_inversion_templates() startup_timer.record("refresh textual inversion templates") + from modules import script_callbacks, sd_hijack_optimizations, sd_hijack + script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) + sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + from modules import sd_unet sd_unet.list_unets() startup_timer.record("scripts list_unets") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 8c1823a1..f933de64 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -391,15 +391,15 @@ def prepare_environment(): openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") assets_repo = os.environ.get('ASSETS_REPO', "https://github.com/AUTOMATIC1111/stable-diffusion-webui-assets.git") - # stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") - # stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") + stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") + stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') huggingface_guess_repo = os.environ.get('HUGGINGFACE_GUESS_REPO', 'https://github.com/lllyasviel/huggingface_guess.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') assets_commit_hash = os.environ.get('ASSETS_COMMIT_HASH', "6f7db241d2f8ba7457bac5ca9753331f0c266917") - # stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") - # stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") + stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") huggingface_guess_commit_hash = os.environ.get('HUGGINGFACE_GUESS_HASH', "78f7d1da6a00721a6670e33a9132fd73c4e987b4") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") @@ -456,8 +456,8 @@ def prepare_environment(): 
os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(assets_repo, repo_dir('stable-diffusion-webui-assets'), "assets", assets_commit_hash) - # git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) - # git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) + git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(huggingface_guess_repo, repo_dir('huggingface_guess'), "huggingface_guess", huggingface_guess_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/paths.py b/modules/paths.py index 18494b6c..501ff658 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -36,8 +36,8 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl mute_sdxl_imports() path_dirs = [ - # (sd_path, 'ldm', 'Stable Diffusion', []), - # (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), + (sd_path, 'ldm', 'Stable Diffusion', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), (os.path.join(sd_path, '../huggingface_guess'), 'huggingface_guess/detection.py', 'huggingface_guess', []), @@ -53,13 +53,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) - # elif "sgm" in options: - # # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we - # # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. - # - # sys.path.insert(0, d) - # import sgm # noqa: F401 - # sys.path.pop(0) + elif "sgm" in options: + # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we + # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir. 
+ + sys.path.insert(0, d) + import sgm # noqa: F401 + sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eb06a849..f292073b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,3 +1,124 @@ +import torch +from torch.nn.functional import silu +from types import MethodType + +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches +from modules.hypernetworks import hypernetwork +from modules.shared import cmd_opts +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 + +import ldm.modules.attention +import ldm.modules.diffusionmodules.model +import ldm.modules.diffusionmodules.openaimodel +import ldm.models.diffusion.ddpm +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms +import ldm.modules.encoders.modules + +import sgm.modules.attention +import sgm.modules.diffusionmodules.model +import sgm.modules.diffusionmodules.openaimodel +import sgm.modules.encoders.modules + +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + +# new memory efficient cross attention blocks do not support hypernets and we already +# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention +ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention +ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention + +# silence new console spam from SD2 +ldm.modules.attention.print = shared.ldm_print +ldm.modules.diffusionmodules.model.print = shared.ldm_print +ldm.util.print = shared.ldm_print +ldm.models.diffusion.ddpm.print = shared.ldm_print + +optimizers = [] +current_optimizer: sd_hijack_optimizations.SdOptimization = None + +ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) +ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) + +sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) +sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) + + +def list_optimizers(): + new_optimizers = script_callbacks.list_optimizers_callback() + + new_optimizers = [x for x in new_optimizers if x.is_available()] + + new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) + + optimizers.clear() + optimizers.extend(new_optimizers) + + +def apply_optimizations(option=None): + return + + +def undo_optimizations(): + return + + +def fix_checkpoint(): + """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want + checkpoints to be added when not training (there's a warning)""" + + pass + + +def weighted_loss(sd_model, pred, target, mean=True): + #Calculate the weight normally, but ignore the mean + loss = sd_model._old_get_loss(pred, target, mean=False) + + #Check if we have weights available + weight = getattr(sd_model, '_custom_loss_weight', None) + if weight is not None: + loss *= weight + + #Return the loss, as mean if specified + return loss.mean() if mean else loss + +def 
weighted_forward(sd_model, x, c, w, *args, **kwargs): + try: + #Temporarily append weights to a place accessible during loss calc + sd_model._custom_loss_weight = w + + #Replace 'get_loss' with a weight-aware one. Otherwise we need to reimplement 'forward' completely + #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set + if not hasattr(sd_model, '_old_get_loss'): + sd_model._old_get_loss = sd_model.get_loss + sd_model.get_loss = MethodType(weighted_loss, sd_model) + + #Run the standard forward function, but with the patched 'get_loss' + return sd_model.forward(x, c, *args, **kwargs) + finally: + try: + #Delete temporary weights if appended + del sd_model._custom_loss_weight + except AttributeError: + pass + + #If we have an old loss function, reset the loss function to the original one + if hasattr(sd_model, '_old_get_loss'): + sd_model.get_loss = sd_model._old_get_loss + del sd_model._old_get_loss + +def apply_weighted_forward(sd_model): + #Add new function 'weighted_forward' that can be called to calc weighted loss + sd_model.weighted_forward = MethodType(weighted_forward, sd_model) + +def undo_weighted_forward(sd_model): + try: + del sd_model.weighted_forward + except AttributeError: + pass + + class StableDiffusionModelHijack: fixes = None layers = None @@ -35,201 +156,74 @@ class StableDiffusionModelHijack: pass +class EmbeddingsWithFixes(torch.nn.Module): + def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): + super().__init__() + self.wrapped = wrapped + self.embeddings = embeddings + self.textual_inversion_key = textual_inversion_key + self.weight = self.wrapped.weight + + def forward(self, input_ids): + batch_fixes = self.embeddings.fixes + self.embeddings.fixes = None + + inputs_embeds = self.wrapped(input_ids) + + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec + emb = devices.cond_cast_unet(vec) + emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) + + vecs.append(tensor) + + return torch.stack(vecs) + + +class TextualInversionEmbeddings(torch.nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): + super().__init__(num_embeddings, embedding_dim, **kwargs) + + self.embeddings = model_hijack + self.textual_inversion_key = textual_inversion_key + + @property + def wrapped(self): + return super().forward + + def forward(self, input_ids): + return EmbeddingsWithFixes.forward(self, input_ids) + + +def add_circular_option_to_conv_2d(): + conv2d_constructor = torch.nn.Conv2d.__init__ + + def conv2d_constructor_circular(self, *args, **kwargs): + return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) + + torch.nn.Conv2d.__init__ = conv2d_constructor_circular + + model_hijack = StableDiffusionModelHijack() -# import torch -# from torch.nn.functional import silu -# from types import MethodType -# -# from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches -# from modules.hypernetworks import hypernetwork -# from modules.shared import cmd_opts -# from modules import sd_hijack_clip, sd_hijack_open_clip, 
sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18 -# -# import ldm.modules.attention -# import ldm.modules.diffusionmodules.model -# import ldm.modules.diffusionmodules.openaimodel -# import ldm.models.diffusion.ddpm -# import ldm.models.diffusion.ddim -# import ldm.models.diffusion.plms -# import ldm.modules.encoders.modules -# -# import sgm.modules.attention -# import sgm.modules.diffusionmodules.model -# import sgm.modules.diffusionmodules.openaimodel -# import sgm.modules.encoders.modules -# -# attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward -# diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity -# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# -# # new memory efficient cross attention blocks do not support hypernets and we already -# # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention -# ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention -# -# # silence new console spam from SD2 -# ldm.modules.attention.print = shared.ldm_print -# ldm.modules.diffusionmodules.model.print = shared.ldm_print -# ldm.util.print = shared.ldm_print -# ldm.models.diffusion.ddpm.print = shared.ldm_print -# -# optimizers = [] -# current_optimizer: sd_hijack_optimizations.SdOptimization = None -# -# ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward) -# ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward) -# -# sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward) -# sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward) -# -# -# def list_optimizers(): -# new_optimizers = script_callbacks.list_optimizers_callback() -# -# new_optimizers = [x for x in new_optimizers if x.is_available()] -# -# new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) -# -# optimizers.clear() -# optimizers.extend(new_optimizers) -# -# -# def apply_optimizations(option=None): -# return -# -# -# def undo_optimizations(): -# return -# -# -# def fix_checkpoint(): -# """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want -# checkpoints to be added when not training (there's a warning)""" -# -# pass -# -# -# def weighted_loss(sd_model, pred, target, mean=True): -# #Calculate the weight normally, but ignore the mean -# loss = sd_model._old_get_loss(pred, target, mean=False) -# -# #Check if we have weights available -# weight = getattr(sd_model, '_custom_loss_weight', None) -# if weight is not None: -# loss *= weight -# -# #Return the loss, as mean if specified -# return loss.mean() if mean else loss -# -# def weighted_forward(sd_model, x, c, w, *args, **kwargs): -# try: -# #Temporarily append weights to a place accessible during loss calc -# sd_model._custom_loss_weight = w -# -# #Replace 'get_loss' with a weight-aware one. 
Otherwise we need to reimplement 'forward' completely -# #Keep 'get_loss', but don't overwrite the previous old_get_loss if it's already set -# if not hasattr(sd_model, '_old_get_loss'): -# sd_model._old_get_loss = sd_model.get_loss -# sd_model.get_loss = MethodType(weighted_loss, sd_model) -# -# #Run the standard forward function, but with the patched 'get_loss' -# return sd_model.forward(x, c, *args, **kwargs) -# finally: -# try: -# #Delete temporary weights if appended -# del sd_model._custom_loss_weight -# except AttributeError: -# pass -# -# #If we have an old loss function, reset the loss function to the original one -# if hasattr(sd_model, '_old_get_loss'): -# sd_model.get_loss = sd_model._old_get_loss -# del sd_model._old_get_loss -# -# def apply_weighted_forward(sd_model): -# #Add new function 'weighted_forward' that can be called to calc weighted loss -# sd_model.weighted_forward = MethodType(weighted_forward, sd_model) -# -# def undo_weighted_forward(sd_model): -# try: -# del sd_model.weighted_forward -# except AttributeError: -# pass -# -# -# -# -# -# class EmbeddingsWithFixes(torch.nn.Module): -# def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'): -# super().__init__() -# self.wrapped = wrapped -# self.embeddings = embeddings -# self.textual_inversion_key = textual_inversion_key -# self.weight = self.wrapped.weight -# -# def forward(self, input_ids): -# batch_fixes = self.embeddings.fixes -# self.embeddings.fixes = None -# -# inputs_embeds = self.wrapped(input_ids) -# -# if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: -# return inputs_embeds -# -# vecs = [] -# for fixes, tensor in zip(batch_fixes, inputs_embeds): -# for offset, embedding in fixes: -# vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec -# emb = devices.cond_cast_unet(vec) -# emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) -# tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]).to(dtype=inputs_embeds.dtype) -# -# vecs.append(tensor) -# -# return torch.stack(vecs) -# -# -# class TextualInversionEmbeddings(torch.nn.Embedding): -# def __init__(self, num_embeddings: int, embedding_dim: int, textual_inversion_key='clip_l', **kwargs): -# super().__init__(num_embeddings, embedding_dim, **kwargs) -# -# self.embeddings = model_hijack -# self.textual_inversion_key = textual_inversion_key -# -# @property -# def wrapped(self): -# return super().forward -# -# def forward(self, input_ids): -# return EmbeddingsWithFixes.forward(self, input_ids) -# -# -# def add_circular_option_to_conv_2d(): -# conv2d_constructor = torch.nn.Conv2d.__init__ -# -# def conv2d_constructor_circular(self, *args, **kwargs): -# return conv2d_constructor(self, *args, padding_mode='circular', **kwargs) -# -# torch.nn.Conv2d.__init__ = conv2d_constructor_circular -# -# -# model_hijack = StableDiffusionModelHijack() -# -# -# def register_buffer(self, name, attr): -# """ -# Fix register buffer bug for Mac OS. -# """ -# -# if type(attr) == torch.Tensor: -# if attr.device != devices.device: -# attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) -# -# setattr(self, name, attr) -# -# -# ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer -# ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer +def register_buffer(self, name, attr): + """ + Fix register buffer bug for Mac OS. 
+ """ + + if type(attr) == torch.Tensor: + if attr.device != devices.device: + attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) + + setattr(self, name, attr) + + +ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer +ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 696835ad..0269f1f5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,677 +1,677 @@ -# from __future__ import annotations -# import math -# import psutil -# import platform -# -# import torch -# from torch import einsum -# -# from ldm.util import default -# from einops import rearrange -# -# from modules import shared, errors, devices, sub_quadratic_attention -# from modules.hypernetworks import hypernetwork -# -# import ldm.modules.attention -# import ldm.modules.diffusionmodules.model -# -# import sgm.modules.attention -# import sgm.modules.diffusionmodules.model -# -# diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward -# -# -# class SdOptimization: -# name: str = None -# label: str | None = None -# cmd_opt: str | None = None -# priority: int = 0 -# -# def title(self): -# if self.label is None: -# return self.name -# -# return f"{self.name} - {self.label}" -# -# def is_available(self): -# return True -# -# def apply(self): -# pass -# -# def undo(self): -# ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward -# ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -# -# sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward -# sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward -# -# -# class SdOptimizationXformers(SdOptimization): -# name = "xformers" -# cmd_opt = "xformers" -# priority = 100 -# -# def is_available(self): -# return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = xformers_attention_forward -# ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward -# sgm.modules.attention.CrossAttention.forward = xformers_attention_forward -# sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward -# -# -# class SdOptimizationSdpNoMem(SdOptimization): -# name = "sdp-no-mem" -# label = "scaled dot product without memory efficient attention" -# cmd_opt = "opt_sdp_no_mem_attention" -# priority = 80 -# -# def is_available(self): -# return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward -# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward -# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward -# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward -# -# -# class SdOptimizationSdp(SdOptimizationSdpNoMem): -# name = "sdp" -# label = "scaled dot product" -# cmd_opt 
= "opt_sdp_attention" -# priority = 70 -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward -# ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward -# sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward -# sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward -# -# -# class SdOptimizationSubQuad(SdOptimization): -# name = "sub-quadratic" -# cmd_opt = "opt_sub_quad_attention" -# -# @property -# def priority(self): -# return 1000 if shared.device.type == 'mps' else 10 -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward -# ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward -# sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward -# sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward -# -# -# class SdOptimizationV1(SdOptimization): -# name = "V1" -# label = "original v1" -# cmd_opt = "opt_split_attention_v1" -# priority = 10 -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 -# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 -# -# -# class SdOptimizationInvokeAI(SdOptimization): -# name = "InvokeAI" -# cmd_opt = "opt_split_attention_invokeai" -# -# @property -# def priority(self): -# return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI -# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI -# -# -# class SdOptimizationDoggettx(SdOptimization): -# name = "Doggettx" -# cmd_opt = "opt_split_attention" -# priority = 90 -# -# def apply(self): -# ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward -# ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward -# sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward -# sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward -# -# -# def list_optimizers(res): -# res.extend([ -# SdOptimizationXformers(), -# SdOptimizationSdpNoMem(), -# SdOptimizationSdp(), -# SdOptimizationSubQuad(), -# SdOptimizationV1(), -# SdOptimizationInvokeAI(), -# SdOptimizationDoggettx(), -# ]) -# -# -# if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: -# try: -# import xformers.ops -# shared.xformers_available = True -# except Exception: -# errors.report("Cannot import xformers", exc_info=True) -# -# -# def get_available_vram(): -# if shared.device.type == 'cuda': -# stats = torch.cuda.memory_stats(shared.device) -# mem_active = stats['active_bytes.all.current'] -# mem_reserved = stats['reserved_bytes.all.current'] -# mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) -# mem_free_torch = mem_reserved - mem_active -# mem_free_total = mem_free_cuda + mem_free_torch -# return mem_free_total -# else: -# return psutil.virtual_memory().available -# -# -# # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -# def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): -# h = self.heads -# -# q_in = self.to_q(x) -# context = default(context, x) -# -# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) -# k_in = 
self.to_k(context_k) -# v_in = self.to_v(context_v) -# del context, context_k, context_v, x -# -# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) -# del q_in, k_in, v_in -# -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k, v = q.float(), k.float(), v.float() -# -# with devices.without_autocast(disable=not shared.opts.upcast_attn): -# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) -# for i in range(0, q.shape[0], 2): -# end = i + 2 -# s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) -# s1 *= self.scale -# -# s2 = s1.softmax(dim=-1) -# del s1 -# -# r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) -# del s2 -# del q, k, v -# -# r1 = r1.to(dtype) -# -# r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) -# del r1 -# -# return self.to_out(r2) -# -# -# # taken from https://github.com/Doggettx/stable-diffusion and modified -# def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): -# h = self.heads -# -# q_in = self.to_q(x) -# context = default(context, x) -# -# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) -# k_in = self.to_k(context_k) -# v_in = self.to_v(context_v) -# -# dtype = q_in.dtype -# if shared.opts.upcast_attn: -# q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() -# -# with devices.without_autocast(disable=not shared.opts.upcast_attn): -# k_in = k_in * self.scale -# -# del context, x -# -# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) -# del q_in, k_in, v_in -# -# r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) -# -# mem_free_total = get_available_vram() -# -# gb = 1024 ** 3 -# tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() -# modifier = 3 if q.element_size() == 2 else 2.5 -# mem_required = tensor_size * modifier -# steps = 1 -# -# if mem_required > mem_free_total: -# steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) -# # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " -# # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") -# -# if steps > 64: -# max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 -# raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' -# f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') -# -# slice_size = q.shape[1] // steps -# for i in range(0, q.shape[1], slice_size): -# end = min(i + slice_size, q.shape[1]) -# s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) -# -# s2 = s1.softmax(dim=-1, dtype=q.dtype) -# del s1 -# -# r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) -# del s2 -# -# del q, k, v -# -# r1 = r1.to(dtype) -# -# r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) -# del r1 -# -# return self.to_out(r2) -# -# -# # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -# mem_total_gb = psutil.virtual_memory().total // (1 << 30) -# -# -# def einsum_op_compvis(q, k, v): -# s = einsum('b i d, b j d -> b i j', q, k) -# s = s.softmax(dim=-1, dtype=s.dtype) -# return einsum('b i j, b j d -> b i d', s, v) -# -# -# def einsum_op_slice_0(q, k, v, slice_size): -# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) -# for i in range(0, q.shape[0], slice_size): -# end = i + slice_size -# r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) -# return r -# -# -# def einsum_op_slice_1(q, k, v, slice_size): -# r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) -# for i in range(0, q.shape[1], slice_size): -# end = i + slice_size -# r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) -# return r -# -# -# def einsum_op_mps_v1(q, k, v): -# if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 -# return einsum_op_compvis(q, k, v) -# else: -# slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) -# if slice_size % 4096 == 0: -# slice_size -= 1 -# return einsum_op_slice_1(q, k, v, slice_size) -# -# -# def einsum_op_mps_v2(q, k, v): -# if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: -# return einsum_op_compvis(q, k, v) -# else: -# return einsum_op_slice_0(q, k, v, 1) -# -# -# def einsum_op_tensor_mem(q, k, v, max_tensor_mb): -# size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) -# if size_mb <= max_tensor_mb: -# return einsum_op_compvis(q, k, v) -# div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() -# if div <= q.shape[0]: -# return einsum_op_slice_0(q, k, v, q.shape[0] // div) -# return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) -# -# -# def einsum_op_cuda(q, k, v): -# stats = torch.cuda.memory_stats(q.device) -# mem_active = stats['active_bytes.all.current'] -# mem_reserved = stats['reserved_bytes.all.current'] -# mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) -# mem_free_torch = mem_reserved - mem_active -# mem_free_total = mem_free_cuda + mem_free_torch -# # Divide factor of safety as there's copying and fragmentation -# return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) -# -# -# def einsum_op(q, k, v): -# if q.device.type == 'cuda': -# return einsum_op_cuda(q, k, v) -# -# if q.device.type == 'mps': -# if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: -# return einsum_op_mps_v1(q, k, v) -# return einsum_op_mps_v2(q, k, v) -# -# # Smaller slices are faster due to L2/L3/SLC caches. -# # Tested on i7 with 8MB L3 cache. 
-# return einsum_op_tensor_mem(q, k, v, 32) -# -# -# def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): -# h = self.heads -# -# q = self.to_q(x) -# context = default(context, x) -# -# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) -# k = self.to_k(context_k) -# v = self.to_v(context_v) -# del context, context_k, context_v, x -# -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() -# -# with devices.without_autocast(disable=not shared.opts.upcast_attn): -# k = k * self.scale -# -# q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) -# r = einsum_op(q, k, v) -# r = r.to(dtype) -# return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) -# -# # -- End of code from https://github.com/invoke-ai/InvokeAI -- -# -# -# # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 -# # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -# def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): -# assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." -# -# h = self.heads -# -# q = self.to_q(x) -# context = default(context, x) -# -# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) -# k = self.to_k(context_k) -# v = self.to_v(context_v) -# del context, context_k, context_v, x -# -# q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) -# k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) -# v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) -# -# if q.device.type == 'mps': -# q, k, v = q.contiguous(), k.contiguous(), v.contiguous() -# -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k = q.float(), k.float() -# -# x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) -# -# x = x.to(dtype) -# -# x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) -# -# out_proj, dropout = self.to_out -# x = out_proj(x) -# x = dropout(x) -# -# return x -# -# -# def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): -# bytes_per_token = torch.finfo(q.dtype).bits//8 -# batch_x_heads, q_tokens, _ = q.shape -# _, k_tokens, _ = k.shape -# qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens -# -# if chunk_threshold is None: -# if q.device.type == 'mps': -# chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) -# else: -# chunk_threshold_bytes = int(get_available_vram() * 0.7) -# elif chunk_threshold == 0: -# chunk_threshold_bytes = None -# else: -# chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) -# -# if kv_chunk_size_min is None and chunk_threshold_bytes is not None: -# kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) -# elif kv_chunk_size_min == 0: -# kv_chunk_size_min = None -# -# if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: -# # the big matmul fits into our memory limit; do 
everything in 1 chunk, -# # i.e. send it down the unchunked fast-path -# kv_chunk_size = k_tokens -# -# with devices.without_autocast(disable=q.dtype == v.dtype): -# return sub_quadratic_attention.efficient_dot_product_attention( -# q, -# k, -# v, -# query_chunk_size=q_chunk_size, -# kv_chunk_size=kv_chunk_size, -# kv_chunk_size_min = kv_chunk_size_min, -# use_checkpoint=use_checkpoint, -# ) -# -# -# def get_xformers_flash_attention_op(q, k, v): -# if not shared.cmd_opts.xformers_flash_attention: -# return None -# -# try: -# flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp -# fw, bw = flash_attention_op -# if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): -# return flash_attention_op -# except Exception as e: -# errors.display_once(e, "enabling flash attention") -# -# return None -# -# -# def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): -# h = self.heads -# q_in = self.to_q(x) -# context = default(context, x) -# -# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) -# k_in = self.to_k(context_k) -# v_in = self.to_v(context_v) -# -# q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) -# -# del q_in, k_in, v_in -# -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k, v = q.float(), k.float(), v.float() -# -# out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) -# -# out = out.to(dtype) -# -# b, n, h, d = out.shape -# out = out.reshape(b, n, h * d) -# return self.to_out(out) -# -# -# # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py -# # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -# def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): -# batch_size, sequence_length, inner_dim = x.shape -# -# if mask is not None: -# mask = self.prepare_attention_mask(mask, sequence_length, batch_size) -# mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) -# -# h = self.heads -# q_in = self.to_q(x) -# context = default(context, x) -# -# context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) -# k_in = self.to_k(context_k) -# v_in = self.to_v(context_v) -# -# head_dim = inner_dim // h -# q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) -# k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) -# v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) -# -# del q_in, k_in, v_in -# -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k, v = q.float(), k.float(), v.float() -# -# # the output of sdp = (batch, num_heads, seq_len, head_dim) -# hidden_states = torch.nn.functional.scaled_dot_product_attention( -# q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False -# ) -# -# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) -# hidden_states = hidden_states.to(dtype) -# -# # linear proj -# hidden_states = self.to_out[0](hidden_states) -# # dropout -# hidden_states = self.to_out[1](hidden_states) -# return hidden_states -# -# -# def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): -# with torch.backends.cuda.sdp_kernel(enable_flash=True, 
enable_math=True, enable_mem_efficient=False): -# return scaled_dot_product_attention_forward(self, x, context, mask) -# -# -# def cross_attention_attnblock_forward(self, x): -# h_ = x -# h_ = self.norm(h_) -# q1 = self.q(h_) -# k1 = self.k(h_) -# v = self.v(h_) -# -# # compute attention -# b, c, h, w = q1.shape -# -# q2 = q1.reshape(b, c, h*w) -# del q1 -# -# q = q2.permute(0, 2, 1) # b,hw,c -# del q2 -# -# k = k1.reshape(b, c, h*w) # b,c,hw -# del k1 -# -# h_ = torch.zeros_like(k, device=q.device) -# -# mem_free_total = get_available_vram() -# -# tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() -# mem_required = tensor_size * 2.5 -# steps = 1 -# -# if mem_required > mem_free_total: -# steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) -# -# slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] -# for i in range(0, q.shape[1], slice_size): -# end = i + slice_size -# -# w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] -# w2 = w1 * (int(c)**(-0.5)) -# del w1 -# w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) -# del w2 -# -# # attend to values -# v1 = v.reshape(b, c, h*w) -# w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) -# del w3 -# -# h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] -# del v1, w4 -# -# h2 = h_.reshape(b, c, h, w) -# del h_ -# -# h3 = self.proj_out(h2) -# del h2 -# -# h3 += x -# -# return h3 -# -# -# def xformers_attnblock_forward(self, x): -# try: -# h_ = x -# h_ = self.norm(h_) -# q = self.q(h_) -# k = self.k(h_) -# v = self.v(h_) -# b, c, h, w = q.shape -# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k = q.float(), k.float() -# q = q.contiguous() -# k = k.contiguous() -# v = v.contiguous() -# out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) -# out = out.to(dtype) -# out = rearrange(out, 'b (h w) c -> b c h w', h=h) -# out = self.proj_out(out) -# return x + out -# except NotImplementedError: -# return cross_attention_attnblock_forward(self, x) -# -# -# def sdp_attnblock_forward(self, x): -# h_ = x -# h_ = self.norm(h_) -# q = self.q(h_) -# k = self.k(h_) -# v = self.v(h_) -# b, c, h, w = q.shape -# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) -# dtype = q.dtype -# if shared.opts.upcast_attn: -# q, k, v = q.float(), k.float(), v.float() -# q = q.contiguous() -# k = k.contiguous() -# v = v.contiguous() -# out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) -# out = out.to(dtype) -# out = rearrange(out, 'b (h w) c -> b c h w', h=h) -# out = self.proj_out(out) -# return x + out -# -# -# def sdp_no_mem_attnblock_forward(self, x): -# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): -# return sdp_attnblock_forward(self, x) -# -# -# def sub_quad_attnblock_forward(self, x): -# h_ = x -# h_ = self.norm(h_) -# q = self.q(h_) -# k = self.k(h_) -# v = self.v(h_) -# b, c, h, w = q.shape -# q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) -# q = q.contiguous() -# k = k.contiguous() -# v = v.contiguous() -# out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) -# out = rearrange(out, 'b (h w) c -> b 
c h w', h=h) -# out = self.proj_out(out) -# return x + out +from __future__ import annotations +import math +import psutil +import platform + +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + +from modules import shared, errors, devices, sub_quadratic_attention +from modules.hypernetworks import hypernetwork + +import ldm.modules.attention +import ldm.modules.diffusionmodules.model + +import sgm.modules.attention +import sgm.modules.diffusionmodules.model + +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward + + +class SdOptimization: + name: str = None + label: str | None = None + cmd_opt: str | None = None + priority: int = 0 + + def title(self): + if self.label is None: + return self.name + + return f"{self.name} - {self.label}" + + def is_available(self): + return True + + def apply(self): + pass + + def undo(self): + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward + + +class SdOptimizationXformers(SdOptimization): + name = "xformers" + cmd_opt = "xformers" + priority = 100 + + def is_available(self): + return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) + + def apply(self): + ldm.modules.attention.CrossAttention.forward = xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + sgm.modules.attention.CrossAttention.forward = xformers_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + + +class SdOptimizationSdpNoMem(SdOptimization): + name = "sdp-no-mem" + label = "scaled dot product without memory efficient attention" + cmd_opt = "opt_sdp_no_mem_attention" + priority = 80 + + def is_available(self): + return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) + + def apply(self): + ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + + +class SdOptimizationSdp(SdOptimizationSdpNoMem): + name = "sdp" + label = "scaled dot product" + cmd_opt = "opt_sdp_attention" + priority = 70 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + + +class SdOptimizationSubQuad(SdOptimization): + name = "sub-quadratic" + cmd_opt = "opt_sub_quad_attention" + + @property + def priority(self): + return 1000 if shared.device.type == 'mps' else 10 + + def apply(self): + 
ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + + +class SdOptimizationV1(SdOptimization): + name = "V1" + label = "original v1" + cmd_opt = "opt_split_attention_v1" + priority = 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + + +class SdOptimizationInvokeAI(SdOptimization): + name = "InvokeAI" + cmd_opt = "opt_split_attention_invokeai" + + @property + def priority(self): + return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + + +class SdOptimizationDoggettx(SdOptimization): + name = "Doggettx" + cmd_opt = "opt_split_attention" + priority = 90 + + def apply(self): + ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + + +def list_optimizers(res): + res.extend([ + SdOptimizationXformers(), + SdOptimizationSdpNoMem(), + SdOptimizationSdp(), + SdOptimizationSubQuad(), + SdOptimizationV1(), + SdOptimizationInvokeAI(), + SdOptimizationDoggettx(), + ]) + + +if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: + try: + import xformers.ops + shared.xformers_available = True + except Exception: + errors.report("Cannot import xformers", exc_info=True) + + +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + del context, context_k, context_v, x + + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v + + r1 = r1.to(dtype) + + r2 = 
rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion and modified +def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps + for i in range(0, q.shape[1], slice_size): + end = min(i + slice_size, q.shape[1]) + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r1 = r1.to(dtype) + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- +mem_total_gb = psutil.virtual_memory().total // (1 << 30) + + +def einsum_op_compvis(q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + + +def einsum_op_slice_0(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + + +def einsum_op_slice_1(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) + return r + + +def einsum_op_mps_v1(q, k, v): + if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 + return einsum_op_compvis(q, k, v) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + if slice_size % 4096 == 0: + slice_size -= 1 + return einsum_op_slice_1(q, k, v, slice_size) + + +def einsum_op_mps_v2(q, k, v): + if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: + return einsum_op_compvis(q, k, v) + else: + return einsum_op_slice_0(q, k, v, 1) + + +def einsum_op_tensor_mem(q, k, v, 
max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return einsum_op_slice_0(q, k, v, q.shape[0] // div) + return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + + +def einsum_op_cuda(q, k, v): + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + # Divide factor of safety as there's copying and fragmentation + return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + + +def einsum_op(q, k, v): + if q.device.type == 'cuda': + return einsum_op_cuda(q, k, v) + + if q.device.type == 'mps': + if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18: + return einsum_op_mps_v1(q, k, v) + return einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return einsum_op_tensor_mem(q, k, v, 32) + + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + +# -- End of code from https://github.com/invoke-ai/InvokeAI -- + + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface +def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
+ + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + if q.device.type == 'mps': + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.to(dtype) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + if chunk_threshold is None: + if q.device.type == 'mps': + chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) + else: + chunk_threshold_bytes = int(get_available_vram() * 0.7) + elif chunk_threshold == 0: + chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) + + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. 
send it down the unchunked fast-path + kv_chunk_size = k_tokens + + with devices.without_autocast(disable=q.dtype == v.dtype): + return sub_quadratic_attention.efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + +def get_xformers_flash_attention_op(q, k, v): + if not shared.cmd_opts.xformers_flash_attention: + return None + + try: + flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = flash_attention_op + if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + return flash_attention_op + except Exception as e: + errors.display_once(e, "enabling flash attention") + + return None + + +def xformers_attention_forward(self, x, context=None, mask=None, **kwargs): + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + q, k, v = (t.reshape(t.shape[0], t.shape[1], h, -1) for t in (q_in, k_in, v_in)) + + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + + out = out.to(dtype) + + b, n, h, d = out.shape + out = out.reshape(b, n, h * d) + return self.to_out(out) + + +# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py +# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs): + batch_size, sequence_length, inner_dim = x.shape + + if mask is not None: + mask = self.prepare_attention_mask(mask, sequence_length, batch_size) + mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) + + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + head_dim = inner_dim // h + q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return scaled_dot_product_attention_forward(self, x, context, mask) + + +def 
cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + mem_free_total = get_available_vram() + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 + + +def xformers_attnblock_forward(self, x): + try: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out + except NotImplementedError: + return cross_attention_attnblock_forward(self, x) + + +def sdp_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False) + out = out.to(dtype) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out + + +def sdp_no_mem_attnblock_forward(self, x): + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): + return sdp_attnblock_forward(self, x) + + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index eb4a0af4..b4f03b13 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,154 +1,154 @@ -# import torch 
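For context on the split-attention code restored above: the key idea is to estimate how much memory the full q·kᵀ attention matrix would need, pick the smallest power-of-two number of steps that makes each slice fit into free memory, and then attend over the query tokens one slice at a time. The standalone sketch below illustrates that heuristic; power_of_two_steps, sliced_attention and the 2.5 safety factor as used here are illustrative names and numbers, not part of the patch.

import math
import torch

def power_of_two_steps(mem_required: float, mem_free: float) -> int:
    # smallest power-of-two step count so each slice of the attention matrix fits in memory
    if mem_required <= mem_free:
        return 1
    return 2 ** math.ceil(math.log2(mem_required / mem_free))

def sliced_attention(q, k, v, mem_free_bytes):
    # q, k, v: (batch*heads, tokens, dim), assumed already scaled
    bytes_needed = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    # the ~2.5x multiplier mirrors the safety modifier used in the patch above
    steps = power_of_two_steps(bytes_needed * 2.5, mem_free_bytes)
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], dtype=q.dtype)
    slice_size = max(q.shape[1] // steps, 1)
    for i in range(0, q.shape[1], slice_size):
        end = min(i + slice_size, q.shape[1])
        # attention scores for one slice of query tokens against all key tokens
        attn = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k).softmax(dim=-1)
        out[:, i:end] = torch.einsum('b i j, b j d -> b i d', attn, v)
    return out

# example: 2 heads, 64 query/key tokens, 40-dim heads, pretending only 64 KiB are free
q = torch.randn(2, 64, 40)
k = torch.randn(2, 64, 40)
v = torch.randn(2, 64, 40)
print(sliced_attention(q, k, v, mem_free_bytes=64 * 1024).shape)  # torch.Size([2, 64, 40])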
-# from packaging import version -# from einops import repeat -# import math -# -# from modules import devices -# from modules.sd_hijack_utils import CondFunc -# -# -# class TorchHijackForUnet: -# """ -# This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; -# this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 -# """ -# -# def __getattr__(self, item): -# if item == 'cat': -# return self.cat -# -# if hasattr(torch, item): -# return getattr(torch, item) -# -# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") -# -# def cat(self, tensors, *args, **kwargs): -# if len(tensors) == 2: -# a, b = tensors -# if a.shape[-2:] != b.shape[-2:]: -# a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") -# -# tensors = (a, b) -# -# return torch.cat(tensors, *args, **kwargs) -# -# -# th = TorchHijackForUnet() -# -# -# # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling -# def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): -# """Always make sure inputs to unet are in correct dtype.""" -# if isinstance(cond, dict): -# for y in cond.keys(): -# if isinstance(cond[y], list): -# cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] -# else: -# cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] -# -# with devices.autocast(): -# result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) -# if devices.unet_needs_upcast: -# return result.float() -# else: -# return result -# -# -# # Monkey patch to create timestep embed tensor on device, avoiding a block. -# def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): -# """ -# Create sinusoidal timestep embeddings. -# :param timesteps: a 1-D Tensor of N indices, one per batch element. -# These may be fractional. -# :param dim: the dimension of the output. -# :param max_period: controls the minimum frequency of the embeddings. -# :return: an [N x dim] Tensor of positional embeddings. -# """ -# if not repeat_only: -# half = dim // 2 -# freqs = torch.exp( -# -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half -# ) -# args = timesteps[:, None].float() * freqs[None] -# embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) -# if dim % 2: -# embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) -# else: -# embedding = repeat(timesteps, 'b -> b d', d=dim) -# return embedding -# -# -# # Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
-# # Prevents a lot of unnecessary aten::copy_ calls -# def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): -# # note: if no context is given, cross-attention defaults to self-attention -# if not isinstance(context, list): -# context = [context] -# b, c, h, w = x.shape -# x_in = x -# x = self.norm(x) -# if not self.use_linear: -# x = self.proj_in(x) -# x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) -# if self.use_linear: -# x = self.proj_in(x) -# for i, block in enumerate(self.transformer_blocks): -# x = block(x, context=context[i]) -# if self.use_linear: -# x = self.proj_out(x) -# x = x.view(b, h, w, c).permute(0, 3, 1, 2) -# if not self.use_linear: -# x = self.proj_out(x) -# return x + x_in -# -# -# class GELUHijack(torch.nn.GELU, torch.nn.Module): -# def __init__(self, *args, **kwargs): -# torch.nn.GELU.__init__(self, *args, **kwargs) -# def forward(self, x): -# if devices.unet_needs_upcast: -# return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) -# else: -# return torch.nn.GELU.forward(self, x) -# -# -# ddpm_edit_hijack = None -# def hijack_ddpm_edit(): -# global ddpm_edit_hijack -# if not ddpm_edit_hijack: -# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) -# CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) -# ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) -# -# -# unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast -# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) -# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) -# CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) -# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) -# -# if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): -# CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) -# CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) -# CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) -# -# first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 -# first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) -# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) -# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) -# CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) -# -# 
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) -# CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) -# -# -# def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): -# if devices.unet_needs_upcast and timesteps.dtype == torch.int64: -# dtype = torch.float32 -# else: -# dtype = devices.dtype_unet -# return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) -# -# -# CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) -# CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +import torch +from packaging import version +from einops import repeat +import math + +from modules import devices +from modules.sd_hijack_utils import CondFunc + + +class TorchHijackForUnet: + """ + This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match; + this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64 + """ + + def __getattr__(self, item): + if item == 'cat': + return self.cat + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def cat(self, tensors, *args, **kwargs): + if len(tensors) == 2: + a, b = tensors + if a.shape[-2:] != b.shape[-2:]: + a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest") + + tensors = (a, b) + + return torch.cat(tensors, *args, **kwargs) + + +th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + """Always make sure inputs to unet are in correct dtype.""" + if isinstance(cond, dict): + for y in cond.keys(): + if isinstance(cond[y], list): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + else: + cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] + + with devices.autocast(): + result = orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs) + if devices.unet_needs_upcast: + return result.float() + else: + return result + + +# Monkey patch to create timestep embed tensor on device, avoiding a block. +def timestep_embedding(_, timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +# Monkey patch to SpatialTransformer removing unnecessary contiguous calls. 
+# Prevents a lot of unnecessary aten::copy_ calls +def spatial_transformer_forward(_, self, x: torch.Tensor, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = x.permute(0, 2, 3, 1).reshape(b, h * w, c) + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = x.view(b, h, w, c).permute(0, 3, 1, 2) + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + + +ddpm_edit_hijack = None +def hijack_ddpm_edit(): + global ddpm_edit_hijack + if not ddpm_edit_hijack: + CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) + CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) + ddpm_edit_hijack = CondFunc('modules.models.diffusion.ddpm_edit.LatentDiffusion.apply_model', apply_model) + + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding) +CondFunc('ldm.modules.attention.SpatialTransformer.forward', spatial_transformer_forward) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) + +if version.parse(torch.__version__) <= version.parse("1.13.2") or torch.cuda.is_available(): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) + +first_stage_cond = lambda _, self, *args, **kwargs: devices.unet_needs_upcast and self.model.diffusion_model.dtype == torch.float16 +first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) + +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model) 
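The CondFunc(...) registrations above come from modules/sd_hijack_utils, which is not shown in this diff; the sketch below is only a rough, self-contained approximation of that conditional-patching pattern. cond_patch and Demo are made-up names for illustration, and the real helper also resolves dotted import paths, which is omitted here.

# A simplified stand-in for the CondFunc pattern: wrap an existing callable so that
# a substitute runs whenever a predicate is true, and the original runs otherwise.
def cond_patch(owner, name, sub_func, cond_func=None):
    orig_func = getattr(owner, name)

    def wrapper(*args, **kwargs):
        # run the substitute only when the (optional) predicate says so
        if cond_func is None or cond_func(orig_func, *args, **kwargs):
            return sub_func(orig_func, *args, **kwargs)
        return orig_func(*args, **kwargs)

    setattr(owner, name, wrapper)
    return wrapper


class Demo:
    def forward(self, x):
        return x + 1


# substitute doubles the result, but only for even inputs
cond_patch(Demo, 'forward',
           sub_func=lambda orig, self, x: orig(self, x) * 2,
           cond_func=lambda orig, self, x: x % 2 == 0)

d = Demo()
print(d.forward(2), d.forward(3))  # 6 4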
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model) + + +def timestep_embedding_cast_result(orig_func, timesteps, *args, **kwargs): + if devices.unet_needs_upcast and timesteps.dtype == torch.int64: + dtype = torch.float32 + else: + dtype = devices.dtype_unet + return orig_func(timesteps, *args, **kwargs).to(dtype=dtype) + + +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', timestep_embedding_cast_result) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index faddd9a2..41e5087d 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -1,137 +1,137 @@ -# import os -# -# import torch -# -# from modules import shared, paths, sd_disable_initialization, devices -# -# sd_configs_path = shared.sd_configs_path -# sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") -# sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") -# -# -# config_default = shared.sd_default_config -# # config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") -# config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -# config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") -# config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") -# config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") -# config_sdxl_inpainting = os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") -# config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") -# config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") -# config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") -# config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") -# config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") -# config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") -# config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") -# config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") -# -# -# def is_using_v_parameterization_for_sd2(state_dict): -# """ -# Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. -# """ -# -# import ldm.modules.diffusionmodules.openaimodel -# -# device = devices.device -# -# with sd_disable_initialization.DisableInitialization(): -# unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( -# use_checkpoint=False, -# use_fp16=False, -# image_size=32, -# in_channels=4, -# out_channels=4, -# model_channels=320, -# attention_resolutions=[4, 2, 1], -# num_res_blocks=2, -# channel_mult=[1, 2, 4, 4], -# num_head_channels=64, -# use_spatial_transformer=True, -# use_linear_in_transformer=True, -# transformer_depth=1, -# context_dim=1024, -# legacy=False -# ) -# unet.eval() -# -# with torch.no_grad(): -# unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." 
in k} -# unet.load_state_dict(unet_sd, strict=True) -# unet.to(device=device, dtype=devices.dtype_unet) -# -# test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 -# x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 -# -# with devices.autocast(): -# out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() -# -# return out < -1 -# -# -# def guess_model_config_from_state_dict(sd, filename): -# sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) -# diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) -# sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) -# -# if "model.diffusion_model.x_embedder.proj.weight" in sd: -# return config_sd3 -# -# if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: -# if diffusion_model_input.shape[1] == 9: -# return config_sdxl_inpainting -# else: -# return config_sdxl -# -# if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: -# return config_sdxl_refiner -# elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: -# return config_depth_model -# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: -# return config_unclip -# elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: -# return config_unopenclip -# -# if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: -# if diffusion_model_input.shape[1] == 9: -# return config_sd2_inpainting -# # elif is_using_v_parameterization_for_sd2(sd): -# # return config_sd2v -# else: -# return config_sd2v -# -# if diffusion_model_input is not None: -# if diffusion_model_input.shape[1] == 9: -# return config_inpainting -# if diffusion_model_input.shape[1] == 8: -# return config_instruct_pix2pix -# -# if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: -# if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: -# return config_alt_diffusion_m18 -# return config_alt_diffusion -# -# return config_default -# -# -# def find_checkpoint_config(state_dict, info): -# if info is None: -# return guess_model_config_from_state_dict(state_dict, "") -# -# config = find_checkpoint_config_near_filename(info) -# if config is not None: -# return config -# -# return guess_model_config_from_state_dict(state_dict, info.filename) -# -# -# def find_checkpoint_config_near_filename(info): -# if info is None: -# return None -# -# config = f"{os.path.splitext(info.filename)[0]}.yaml" -# if os.path.exists(config): -# return config -# -# return None -# +import os + +import torch + +from modules import shared, paths, sd_disable_initialization, devices + +sd_configs_path = shared.sd_configs_path +sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference") + + +config_default = shared.sd_default_config +# config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") +config_sdxl_inpainting = 
os.path.join(sd_configs_path, "sd_xl_inpaint.yaml") +config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") +config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") +config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") +config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") +config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") +config_alt_diffusion_m18 = os.path.join(sd_configs_path, "alt-diffusion-m18-inference.yaml") +config_sd3 = os.path.join(sd_configs_path, "sd3-inference.yaml") + + +def is_using_v_parameterization_for_sd2(state_dict): + """ + Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. + """ + + import ldm.modules.diffusionmodules.openaimodel + + device = devices.device + + with sd_disable_initialization.DisableInitialization(): + unet = ldm.modules.diffusionmodules.openaimodel.UNetModel( + use_checkpoint=False, + use_fp16=False, + image_size=32, + in_channels=4, + out_channels=4, + model_channels=320, + attention_resolutions=[4, 2, 1], + num_res_blocks=2, + channel_mult=[1, 2, 4, 4], + num_head_channels=64, + use_spatial_transformer=True, + use_linear_in_transformer=True, + transformer_depth=1, + context_dim=1024, + legacy=False + ) + unet.eval() + + with torch.no_grad(): + unet_sd = {k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if "model.diffusion_model." in k} + unet.load_state_dict(unet_sd, strict=True) + unet.to(device=device, dtype=devices.dtype_unet) + + test_cond = torch.ones((1, 2, 1024), device=device) * 0.5 + x_test = torch.ones((1, 4, 8, 8), device=device) * 0.5 + + with devices.autocast(): + out = (unet(x_test, torch.asarray([999], device=device), context=test_cond) - x_test).mean().cpu().item() + + return out < -1 + + +def guess_model_config_from_state_dict(sd, filename): + sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) + diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) + + if "model.diffusion_model.x_embedder.proj.weight" in sd: + return config_sd3 + + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: + if diffusion_model_input.shape[1] == 9: + return config_sdxl_inpainting + else: + return config_sdxl + + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: + return config_sdxl_refiner + elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + return config_depth_model + elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: + return config_unclip + elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 1024: + return config_unopenclip + + if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: + if diffusion_model_input.shape[1] == 9: + return config_sd2_inpainting + # elif is_using_v_parameterization_for_sd2(sd): + # return config_sd2v + else: + return config_sd2v + + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return config_inpainting + if diffusion_model_input.shape[1] == 8: + return config_instruct_pix2pix + + if 
sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: + if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: + return config_alt_diffusion_m18 + return config_alt_diffusion + + return config_default + + +def find_checkpoint_config(state_dict, info): + if info is None: + return guess_model_config_from_state_dict(state_dict, "") + + config = find_checkpoint_config_near_filename(info) + if config is not None: + return config + + return guess_model_config_from_state_dict(state_dict, info.filename) + + +def find_checkpoint_config_near_filename(info): + if info is None: + return None + + config = f"{os.path.splitext(info.filename)[0]}.yaml" + if os.path.exists(config): + return config + + return None + diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 0b84f2fc..3f1bab96 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,115 +1,115 @@ -# from __future__ import annotations -# -# import torch -# -# import sgm.models.diffusion -# import sgm.modules.diffusionmodules.denoiser_scaling -# import sgm.modules.diffusionmodules.discretizer -# from modules import devices, shared, prompt_parser -# from modules import torch_utils -# -# from backend import memory_management -# -# -# def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): -# -# for embedder in self.conditioner.embedders: -# embedder.ucg_rate = 0.0 -# -# width = getattr(batch, 'width', 1024) or 1024 -# height = getattr(batch, 'height', 1024) or 1024 -# is_negative_prompt = getattr(batch, 'is_negative_prompt', False) -# aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score -# -# devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) -# -# sdxl_conds = { -# "txt": batch, -# "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), -# "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), -# "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), -# "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), -# } -# -# force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) -# c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) -# -# return c -# -# -# def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs): -# if self.model.diffusion_model.in_channels == 9: -# x = torch.cat([x] + cond['c_concat'], dim=1) -# -# return self.model(x, t, cond, *args, **kwargs) -# -# -# def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility -# return x -# -# -# sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -# sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -# sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding -# -# -# def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): -# res = [] -# -# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: -# encoded = embedder.encode_embedding_init_text(init_text, nvpt) -# 
res.append(encoded) -# -# return torch.cat(res, dim=1) -# -# -# def tokenize(self: sgm.modules.GeneralConditioner, texts): -# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: -# return embedder.tokenize(texts) -# -# raise AssertionError('no tokenizer available') -# -# -# -# def process_texts(self, texts): -# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: -# return embedder.process_texts(texts) -# -# -# def get_target_prompt_token_count(self, token_count): -# for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: -# return embedder.get_target_prompt_token_count(token_count) -# -# -# # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist -# sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text -# sgm.modules.GeneralConditioner.tokenize = tokenize -# sgm.modules.GeneralConditioner.process_texts = process_texts -# sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count -# -# -# def extend_sdxl(model): -# """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" -# -# dtype = torch_utils.get_param(model.model.diffusion_model).dtype -# model.model.diffusion_model.dtype = dtype -# model.model.conditioning_key = 'crossattn' -# model.cond_stage_key = 'txt' -# # model.cond_stage_model will be set in sd_hijack -# -# model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" -# -# discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() -# model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) -# -# model.conditioner.wrapped = torch.nn.Module() -# -# -# sgm.modules.attention.print = shared.ldm_print -# sgm.modules.diffusionmodules.model.print = shared.ldm_print -# sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print -# sgm.modules.encoders.modules.print = shared.ldm_print -# -# # this gets the code to load the vanilla attention that we override -# sgm.modules.attention.SDP_IS_AVAILABLE = True -# sgm.modules.attention.XFORMERS_IS_AVAILABLE = False +from __future__ import annotations + +import torch + +import sgm.models.diffusion +import sgm.modules.diffusionmodules.denoiser_scaling +import sgm.modules.diffusionmodules.discretizer +from modules import devices, shared, prompt_parser +from modules import torch_utils + +from backend import memory_management + + +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): + + for embedder in self.conditioner.embedders: + embedder.ucg_rate = 0.0 + + width = getattr(batch, 'width', 1024) or 1024 + height = getattr(batch, 'height', 1024) or 1024 + is_negative_prompt = getattr(batch, 'is_negative_prompt', False) + aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score + + devices_args = dict(device=self.forge_objects.clip.patcher.current_device, dtype=memory_management.text_encoder_dtype()) + + sdxl_conds = { + "txt": batch, + "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], 
**devices_args).repeat(len(batch), 1), + "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), + } + + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) + c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) + + return c + + +def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond, *args, **kwargs): + if self.model.diffusion_model.in_channels == 9: + x = torch.cat([x] + cond['c_concat'], dim=1) + + return self.model(x, t, cond, *args, **kwargs) + + +def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility + return x + + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding + + +def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): + res = [] + + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: + encoded = embedder.encode_embedding_init_text(init_text, nvpt) + res.append(encoded) + + return torch.cat(res, dim=1) + + +def tokenize(self: sgm.modules.GeneralConditioner, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: + return embedder.tokenize(texts) + + raise AssertionError('no tokenizer available') + + + +def process_texts(self, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: + return embedder.process_texts(texts) + + +def get_target_prompt_token_count(self, token_count): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: + return embedder.get_target_prompt_token_count(token_count) + + +# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.tokenize = tokenize +sgm.modules.GeneralConditioner.process_texts = process_texts +sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count + + +def extend_sdxl(model): + """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" + + dtype = torch_utils.get_param(model.model.diffusion_model).dtype + model.model.diffusion_model.dtype = dtype + model.model.conditioning_key = 'crossattn' + model.cond_stage_key = 'txt' + # model.cond_stage_model will be set in sd_hijack + + model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" + + discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() + model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=torch.float32) + + model.conditioner.wrapped = torch.nn.Module() + + +sgm.modules.attention.print = shared.ldm_print +sgm.modules.diffusionmodules.model.print = shared.ldm_print +sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +sgm.modules.encoders.modules.print = shared.ldm_print + +# this gets the code to load the vanilla attention 
that we override +sgm.modules.attention.SDP_IS_AVAILABLE = True +sgm.modules.attention.XFORMERS_IS_AVAILABLE = False diff --git a/modules/shared_items.py b/modules/shared_items.py index 1568ba36..11f10b3f 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -35,7 +35,9 @@ def refresh_vae_list(): def cross_attention_optimizations(): - return ["Automatic"] + import modules.sd_hijack + + return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] def sd_unet_items():
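One behaviour worth calling out from the restored TorchHijackForUnet further up: when two feature maps passed to cat disagree on spatial size, the first is nearest-resized to match the second before concatenation, which is what allows image dimensions that are multiples of 8 rather than 64. A minimal standalone sketch of that behaviour follows; cat_resizing is an illustrative name, not the patched class itself.

import torch
import torch.nn.functional as F

def cat_resizing(tensors, dim=1):
    # mirrors the idea in TorchHijackForUnet.cat: if two feature maps disagree on
    # their spatial size, nearest-resize the first to match the second, then concat
    if len(tensors) == 2:
        a, b = tensors
        if a.shape[-2:] != b.shape[-2:]:
            a = F.interpolate(a, b.shape[-2:], mode="nearest")
        tensors = (a, b)
    return torch.cat(tensors, dim=dim)

# a 64x64 skip connection meeting a 63x63 upsampled branch still concatenates cleanly
skip = torch.randn(1, 4, 64, 64)
up = torch.randn(1, 4, 63, 63)
print(cat_resizing((skip, up)).shape)  # torch.Size([1, 8, 63, 63])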