본문 바로가기

논문 리뷰&구현/NLP | Natural Language Processing

[논문 구현] Transformer 모델 구현 및 개념 정리

이번에는 NLP 분야에서 가장 유명한 논문이라고 할 수 있는 "Attention is All You Need" 논문의 transformer 모델을 구현해보고자 합니다. 구현의 경우 논문은 내용을 그대로 구현하는 것을 베이스라인으로 하여 최대한 동일하게 구현하고자 하였고 몇몇 막히는 부분은 다른 분들의 티스토리를 참고하였습니다. 참고한 티스토리 링크는 아래 reference에 넣어두겠습니다. 추가적으로 논문을 구현하며 논문에 대해 왜?와 같이 의문을 가졌던 부분은 함께 정리하였습니다.

 

 

1. Input Embedding / Output Embedding

모델 구조에서 input과 output을 받아 가장 먼저 embedding하는 부분이다.

# importing required libraries
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
import warnings
import pandas as pd
import numpy as np
warnings.simplefilter("ignore")


# paper 3.4 Embeddings and Softmax part
class InputEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        """
        임베딩 레이어 초기화.
        Args:
            vocab_size: 어휘 크기
            d_model: 임베딩 차원
        """
        super().__init__()
        
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        """
        순전파 연산.
        Args:
            x: 입력 벡터 (토큰 인덱스)
        Returns:
            w_o: 임베딩 벡터
        """
        # 입력 데이터를 임베딩
        # 루트(d_model)을 가중치로 하여 embedding layer에 곱한다고 paper에 명시되어 있음.
        w_o = self.embedding(x) * math.sqrt(self.d_model)  #[2, 20] -> [2, 20, 512]

        return w_o

 

2. Positional Encoding

Transformer 모델은 recurrence와 convolution을 가지고 있지 않아 토큰의 위치 정보를 알려주기 위해 positional encoding을 사용한다. 이를 input embedding에 더해준다.

