Instruction
PixelRNN은 VAE, GAN과 같이 latent를 받아서 image x를 생성하는 것들과는 다른 방법을 통해 이미지를 생성한다.
이 PixelRNN 전반적인 conditional에 대한 내용 등은 논문 및 포스트들에 잘 설명되어있다.
이 블로그에선 PixelRNN 중 하나인 PixelCNN의 구현에 대해서 정리해보고자한다.
Implementation
우선 구현체에 대해 요약해서 말해보자면 Input으로 [-1, 1] 구간의 값을 갖는 Image를 받아서 output으로 input으로 들어온 image의 bit 형식에 맞게 (8비트 -> 256) 개수의 클래스를 갖도록 한다. 이는 각 픽셀, 채널 당의 값을 클래스로서 갖도록 하기 위함이다.
위 내용을 간단히 shape으로 표기하겠다.
(torch 기준)
Input : [B,C,H,W] (-1~1) -> Model-> prediction : [B, 2**n_bits, C, H, W]
target : [B, C, H, W] (0 ~ 2**n_bits -1)
이렇게 되면 pytorch에서 cross entropy loss를 통해 loss를 구할 수 있다. (loss안에 log softmax 함수가 포함된다.)
전체적인 형식은 이렇다.
PixelCNN에서 가장 중요한 mask CNN 부분이다.
maskCNN은 크게 두 구간으로 나뉜다.
- Convolutional weight에 적용할 mask 생성
- mask가 적용된 Convolutional weight으로 Forward.
1채널이 아닌 3채널 RGB의 경우에 대해 이야기하겠다.
우선 하나의 전제를 두자면 channel 각각의 순서는 R-G-B 순서이다.
그럼 이제 맨 처음 Input에 masked Convolution을 적용할 때는 R 먼저 적용을 할텐데, 이때는 첫 적용이기 때문에 이전 레이어에 대한 R 자기자신의 정보가 없다. 그래서 조건으로 Context만 받을 수 있다. 이렇게 R이 나오고, G에 대해 적용할 때는 이전에 얻은 정보인 R과 Context를 Condition으로 사용할 수 있다. 이렇게 이전 레이어의 자기 자신에 대한 정보를 사용하지 않는 것을 A type mask라고 한다.
반면에 B type mask는 mask conv를 두번째 적용 할 때부터 적용하는데 이 때는 이전 레이어에 대한 자기 자신 채널의 정보가 있기 때문에 R을 생성할 때 이전의 R과 Context를 받을 수 있다. G, B도 마찬가지이다.
이 정보에 따라 maskConv를 구현해보자.
class MaskedConv2d(nn.Conv2d):
def __init__(self, *args, mask_type,n_channel, gated = True, **kwargs):
super().__init__(*args, **kwargs)
nn.init.constant_(self.bias, 0.) # bias 0
out_c, in_c, h, w = self.weight.size()
self.mask = torch.zeros_like(self.weight).to('cuda' if torch.cuda.is_available() else 'cpu')
c_row = self.kernel_size[0]//2
c_col = self.kernel_size[1]//2
self.mask[:, :, :c_row, :] = 1
self.mask[:, :, c_row, :c_col+1] = 1
# Mask Type에 따라 RGB 채널 간의 관계
for i in range(n_channel): # in channel
for j in range(n_channel): # out channel
if (mask_type == 'a' and i >= j) or (mask_type == 'b' and i > j): # conditional (r|context), (g|r, constext), (b|r,g, context)
self.mask[j::n_channel, i::n_channel, c_row, c_col] = 0
def forward(self, x):
self.weight.data = self.mask
return super().forward(x)
대충 이렇다.
어쨌든 이런 구조를 가진 mask conv가지고 논문에 나온대로 architecture를 적용한 다음 훈련을 시키면 된다.
그럼 샘플링은 어떻게 할것인가?
그냥 torch.zeors_like 같은 모듈을 사용해서 0으로만 이루어진 텐서를 만든다.
이 텐서를 가지고 모듈에 넣어주면 된다.
예를 들면 내가 10개의 샘플을 만들고 싶다고 하자.
def sample(shape, count, device='cuda'):
channels, height, width = shape
samples = torch.zeros(count, *shape).to(device)
with torch.no_grad():
for i in range(height):
for j in range(width):
for c in range(channels):
unnormalized_probs = model(samples)
pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
sampled_levels = torch.multinomial(pixel_probs, 1).squeeze().float()
samples[:, c, i, j] = sampled_levels
return samples
그럼 위처럼 각각 위치, 채널에 대해 하나하나씩 샘플링 하면 된다.
끝.
'Generative Model' 카테고리의 다른 글
Density Estimation Using Real NVP (0) | 2021.09.26 |
---|---|
NICE : Non-linear Independent Components Estimation (0) | 2021.09.24 |
Generating Diverse High-Fidelity Images with VQ-VAE2 (0) | 2021.09.20 |
Neural Discrete Representation Learning : VQ-VAE (0) | 2021.09.17 |
Variational Auto Encoder를 이해해보자! (3) (0) | 2021.09.14 |