forge 2.0.0

see also discussions
This commit is contained in:
lllyasviel
2024-08-10 19:24:19 -07:00
committed by GitHub
parent 4014013d05
commit cfa5242a75
28 changed files with 785 additions and 1249 deletions
+8 -2
View File
@@ -428,11 +428,17 @@ class ControlLora(ControlNet):
controlnet_config = model.diffusion_model.config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["dtype"] = dtype = model.storage_dtype
dtype = model.storage_dtype
if dtype in ['nf4', 'fp4']:
dtype = torch.float16
controlnet_config["dtype"] = dtype
self.manual_cast_dtype = model.computation_dtype
with using_forge_operations(operations=ControlLoraOps):
with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
self.control_model = cldm.ControlNet(**controlnet_config)
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
+9 -11
View File
@@ -3,20 +3,18 @@ import torch
from backend.modules.k_model import KModel
from backend.patcher.base import ModelPatcher
from backend import memory_management
class UnetPatcher(ModelPatcher):
@classmethod
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
parameters = memory_management.module_size(model)
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=config.supported_inference_dtypes)
model.to(device=initial_load_device, dtype=unet_dtype)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=computation_dtype)
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor)
return UnetPatcher(
model,
load_device=model.diffusion_model.load_device,
offload_device=model.diffusion_model.offload_device,
current_device=model.diffusion_model.initial_device
)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -169,8 +167,8 @@ class UnetPatcher(ModelPatcher):
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
return
def set_groupnorm_wrapper(self, wrapper):
self.set_transformer_option('groupnorm_wrapper', wrapper)
def set_group_norm_wrapper(self, wrapper):
self.set_transformer_option('group_norm_wrapper', wrapper)
return
def set_controlnet_model_function_wrapper(self, wrapper):