본문 바로가기
데이터 분석

Vision Transformer With Pytorch

by Toddler_AD 2024. 12. 17.

1. Transformer의 기본 구조

 

 

 

 

 

 

 

 

 

 

  • 전체 흐름

 

 

 

 

 

  • MNIST 구현
# 1번 블럭
# 패키지 수입

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm
from time import time
# 2번 블럭
# 하이퍼 파라미터 지정

MY_SHAPE = (1, 28, 28)       # 손글씨 이미지 데이터 모양
MY_EPOCH = 1               # 반복 학습 수
MY_BATCH = 128             # 배치 수
MY_LEARNING = 0.005        # 학습율

MY_PATCH = 7               # 패치 갯수
MY_ENCODER = 2             # 인코더 갯수
MY_HIDDEN = 8              # 임베딩 차원 수
MY_HEAD = 2                # 에텐션 계산 머리 수
MY_MLP = 4                 # MLP 크기
MY_CLASS = 10              # 분류화 대상 수
# 3번 블럭
# 이미지 패치 처리
# 입력 데이터 모양: [128, 1, 28, 28]
# 출력 데이터 모양: [128, 49, 16]

def patchfy(images, n_patches):
    n, c, h, w = images.shape

    patches = torch.zeros(n, n_patches ** 2, h * w // n_patches ** 2)
    #print(patches.shape)
    patch_size = h // n_patches
    #print(patch_size)

    for idx, image in enumerate(images):
        for i in range(n_patches):
            for j in range(n_patches):
                patch = image[:,
                              i * patch_size: (i + 1) * patch_size,
                              j * patch_size: (j + 1) * patch_size]
                patches[idx, i * n_patches + j] = patch.flatten()
    return patches

# 테스트용 코드
#temp = torch.randn(MY_BATCH, 1, 28, 28)
#y = patchfy(temp, MY_PATCH)
#print(y.shape)
#print(temp[0].shape)
#print(y[0].shape)

#data = np.array([2.0, 4.0, 4.0])
#prob = nn.Softmax(dim=0)(torch.from_numpy(data))
#print(prob)

 

 

 

 

 

# 4번 블럭
# multi-head attention 클래스 정의

class MyMHA(nn.Module):
    def __init__(self, n_hidden, n_head):
        super(MyMHA, self).__init__()
        self.n_hidden = n_hidden
        self.n_head = n_head

        # 각 head가 처리할 차원 = 8/2 = 4
        d_head = int(n_hidden / n_head)
        self.d_head = d_head

        # Q, K, V 행렬 준비
        self.q_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(n_head)])
        self.k_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(n_head)])
        self.v_mappings = nn.ModuleList(
            [nn.Linear(d_head, d_head) for _ in range(n_head)])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, images):
        # 어텐션 결과 저장
        results = []

        for sequence in images:
            seq_result = []

            for head in range(self.n_head):
                q_mapping = self.q_mappings[head]
                k_mapping = self.k_mappings[head]
                v_mapping = self.v_mappings[head]

                # 자리 만들기
                seq = sequence[:, head * self.d_head: (head + 1) * self.d_head]

                # 어텐션 계산
                q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)
                attention = self.softmax(q @ k.T / (self.d_head ** 0.5))
                #print('attention의 모양:', attention.shape)
                seq_result.append(attention @ v)
                break;

            # 각 머리들의 계산 결과 통합
            merge = torch.hstack(seq_result)
            #print('merge의 모양:', merge.shape)
            results.append(merge)
            break;

        # 결과를 텐서로 전환
        #print('최종 결과', len(results))
        final = [torch.unsqueeze(result, dim=0) for result in results]
        final = torch.cat(final, dim=0)
        #print('final의 모양:', final.shape)
        # 결과 통합
        return final

# 테스트용 코드
#temp = torch.randn(MY_BATCH, 1, 28, 28)
#y = patchfy(temp, MY_PATCH)
#mha = MyMHA(MY_HIDDEN, MY_HEAD)
#mha(y)

 

 

 

 

 

# 5번 블럭
# VIT 인코더 구현

class MyEncoder(nn.Module):
    def __init__(self, n_hidden, n_head):
        super(MyEncoder, self).__init__()

        # 패치 임베딩 차원
        self.n_hidden = n_hidden

        # MHA 머리 수
        self.n_head = n_head

        # 첫번째 layer normalization 층
        self.norm1 = nn.LayerNorm(n_hidden)

        # MHA 층
        self.mha = MyMHA(n_hidden, n_head)

        # 두번째 layer normalization 층
        self.norm2 = nn.LayerNorm(n_hidden)

        # 최종 multi-layer perceptron 층
        self.mlp = nn.Sequential(
            nn.Linear(n_hidden, MY_MLP * n_hidden),
            nn.GELU(),
            nn.Linear(MY_MLP * n_hidden, n_hidden)
        )

    def forward(self, x):
        #print('입력 데이터 모양', x.shape)
        out = x + self.mha(self.norm1(x))
        out = self.norm2(out) + self.mlp(self.norm2(out))
        #print('출력 데이터 모양', out.shape)

        return out

