z PixelCNN의 실행 흐름
본문 바로가기

Generative Model

PixelCNN의 실행 흐름

728x90

Instruction

PixelRNN은 VAE, GAN과 같이 latent를 받아서 image x를 생성하는 것들과는 다른 방법을 통해 이미지를 생성한다.

이 PixelRNN 전반적인 conditional에 대한 내용 등은 논문 및 포스트들에 잘 설명되어있다.

이 블로그에선 PixelRNN 중 하나인 PixelCNN의 구현에 대해서 정리해보고자한다.

 

Implementation

우선 구현체에 대해 요약해서 말해보자면 Input으로 [-1, 1] 구간의 값을 갖는 Image를 받아서 output으로 input으로 들어온 image의 bit 형식에 맞게 (8비트 -> 256) 개수의 클래스를 갖도록 한다. 이는 각 픽셀, 채널 당의 값을 클래스로서 갖도록 하기 위함이다.

 

위 내용을 간단히 shape으로 표기하겠다.

(torch 기준)

Input : [B,C,H,W] (-1~1) -> Model-> prediction : [B, 2**n_bits, C, H, W]

target : [B, C, H, W] (0 ~ 2**n_bits -1)

 

이렇게 되면 pytorch에서 cross entropy loss를 통해 loss를 구할 수 있다. (loss안에 log softmax 함수가 포함된다.)

 

전체적인 형식은 이렇다.

 

PixelCNN에서 가장 중요한 mask CNN 부분이다.

 

maskCNN은 크게 두 구간으로 나뉜다.

  1. Convolutional weight에 적용할 mask 생성
  2. mask가 적용된 Convolutional weight으로 Forward.

1채널이 아닌 3채널 RGB의 경우에 대해 이야기하겠다.

우선 하나의 전제를 두자면 channel 각각의 순서는 R-G-B 순서이다.

 

그럼 이제 맨 처음 Input에 masked Convolution을 적용할 때는 R 먼저 적용을 할텐데, 이때는 첫 적용이기 때문에 이전 레이어에 대한 R 자기자신의 정보가 없다. 그래서 조건으로 Context만 받을 수 있다. 이렇게 R이 나오고, G에 대해 적용할 때는 이전에 얻은 정보인 R과 Context를 Condition으로 사용할 수 있다. 이렇게 이전 레이어의 자기 자신에 대한 정보를 사용하지 않는 것을 A type mask라고 한다.

 

반면에 B type mask는 mask conv를 두번째 적용 할 때부터 적용하는데 이 때는 이전 레이어에 대한 자기 자신 채널의 정보가 있기 때문에 R을 생성할 때 이전의 R과 Context를 받을 수 있다. G, B도 마찬가지이다.

 

이 정보에 따라 maskConv를 구현해보자.

 

class MaskedConv2d(nn.Conv2d):
    def __init__(self, *args, mask_type,n_channel, gated = True, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.constant_(self.bias, 0.) # bias 0

        out_c, in_c, h, w = self.weight.size()


        self.mask = torch.zeros_like(self.weight).to('cuda' if torch.cuda.is_available() else 'cpu')

        c_row = self.kernel_size[0]//2
        c_col = self.kernel_size[1]//2

        self.mask[:, :, :c_row, :] = 1
        self.mask[:, :, c_row, :c_col+1] = 1

        # Mask Type에 따라 RGB 채널 간의 관계
        for i in range(n_channel): # in channel
            for j in range(n_channel): # out channel
                if (mask_type == 'a' and i >= j) or (mask_type == 'b' and i > j): # conditional (r|context), (g|r, constext), (b|r,g, context)
                    self.mask[j::n_channel, i::n_channel, c_row, c_col] = 0
        
    def forward(self, x):
        self.weight.data = self.mask
        return super().forward(x)

대충 이렇다.

 

어쨌든 이런 구조를 가진 mask conv가지고 논문에 나온대로 architecture를 적용한 다음 훈련을 시키면 된다.

 

그럼 샘플링은 어떻게 할것인가?

그냥 torch.zeors_like 같은 모듈을 사용해서 0으로만 이루어진 텐서를 만든다.

이 텐서를 가지고 모듈에 넣어주면 된다.

 

예를 들면 내가 10개의 샘플을 만들고 싶다고 하자.

def sample(shape, count, device='cuda'):
	channels, height, width = shape

	samples = torch.zeros(count, *shape).to(device)

	with torch.no_grad():
	for i in range(height):
		for j in range(width):
			for c in range(channels):
				unnormalized_probs = model(samples)
				pixel_probs = torch.softmax(unnormalized_probs[:, :, c, i, j], dim=1)
				sampled_levels = torch.multinomial(pixel_probs, 1).squeeze().float()
				samples[:, c, i, j] = sampled_levels
	return samples

 

그럼 위처럼 각각 위치, 채널에 대해 하나하나씩 샘플링 하면 된다.

끝.

 

 

728x90