tiled diffusion
This commit is contained in:
@@ -7,37 +7,65 @@
|
||||
from __future__ import division
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import ldm_patched.modules.model_management
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
import ldm_patched.modules.model_patcher
|
||||
from ldm_patched.modules.model_base import BaseModel
|
||||
from backend import memory_management
|
||||
from backend.misc.image_resize import adaptive_resize
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
from typing import List, Union, Tuple, Dict
|
||||
from ldm_patched.contrib.external import ImageScale
|
||||
import ldm_patched.modules.utils
|
||||
from backend.patcher.controlnet import ControlNet, T2IAdapter
|
||||
|
||||
|
||||
class ImageScale:
    """Resize helper exposing a single ``upscale`` operation."""

    def upscale(self, image, upscale_method, width, height, crop):
        """Resize ``image`` to ``width`` x ``height`` and return it in a 1-tuple.

        If both ``width`` and ``height`` are 0 the input is returned unchanged.
        If exactly one of them is 0 it is derived from the other, scaled by the
        ratio of the input's two trailing spatial sizes and clamped to at
        least 1.

        NOTE(review): presumably ``image`` is channels-last
        (batch, height, width, channels) -- confirm against callers.

        Args:
            image: Tensor-like object supporting ``movedim`` and ``shape``.
            upscale_method: Forwarded unchanged to ``adaptive_resize``.
            width: Target width, or 0 to derive it from ``height``.
            height: Target height, or 0 to derive it from ``width``.
            crop: Forwarded unchanged to ``adaptive_resize``.

        Returns:
            A one-element tuple holding the (possibly resized) image.
        """
        if width == 0 and height == 0:
            # No target size requested: pass the input through untouched.
            return (image,)

        # Move the last axis to position 1 before resizing.
        samples = image.movedim(-1, 1)

        if width == 0:
            width = max(1, round(samples.shape[3] * height / samples.shape[2]))
        elif height == 0:
            height = max(1, round(samples.shape[2] * width / samples.shape[3]))

        resized = adaptive_resize(samples, width, height, upscale_method, crop)
        # Put axis 1 back at the end so the layout matches the input.
        return (resized.movedim(1, -1),)
|
||||
|
||||
|
||||
# NOTE(review): presumably the number of latent channels -- not referenced in
# the visible part of this file, confirm against callers.
opt_C = 4
|
||||
# Multiplier applied to the default/max/step of the "tile_overlap" input in
# TiledDiffusion.INPUT_TYPES.
# NOTE(review): presumably the latent-to-pixel downscale factor -- confirm.
opt_f = 8
|
||||
|
||||
|
||||
def ceildiv(big, small):
    """Return ``big`` divided by ``small``, rounded up.

    Uses floor division only, so it avoids floating-point error on integers
    and does not need ``math.ceil``.
    """
    # Floor-dividing by the negated divisor rounds the negated quotient down;
    # negating that result turns the floor into a ceiling.
    negated_quotient = big // -small
    return -negated_quotient
|
||||
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class BlendMode(Enum):
    """Layer type of a region, i.e. whether it is foreground or background."""

    # Member values are the display strings; they must stay byte-identical.
    FOREGROUND = "Foreground"
    BACKGROUND = "Background"
|
||||
|
||||
|
||||
class Processing:
    """Empty placeholder type; it declares no attributes or methods."""
|
||||
|
||||
|
||||
class Device:
    """Empty attribute holder; fields are attached to instances after creation."""
