Just beat IT

情報系学生が技術的なことから昨日の夕飯のことまで

【転置畳み込み】PyTorchのConvTranspose2dの動作についてまとめとく【Python】

はじめに

どうも!自宅で筋トレを始めたい、と考え続けているロピタルです('ω')

今回は、研究の中で触れる機会のあったPyTorchConvTranspose2dという関数が分かりにくかったので、分かりやすくまとめてみようと思います('Д')

なお、私はPython及びPyTorchの初心者ですので、間違い等ありましたら教えてください(-"-)

もくじ

経緯

研究の一環で、とあるコードの動作について調べていく中にConvTranspose2dが出てきたのですが、入力と出力のサイズが私の想定と合わず、動作がなかなか理解できませんでした…

いろいろ調べて理解できたので、ここにまとめておきます!

畳み込み

まず、畳み込みについて簡単に説明しておきます。

画像データについて

まず、画像というのは二次元の配列で表すことができます。(モノクロ画像は二次元ですが、カラー画像はRGBの三色分の行列が必要なので、三次元です)

例えば、以下の画像を見て下さい。

この画像は、以下のような配列で表されます。

モノクロ画像だと、0が黒色で、255に近づくにつれ白くなっていきます。

ちなみに先程の画像の生成コードはこんな感じです。

import numpy as np
from PIL import Image

array = np.array([[0.,100.],[200.,255.]])
image = Image.fromarray(array.astype(np.uint8)).resize((512,512),resample = 0)
image.save('sample.png')

畳み込みとは

では、畳み込みについて説明します。

畳み込みとは、カーネルと呼ばれる配列を、画像配列の、カーネルと同サイズの部分画像(ウィンドウ)に繰り返し適用し、各要素の積の和を新たな行列として出力する処理のことです。(文章で書くと全然分からんなぁ...)

例として、以下の画像(3x3)とカーネル(2x2)による畳み込みを見てみましょう。

この場合、画像に赤、緑、青、黄で示している4つのウィンドウにカーネルを適用していきます。

赤のウィンドウに適用した場合、2*1 + 1*0 + 1*0 + 0*1 = 2となります。

同様に他のウィンドウにも適用することで、計4つの数値を計算でき、出力は2x2の配列となります。

この畳み込みを行うコードはこんな感じ。

import torch
import numpy as np

Input = torch.tensor(np.array([[[[2.,1.,3.],[1.,0.,2.],[0.,2.,1.]]]]))
Kernel = torch.tensor(np.array([[[[1.,0.],[0.,1.]]]]))
print(Input)
print(Kernel)

Conv = torch.nn.Conv2d(1,1,3, bias = False)
Conv.weight.data = Kernel

Output = Conv(Input)
print(Output)

上記のコードで、以下の出力が得られます。

tensor([[[[2., 1., 3.],
          [1., 0., 2.],
          [0., 2., 1.]]]], dtype=torch.float64)
tensor([[[[1., 0.],
          [0., 1.]]]], dtype=torch.float64)
tensor([[[[2., 3.],
          [3., 1.]]]], dtype=torch.float64, grad_fn=<ThnnConv2DBackward>)

画像通りの出力が得られていますね。畳み込みでは、カーネルの値を変えることで画像の特徴を抽出したりできます。

畳み込みのオプション

前章の畳み込みは非常に単純な畳み込み処理でした。

畳み込みには、様々なオプション(?)が存在します。発展的な手法ってことですね。

それらを簡単に紹介します。

パディング

画像の周囲に値が0のピクセルを追加してから畳み込む手法です。パディングの幅が1なら上下左右に1列ずつ0が追加されるので、2x2の画像なら4x4の配列になります。

ストライド

画像にカーネルを適用する際、ウィンドウを数個とばしで適用していくことができます。ストライドが1であれば前例のように全ウィンドウに適用することになりますが、ストライドが2ならば1つとばしに、3なら2つとばしに適用することになります。

畳み込み例

先程提示した畳み込みに、パディングを幅1で適用し、ストライドを2にして畳み込みをしてみます。

まず、パディングを行うと以下のようになります。

そして、ストライドが2であるため以下の4色のウィンドウにカーネルが適用され、配列が出力されます。

これは以下のコードで実行できます。

import torch
import numpy as np

Input = torch.tensor(np.array([[[[2.,1.,3.],[1.,0.,2.],[0.,2.,1.]]]]))
Kernel = torch.tensor(np.array([[[[1.,0.],[0.,1.]]]]))
Conv = torch.nn.Conv2d(1,1,3, stride=2, padding=1, bias = False)
Conv.weight.data = Kernel

Output = Conv(Input)
print(Output)

上記のコードで、以下の出力が得られます。

tensor([[[[2., 3.],
          [0., 1.]]]], dtype=torch.float64, grad_fn=<ThnnConv2DBackward>)

画像と同じ出力が得られていますね~。

PyTorchでは、Conv2dという関数を用いることで簡単に畳み込み処理を行うことができます。パディングやストライドもオプションとして数値を指定するだけなので簡単ですね!

pytorch.org

転置畳み込み

さて、やっと本題です。転置畳み込みとは、逆畳み込みとも呼ばれる処理のことで、名前の通り畳み込みの逆の処理を行います。

畳み込みでは、複数のピクセルカーネルを適用し、一つのピクセルを生成します。

転置畳み込みでは逆に、一つのピクセルカーネルを適用して、複数のピクセルを生成します。

イメージだと、下図のようになりますね(-"-)

では、具体的にどのような動作を行うのか、PyTorchで転置畳み込みを行う関数であるConvTranspose2dの動作で見ていきましょう。

処理概要

転置畳み込みは、一つのピクセルから複数のピクセルを生成するというイメージしにくい処理を行うのですが、実は行う処理としては畳み込みと同じなんです(?)

ただし畳み込みと違い、畳み込む前の配列に、値が0ピクセルを追加して配列の拡大を行います。その拡大した配列を畳み込むことで、「一つのピクセルから複数のピクセルの生成」を実現します。

入力配列の拡大

畳み込みでも、パディングの章で説明したように配列の周囲に0を追加していく処理がありましたが、転置畳み込みでは3種類の追加処理があります。それぞれについて見ていきましょう!!

ストライド

さて、畳み込みでも出てきた「ストライド」です。

畳み込みでは、畳み込む際にウィンドウを何個とばしで処理していくか、という値でしたが、転置畳み込みにおいては、何ピクセルごとに元の配列のピクセルが入るか、という値になります。

例えばストライドが3の場合、元の配列の各ピクセルの間に2ピクセルずつ0が入れられます。

パディング

畳み込みでも出てきた「パディング」です。

これは、畳み込みの「パディング」と同じで、配列の周囲に0を埋めていきます。

※注意

ここで、私がPyTorchConvTranspose2dという関数の理解に苦しんだ原因についてお話します。

この関数では、paddingという名前のオプションがあり、これでパディングを行う幅が決まっています。

しかし、paddingの値がそのままパディングの幅になる訳ではありません。

ConvTranspose2dでは、Conv2dによる出力のサイズと整合性をとるため、dilation * (kernel_size - 1) - paddingがパディング幅となります。( kernel_size はカーネルの一辺の長さ、dilationは拡張畳み込みと言われる手法を利用しない場合は1です。拡張畳み込みの説明は割愛しますm(__)m )

整合性というのは、例えば3x3の配列に幅1のパディングを適用して2x2カーネルで畳み込みを行った場合、出力は4x4になりますよね?

転置畳み込みは、その逆の動作を行いたいという発想の処理です。そのため、Conv2dに設定したpaddingの値と同じ値をConvTranspose2dpaddingに設定することで、Conv2dで得られた出力を、その入力と同じサイズに変化させられる処理ができるようになっています。

4x4の配列を2x2カーネルで畳み込んで3x3の出力を得たい場合、パディングは必要ありません。しかし、ConvTranspose2dではpaddingの値を0にするのではなく、Conv2dで設定した1を設定することで、逆向きの処理を行うことができます。dilation * (kernel_size - 1) - paddingdilation=1,kernelsize=2,padding=1を入れると、たしかに0になります。

勘違いしやすいと思うので、気を付けましょう('Д')

アウトプットパディング

これは、畳み込みには出てきませんでしたね。

というか、ConvTranspose2dのオプション名がoutput_paddingであるため「アウトプットパディング」と称しましたが、実際にこういう名前かは分かりません(笑)

これは、「パディング」に近い処理をするのですが、周囲4方向に0を埋めるのではなく、右と下にのみ0を埋めます。これは、出力される配列のサイズを調整するために利用するみたいです!

以下の入力配列とカーネルを利用した転置畳み込みの例を見ていきましょう!

まず入力配列の拡大を行います。ConvTranspose2dの引数を、stride=2,padding=1,output_padding=1として拡大します。パディング幅は前述した通り、拡張畳み込みを行わない場合はkernel_size - 1 - paddingで求められるので、今回は0となり、以下のような拡大となります。

あとは、畳み込みと同じように処理するだけです!!

なんとなく、転置畳み込みの処理が分かっていただけたでしょうか??今の例の転置畳み込みを行うコードはこんな感じ。

import torch
import numpy as np

Input = torch.tensor(np.array([[[[2.,3.,2.],[1.,2.,4.],[4.,2.,1.]]]]))
Kernel = torch.tensor(np.array([[[[2.,1.],[1.,2.]]]]))
print(Input)
print(Kernel)

Conv = torch.nn.ConvTranspose2d(1,1,2, stride=2, padding=1, output_padding=1, bias = False)
Conv.weight.data = Kernel

Output = Conv(Input)
print(Output)

これを実行すると、以下の出力が得られます('ω')

tensor([[[[2., 3., 2.],
          [1., 2., 4.],
          [4., 2., 1.]]]], dtype=torch.float64)
tensor([[[[2., 1.],
          [1., 2.]]]], dtype=torch.float64)
tensor([[[[4., 3., 6., 2., 4.],
          [1., 4., 2., 8., 4.],
          [2., 2., 4., 4., 8.],
          [4., 4., 2., 2., 1.],
          [8., 2., 4., 1., 2.]]]], dtype=torch.float64,
       grad_fn=<SlowConvTranspose2DBackward>)

手計算と同じ結果になっています!!

まとめ

今回は転置畳み込みの処理について書いてみました。

分かりにくい文になってしまったかもしれませんが、お役に立てれば幸いです( *´艸`)

でわノシ