본문 바로가기
Dev

XLM

by 호랑2 2023. 3. 17.

XLM(Extensible Language Model)은 Facebook AI Research(Facebook AI)에서 개발한 언어 모델이다.
XLM은 다국어 학습 데이터를 이용하여 학습된 다국어 언어 모델로 Transformer 아키텍처에 기반하며, 다양한 언어의 텍스트를 이해하고 생성할 수 있다.
 
XLM 모델은 BERT와 같은 pre-training 및 fine-tuning 접근 방식을 사용하지만, XLM 모델은 BERT와는 달리 다국어 pre-training과 fine-tuning을 모두 지원한다. 이를 통해 한 번의 학습으로 다양한 언어에 대한 일반화 성능을 높일 수 있다.
 
XLM 모델은 입력 임베딩, Transformer 인코더 및 디코더 레이어, 그리고 출력 레이어로 구성된다. XLM 모델에서는 인코더와 디코더가 서로 다른 언어의 텍스트를 다루도록 설계되어 있으며 입력 임베딩 레이어에는 언어 ID를 추가하여 언어 정보를 모델에 전달한다.
 
XLM 모델은 pre-training 데이터셋을 사용하여 언어 모델을 학습한 다음, fine-tuning을 통해 다양한 태스크에 적용될 수 있도록 조정된다. Pre-training 단계에서는 MLM(Masked Language Modeling) 및 TLM(Translation Language Modeling) 태스크를 수행한다.
 

출처 : https://github.com/facebookresearch/XLM#ii-cross-lingual-language-model-pretraining-xlm

 
아래는 XLM 모델을 구현한 샘플 코드이다.
 

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_size, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embedding_size = embedding_size
        self.num_heads = num_heads
        
        self.query_fc = nn.Linear(embedding_size, embedding_size)
        self.key_fc = nn.Linear(embedding_size, embedding_size)
        self.value_fc = nn.Linear(embedding_size, embedding_size)
        self.softmax = nn.Softmax(dim=2)
        self.dropout = nn.Dropout(0.1)
        self.output_fc = nn.Linear(embedding_size, embedding_size)

    def forward(self, x, mask=None):
        batch_size = x.size(0)
        q = self.query_fc(x).view(batch_size, -1, self.num_heads, self.embedding_size//self.num_heads).transpose(1, 2)
        k = self.key_fc(x).view(batch_size, -1, self.num_heads, self.embedding_size//self.num_heads).transpose(1, 2)
        v = self.value_fc(x).view(batch_size, -1, self.num_heads, self.embedding_size//self.num_heads).transpose(1, 2)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.embedding_size, dtype=torch.float32))
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        scores = self.softmax(scores)
        scores = self.dropout(scores)
        weighted_values = torch.matmul(scores, v)
        
        weighted_values = weighted_values.transpose(1, 2).contiguous().view(batch_size, -1, self.embedding_size)
        output = self.output_fc(weighted_values)
        return output

class PositionwiseFeedforward(nn.Module):
    def __init__(self, embedding_size, ff_hidden_size):
        super(PositionwiseFeedforward, self).__init__()
        self.fc1 = nn.Linear(embedding_size, ff_hidden_size)
        self.fc2 = nn.Linear(ff_hidden_size, embedding_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

class XLM(nn.Module):
    def __init__(self, vocab_size, embedding_size, num_heads, num_layers):
        super(XLM, self).__init__()

        self.token_embedding = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding = nn.Embedding(512, embedding_size)
        self.dropout = nn.Dropout(0.1)
        
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            layer = nn.TransformerEncoderLayer(embedding_size, num_heads, dim_feedforward=4*embedding_size)
            self.layers.append(layer)

        self.layernorm = nn.LayerNorm(embedding_size)

    def forward(self, x, mask=None):
        positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), x.size(1)).contiguous

위 코드에서는 MultiHeadAttention과 PositionwiseFeedforward 클래스를 정의한 후, XLM 클래스를 정의한다.
XLM 클래스 생성자에서는 입력 단어 집합의 크기, 임베딩 크기, 어텐션 헤드 수, 레이어 수를 인자로 받는다.
 
클래스 내부에서는 먼저 입력 토큰과 위치 정보를 임베딩한 후, 인코더 레이어를 num_layers개 만큼 쌓아서 인코딩을 수행한다.
인코딩 결과는 LayerNorm을 거쳐 출력되며 인코더 레이어는 TransformerEncoderLayer 클래스를 사용하여 구현한다.
이 클래스는 self-attention, multi-head attention 및 position-wise feedforward 네트워크를 포함한다.

'Dev' 카테고리의 다른 글

Electra  (0) 2023.03.24
ALBERT  (0) 2023.03.19
RoBERTa  (0) 2023.03.16
DistilBERT  (0) 2023.03.12
Zero-shot learning  (0) 2023.03.11

댓글