무슨 생각을 해 그냥 하는거지

[Pytorch] nn.Sequential은 multiple input을 다룰 수 없다 본문

COMPUTER/Deep Learning

[Pytorch] nn.Sequential은 multiple input을 다룰 수 없다

빛나는콩 2022. 2. 16. 11:22

기초적인 얘기지만 혹시나 나중에 까먹을까봐 + 나처럼 답답해 하고 있는 분이 계실까봐 기록한다.

 

분명 forward 때 인풋을 두 개 줬는데 자꾸 에러가 발생했다.

*** TypeError: forward() takes 2 positional arguments but 3 were given

 

다른 부분은 이상이 없어보였는데, 찾아보니 nn.Sequential이 문제였다.

nn.Sequential로 선언한 레이어는 forward 때 인풋을 하나만 넣어줄 수 있다.

(나는 Encoder Layer를 nn.Sequential 안에 여러 개 넣어놨는데 Encoder layer의 forward가 2개의 인자를 받아야 했다.)

 

예를 들어..

어떤 모델 클래스 __init__ 함수에서 이렇게 선언했다면

self.enc_layers = nn.Sequential(~) # ~ == Encoder layer 여러 개

 

forward 함수에서

x = self.enc_layers(x, mask) ← 이렇게 못한다!

 


해결 방법

두 가지가 있다

  1. nn.Sequential를 상속받아 새로 만든다.
  2. input을 하나로 만들어 nn.Sequential의 인풋으로 넣어주고 나중에 분리한다.
  3. nn.ModuleList를 사용한다.

 

두 번째 방법은 간단하니까 생략하고,

1번 방법은 아래와 같다. (참고)

class mySequential(nn.Sequential):
    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs

 

이렇게 선언하고 원래 nn.Sequential 사용하는 것처럼 사용하면 된다. (forward만 override 하는거니까)

 

 

세 번째 방법이 가장 좋은 것 같다.

nn.ModuleList를 쓰고 forward할 때 for문으로 돌려주면 된다.

선언 시

self.enc_layers = nn.ModuleList([~])

forward 함수에서

for layer in self.enc_layers:

    x = f(x)

 

이런 식!

'COMPUTER > Deep Learning' 카테고리의 다른 글

[TensorFlow] tensorflow.linalg.band_part  (0) 2022.02.14