Implement many kernels from scratch

This commit is contained in:
layerdiffusion
2024-08-06 18:20:34 -07:00
parent 4c8331b806
commit b57573c8da
15 changed files with 209 additions and 100 deletions
+8 -7
View File
@@ -1,12 +1,13 @@
import torch
from transformers import CLIPTextModel, CLIPTextConfig
class IntegratedCLIP(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
def __init__(self, cls, config, add_text_projection=False):
super().__init__()
self.transformer = CLIPTextModel(config)
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
self.transformer = cls(config)
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
if add_text_projection:
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))
+9 -4
View File
@@ -397,8 +397,8 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256,
in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
self.embed_dim = latent_channels
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
@@ -408,7 +408,10 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
def encode(self, x, regulation=None):
z = self.encoder(x)
z = self.quant_conv(z)
if self.quant_conv is not None:
z = self.quant_conv(z)
posterior = DiagonalGaussianDistribution(z)
if regulation is not None:
return regulation(posterior)
@@ -416,7 +419,9 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
return posterior.sample()
def decode(self, z):
z = self.post_quant_conv(z)
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
x = self.decoder(z)
return x