Taming Scalable Visual Tokenizer for Autoregressive Image Generation

  1. 1 研究背景、动机、主要贡献
    1. 1.1 存在问题(动机)
    2. 1.2 主要贡献
  2. 2 论文提出的新方法
  3. 3 论文实验评估方法与效果

1 研究背景、动机、主要贡献

1.1 存在问题(动机)

现有的方法通过降低潜在空间的维度来缓解表示崩溃问题(只有一小部分 codebook 中向量通过梯度下降更新),但会以牺牲模型容量为代价。

1.2 主要贡献

提出了 IBQ,有效解决了 code book 崩溃问题,在图像重建和生成上都取得了不错的结果。

2 论文提出的新方法

z 和 codebook 相乘再用 softmax 算"相似度",得到 soft_one_hot ( b d h w, n d -> b n h w)

获取"相似度"最大的向量作为 z_q (通过最大"相似度"的 index 转为 b n h w 形状的张量后与 codebook 相乘得到)

而反向传播时则通过更新 最大"相似度"的 index 的来源(soft_one_hot),从而更新所有 code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def forward(self, z, temp=None, return_logits=False):
# z: [b, d, h, w]
# embed.weight: [n, d]

logits = einsum('b d h w, n d -> b n h w', z, self.embedding.weight)

soft_one_hot = F.softmax(logits, dim=1)

dim = 1
ind = soft_one_hot.max(dim, keepdim=True)[1]
hard_one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, ind, 1.0)
one_hot = hard_one_hot - soft_one_hot.detach() + soft_one_hot

z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embedding.weight)
z_q_2 = einsum('b n h w, n d -> b d h w', hard_one_hot, self.embedding.weight)

quant_loss = torch.mean((z_q - z)**2) + torch.mean((z_q_2.detach()-z)**2) + self.beta * \
torch.mean((z_q_2 - z.detach()) ** 2)
diff = quant_loss

if self.use_entropy_loss:
sample_entropy, avg_entropy, entropy_loss= compute_entropy_loss(logits=logits.permute(0, 2, 3, 1).reshape(-1, self.n_e), temperature=self.entropy_temperature, sample_minimization_weight=self.sample_minimization_weight, batch_maximization_weight=self.batch_maximization_weight) # logits [b d h w] -> [b * h * w, n]
diff = (quant_loss, sample_entropy, avg_entropy, entropy_loss)

ind = torch.flatten(ind)

return z_q, diff, (None, None, ind)

3 论文实验评估方法与效果


原文链接:Taming Scalable Visual Tokenizer for Autoregressive Image Generation