revise stream logic
@@ -340,8 +340,6 @@ class LoadedModel:
             raise e

         if not disable_async_load:
-            flag = 'ASYNC' if stream.using_stream else 'SYNC'
-            print(f"[Memory Management] Requested {flag} Preserved Memory (MB) = ", async_kept_memory / (1024 * 1024))
             real_async_memory = 0
             mem_counter = 0
             for m in self.real_model.modules():
@@ -360,9 +358,14 @@ class LoadedModel:
                 elif hasattr(m, "weight"):
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print(f"[Memory Management] {flag} Loader Disabled for", type(m).__name__)
-            print(f"[Memory Management] Parameters Loaded to {flag} Stream (MB) = ", real_async_memory / (1024 * 1024))
-            print(f"[Memory Management] Parameters Loaded to GPU (MB) = ", mem_counter / (1024 * 1024))
+                    print(f"[Memory Management] Swap disabled for", type(m).__name__)
+
+            if stream.should_use_stream():
+                print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (asynchronous method)")
+            else:
+                print(f"[Memory Management] Loaded to CPU Swap: {real_async_memory / (1024 * 1024):.2f} MB (blocked method)")
+
+            print(f"[Memory Management] Loaded to GPU: {mem_counter / (1024 * 1024):.2f} MB")

             self.model_accelerated = True

@@ -390,8 +393,12 @@ class LoadedModel:
         return self.model is other.model # and self.memory_required == other.memory_required


+current_inference_memory = 1024 * 1024 * 1024
+
+
 def minimum_inference_memory():
-    return 1024 * 1024 * 1024
+    global current_inference_memory
+    return current_inference_memory


 def unload_model_clones(model):
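The hunk above replaces the hard-coded 1 GiB floor with a module-level `current_inference_memory`, so the value returned by `minimum_inference_memory()` can be adjusted at runtime. A minimal usage sketch, assuming the file is importable as `memory_management` (the module name is not shown in this diff; only `current_inference_memory` and `minimum_inference_memory()` come from the hunk above):

```python
# Sketch only: raise the inference-memory floor to 2 GiB at runtime.
# `memory_management` is an assumed import name for the file patched above.
import memory_management

memory_management.current_inference_memory = 2 * 1024 * 1024 * 1024  # 2 GiB
assert memory_management.minimum_inference_memory() == 2 * 1024 * 1024 * 1024
```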
@@ -487,17 +494,17 @@ def load_models_gpu(models, memory_required=0):
         if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
             model_memory = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
-            minimal_inference_memory = minimum_inference_memory()
-            estimated_remaining_memory = current_free_mem - model_memory - minimal_inference_memory
+            inference_memory = minimum_inference_memory()
+            estimated_remaining_memory = current_free_mem - model_memory - inference_memory

-            print("[Memory Management] Current Free GPU Memory (MB) = ", current_free_mem / (1024 * 1024))
-            print("[Memory Management] Model Memory (MB) = ", model_memory / (1024 * 1024))
-            print("[Memory Management] Minimal Inference Memory (MB) = ", minimal_inference_memory / (1024 * 1024))
-            print("[Memory Management] Estimated Remaining GPU Memory (MB) = ", estimated_remaining_memory / (1024 * 1024))
+            print(f"[Memory Management] Current Free GPU Memory: {current_free_mem / (1024 * 1024):.2f} MB")
+            print(f"[Memory Management] Required Model Memory: {model_memory / (1024 * 1024):.2f} MB")
+            print(f"[Memory Management] Required Inference Memory: {inference_memory / (1024 * 1024):.2f} MB")
+            print(f"[Memory Management] Estimated Remaining GPU Memory: {estimated_remaining_memory / (1024 * 1024):.2f} MB")

             if estimated_remaining_memory < 0:
                 vram_set_state = VRAMState.LOW_VRAM
-                async_kept_memory = (current_free_mem - minimal_inference_memory) / 1.3
+                async_kept_memory = (current_free_mem - inference_memory) / 1.3
                 async_kept_memory = int(max(0, async_kept_memory))

         if vram_set_state == VRAMState.NO_VRAM:
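For the low-VRAM fallback in the hunk above, a small worked example with assumed numbers (8192 MB currently free, the default 1024 MB inference floor) shows how much memory the `/ 1.3` safety margin actually keeps for the swap path:

```python
# Worked example with assumed values; mirrors the formula in the hunk above.
current_free_mem = 8192 * 1024 * 1024      # assume 8192 MB free on the GPU
inference_memory = 1024 * 1024 * 1024      # default minimum_inference_memory()

async_kept_memory = (current_free_mem - inference_memory) / 1.3
async_kept_memory = int(max(0, async_kept_memory))
print(async_kept_memory / (1024 * 1024))   # ~5513.8 MB kept for async/swap loading
```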
@@ -21,7 +21,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):
     if skip_dtype:
         target_dtype = None

-    if stream.using_stream:
+    if stream.should_use_stream():
         with stream.stream_context()(stream.mover_stream):
             if layer.weight is not None:
                 weight = layer.weight.to(device=target_device, dtype=target_dtype, non_blocking=non_blocking)
@@ -39,7 +39,7 @@ def weights_manual_cast(layer, x, skip_dtype=False):

 @contextlib.contextmanager
 def main_stream_worker(weight, bias, signal):
-    if not stream.using_stream or signal is None:
+    if signal is None or not stream.should_use_stream():
         yield
         return

@@ -60,7 +60,7 @@ def main_stream_worker(weight, bias, signal):


 def cleanup_cache():
-    if not stream.using_stream:
+    if not stream.should_use_stream():
         return

     stream.current_stream.synchronize()
@@ -52,11 +52,10 @@ def get_new_stream():
    return None


-current_stream = None
-mover_stream = None
-using_stream = False
-
-if args.cuda_stream:
-    current_stream = get_current_stream()
-    mover_stream = get_new_stream()
-    using_stream = current_stream is not None and mover_stream is not None
+def should_use_stream():
+    return stream_activated and current_stream is not None and mover_stream is not None
+
+
+current_stream = get_current_stream()
+mover_stream = get_new_stream()
+stream_activated = args.cuda_stream
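The hunk above replaces the import-time `using_stream` flag with a `should_use_stream()` check evaluated at call time. A minimal sketch of the call-site pattern, matching the one used in `weights_manual_cast` earlier in this diff (the helper name `move_async` is made up for illustration):

```python
from backend import stream

def move_async(tensor, device):
    # Hypothetical helper, for illustration only: copy on the mover stream when
    # streaming is active, otherwise fall back to a plain blocking copy.
    if stream.should_use_stream():
        with stream.stream_context()(stream.mover_stream):
            return tensor.to(device=device, non_blocking=True)
    return tensor.to(device=device)
```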
@@ -58,7 +58,7 @@ def initialize_forge():
     modules_forge.patch_basic.patch_all_basics()

     from backend import stream
-    print('CUDA Stream Activated: ', stream.using_stream)
+    print('CUDA Using Stream:', stream.should_use_stream())

     from modules_forge.shared import diffusers_dir
