move to new backend - part 1

This commit is contained in:
layerdiffusion
2024-08-03 14:59:46 -07:00
parent a17abbc097
commit 8a01b2c5db
8 changed files with 200 additions and 52 deletions
+18 -18
View File
@@ -1,14 +1,14 @@
import contextlib
import torch
import ldm_patched.modules.model_management as model_management
from backend import memory_management
def has_xpu() -> bool:
return model_management.xpu_available
return memory_management.xpu_available
def has_mps() -> bool:
return model_management.mps_mode()
return memory_management.mps_mode()
def cuda_no_autocast(device_id=None) -> bool:
@@ -16,27 +16,27 @@ def cuda_no_autocast(device_id=None) -> bool:
def get_cuda_device_id():
return model_management.get_torch_device().index
return memory_management.get_torch_device().index
def get_cuda_device_string():
return str(model_management.get_torch_device())
return str(memory_management.get_torch_device())
def get_optimal_device_name():
return model_management.get_torch_device().type
return memory_management.get_torch_device().type
def get_optimal_device():
return model_management.get_torch_device()
return memory_management.get_torch_device()
def get_device_for(task):
return model_management.get_torch_device()
return memory_management.get_torch_device()
def torch_gc():
model_management.soft_empty_cache()
memory_management.soft_empty_cache()
def torch_npu_set_device():
@@ -49,15 +49,15 @@ def enable_tf32():
cpu: torch.device = torch.device("cpu")
fp8: bool = False
device: torch.device = model_management.get_torch_device()
device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now
device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system
dtype: torch.dtype = model_management.unet_dtype()
dtype_vae: torch.dtype = model_management.vae_dtype()
dtype_unet: torch.dtype = model_management.unet_dtype()
dtype_inference: torch.dtype = model_management.unet_dtype()
device: torch.device = memory_management.get_torch_device()
device_interrogate: torch.device = memory_management.text_encoder_device() # for backward compatibility, not used now
device_gfpgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
device_esrgan: torch.device = memory_management.get_torch_device() # will be managed by memory management system
device_codeformer: torch.device = memory_management.get_torch_device() # will be managed by memory management system
dtype: torch.dtype = memory_management.unet_dtype()
dtype_vae: torch.dtype = memory_management.vae_dtype()
dtype_unet: torch.dtype = memory_management.unet_dtype()
dtype_inference: torch.dtype = memory_management.unet_dtype()
unet_needs_upcast = False