forge 2.0.0
see also discussions
This commit is contained in:
@@ -8,7 +8,7 @@ import platform
|
||||
|
||||
from enum import Enum
|
||||
from backend import stream
|
||||
from backend.args import args, dynamic_args
|
||||
from backend.args import args
|
||||
|
||||
|
||||
cpu = torch.device('cpu')
|
||||
@@ -281,12 +281,8 @@ except:
|
||||
print("Could not pick default device.")
|
||||
|
||||
if 'rtx' in torch_device_name.lower():
|
||||
if not args.pin_shared_memory:
|
||||
print('Hint: your device supports --pin-shared-memory for potential speed improvements.')
|
||||
if not args.cuda_malloc:
|
||||
print('Hint: your device supports --cuda-malloc for potential speed improvements.')
|
||||
if not args.cuda_stream:
|
||||
print('Hint: your device supports --cuda-stream for potential speed improvements.')
|
||||
|
||||
|
||||
current_loaded_models = []
|
||||
@@ -305,8 +301,54 @@ def state_dict_size(sd, exclude_device=None):
|
||||
return module_mem
|
||||
|
||||
|
||||
def state_dict_dtype(state_dict):
|
||||
for k in state_dict.keys():
|
||||
if 'bitsandbytes__nf4' in k:
|
||||
return 'nf4'
|
||||
if 'bitsandbytes__fp4' in k:
|
||||
return 'fp4'
|
||||
|
||||
dtype_counts = {}
|
||||
|
||||
for tensor in state_dict.values():
|
||||
dtype = tensor.dtype
|
||||
if dtype in dtype_counts:
|
||||
dtype_counts[dtype] += 1
|
||||
else:
|
||||
dtype_counts[dtype] = 1
|
||||
|
||||
major_dtype = None
|
||||
max_count = 0
|
||||
|
||||
for dtype, count in dtype_counts.items():
|
||||
if count > max_count:
|
||||
max_count = count
|
||||
major_dtype = dtype
|
||||
|
||||
return major_dtype
|
||||
|
||||
|
||||
def module_size(module, exclude_device=None):
|
||||
return state_dict_size(module.state_dict(), exclude_device=exclude_device)
|
||||
module_mem = 0
|
||||
for p in module.parameters():
|
||||
t = p.data
|
||||
|
||||
if exclude_device is not None:
|
||||
if t.device == exclude_device:
|
||||
continue
|
||||
|
||||
element_size = t.element_size()
|
||||
|
||||
if getattr(p, 'quant_type', None) in ['fp4', 'nf4']:
|
||||
if element_size > 1:
|
||||
# not quanted yet
|
||||
element_size = 0.55 # a bit more than 0.5 because of quant state parameters
|
||||
else:
|
||||
# quanted
|
||||
element_size = 1.1 # a bit more than 0.5 because of quant state parameters
|
||||
|
||||
module_mem += t.nelement() * element_size
|
||||
return module_mem
|
||||
|
||||
|
||||
class LoadedModel:
|
||||
@@ -587,11 +629,6 @@ def unet_inital_load_device(parameters, dtype):
|
||||
|
||||
|
||||
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
|
||||
unet_storage_dtype_overwrite = dynamic_args.get('forge_unet_storage_dtype')
|
||||
|
||||
if unet_storage_dtype_overwrite is not None:
|
||||
return unet_storage_dtype_overwrite
|
||||
|
||||
if args.unet_in_bf16:
|
||||
return torch.bfloat16
|
||||
|
||||
@@ -1040,6 +1077,18 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
|
||||
def can_install_bnb():
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
|
||||
cuda_version = tuple(int(x) for x in torch.version.cuda.split('.'))
|
||||
|
||||
if cuda_version >= (11, 7):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
||||
Reference in New Issue
Block a user