The vanilla VAE exhibits distinct clusters, whereas the CVAE has a more homogeneous distribution. The vanilla VAE encodes both class identity and intra-class variation into the latent space, since no conditioning signal is provided. The CVAE, however, does not need to learn class distinctions, so its latent space can focus on the variation within classes. A CVAE can therefore potentially capture more information, since it does not have to spend capacity on basic class conditioning.
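One way to see this cluster structure is to project the encoder means into 2D; here is a minimal sketch assuming a trained model and a test_loader of labeled Fashion-MNIST batches (both hypothetical names):

import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Collect latent means for a batch of test images (model and test_loader are assumed).
model.eval()
with torch.no_grad():
    images, labels = next(iter(test_loader))
    mu, _ = model.encode(images, labels)  # for a vanilla VAE, encode(images) instead

# Project the latent means to 2D and color by class to inspect clustering.
coords = TSNE(n_components=2).fit_transform(mu.cpu().numpy())
plt.scatter(coords[:, 0], coords[:, 1], c=labels.numpy(), cmap='tab10', s=5)
plt.colorbar(label='class')
plt.title('Latent space (t-SNE of encoder means)')
plt.show()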
Two model architectures were created to test image generation. The first architecture was a convolutional CVAE with a concatenation-based conditioning approach. All networks were built for Fashion-MNIST images of size 28×28 (784 total pixels).
import torch
import torch.nn as nn

class ConcatConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4

        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, 32)

        # Latent space (with concatenated condition)
        self.fc_mu = nn.Linear(self.flatten_size + 32, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 32, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 32, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Concatenate condition with encoded input
        x = torch.cat([x, c], dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        c = self.label_embedding(c)
        # Concatenate condition with latent vector
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The CVAE encoder consists of 3 convolutional layers, each followed by a ReLU non-linearity, and the encoder output is then flattened. The class number is passed through an embedding layer and concatenated with the flattened encoder output. Two linear layers then produce μ and log σ² for the latent space, and the reparameterization trick is used to draw a sample. Once sampled, the latent vector, again concatenated with the class-number embedding output, is passed to the decoder. The decoder consists of 3 transposed convolutional layers; the first two are followed by a ReLU non-linearity and the last by a sigmoid non-linearity. The output of the decoder is a 28×28 generated image.
The other model architecture follows the same approach but adds the conditional input instead of concatenating it. A major question was whether adding or concatenating would lead to better reconstruction or generation results.
class AdditiveConditionalVAE(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 128 * 4 * 4

        # Conditional embedding
        self.label_embedding = nn.Embedding(num_classes, self.flatten_size)

        # Latent space (without concatenation)
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)

        # Decoder condition embedding
        self.decoder_label_embedding = nn.Embedding(num_classes, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 4 * 4 * 128)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.label_embedding(c)
        # Add condition to encoded input
        x = x + c
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        # Add condition to latent vector
        c = self.decoder_label_embedding(c)
        z = z + c
        z = self.decoder_input(z)
        z = z.view(-1, 128, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
The same loss function, derived from the ELBO equation shown above, is used for all CVAEs.
import torch
import torch.nn.functional as F

def loss_function(recon_x, x, mu, logvar):
    """Computes the loss = -ELBO = Negative Log-Likelihood + KL Divergence.
    Args:
        recon_x: Decoder output.
        x: Ground truth.
        mu: Mean of Z
        logvar: Log-Variance of Z
    """
    # Reconstruction term: pixel-wise binary cross-entropy, summed over the batch
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x, c) and the standard normal prior
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
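To tie the pieces together, here is a minimal training loop sketch; the optimizer, learning rate, and epoch count are assumed values rather than the exact settings used, and train_loader refers to a Fashion-MNIST loader like the one sketched further below.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConcatConditionalVAE(latent_dim=128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 50   # assumed; training finished in under 100 epochs

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:   # Fashion-MNIST batches (see loading sketch below)
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(images, labels)
        loss = loss_function(recon, images, mu, logvar)
        loss.backward()
        optimizer.step()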
In order to assess model-generated images, 3 quantitative metrics are commonly used. Mean Squared Error (MSE) was calculated by summing the squared pixel-wise differences between the generated image and a ground truth image. The Structural Similarity Index Measure (SSIM) is a metric that evaluates image quality by comparing two images based on structural information, luminance, and contrast [3]. SSIM can be used to compare images of any size, whereas MSE is relative to pixel count. The SSIM score ranges from -1 to 1, where 1 indicates identical images. Fréchet Inception Distance (FID) is a metric for quantifying the realism and diversity of generated images. As FID is a distance measure, lower scores indicate a better reconstruction of a set of images.
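As a quick reference, these metrics can be computed roughly as follows; this is a sketch assuming scikit-image and torchmetrics, where generated_img, ground_truth_img, real_batch, and fake_batch are hypothetical inputs.

import numpy as np
from skimage.metrics import structural_similarity
from torchmetrics.image.fid import FrechetInceptionDistance

# generated_img and ground_truth_img are H x W float arrays in [0, 1] (hypothetical inputs)
mse = np.sum((generated_img - ground_truth_img) ** 2)   # pixel-wise sum of squared error
ssim = structural_similarity(generated_img, ground_truth_img, data_range=1.0)

# FID compares distributions of real and generated image batches (uint8, N x 3 x H x W)
fid = FrechetInceptionDistance(feature=2048)
fid.update(real_batch, real=True)
fid.update(fake_batch, real=False)
print(mse, ssim, fid.compute())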
Before scaling up to full text-to-image, CVAE image reconstruction and generation were tested on Fashion-MNIST. Fashion-MNIST is an MNIST-like dataset consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28×28 grayscale image associated with a label from 10 classes [4].
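Loading Fashion-MNIST is straightforward through torchvision; a minimal sketch (the batch size is an assumed value):

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.ToTensor()  # scales pixels to [0, 1], matching the sigmoid output
train_set = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)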
Preprocessing functions were created to extract the relevant keyword containing the class name from the input short text using regular expression matching. Additional descriptors (synonyms) were added for most classes to account for similar fashion items belonging to each class (e.g. Coat & Jacket).
import re
import torch

classes = {
    'T-shirt': 0,
    'Top': 0,
    'Trouser': 1,
    'Pants': 1,
    'Pullover': 2,
    'Sweater': 2,
    'Hoodie': 2,
    'Dress': 3,
    'Coat': 4,
    'Jacket': 4,
    'Sandal': 5,
    'Shirt': 6,
    'Sneaker': 7,
    'Shoe': 7,
    'Bag': 8,
    'Ankle boot': 9,
    'Boot': 9
}

def word_to_text(input_str, classes, model, device):
    label = class_embedding(input_str, classes)
    if label == -1:
        raise ValueError("No valid label found in input text")
    samples = sample_images(model, num_samples=4, label=label, device=device)
    plot_samples(samples, input_str, torch.tensor([label]))
    return

def class_embedding(input_str, classes):
    for key in list(classes.keys()):
        template = rf'(?i)\b{key}\b'  # case-insensitive whole-word match
        output = re.search(template, input_str)
        if output:
            return classes[key]
    return -1
The class name was then converted to its class number and used as the conditional input to the CVAE alongside the image. In order to generate an image, the class label extracted from the short text description is passed into the decoder together with random samples from a Gaussian distribution as the latent variable.
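The sample_images helper referenced in word_to_text is not shown above, so here is a minimal sketch of what it could look like for the ConcatConditionalVAE defined earlier (names and arguments follow that call):

def sample_images(model, num_samples, label, device):
    """Generates images for one class by decoding Gaussian noise with the class condition."""
    model.eval()
    with torch.no_grad():
        # Sample latent vectors from the standard normal prior
        z = torch.randn(num_samples, model.latent_dim, device=device)
        # Repeat the class label for every sample and decode
        c = torch.full((num_samples,), label, dtype=torch.long, device=device)
        samples = model.decode(z, c)
    return samples.cpu()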
Before testing generation, image reconstruction is examined to ensure the CVAE functions correctly. Because the convolutional network operates on 28×28 images, it can be trained in under an hour with fewer than 100 epochs.
Reconstructions capture the general shape of the ground truth images, but sharp, high-frequency features are missing. Any text or intricate design patterns are blurred in the model output. Inputting any short text containing a Fashion-MNIST class gives generated outputs resembling the reconstructed images.
The generated images have an MSE of 11 and an SSIM of 0.76. These constitute good generations, indicating that for simple, small images, CVAEs can generate quality outputs. GANs and DDPMs will produce higher-quality images with complex features, but CVAEs can handle simple cases.
When scaling up to image generation from text of any length, more robust methods are needed beyond regular expression matching. To do this, OpenAI's CLIP is used to convert text into a high-dimensional embedding vector. The embedding model is used in its ViT-B/32 configuration, which outputs embeddings of length 512. A limitation of the CLIP model is its maximum token length of 77, with studies showing an even smaller effective length of 20 [5]. Thus, when the input text contains multiple sentences, the text is split up by sentence and passed through the CLIP encoder. The resulting embeddings are averaged together to create the final output embedding.
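The sentence splitting itself happens outside the model; here is a minimal sketch of the idea, where splitting on periods is an assumed heuristic:

import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/32", device=device)

def embed_long_text(text):
    # Split on periods so each chunk stays under CLIP's 77-token limit (naive heuristic)
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    with torch.no_grad():
        tokens = clip.tokenize(sentences).to(device)          # (num_sentences, 77)
        embeddings = clip_model.encode_text(tokens).float()   # (num_sentences, 512)
    # Average the per-sentence embeddings into one conditioning vector
    return embeddings.mean(dim=0, keepdim=True)               # (1, 512)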
A long-text model requires far more complicated training data than Fashion-MNIST, so the COCO dataset was used. COCO has annotations (which are not entirely robust, but that will be discussed later) that can be passed into CLIP to get embeddings. However, COCO images are of size 640×480, meaning that even with cropping transforms, a larger network is required. Both the additive and concatenative conditioning architectures were tested for long text-to-image generation, but the concatenation approach is shown here:
import torch
import torch.nn as nn
import clip

class cVAE(nn.Module):
    def __init__(self, latent_dim=128):
        super().__init__()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False
        self.latent_dim = latent_dim

        # Modified encoder for 128x128 input
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),    # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),   # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, 4, stride=2, padding=1), # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 512, 4, stride=2, padding=1), # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Flatten()
        )
        self.flatten_size = 512 * 4 * 4  # Flattened size from encoder

        # Process CLIP embeddings for encoder
        self.condition_processor_encoder = nn.Sequential(
            nn.Linear(512, 1024)
        )
        self.fc_mu = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size + 1024, latent_dim)
        self.decoder_input = nn.Linear(latent_dim + 512, 512 * 4 * 4)

        # Modified decoder for 128x128 output
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),  # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1), # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),   # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),   # 128x128
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 3, 3, stride=1, padding=1),             # 128x128
            nn.Sigmoid()
        )

    def encode_condition(self, text):
        # Encode each sentence with CLIP and average the embeddings
        with torch.no_grad():
            embeddings = []
            for sentence in text:
                embeddings.append(
                    self.clip_model.encode_text(
                        clip.tokenize(sentence).to('cuda')
                    ).type(torch.float32)
                )
            return torch.mean(torch.stack(embeddings), dim=0)

    def encode(self, x, c):
        x = self.encoder(x)
        c = self.condition_processor_encoder(c)
        x = torch.cat([x, c], dim=1)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 512, 4, 4)
        return self.decoder(z)

    def forward(self, x, c):
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var
Another major point of investigation was image generation and reconstruction on images of varying sizes, specifically modifying COCO images to sizes 64×64, 128×128, and 256×256. After training the network, reconstruction results should first be examined.
All image sizes lead to reconstructed backgrounds with some feature outlines and correct colors. However, as image size increases, more features can be recovered. This makes sense: although it takes much longer to train a model with a larger image size, there is more information that can be captured and learned by the model.
With image generation, it is extremely difficult to produce high-quality images. Most images have backgrounds to some degree and blurred features in the image. This would be expected for image generation from a CVAE. This occurs with both concatenation and addition of the conditional input, but the concatenation approach performs better. This is likely because concatenated conditional inputs do not interfere with important features and ensure information is preserved distinctly; conditions can be ignored if they are irrelevant. Additive conditional inputs, however, can interfere with existing features and thoroughly disrupt the network when weights are updated during backpropagation.
All of the COCO generated images have a far lower SSIM of about 0.4 compared to the SSIM on Fashion-MNIST. MSE is proportional to image size, so it is difficult to quantify differences. FID scores for COCO image generations are in the 200s, further evidence that the COCO CVAE generated images are not robust.
The biggest limitation in trying to use CVAEs for image generation is, well, the CVAE. The amount of information that can be contained and reconstructed/generated is extremely dependent on the size of the latent space. A latent space that is too small won't capture any meaningful information, and the required size is proportional to the size of the output image: a 28×28 image needs a much smaller latent space than a 64×64 image (since the pixel count grows with the square of the side length). However, a latent space bigger than the actual image adds unnecessary information and at that point just creates a 1-to-1 mapping. For the COCO dataset, a latent space of at least 512 is required to capture some features. And while CVAEs are generative models, a convolutional encoder and decoder is a fairly rudimentary network. The training style of a GAN or the complex denoising process of a DDPM allows for far more complicated image generation.
Another major limitation for image generation is the training dataset. Although the COCO dataset has annotations, they are not extensively detailed. In order to train complex generative models, a different dataset should be used. COCO does not provide locations or additional information for background details, so a complex feature vector from the CLIP encoder cannot be effectively exploited by a CVAE trained on COCO.
Although CVAEs and image generation on COCO have their limitations, they produce a workable image generation model. More code and details can be provided, just reach out!
[1] Kingma, Diederik P., et al. "Auto-Encoding Variational Bayes." arXiv:1312.6114 (2013).
[2] Sohn, Kihyuk, et al. "Learning Structured Output Representation using Deep Conditional Generative Models." NeurIPS Proceedings (2015).
[3] Nilsson, J., et al. "Understanding SSIM." arXiv:2102.12037 (2020).
[4] Xiao, Han, et al. "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms." arXiv:1708.07747 (2017) (MIT license).
[5] Zhang, B., et al. "Long-CLIP: Unlocking the Long-Text Capability of CLIP." arXiv:2403.15378 (2024).
A shout-out to my group project partners Jake Hession (Deloitte Consultant), Ashley Hong (Google SWE), and Julian Kuppel (Quant)!