|
||||
|
||||
|
||||
# Module-level holder for the torch device used by this file
# (read as ``devices.device``, e.g. when placing the Gaussian tile weights).
devices = Device()
# Single assignment: a preceding lookup through ldm_patched.modules.model_management
# was a dead store that this line overwrote immediately, so it has been dropped.
devices.device = memory_management.get_torch_device()
|
||||
|
||||
|
||||
def null_decorator(fn):
    """Return a pass-through wrapper around ``fn``.

    The wrapper forwards all positional and keyword arguments to ``fn`` and
    returns its result unchanged. ``functools.wraps`` copies ``fn``'s metadata
    (``__name__``, ``__doc__``, ``__wrapped__`` ...) onto the wrapper so that
    decorated functions remain introspectable.

    Args:
        fn: The callable to wrap.

    Returns:
        A callable behaving exactly like ``fn``.
    """
    # Imported locally so this block does not depend on a file-level import.
    import functools

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    return wrapper
|
||||
|
||||
|
||||
# This and the decorator names bound below are all aliases of null_decorator,
# so applying any of them simply forwards calls to the decorated function.
keep_signature = null_decorator
|
||||
controlnet = null_decorator
|
||||
stablesr = null_decorator
|
||||
@@ -45,6 +73,7 @@ grid_bbox = null_decorator
|
||||
custom_bbox = null_decorator
|
||||
noise_inverse = null_decorator
|
||||
|
||||
|
||||
class BBox:
|
||||
''' grid bbox '''
|
||||
|
||||
@@ -59,6 +88,7 @@ class BBox:
|
||||
def __getitem__(self, idx: int) -> int:
|
||||
return self.box[idx]
|
||||
|
||||
|
||||
def split_bboxes(w: int, h: int, tile_w: int, tile_h: int, overlap: int = 16, init_weight: Union[Tensor, float] = 1.0) -> Tuple[List[BBox], Tensor]:
|
||||
cols = ceildiv((w - overlap), (tile_w - overlap))
|
||||
rows = ceildiv((h - overlap), (tile_h - overlap))
|
||||
@@ -78,16 +108,17 @@ def split_bboxes(w:int, h:int, tile_w:int, tile_h:int, overlap:int=16, init_weig
|
||||
|
||||
return bbox_list, weight
|
||||
|
||||
|
||||
class CustomBBox(BBox):
    """Region control bbox.

    Adds nothing to ``BBox``; it exists only as a distinct type.
    """
|
||||
|
||||
|
||||
class AbstractDiffusion:
|
||||
def __init__(self):
|
||||
self.method = self.__class__.__name__
|
||||
self.pbar = None
|
||||
|
||||
|
||||
self.w: int = 0
|
||||
self.h: int = 0
|
||||
self.tile_width: int = None
|
||||
@@ -167,6 +198,7 @@ class AbstractDiffusion:
|
||||
return torch.cat([x for _ in range(n)], dim=0)[:concat_to]
|
||||
shape = [n] + [1] * r_dims # [N, 1, ...]
|
||||
return x.repeat(shape)
|
||||
|
||||
def update_pbar(self):
|
||||
if self.pbar.n >= self.pbar.total:
|
||||
self.pbar.close()
|
||||
@@ -180,6 +212,7 @@ class AbstractDiffusion:
|
||||
else:
|
||||
self.step_count = sampling_step
|
||||
self.inner_loop_count = 0
|
||||
|
||||
def reset_buffer(self, x_in: Tensor):
|
||||
# Judge if the shape of x_in is the same as the shape of x_buffer
|
||||
if self.x_buffer is None or self.x_buffer.shape != x_in.shape:
|
||||
@@ -319,22 +352,22 @@ class AbstractDiffusion:
|
||||
if dtype is None: dtype = x_dtype
|
||||
if isinstance(control, T2IAdapter):
|
||||
width, height = control.scale_image_to(PW, PH)
|
||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
|
||||
control.cond_hint = adaptive_resize(control.cond_hint_original, width, height, 'nearest-exact', "center").float().to(control.device)
|
||||
if control.channels_in == 1 and control.cond_hint.shape[1] > 1:
|
||||
control.cond_hint = torch.mean(control.cond_hint, 1, keepdim=True)
|
||||
elif control.__class__.__name__ == 'ControlLLLiteAdvanced':
|
||||
if control.sub_idxs is not None and control.cond_hint_original.shape[0] >= control.full_latent_length:
|
||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
||||
control.cond_hint = adaptive_resize(control.cond_hint_original[control.sub_idxs], PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
||||
else:
|
||||
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
||||
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
||||
else:
|
||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
||||
control.cond_hint = adaptive_resize(control.cond_hint_original, PW, PH, 'nearest-exact', "center").to(dtype=dtype, device=control.device)
|
||||
else:
|
||||
if (PH, PW) == (control.cond_hint_original.shape[-2], control.cond_hint_original.shape[-1]):
|
||||
control.cond_hint = control.cond_hint_original.clone().to(dtype=dtype, device=control.device)
|
||||
else:
|
||||
control.cond_hint = ldm_patched.modules.utils.common_upscale(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)
|
||||
control.cond_hint = adaptive_resize(control.cond_hint_original, PW, PH, 'nearest-exact', 'center').to(dtype=dtype, device=control.device)
|
||||
|
||||
# Broadcast then tile
|
||||
#
|
||||
@@ -350,8 +383,11 @@ class AbstractDiffusion:
|
||||
control.cond_hint = self.control_params[tuple_key][param_id][batch_id]
|
||||
control = control.previous_controlnet
|
||||
|
||||
|
||||
import numpy as np
|
||||
from numpy import pi, exp, sqrt
|
||||
|
||||
|
||||
def gaussian_weights(tile_w: int, tile_h: int) -> Tensor:
|
||||
'''
|
||||
Copy from the original implementation of Mixture of Diffusers
|
||||
@@ -366,12 +402,14 @@ def gaussian_weights(tile_w:int, tile_h:int) -> Tensor:
|
||||
w = np.outer(y_probs, x_probs)
|
||||
return torch.from_numpy(w).to(devices.device, dtype=torch.float32)
|
||||
|
||||
|
||||
class CondDict:
    """Empty placeholder type; it declares no attributes or methods."""
|
||||
|
||||
|
||||
class MultiDiffusion(AbstractDiffusion):
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, model_function: BaseModel.apply_model, args: dict):
|
||||
def __call__(self, model_function, args: dict):
|
||||
x_in: Tensor = args["input"]
|
||||
t_in: Tensor = args["timestep"]
|
||||
c_in: dict = args["c"]
|
||||
@@ -395,7 +433,7 @@ class MultiDiffusion(AbstractDiffusion):
|
||||
# Background sampling (grid bbox)
|
||||
if self.draw_background:
|
||||
for batch_id, bboxes in enumerate(self.batched_bboxes):
|
||||
if ldm_patched.modules.model_management.processing_interrupted():
|
||||
if memory_management.processing_interrupted():
|
||||
# self.pbar.close()
|
||||
return x_in
|
||||
|
||||
@@ -439,6 +477,7 @@ class MultiDiffusion(AbstractDiffusion):
|
||||
|
||||
return x_out
|
||||
|
||||
|
||||
class MixtureOfDiffusers(AbstractDiffusion):
|
||||
"""
|
||||
Mixture-of-Diffusers Implementation
|
||||
@@ -470,7 +509,7 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
||||
return self.tile_weights
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, model_function: BaseModel.apply_model, args: dict):
|
||||
def __call__(self, model_function, args: dict):
|
||||
x_in: Tensor = args["input"]
|
||||
t_in: Tensor = args["timestep"]
|
||||
c_in: dict = args["c"]
|
||||
@@ -497,7 +536,7 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
||||
# Global sampling
|
||||
if self.draw_background:
|
||||
for batch_id, bboxes in enumerate(self.batched_bboxes): # batch_id is the `Latent tile batch size`
|
||||
if ldm_patched.modules.model_management.processing_interrupted():
|
||||
if memory_management.processing_interrupted():
|
||||
# self.pbar.close()
|
||||
return x_in
|
||||
|
||||
@@ -574,6 +613,8 @@ class MixtureOfDiffusers(AbstractDiffusion):
|
||||
|
||||
|
||||
# Upper bound used for the "tile_batch_size" input in TiledDiffusion.INPUT_TYPES.
MAX_RESOLUTION = 8192
|
||||
|
||||
|
||||
class TiledDiffusion():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@@ -586,6 +627,7 @@ class TiledDiffusion():
|
||||
"tile_overlap": ("INT", {"default": 8 * opt_f, "min": 0, "max": 256 * opt_f, "step": 4 * opt_f}),
|
||||
"tile_batch_size": ("INT", {"default": 4, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "apply"
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
Reference in New Issue
Block a user