【転置畳み込み】PyTorchのConvTranspose2dの動作についてまとめとく【Python】
はじめに
どうも!自宅で筋トレを始めたい、と考え続けているロピタルです('ω')
今回は、研究の中で触れる機会のあったPyTorch
のConvTranspose2d
という関数が分かりにくかったので、分かりやすくまとめてみようと思います('Д')
なお、私はPython
及びPyTorch
の初心者ですので、間違い等ありましたら教えてください(-"-)
もくじ
経緯
研究の一環で、とあるコードの動作について調べていく中にConvTranspose2d
が出てきたのですが、入力と出力のサイズが私の想定と合わず、動作がなかなか理解できませんでした…
いろいろ調べて理解できたので、ここにまとめておきます!
畳み込み
まず、畳み込みについて簡単に説明しておきます。
画像データについて
まず、画像というのは二次元の配列で表すことができます。(モノクロ画像は二次元ですが、カラー画像はRGBの三色分の行列が必要なので、三次元です)
例えば、以下の画像を見て下さい。
この画像は、以下のような配列で表されます。
モノクロ画像だと、0
が黒色で、255
に近づくにつれ白くなっていきます。
ちなみに先程の画像の生成コードはこんな感じです。
import numpy as np from PIL import Image array = np.array([[0.,100.],[200.,255.]]) image = Image.fromarray(array.astype(np.uint8)).resize((512,512),resample = 0) image.save('sample.png')
畳み込みとは
では、畳み込みについて説明します。
畳み込みとは、カーネルと呼ばれる配列を、画像配列の、カーネルと同サイズの部分画像(ウィンドウ)に繰り返し適用し、各要素の積の和を新たな行列として出力する処理のことです。(文章で書くと全然分からんなぁ...)
例として、以下の画像(3x3)とカーネル(2x2)による畳み込みを見てみましょう。
この場合、画像に赤、緑、青、黄で示している4つのウィンドウにカーネルを適用していきます。
赤のウィンドウに適用した場合、2*1 + 1*0 + 1*0 + 0*1 = 2
となります。
同様に他のウィンドウにも適用することで、計4つの数値を計算でき、出力は2x2
の配列となります。
この畳み込みを行うコードはこんな感じ。
import torch import numpy as np Input = torch.tensor(np.array([[[[2.,1.,3.],[1.,0.,2.],[0.,2.,1.]]]])) Kernel = torch.tensor(np.array([[[[1.,0.],[0.,1.]]]])) print(Input) print(Kernel) Conv = torch.nn.Conv2d(1,1,3, bias = False) Conv.weight.data = Kernel Output = Conv(Input) print(Output)
上記のコードで、以下の出力が得られます。
tensor([[[[2., 1., 3.], [1., 0., 2.], [0., 2., 1.]]]], dtype=torch.float64) tensor([[[[1., 0.], [0., 1.]]]], dtype=torch.float64) tensor([[[[2., 3.], [3., 1.]]]], dtype=torch.float64, grad_fn=<ThnnConv2DBackward>)
画像通りの出力が得られていますね。畳み込みでは、カーネルの値を変えることで画像の特徴を抽出したりできます。
畳み込みのオプション
前章の畳み込みは非常に単純な畳み込み処理でした。
畳み込みには、様々なオプション(?)が存在します。発展的な手法ってことですね。
それらを簡単に紹介します。
パディング
画像の周囲に値が0のピクセルを追加してから畳み込む手法です。パディングの幅が1なら上下左右に1列ずつ0が追加されるので、2x2
の画像なら4x4
の配列になります。
ストライド
画像にカーネルを適用する際、ウィンドウを数個とばしで適用していくことができます。ストライドが1であれば前例のように全ウィンドウに適用することになりますが、ストライドが2ならば1つとばしに、3なら2つとばしに適用することになります。
畳み込み例
先程提示した畳み込みに、パディングを幅1で適用し、ストライドを2にして畳み込みをしてみます。
まず、パディングを行うと以下のようになります。
そして、ストライドが2であるため以下の4色のウィンドウにカーネルが適用され、配列が出力されます。
これは以下のコードで実行できます。
import torch import numpy as np Input = torch.tensor(np.array([[[[2.,1.,3.],[1.,0.,2.],[0.,2.,1.]]]])) Kernel = torch.tensor(np.array([[[[1.,0.],[0.,1.]]]])) Conv = torch.nn.Conv2d(1,1,3, stride=2, padding=1, bias = False) Conv.weight.data = Kernel Output = Conv(Input) print(Output)
上記のコードで、以下の出力が得られます。
tensor([[[[2., 3.], [0., 1.]]]], dtype=torch.float64, grad_fn=<ThnnConv2DBackward>)
画像と同じ出力が得られていますね~。
PyTorch
では、Conv2d
という関数を用いることで簡単に畳み込み処理を行うことができます。パディングやストライドもオプションとして数値を指定するだけなので簡単ですね!
転置畳み込み
さて、やっと本題です。転置畳み込みとは、逆畳み込みとも呼ばれる処理のことで、名前の通り畳み込みの逆の処理を行います。
畳み込みでは、複数のピクセルにカーネルを適用し、一つのピクセルを生成します。
転置畳み込みでは逆に、一つのピクセルにカーネルを適用して、複数のピクセルを生成します。
イメージだと、下図のようになりますね(-"-)
では、具体的にどのような動作を行うのか、PyTorch
で転置畳み込みを行う関数であるConvTranspose2d
の動作で見ていきましょう。
処理概要
転置畳み込みは、一つのピクセルから複数のピクセルを生成するというイメージしにくい処理を行うのですが、実は行う処理としては畳み込みと同じなんです(?)
ただし畳み込みと違い、畳み込む前の配列に、値が0
のピクセルを追加して配列の拡大を行います。その拡大した配列を畳み込むことで、「一つのピクセルから複数のピクセルの生成」を実現します。
入力配列の拡大
畳み込みでも、パディングの章で説明したように配列の周囲に0を追加していく処理がありましたが、転置畳み込みでは3種類の追加処理があります。それぞれについて見ていきましょう!!
ストライド
さて、畳み込みでも出てきた「ストライド」です。
畳み込みでは、畳み込む際にウィンドウを何個とばしで処理していくか、という値でしたが、転置畳み込みにおいては、何ピクセルごとに元の配列のピクセルが入るか、という値になります。
例えばストライドが3の場合、元の配列の各ピクセルの間に2ピクセルずつ0
が入れられます。
パディング
畳み込みでも出てきた「パディング」です。
これは、畳み込みの「パディング」と同じで、配列の周囲に0
を埋めていきます。
※注意
ここで、私がPyTorch
のConvTranspose2d
という関数の理解に苦しんだ原因についてお話します。
この関数では、padding
という名前のオプションがあり、これでパディングを行う幅が決まっています。
しかし、padding
の値がそのままパディングの幅になる訳ではありません。
ConvTranspose2d
では、Conv2d
による出力のサイズと整合性をとるため、dilation * (kernel_size - 1) - padding
がパディング幅となります。( kernel_size はカーネルの一辺の長さ、dilationは拡張畳み込みと言われる手法を利用しない場合は1
です。拡張畳み込みの説明は割愛しますm(__)m )
整合性というのは、例えば3x3
の配列に幅1のパディングを適用して2x2
のカーネルで畳み込みを行った場合、出力は4x4
になりますよね?
転置畳み込みは、その逆の動作を行いたいという発想の処理です。そのため、Conv2d
に設定したpadding
の値と同じ値をConvTranspose2d
のpadding
に設定することで、Conv2d
で得られた出力を、その入力と同じサイズに変化させられる処理ができるようになっています。
4x4
の配列を2x2
のカーネルで畳み込んで3x3
の出力を得たい場合、パディングは必要ありません。しかし、ConvTranspose2d
ではpadding
の値を0
にするのではなく、Conv2d
で設定した1
を設定することで、逆向きの処理を行うことができます。dilation * (kernel_size - 1) - padding
にdilation=1
,kernelsize=2
,padding=1
を入れると、たしかに0
になります。
勘違いしやすいと思うので、気を付けましょう('Д')
アウトプットパディング
これは、畳み込みには出てきませんでしたね。
というか、ConvTranspose2d
のオプション名がoutput_padding
であるため「アウトプットパディング」と称しましたが、実際にこういう名前かは分かりません(笑)
これは、「パディング」に近い処理をするのですが、周囲4方向に0
を埋めるのではなく、右と下にのみ0
を埋めます。これは、出力される配列のサイズを調整するために利用するみたいです!
例
以下の入力配列とカーネルを利用した転置畳み込みの例を見ていきましょう!
まず入力配列の拡大を行います。ConvTranspose2d
の引数を、stride=2
,padding=1
,output_padding=1
として拡大します。パディング幅は前述した通り、拡張畳み込みを行わない場合はkernel_size - 1 - padding
で求められるので、今回は0
となり、以下のような拡大となります。
あとは、畳み込みと同じように処理するだけです!!
なんとなく、転置畳み込みの処理が分かっていただけたでしょうか??今の例の転置畳み込みを行うコードはこんな感じ。
import torch import numpy as np Input = torch.tensor(np.array([[[[2.,3.,2.],[1.,2.,4.],[4.,2.,1.]]]])) Kernel = torch.tensor(np.array([[[[2.,1.],[1.,2.]]]])) print(Input) print(Kernel) Conv = torch.nn.ConvTranspose2d(1,1,2, stride=2, padding=1, output_padding=1, bias = False) Conv.weight.data = Kernel Output = Conv(Input) print(Output)
これを実行すると、以下の出力が得られます('ω')
tensor([[[[2., 3., 2.], [1., 2., 4.], [4., 2., 1.]]]], dtype=torch.float64) tensor([[[[2., 1.], [1., 2.]]]], dtype=torch.float64) tensor([[[[4., 3., 6., 2., 4.], [1., 4., 2., 8., 4.], [2., 2., 4., 4., 8.], [4., 4., 2., 2., 1.], [8., 2., 4., 1., 2.]]]], dtype=torch.float64, grad_fn=<SlowConvTranspose2DBackward>)
手計算と同じ結果になっています!!
まとめ
今回は転置畳み込みの処理について書いてみました。
分かりにくい文になってしまったかもしれませんが、お役に立てれば幸いです( *´艸`)
でわノシ