# 테스트용 코드
#temp = torch.randn(MY_BATCH, 1, 28, 28)
#y = patchfy(temp, MY_PATCH)
#encoder = MyEncoder(MY_HIDDEN, MY_HEAD)
#t = encoder(y[:, :, :8])

 

 

 

 

# 6번 블럭
# 비젼 트랜스포머 전체 아키텍쳐 구성

class MyVIT(nn.Module):
    def __init__(self, image, n_patch, n_encoder, n_hidden, n_head, n_class):
        super(MyVIT, self).__init__()

        self.n_hidden = n_hidden
        self.n_patch = n_patch
        self.n_head = n_head
        self.n_encoder = n_encoder

        # d_input: 패치 안에 회소 수
        self.d_input = int (image[2] / n_patch) ** 2
        #print('패치 안에 회소 수', self.d_input)

        # 16차원을 8차원으로 임베딩
        self.embedding = nn.Linear(self.d_input, n_hidden)

        # 이미지 분류용 추가 패치
        self.cls_token = nn.Parameter(torch.randn(1, n_hidden))
        #print('추가 패치의 모양', self.cls_token.shape)

        # 인코더 적층
        self.encoder = nn.ModuleList(
            [MyEncoder(n_hidden, n_head) for _ in range(n_encoder)])
        #print('적층된 인코더 수', len(self.encoder))

        # 최종 이미지 분류 작업
        self.classify = nn.Sequential(
            nn.Linear(n_hidden, n_class),
            nn.Softmax(dim=-1)
        )

    def forward(self, images):
        # 이미지 패치화
        n_batch, _, _, _ = images.shape
        patches = patchfy(images, self.n_patch)
        #print('패치화 결과 모양', patches.shape)

        # 16차원 화소를 8차원으로 임베딩
        tokens = self.embedding(patches)
        #print('임베딩 결과 모양', tokens.shape)

        # class token 추가
        tokens = torch.cat((self.cls_token.expand(n_batch, 1, -1), tokens), dim=1)
        #print('class token 추가 결과 모양', tokens.shape)

        # 위치 임베딩 추가
        out = tokens + pos_embed(tokens.shape[1], self.n_hidden)
        #print('위치 임베딩 결과 모양', out.shape)

        # 인코더 적용
        for encoder in self.encoder:
            out = encoder(out)
        #print('인코더 적층 후 모양', out.shape)

        # 최종 분류
        out = out[:, 0]
        out = self.classify(out)
        #print('최종 분류 모양', out.shape)
        #print(out)
        return out

# 테스트 용 코드
#temp = torch.randn(MY_BATCH, 1, 28, 28)
#y = MyVIT(MY_SHAPE, MY_PATCH, MY_ENCODER, MY_HIDDEN, MY_HEAD, MY_CLASS)
#z = y(temp)

 

 

 

 

# 7번 블럭
# 위치 임베딩 함수

def pos_embed(n_token, n_hidden):
    result = torch.zeros(n_token, n_hidden)

    for i in range(n_token):
        for j in range(n_hidden):
            result[i][j] = (
                np.sin(i / 10000 ** (j / n_hidden))
                if j % 2 == 0
                else np.cos(i / 10000 ** ((j - 1) / n_hidden))
            )
    return result

# 테스트용 코드
#result = pos_embed(50, 8)
#print(result[0])
#print(result[1])

 

 

 

# 8번 블럭
# VIT 지도 학습

def train_VIT():
    transform = ToTensor()
    dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=MY_BATCH, shuffle=True)

    for epoch in range(MY_EPOCH):
        loss = 0.0
        for batch in tqdm(dataloader):
            x, y = batch
            pred = model(x)
            loss = criterion(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss = loss / len(dataloader)
        print('에포크', epoch, '손실', loss)
# 9번 블럭
# 컨트롤 타워

# VIT 생성
model = MyVIT(
    image=MY_SHAPE,
    n_patch=MY_PATCH,
    n_encoder=MY_ENCODER,
    n_hidden=MY_HIDDEN,
    n_head=MY_HEAD,
    n_class=MY_CLASS
)

# 최적화 함수 지정
optimizer = Adam(model.parameters(), lr = MY_LEARNING)
criterion = CrossEntropyLoss()

# VIT 지도 학습
print('학습 시작')
begin = time()
train_VIT()
end = time()
print('학습 시간', end - begin)

 

 

# 10번 블럭
# VIT 평가

def eval_VIT():
    transform = ToTensor()
    dataset = MNIST(root='./data', train=False, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=MY_BATCH, shuffle=True)

    with torch.no_grad():
        correct, total = 0,0
        loss = 0.0
        for batch in tqdm(dataloader):
            x, y = batch
            pred = model(x)
            loss = criterion(pred, y)
            loss = loss / len(dataloader)
            correct += torch.sum(torch.argmax(pred, dim=1) == y).item()
            total += len(y)
        print('\n정확도', correct / total * 100)

# VIT 평가
eval_VIT()

'데이터 분석' 카테고리의 다른 글

Language Transformer With Pytorch  (0) 2024.12.18
AARRR - Referral 지표  (0) 2024.08.22
AARRR - Revenue 지표  (2) 2024.08.22
AARRR - Retention 지표  (1) 2024.08.22
AARRR - Activation 지표  (0) 2024.08.22