전공/python

PyTorch module forward() 와 특수 매소드 __call__() 비교

import ysy 2022. 7. 12. 13:33

https://ysy2000.tistory.com/117에서 잠깐 특수 매소드에 대해 언급했다. 

특수메소드는 built-in function이라고 보면 된다.

그 중에서 __call__ 메소드를 정의하면 f = Function() 형대로 함수의 인스턴스를 변수 f에 대입하고,

나중에 f(...) 형태로 __call__매소드를 호출할 수 있다. 

 

 


그런데 https://github.com/clovaai/ClovaCall 의 코드를 리뷰하던 중 마치 __call__()을 호출하듯 forward()매소드를 호출하는 것을 볼 수 있었다.

model = Seq2Seq(enc, dec)
logit = model(feats, feat_lengths, scripts, teacher_forcing_ratio=0)
 

첫 문장에서 Seq2Seq의 객체로 model을 선언하고 둘째 문장에서 바로 input을 줘버린다.

무슨 input이고, 애초에 저렇게 줄 수 있을까? 

그래서 Seq2Seq으로 가보면 아래와 같다.

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, decode_function=F.log_softmax):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.decode_function = decode_function

    def flatten_parameters(self):
        pass

    def forward(self, input_variable, input_lengths=None, target_variable=None,
                teacher_forcing_ratio=0):

        self.encoder.rnn.flatten_parameters()
        encoder_outputs, encoder_hidden = self.encoder(input_variable, input_lengths)
		(중략)
        return decoder_output
        
    @staticmethod
    def get_param_size(model):
        params = 0
        for p in model.parameters():
            tmp = 1
            for x in p.size():
                tmp *= x
            params += tmp
        return params

여러 매소드들 중에서  forward() 매소드가 둘째 문장의 입력과 같은 인수를 받는다는 것을 알 수 있다.

즉 둘째문장과 같이 쓰면 바로 forward() 매소드가 호출된다는 것이다.

__call__() 형태의 특수 매소드도 아닌데 어떻게 이것이 가능할까?

 

 


사실 이것이 가능한 이유는 PyTorch에서 nn.module을 설계할 때 그렇게 설계했기 때문이다.

대부분 파이썬을 활용한 딥러닝은 클래스로 모델을 구현하고 있는데 여러 구현체 중 PyTorch는 nn.Module을 상속받아 사용할 수 있다.

이런 PyTorch의 특수 매소드는 forward()외에도 nn.Module 클래스의 속성을 가지고 초기화되는 super() 등이 있다.

 

 


https://stephencowchau.medium.com/pytorch-module-call-vs-forward-c4df3ff304b1

 

PyTorch module__call__() vs forward()

In Python, there is this built-in function __call__() for a class you can override, this make your object instance callable.

stephencowchau.medium.com

위 사이트에서는 __call__()과 forward()의 차이점을 설명했다.

한국어로 간단히 여기서 정리하자면 아래와 같다.

 


__call__() 은 파이썬의 built-in function으로 오버라이드하여 객체 인스턴스를 바로 불러 사용할 수 있다.

한편 PyTorch에서는 nn.module에서 같은 기능을 하는 함수로 forward()를 구현해뒀다.

https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1101를 보면 아래와 같이 _call_impl이 있는 문장이 있다.

( https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1101 )

 

그래서 다시 _call_impl을 찾아가면 아래 사진과 같다.

( https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1045 )

1071번째 줄을 보면

result = forward_call(*input, **kwargs)

이라고 되어 있는데, 이 문장이 forward() 매소드를 호출하는 문장이다.

main call 이전과 후 모두에 대해 동작하기에 바로 호출할 수 있다는 것 같다.

 

 


마지막으로 이러한 hook system에 대해 더 알고 싶다면 아래 사이트를 참고하면 될 것이다.

https://pytorch.org/docs/stable/notes/modules.html?highlight=hook#module-hooks 

 

Modules — PyTorch 1.12 documentation

Modules PyTorch uses modules to represent neural networks. Modules are: Building blocks of stateful computation. PyTorch provides a robust library of modules and makes it simple to define new custom modules, allowing for easy construction of elaborate, mul

pytorch.org

 

반응형