PyTorch module forward() 와 특수 매소드 __call__() 비교
https://ysy2000.tistory.com/117에서 잠깐 특수 매소드에 대해 언급했다.
특수메소드는 built-in function이라고 보면 된다.
그 중에서 __call__ 메소드를 정의하면 f = Function() 형대로 함수의 인스턴스를 변수 f에 대입하고,
나중에 f(...) 형태로 __call__매소드를 호출할 수 있다.
그런데 https://github.com/clovaai/ClovaCall 의 코드를 리뷰하던 중 마치 __call__()을 호출하듯 forward()매소드를 호출하는 것을 볼 수 있었다.
model = Seq2Seq(enc, dec)
logit = model(feats, feat_lengths, scripts, teacher_forcing_ratio=0)
첫 문장에서 Seq2Seq의 객체로 model을 선언하고 둘째 문장에서 바로 input을 줘버린다.
무슨 input이고, 애초에 저렇게 줄 수 있을까?
그래서 Seq2Seq으로 가보면 아래와 같다.
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder, decode_function=F.log_softmax):
super(Seq2Seq, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.decode_function = decode_function
def flatten_parameters(self):
pass
def forward(self, input_variable, input_lengths=None, target_variable=None,
teacher_forcing_ratio=0):
self.encoder.rnn.flatten_parameters()
encoder_outputs, encoder_hidden = self.encoder(input_variable, input_lengths)
(중략)
return decoder_output
@staticmethod
def get_param_size(model):
params = 0
for p in model.parameters():
tmp = 1
for x in p.size():
tmp *= x
params += tmp
return params
여러 매소드들 중에서 forward() 매소드가 둘째 문장의 입력과 같은 인수를 받는다는 것을 알 수 있다.
즉 둘째문장과 같이 쓰면 바로 forward() 매소드가 호출된다는 것이다.
__call__() 형태의 특수 매소드도 아닌데 어떻게 이것이 가능할까?
사실 이것이 가능한 이유는 PyTorch에서 nn.module을 설계할 때 그렇게 설계했기 때문이다.
대부분 파이썬을 활용한 딥러닝은 클래스로 모델을 구현하고 있는데 여러 구현체 중 PyTorch는 nn.Module을 상속받아 사용할 수 있다.
이런 PyTorch의 특수 매소드는 forward()외에도 nn.Module 클래스의 속성을 가지고 초기화되는 super() 등이 있다.
https://stephencowchau.medium.com/pytorch-module-call-vs-forward-c4df3ff304b1
PyTorch module__call__() vs forward()
In Python, there is this built-in function __call__() for a class you can override, this make your object instance callable.
stephencowchau.medium.com
위 사이트에서는 __call__()과 forward()의 차이점을 설명했다.
한국어로 간단히 여기서 정리하자면 아래와 같다.
__call__() 은 파이썬의 built-in function으로 오버라이드하여 객체 인스턴스를 바로 불러 사용할 수 있다.
한편 PyTorch에서는 nn.module에서 같은 기능을 하는 함수로 forward()를 구현해뒀다.
https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1101를 보면 아래와 같이 _call_impl이 있는 문장이 있다.
그래서 다시 _call_impl을 찾아가면 아래 사진과 같다.
1071번째 줄을 보면
result = forward_call(*input, **kwargs)
이라고 되어 있는데, 이 문장이 forward() 매소드를 호출하는 문장이다.
main call 이전과 후 모두에 대해 동작하기에 바로 호출할 수 있다는 것 같다.
마지막으로 이러한 hook system에 대해 더 알고 싶다면 아래 사이트를 참고하면 될 것이다.
https://pytorch.org/docs/stable/notes/modules.html?highlight=hook#module-hooks
Modules — PyTorch 1.12 documentation
Modules PyTorch uses modules to represent neural networks. Modules are: Building blocks of stateful computation. PyTorch provides a robust library of modules and makes it simple to define new custom modules, allowing for easy construction of elaborate, mul
pytorch.org