revise kernel

This commit is contained in:
layerdiffusion
2024-08-07 17:24:22 -07:00
parent b61bf553ea
commit e1df7a1bae
2 changed files with 14 additions and 3 deletions
+12
View File
@@ -619,6 +619,18 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return torch.float32
def get_computation_dtype(inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
    """Pick the dtype to run computation in on ``inference_device``.

    Walks ``supported_dtypes`` in order and returns the first half-precision
    entry that its matching helper accepts for the device:

    * ``torch.float16`` is returned when
      ``should_use_fp16(inference_device, prioritize_performance=False)`` is truthy.
    * ``torch.bfloat16`` is returned when
      ``should_use_bf16(inference_device)`` is truthy.

    Any other entry (including ``torch.float32``) is passed over, and
    ``torch.float32`` is returned when nothing was selected.

    Args:
        inference_device: Device handed unchanged to the ``should_use_*`` helpers.
        supported_dtypes: Candidate dtypes in priority order. The default list
            is only read here, never mutated.

    Returns:
        The selected ``torch.dtype``; ``torch.float32`` as the fallback.
    """
    for dtype in supported_dtypes:
        # Short-circuit `and` keeps each helper from being called unless the
        # candidate is the dtype that helper is responsible for.
        if dtype == torch.float16 and should_use_fp16(inference_device, prioritize_performance=False):
            return dtype
        if dtype == torch.bfloat16 and should_use_bf16(inference_device):
            return dtype

    # No half-precision candidate was accepted.
    return torch.float32
def text_encoder_offload_device():
if args.always_gpu:
return get_torch_device()