Implement many kernels from scratch
This commit is contained in:
+8
-7
@@ -1,12 +1,13 @@
|
||||
import torch
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTextConfig
|
||||
|
||||
|
||||
class IntegratedCLIP(torch.nn.Module):
    """Wrap a text-encoder module and attach the extra parameters used with it.

    Args:
        cls: Callable invoked as ``cls(config)`` to build the wrapped module,
            which is stored as ``self.transformer``.
        config: Configuration object handed to ``cls``. ``config.hidden_size``
            is read only when ``add_text_projection`` is true.
        add_text_projection: When true, attach a bias-free square linear layer
            named ``text_projection`` to the transformer, initialised to the
            identity matrix.
    """

    def __init__(self, cls, config, add_text_projection=False):
        super().__init__()
        self.transformer = cls(config)
        self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))

        if add_text_projection:
            embed_dim = config.hidden_size
            projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
            # An in-place write to a leaf parameter that requires grad raises
            # RuntimeError while autograd is recording, so disable it locally
            # instead of relying on the caller to construct us under no_grad.
            with torch.no_grad():
                projection.weight.copy_(torch.eye(embed_dim))
            self.transformer.text_projection = projection
|
||||
|
||||
+9
-4
@@ -397,8 +397,8 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
|
||||
self.decoder = Decoder(double_z=True, z_channels=latent_channels, resolution=256,
|
||||
in_channels=in_channels, out_ch=out_channels, ch=ch, ch_mult=ch_mult,
|
||||
num_res_blocks=layers_per_block, attn_resolutions=[], dropout=0.0)
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||
self.embed_dim = latent_channels
|
||||
self.scaling_factor = scaling_factor
|
||||
self.shift_factor = shift_factor
|
||||
@@ -408,7 +408,10 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
|
||||
|
||||
def encode(self, x, regulation=None):
|
||||
z = self.encoder(x)
|
||||
z = self.quant_conv(z)
|
||||
|
||||
if self.quant_conv is not None:
|
||||
z = self.quant_conv(z)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(z)
|
||||
if regulation is not None:
|
||||
return regulation(posterior)
|
||||
@@ -416,7 +419,9 @@ class IntegratedAutoencoderKL(nn.Module, ConfigMixin):
|
||||
return posterior.sample()
|
||||
|
||||
def decode(self, z):
    """Run ``z`` through the optional post-quant conv, then the decoder.

    Args:
        z: Input passed to ``self.post_quant_conv`` (when present) and then
            to ``self.decoder``.

    Returns:
        Whatever ``self.decoder`` returns for the (possibly transformed) ``z``.
    """
    # post_quant_conv may be None, so it must be applied only when present,
    # and exactly once.
    if self.post_quant_conv is not None:
        z = self.post_quant_conv(z)

    x = self.decoder(z)
    return x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user