# paper 3.5 Positional Encoding part
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len, dropout):
        """
        위치 임베딩 레이어 초기화.
        Args:
            d_model: 임베딩 차원
            max_seq_len: 입력 시퀀스의 최대 길이
            dropout: dropout 비율
        """
        super().__init__()

        self.d_model = d_model  #512
        self.max_seq_len = max_seq_len  #20
        self.dropout = nn.Dropout(dropout)
        
        # pe 빈 텐서 생성
        pe = torch.zeros(max_seq_len, d_model)  #[20, 512]

        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                # 짝수 인덱스에 대해 sin 값을 계산
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                # 홀수 인덱스에 대해 cos 값을 계산
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))

        # 0번째에 batch 차원을 추가 [max_seq_len, self.d_model] -> [1, max_seq_len, self.d_model]
        pe = pe.unsqueeze(0)  #[1, 20, 512]

        # 위치 임베딩을 고정된 텐서로 등록
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        순전파 연산.
        Args:
            x: 입력 벡터
        Returns:
            x: 위치 정보가 추가된 출력 벡터
        """
        seq_len = x.size(1)  #[2, 20, 512] 그 중에서 20을 반환
        # 임베딩에 위치 정보를 추가
        x = x + self.pe[:, :seq_len]

        x = self.dropout(x)

        return x

 

3. Layer Normalization

# Layer Normalization part
class LayerNormalization(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.layer_norm = nn.LayerNorm(normalized_shape, eps=eps)

    def forward(self, x):
        return self.layer_norm(x)

 

4. Multi-Head Attention

# paper 3.2.2 Multi-Head Attention part
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model, h, dropout):
        """
        Multi-Head Self-Attention 클래스
        Args:
            d_model: 임베딩 벡터 출력 차원
            h: Self-Attention 헤드의 수
            dropout: dropout 비율
        """
        super().__init__()

        self.d_model = d_model  #512
        self.h = h  #8
        # multi-head attention 코드
        self.d_k = int(self.d_model / self.h)  #64

        """
        paper의 차원과 다른 이유
        왜 [d_model, d_k]가 아니라 [d_model, d_model]인가?
        : PyTorch에서는 병렬로 처리하기 때문에 d_k가 아닌 d_model로 한꺼번에 받아 처리 
        """
        self.w_q = nn.Linear(self.d_model, self.d_model ,bias=False)
        self.w_k = nn.Linear(self.d_model, self.d_model, bias=False)
        self.w_v = nn.Linear(self.d_model, self.d_model , bias=False)
        self.w_o = nn.Linear(self.d_model, self.d_model)

        self.dropout = nn.Dropout(dropout)

    def attention(self, query, key, value, mask=None):
        """
        Args:
            query: Query 벡터 [batch_size, seq_length, d_model]
            key: Key 벡터 [batch_size, seq_length, d_model]
            value: Value 벡터 [batch_size, seq_length, d_model]
            mask: 디코더에서 사용되는 마스크 텐서(opt.)

        Returns: 
            Multi-Head Attention 출력 벡터
        """
        d_k = query.shape[-1]  #64
        # key 차원 transpose: [batch_size, h, seq_length, d_k] -> [batch_size, h, d_k, seq_length] 
        key = key.transpose(-2, -1)
        attention_scores = (query @ key) / math.sqrt(d_k) # QK^T 연산 [batch_size, h, seq_length, seq_length]

        # mask 계산
        """
        Masking 왜? 원리?
        : self-attention 시 패딩된 단어는 의미가 없기 때문에 무시하도록 함 디코더에서 미래 단어를 보지 않도록 시행 
        mask는 0인 부분을 -1e9로 바꾸어주어 softmax시 거의 0이 되도록 만듦
        """
        if mask is not None:
            attention_scores.masked_fill_(mask==0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1) # 행 기준 softmax

        attention_scores = self.dropout(attention_scores)

        out = attention_scores @ value # attention 결과값

        return out, attention_scores

    def forward(self, q, k, v, mask=None):
        """
        Args:
            query: Query 벡터 [batch_size, seq_length, d_model]
            key: Key 벡터 [batch_size, seq_length, d_model]
            value: Value 벡터 [batch_size, seq_length, d_model]
            mask: 디코더에서 사용되는 마스크 텐서(opt.)

        Returns: 
            Multi-Head Attention 출력 벡터
        """
        
        # 1. Q, K, V를 d_k, d_k, d_v로 projection
        """
        왜 projection하는가?
        : 가중치 행렬을 곱하지 않으면 q, k, v 모두 같은 값으로 attention 과정이 의미없음
        각 head마다 서로 다른 특징을 학습할 수 있도록 차별성을 부여할 수 있음
        """
        query = self.w_q(q)  #[2, 20, 512]
        key = self.w_k(k)  #[2, 20, 512]
        value = self.w_v(v)  #[2, 20, 512]

        # 2. Q, K, V를 head 수만큼 분리
        # 차원: [batch_size, seq_len, d_model] -> [batch_size, seq_len, h, d_k] -> [batch_size, h, seq_len, d_k]
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # 3. attention function 이용해서 attention score 계산
        x, self.attention_scores = self.attention(query, key, value, mask)

        # 4. 각 head별로 나온 결과값을 concat하는 과정
        # 차원: [batch_size, h, seq_len, d_k] -> [batch_size, seq_len, h, d_k] -> [batch_size, seq_len, d_model]
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        out = self.w_o(x)

        return out

 

5. Residual Connection

# Residual Connection part
class ResidualConnection(nn.Module):
    def __init__(self, d_model, dropout):
        """
        Transformer의 Residual Connection + Layer Normalization 레이어
        Args:
            d_model: 임베딩 차원
            dropout: 드롭아웃 비율
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(d_model)

    def forward(self, x, sublayer):
        """
        Post-norm 방식(기존 Transformer 논문 그대로 구현, 현재는 Pre-norm 방식 더 선호!)
        Residual Connection을 적용하는 forward 함수
        Args:
            x: 원본 입력 텐서 (batch_size, seq_len, d_model)
            sublayer: Transformer 블록 내부의 서브레이어 (Multi-Head Attention, Feed Forward 등)
        
        Returns:
            Layer Normalization을 적용한 Residual Connection 결과
        """
        residual = x + self.dropout(sublayer(x))  #[2, 20, 512]
        norm_residual = self.norm(residual)

        return norm_residual

 

6. Position-wise Feed-Forward Networks

