Revert "change some dtype behaviors based on community feedbacks"
This reverts commit 31bed671ac.
@@ -301,13 +301,6 @@ def state_dict_size(sd, exclude_device=None):
     return module_mem
 
 
-def state_dict_parameters(sd):
-    module_mem = 0
-    for k, v in sd.items():
-        module_mem += v.nelement()
-    return module_mem
-
-
 def state_dict_dtype(state_dict):
     for k, v in state_dict.items():
         if hasattr(v, 'is_gguf'):
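
For context, the dropped state_dict_parameters helper counted raw tensor elements (v.nelement()); a byte-level estimate would also multiply by element size. A minimal, self-contained sketch of the difference, using a hypothetical state dict that is not part of this diff:

import torch

sd = {"weight": torch.zeros(4, 4, dtype=torch.float16)}
num_params = sum(v.nelement() for v in sd.values())                     # 16 elements
num_bytes = sum(v.nelement() * v.element_size() for v in sd.values())   # 32 bytes in fp16
print(num_params, num_bytes)
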
@@ -660,22 +653,44 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
+            if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(device, model_params=model_params, prioritize_performance=True, manual_cast=True):
+            if should_use_bf16(device, model_params=model_params, manual_cast=True):
                 return candidate
 
     return torch.float32
 
 
-def get_computation_dtype(inference_device, parameters=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+# None means no manual cast
+def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+    if weight_dtype == torch.float32:
+        return None
+
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
+    if fp16_supported and weight_dtype == torch.float16:
+        return None
+
+    bf16_supported = should_use_bf16(inference_device)
+    if bf16_supported and weight_dtype == torch.bfloat16:
+        return None
+
+    if fp16_supported and torch.float16 in supported_dtypes:
+        return torch.float16
+
+    elif bf16_supported and torch.bfloat16 in supported_dtypes:
+        return torch.bfloat16
+    else:
+        return torch.float32
+
+
+def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     for candidate in supported_dtypes:
         if candidate == torch.float16:
-            if should_use_fp16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
+            if should_use_fp16(inference_device, prioritize_performance=False):
                 return candidate
         if candidate == torch.bfloat16:
-            if should_use_bf16(inference_device, model_params=parameters, prioritize_performance=True, manual_cast=False):
+            if should_use_bf16(inference_device):
                 return candidate
 
     return torch.float32
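
The restored unet_manual_cast encodes a storage-versus-computation split: it returns None when the device can compute directly in the weights' dtype, and otherwise the dtype activations should be cast to at runtime. A minimal torch-only illustration of what such a manual cast means in practice (dtypes and shapes are hypothetical; only torch is assumed):

import torch

storage_dtype = torch.float16    # e.g. the dtype unet_dtype() selected for the weights
compute_dtype = torch.float32    # e.g. what unet_manual_cast() returned (None would mean no cast)

weight = torch.randn(8, 8, dtype=storage_dtype)   # stored copy stays in the smaller dtype
x = torch.randn(2, 8, dtype=compute_dtype)

# cast per use instead of converting the stored weights
y = x @ weight.to(compute_dtype).t()
print(y.dtype)  # torch.float32
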
@@ -1005,17 +1020,19 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major < 6:
         return False
 
+    fp16_works = False
     # FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
     # when the model doesn't actually fit on the card
     # TODO: actually test if GP106 and others have the same type of behavior
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            if manual_cast:
-                # For storage dtype
-                free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
-                if (not prioritize_performance) or model_params * 4 > free_model_memory:
-                    return True
-            else:
-                # For computation dtype
-                return False # Flux on 1080 can store model in fp16 to reduce swap, but computation must be fp32, otherwise super slow.
+            fp16_works = True
+
+    if fp16_works or manual_cast:
+        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
+        if (not prioritize_performance) or model_params * 4 > free_model_memory:
+            return True
+
+    if props.major < 7:
+        return False
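
The threshold restored here compares the model's fp32 footprint (about 4 bytes per parameter) against 90% of free memory minus a reserved inference margin. A rough, self-contained arithmetic sketch with made-up numbers (minimum_inference_memory below is a placeholder constant, not the module's helper):

model_params = 12_000_000_000                       # hypothetical ~12B-parameter model
free_memory = 16 * 1024**3                          # hypothetical 16 GiB free
minimum_inference_memory = 1024**3                  # placeholder reserve of 1 GiB

free_model_memory = free_memory * 0.9 - minimum_inference_memory
prefer_fp16 = model_params * 4 > free_model_memory  # True: fp32 weights would not fit
print(prefer_fp16)
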
@@ -1063,7 +1080,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     bf16_works = torch.cuda.is_bf16_supported()
 
     if bf16_works or manual_cast:
-        free_model_memory = (get_free_memory() * 0.85 - minimum_inference_memory())
+        free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
 
@@ -1099,3 +1116,43 @@ def soft_empty_cache(force=False):
 
 def unload_all_models():
     free_memory(1e30, get_torch_device())
+
+
+def resolve_lowvram_weight(weight, model, key): # TODO: remove
+    return weight
+
+
+# TODO: might be cleaner to put this somewhere else
+import threading
+
+
+class InterruptProcessingException(Exception):
+    pass
+
+
+interrupt_processing_mutex = threading.RLock()
+
+interrupt_processing = False
+
+
+def interrupt_current_processing(value=True):
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        interrupt_processing = value
+
+
+def processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        return interrupt_processing
+
+
+def throw_exception_if_processing_interrupted():
+    global interrupt_processing
+    global interrupt_processing_mutex
+    with interrupt_processing_mutex:
+        if interrupt_processing:
+            interrupt_processing = False
+            raise InterruptProcessingException()
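
A hedged sketch of how the restored interrupt API is typically used: a long-running loop polls the flag via throw_exception_if_processing_interrupted(), while another thread sets it with interrupt_current_processing(). The loop below is hypothetical; only the functions re-added in this hunk are assumed:

def run_steps(total_steps):
    try:
        for step in range(total_steps):
            throw_exception_if_processing_interrupted()  # raises and clears the flag if set
            ...  # one unit of work, e.g. a denoising step
    except InterruptProcessingException:
        print("processing interrupted by user")

# from another thread, e.g. a UI callback:
# interrupt_current_processing()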