forge 2.0.0

see also discussions
lllyasviel committed 2024-08-10 19:24:19 -07:00 (committed by GitHub)
commit cfa5242a75, parent 4014013d05
28 changed files with 785 additions and 1249 deletions
+60 -11
@@ -8,7 +8,7 @@ import platform
 from enum import Enum
 from backend import stream
-from backend.args import args
+from backend.args import args, dynamic_args

 cpu = torch.device('cpu')
@@ -281,12 +281,8 @@ except:
     print("Could not pick default device.")

 if 'rtx' in torch_device_name.lower():
     if not args.pin_shared_memory:
         print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
     if not args.cuda_malloc:
         print('Hint: your device supports --cuda-malloc for potential speed improvements.')
     if not args.cuda_stream:
         print('Hint: your device supports --cuda-stream for potential speed improvements.')

 current_loaded_models = []
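
Note: the three hints above refer to Forge launch flags of the same names. Assuming the standard launch.py entry point (an assumption, not stated in this diff), they can be enabled together:

python launch.py --pin-shared-memory --cuda-malloc --cuda-stream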
@@ -305,8 +301,54 @@ def state_dict_size(sd, exclude_device=None):
     return module_mem

+
+def state_dict_dtype(state_dict):
+    for k in state_dict.keys():
+        if 'bitsandbytes__nf4' in k:
+            return 'nf4'
+        if 'bitsandbytes__fp4' in k:
+            return 'fp4'
+
+    dtype_counts = {}
+    for tensor in state_dict.values():
+        dtype = tensor.dtype
+        if dtype in dtype_counts:
+            dtype_counts[dtype] += 1
+        else:
+            dtype_counts[dtype] = 1
+
+    major_dtype = None
+    max_count = 0
+    for dtype, count in dtype_counts.items():
+        if count > max_count:
+            max_count = count
+            major_dtype = dtype
+
+    return major_dtype
+
+
 def module_size(module, exclude_device=None):
-    return state_dict_size(module.state_dict(), exclude_device=exclude_device)
+    module_mem = 0
+
+    for p in module.parameters():
+        t = p.data
+
+        if exclude_device is not None:
+            if t.device == exclude_device:
+                continue
+
+        element_size = t.element_size()
+
+        if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
+            if element_size > 1:
+                # not quantized yet
+                element_size = 0.55  # a bit more than 0.5 because of quant state parameters
+            else:
+                # already quantized
+                element_size = 1.1  # a bit more than 1.0 because of quant state parameters
+
+        module_mem += t.nelement() * element_size
+
+    return module_mem
+
+
 class LoadedModel:
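
Note: a minimal sketch (not part of the commit) of how the two helpers above behave; the state dict below is hypothetical:

import torch

sd = {
    'linear.weight': torch.zeros(4096, 4096, dtype=torch.float16),
    'linear.bias': torch.zeros(4096, dtype=torch.float32),
}

# No key contains 'bitsandbytes__nf4' or 'bitsandbytes__fp4', so the majority
# dtype wins: state_dict_dtype(sd) returns torch.float16.
# For module_size, the fp16 weight is 4096 * 4096 = 16,777,216 elements at
# 2 bytes each = 33,554,432 bytes; the same parameter tagged quant_type == 'nf4'
# but not yet quantized would instead be estimated at 16,777,216 * 0.55 ≈ 9.2 MB.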
@@ -587,11 +629,6 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
+
+    if unet_storage_dtype_overwrite is not None:
+        return unet_storage_dtype_overwrite
+
     if args.unet_in_bf16:
         return torch.bfloat16
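
Note: dynamic_args appears to act as a runtime override channel read by unet_dtype(). A minimal sketch, assuming it is the plain dict imported at the top of the file and that a float8 storage dtype is wanted:

import torch
from backend.args import dynamic_args

# Assumption: dict-style assignment; this key is read back by unet_dtype() above.
dynamic_args['forge_unet_storage_dtype'] = torch.float8_e4m3fn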
@@ -1040,6 +1077,18 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     return False

+
+def can_install_bnb():
+    if not torch.cuda.is_available():
+        return False
+
+    cuda_version = tuple(int(x) for x in torch.version.cuda.split('.'))
+
+    if cuda_version >= (11, 7):
+        return True
+
+    return False
+
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
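
Note: can_install_bnb() gates the bitsandbytes (nf4/fp4) code paths added above. A worked pass through the check, plus one hedged observation:

# torch.version.cuda == '12.1' parses to (12, 1); (12, 1) >= (11, 7) is True,
# so bitsandbytes is considered installable.
# Observation (not from the commit): on ROCm builds torch.cuda.is_available()
# can return True while torch.version.cuda is None, in which case split('.')
# would raise; the check assumes an NVIDIA CUDA build of PyTorch.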