forge 2.0.0
see also discussions
This commit is contained in:
@@ -428,11 +428,17 @@ class ControlLora(ControlNet):
|
||||
controlnet_config = model.diffusion_model.config.copy()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||
controlnet_config["dtype"] = dtype = model.storage_dtype
|
||||
|
||||
dtype = model.storage_dtype
|
||||
|
||||
if dtype in ['nf4', 'fp4']:
|
||||
dtype = torch.float16
|
||||
|
||||
controlnet_config["dtype"] = dtype
|
||||
|
||||
self.manual_cast_dtype = model.computation_dtype
|
||||
|
||||
with using_forge_operations(operations=ControlLoraOps):
|
||||
with using_forge_operations(operations=ControlLoraOps, dtype=dtype):
|
||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
self.control_model.to(device=memory_management.get_torch_device(), dtype=dtype)
|
||||
|
||||
+9
-11
@@ -3,20 +3,18 @@ import torch
|
||||
|
||||
from backend.modules.k_model import KModel
|
||||
from backend.patcher.base import ModelPatcher
|
||||
from backend import memory_management
|
||||
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
@classmethod
|
||||
def from_model(cls, model, diffusers_scheduler, config, k_predictor=None):
|
||||
parameters = memory_management.module_size(model)
|
||||
unet_dtype = memory_management.unet_dtype(model_params=parameters)
|
||||
load_device = memory_management.get_torch_device()
|
||||
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
computation_dtype = memory_management.get_computation_dtype(load_device, supported_dtypes=config.supported_inference_dtypes)
|
||||
model.to(device=initial_load_device, dtype=unet_dtype)
|
||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor, storage_dtype=unet_dtype, computation_dtype=computation_dtype)
|
||||
return UnetPatcher(model, load_device=load_device, offload_device=memory_management.unet_offload_device(), current_device=initial_load_device)
|
||||
model = KModel(model=model, diffusers_scheduler=diffusers_scheduler, k_predictor=k_predictor)
|
||||
return UnetPatcher(
|
||||
model,
|
||||
load_device=model.diffusion_model.load_device,
|
||||
offload_device=model.diffusion_model.offload_device,
|
||||
current_device=model.diffusion_model.initial_device
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -169,8 +167,8 @@ class UnetPatcher(ModelPatcher):
|
||||
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_groupnorm_wrapper(self, wrapper):
|
||||
self.set_transformer_option('groupnorm_wrapper', wrapper)
|
||||
def set_group_norm_wrapper(self, wrapper):
|
||||
self.set_transformer_option('group_norm_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_controlnet_model_function_wrapper(self, wrapper):
|
||||
|
||||
Reference in New Issue
Block a user