# paper 3.3 Position-wise Feed-Forward Networks part
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        """
        Position-wise Feed-Forward Networks 블록
        Args:
           d_model: input, output 차원
           d_ff: inner-layer 차원
           dropout: dropout 비율
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        relu = torch.relu(self.linear1(x))  #[2, 20, 2048]
        relu_dropout = self.dropout(relu)
        out = self.linear2(relu_dropout)  #[2, 20, 512]
        return out

 

7. Encoder Block

# paper 3.1 Encoder part
class EncoderBlock(nn.Module):
    def __init__(self, self_attention_block: MultiHeadAttentionBlock, 
                 feed_forward_block: FeedForwardBlock, dropout):
        """
        Encoder 블록
        Args:
           self_attention_block (MultiHeadAttentionBlock): Self-Attention 레이어
           -> 입력 문장 전체를 참고하여 context 정보를 추출 / attention을 통해 각 단어가 다른 단어들과 어떻게 관련되어 있는지를 계산
           feed_forward_block (FeedForwardBlock): Feed Forward 레이어
           -> self-attention으로만 학습된 정보가 너무 단순할 수 있음 / 더욱 풍부한 의미를 가지도록 학습
           dropout: dropout 비율
        """
        super().__init__()

        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([
            ResidualConnection(self_attention_block.d_model, dropout),  # 첫 번째 Residual Connection (Self-Attention 용)
            ResidualConnection(self_attention_block.d_model, dropout)   # 두 번째 Residual Connection (Feed Forward 용)
        ])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
    

class Encoder(nn.Module):
    def __init__(self, d_model, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(d_model)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        out = self.norm(x)
        return out

 

8. Decoder Block

# paper 3.1 Decoder part
class DecoderBlock(nn.Module):
    def __init__(self, self_attention_block : MultiHeadAttentionBlock, cross_attention_block : MultiHeadAttentionBlock, 
                 feed_forward_block : FeedForwardBlock, dropout):
        """
        Decoder 블록
        Args:
           self_attention_block (MultiHeadAttentionBlock): Self-Attention 레이어
           -> decoder가 이전 단어만을 참고해 다음 단어를 예측 / 미래 단어를 보지 못하도록 tgt_mask 사용
           cross_attention_block (MultiHeadAttentionBlock): 두번째 encoder의 값을 받아 사용하는 레이어
           -> decoder가 encoder의 정보를 참고하여 더 정확한 번역을 생성 / src_mask를 사용하여 encoder에서의 padding 무시
           feed_forward_block (FeedForwardBlock): Feed Forward 레이어
           -> attention layer에서 얻은 정보를 더 복잡하게 변형하여 표현력을 증가 / 문맥을 보지 않음
           dropout: dropout 비율
        """
        super().__init__()

        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([
            ResidualConnection(self_attention_block.d_model, dropout),  # 첫 번째 Residual Connection (Multi-Head attention 용)
            ResidualConnection(self_attention_block.d_model, dropout),  # 두 번째 Residual Connection (Cross attention 용)
            ResidualConnection(self_attention_block.d_model, dropout)   # 세 번째 Residual Connection (Feed Forward Network 용)
        ])
    
    # 단어 길이를 맞추기 위함(연산량 감소)   
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """
        Args:
            x: 이전 Masked Multi-Head Attention 결과값
            encoder_output: 앞선 encoder의 결과값
            src_mask: 인코더 출력에서 패딩 토큰에 해당하는 위치를 0으로, 실제 단어에 해당하는 위치를 1로 채운 이진 마스크
            tgt_mask: 디코더의 현재 위치 이후의 단어들을 가려주는 마스크 
        """
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x



class Decoder(nn.Module):
    def __init__(self, d_model, layers):
        super().__init__()
        self.layers = layers
        self.norm  = LayerNormalization(d_model)
        
    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        out = self.norm(x)
        return out

 

9. Projection Layer

# Projection Layer part
class ProjectionLayer(nn.Module):
    def __init__(self, d_model, vocab_size):
        """
        Projection Layer: transformer 마지막 단계에서 사용되는 출력 변환 레이어(디코더의 최종 출력을 단어 예측을 위한 확률 분포로 변환하는 역할)
        Args:
            d_model: 임베딩 차원
            vocab_size: tgt_vocab_size로 target 어휘 크기
        """
        super().__init__()

        self.proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, x):
        #(Batch, seq_len,d_model) -> (Batch,seq_len,vocab_size)
        linear = self.proj(x)
        out = torch.log_softmax(linear, dim=-1)
        return out

 

10. Transformer

앞서 구현한 각 transformer에 필요한 클래스들을 활용해 최종 transformer 모델의 클래스입니다.

#Transformer
class Transformer(nn.Module):
    def __init__(self,encoder :Encoder, decoder : Decoder, src_embed : InputEmbedding, tgt_embed : InputEmbedding, 
                 src_pos : PositionalEncoding, tgt_pos :PositionalEncoding, projection_layer : ProjectionLayer): 
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
        
    def encode(self, src, src_mask):
        src = self.src_embed(src)
        src = self.src_pos(src)
        src = self.encoder(src, src_mask)
        return src
    
    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        tgt = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
        return tgt
    
    def project(self, x):
        prj = self.projection_layer(x)
        return prj
    
    
def build_transformer(src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, d_model, N, h, dropout, d_ff):
    """
        전체 Transformer
        Args:
           src_vocab_size: source 어휘 크기
           tgt_vocab_size: target 어휘 크기
           src_seq_len: source 문장 최대 길이
           tgt_seq_len: target 문장 최대 길이
           d_model: 임베딩 차원(512)
           N: encoder/decoder 수(6)
           h: Multi-head 수(8)
           dropout: dropout 비율
           d_ff: feed forward network 시 차원 확장(2048)
    """
    # Embedding layers
    src_embed = InputEmbedding(src_vocab_size, d_model)  # Embedding(10000, 512) 
    tgt_embed = InputEmbedding(tgt_vocab_size, d_model)  # Embedding(10000, 512) 
    
    # Positional Encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    
    # Encoder Blocks
    encoder_blocks=[]
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)
        
    # Decoder Blocks
    decoder_blocks=[]
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Encoder and Decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    
    # Projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    
    # Transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    #initial parameters
    for p in transformer.parameters():  # 205개의 파라미터
        if p.dim() >1 :  
            nn.init.xavier_uniform_(p)
            
    return transformer

 

 

11. Test 코드

위 클래스들은 transformer.py 파일로 구현하였으며 간단하게 test를 위한 코드의 경우 따로 test.py로 저장하여 테스트하였습니다. test.py 코드는 아래와 같습니다.

import torch
from transformer import build_transformer 


def main():
    # 하이퍼파라미터 설정
    src_vocab_size = 10000  # source 언어 어휘 크기
    tgt_vocab_size = 10000  # target 언어 어휘 크기
    src_seq_len = 20  # source 문장의 최대 길이
    tgt_seq_len = 20  # target 문장의 최대 길이
    d_model = 512  # 임베딩 차원
    N = 6  # 인코더/디코더 블록 개수
    h = 8  # Multi-Head Attention에서 헤드 개수
    dropout = 0.1  # 드롭아웃 비율
    d_ff = 2048  # Feed Forward 네트워크 차원 확장

    # Transformer 모델 생성
    transformer = build_transformer(
        src_vocab_size, tgt_vocab_size, src_seq_len, tgt_seq_len, 
        d_model, N, h, dropout, d_ff
    )

    # 더미 입력 데이터 생성 (batch_size=2)
    batch_size = 2

    # 소스 입력 (랜덤한 토큰 ID, 범위: 0~vocab_size-1)
    src_input = torch.randint(0, src_vocab_size, (batch_size, src_seq_len))

    # 타겟 입력 (랜덤한 토큰 ID)
    tgt_input = torch.randint(0, tgt_vocab_size, (batch_size, tgt_seq_len))

    # 마스크 생성 (모든 값 1)
    src_mask = torch.ones(batch_size, 1, src_seq_len, src_seq_len)
    tgt_mask = torch.ones(batch_size, 1, tgt_seq_len, tgt_seq_len)

    # **Transformer Forward Pass 실행**
    print("===== Encoding 시작 =====")
    encoder_output = transformer.encode(src_input, src_mask)
    print("Encoder Output Shape:", encoder_output.shape)  #[2, 20, 512]

    print("\n===== Decoding 시작 =====")
    decoder_output = transformer.decode(encoder_output, src_mask, tgt_input, tgt_mask)
    print("Decoder Output Shape:", decoder_output.shape)  #[2, 20, 512]

    print("\n===== Projection 시작 =====")
    final_output = transformer.project(decoder_output)
    print("Final Output Shape:", final_output.shape)  #[2, 20, 10000]


if __name__ == "__main__":
    main()

 

 

 


 

Reference

 

[Transformer] 아키텍처 구현하기 - 1 (Pytorch)

Transformer는 논문으로만 읽어봤지, 코드로 뜯어보는 것은 처음이다. 논문 저자들은 정말 천재가 맞는 것 같다. 유튜브를 참고해서 코드를 구현하였으며, 이번 포스팅은 오로지 아키텍처에만 초점

j2rooong.tistory.com