Gemma 설명: RecurrentGemma 아키텍처

8월 29, 2024
Ju-yeong Ji Gemma DevRel
Ravin Kumar Google Data Scientist Language Applications

Gemma 설명 시리즈의 이전 게시물에서는 최신 Gemma 2 아키텍처에 대해 논했습니다. 이번 게시물에서는 RecurrentGemma 아키텍처를 살펴보겠습니다. 자, 시작합니다!


RecurrentGemma 2B, 9B

RecurrentGemma는 게이트 선형 순환과 로컬 슬라이딩 윈도우 어텐션을 혼합한 하이브리드 모델인 Griffin을 기반으로 합니다. 이러한 변경 덕분에 계산과 메모리 성능이 향상되어 긴 컨텍스트 프롬프트에 더 적합합니다.

Griffin hybrid model architecture

그러나 Griffin 아키텍처의 크기가 고정된 상태이므로 건초더미에서 바늘을 찾듯 다량의 데이터에서 핵심을 찾는 성능이 저하되는 단점이 있습니다. 책 한 권에 수록된 텍스트 전체를 입력 데이터로 제공할 수도 있지만 이는 최적의 접근 방식이 아닐 수 있습니다. 순환 신경망(RNN)은 지나치게 긴 시퀀스에서 장거리 의존성을 학습하는 데 어려움을 겪을 수 있으며 모델의 컨텍스트 윈도우가 제한됩니다. 즉, 예측할 때 특정한 수의 이전 토큰만 실질적으로 고려할 수 있다는 뜻입니다.

더욱이, 순환 모델은 트랜스포머 모델에 비해 추론 시간 최적화 측면에서 아직 많은 주목을 받지 못하고 있습니다. 그리고 잘 정립된 트랜스포머 아키텍처에 비해 사용 가능한 연구 및 커뮤니티 지원이 적습니다.

따라서 LLM의 컨텍스트 윈도우가 소진될 우려가 있는 상황에서 이 모델은 매우 유용합니다. RecurrentGemma는 최신 정보를 우선시하고 오래된 데이터부터 전략적으로 삭제함으로써 컨텍스트가 확장되더라도 LLM의 성능이 강력한 상태로 유지되도록 합니다.

다음은 Recurrent Gemma 2B 모델의 아키텍처 다이어그램입니다.

Recurrent Gemma 2B model architecture

Griffin은 다른 Transformer 베이스라인과 동일한 잔차 패턴 및 MLP 블록을 따릅니다. 그러나 MQA Transformer 베이스라인 및 Hawk 모델과 달리 Griffin은 순환 블록과 MQA 블록을 혼합하여 사용합니다.

Layered structure of recurrent and MQA blocks

Griffin은 두 개의 잔차 블록을 순환 블록과 번갈아 가며 계층화된 구조를 사용한 다음, 로컬 MQA 어텐션 블록을 통합하는 잔차 블록을 사용합니다.

이 아키텍처의 핵심 매개변수는 아래 표에 정리되어 있습니다.

Core parameters of the architecture of 2B and 9B models

비(非)임베딩 매개변수와 임베딩 매개변수

비(非)임베딩 매개변수는 어텐션 메커니즘 및 피드포워드 신경망과 같은 구성요소에서 모델의 숨겨진 레이어 전체에 분산됩니다.

참고: 모델 '2B'라는 이름은 이 매개변수에서 유래했습니다.

임베딩 매개변수는 보통 임베딩 레이어라는 전용 레이어에서 찾을 수 있습니다. 이 레이어는 이산 토큰(예: 단어 또는 문자)을 연속 벡터 표현(임베딩)에 매핑하는 역할을 합니다.

참고: 0.7B는 256k(어휘 크기) x 2560(모델 너비)으로 계산할 수 있습니다.


모델 너비 및 RNN 너비

모델 너비란 모델에 숨겨진 레이어의 크기를 나타냅니다. 기본 Gemma 모델과 마찬가지로, 모델 너비가 복잡한 패턴 표현에 관한 모델의 능력을 결정합니다.

순환 신경망(RNN) 너비는 RG-LRU(Real-Gated Linear Recurrent Unit)에 의해 유지되는 숨겨진 상태의 크기입니다. 기존 Transformer와 달리, 순환 블록은 입력 길이에 관계없이 고정된 크기의 내부 상태가 유지합니다. 덕분에 RecurrentGemma는 더 적은 메모리로 더 긴 시퀀스를 처리할 수 있어, 긴 글이나 긴 코드 생성 같은 작업에 더 효율적입니다.


MLP 확장 계수

기본 Gemma 모델의 피드포워드 은닉 측정기준과 동일합니다. 단순화를 위해 Recurrent Gemma 모델에서는 확장 계수 3을 적용하여 MLP 측정기준이 7680(2560 x 3으로 계산)으로 나왔습니다.


로컬 어텐션 윈도우 크기

RecurrentGemma에 의해 유지되는 상태는 그 크기가 유한하며 토큰 2k개의 로컬 어텐션 윈도우보다 더 긴 시퀀스를 통해 커지지 않습니다. 즉, Gemma에서 자동 회귀적으로 생성되는 샘플의 최대 길이는 호스트 시스템의 메모리 용량으로 제한되지만 RecurrentGemma는 임의 길이의 시퀀스를 생성하여 이러한 제약을 극복할 수 있습니다.

RecurrentGemmaForCausalLM(
  (model): RecurrentGemmaModel(
    (embed_tokens): Embedding(256000, 2560, padding_idx=0)
    (layers): ModuleList(
      (0-1): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (2): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
 
      :
 
      (23): RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaSdpaAttention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=256, bias=False)
          (v_proj): Linear(in_features=2560, out_features=256, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=True)
          (rotary_emb): RecurrentGemmaRotaryEmbedding()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
      (24-25): 2 x RecurrentGemmaDecoderLayer(
        (temporal_pre_norm): RecurrentGemmaRMSNorm()
        (temporal_block): RecurrentGemmaRecurrentBlock(
          (linear_y): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_x): Linear(in_features=2560, out_features=2560, bias=True)
          (linear_out): Linear(in_features=2560, out_features=2560, bias=True)
          (conv_1d): Conv1d(2560, 2560, kernel_size=(4,), stride=(1,), padding=(3,), groups=2560)
          (rg_lru): RecurrentGemmaRglru()
          (act_fn): PytorchGELUTanh()
        )
        (channel_pre_norm): RecurrentGemmaRMSNorm()
        (mlp_block): RecurrentGemmaMlp(
          (gate_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (up_proj): Linear(in_features=2560, out_features=7680, bias=True)
          (down_proj): Linear(in_features=7680, out_features=2560, bias=True)
          (act_fn): PytorchGELUTanh()
        )
      )
    )
    (final_norm): RecurrentGemmaRMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=256000, bias=False)
)

embed_tokens(임베딩 레이어)

입력 텍스트를 토큰 시퀀스로 가져와 각 토큰을 2560 크기의 연속 벡터 표현에 매핑합니다. 어휘 크기는 기본 Gemma 모델과 동일한 256000개입니다.


레이어

총 26개의 디코더 레이어가 반복 패턴으로 그룹화됩니다.

모델은 순환 블록(0-1)이 있는 두 개의 잔차 블록으로 시작합니다. 다음으로, 잔차 블록(2)과 레이어(25) 끝까지 교대하는 일련의 연속 블록이 이 시퀀스를 뒤따릅니다.

Recurrent block architecture

순환 블록이 있는 잔차 블록

순환 블록(시간 혼합 블록)에서 이 모델은 측정기준(모델 너비) 2560의 입력을 받아 출력 측정기준(RNN 너비) 2560을 갖는 두 선형 레이어를 병렬로 적용하여 두 개의 분기를 생성합니다.

첫 번째 분기(오른쪽)에서는 시간 필터 측정기준이 4인 소형 분리형 Conv1D 레이어를 적용합니다. 그리고 RG-LRU(Real-Gated Linear Recurrent Unit) 레이어가 이어집니다.

두 번째 분기(왼쪽)에서는 GeLU 비선형성을 적용합니다.

그런 다음, 요소별 곱셈으로 분기를 병합하고 출력 측정기준(모델 너비)이 2560인 최종 선형 레이어를 적용합니다.

RecurrentGemma-Residual-block

RMSNorm을 적용하면 MLP 블록이 뒤따릅니다.


로컬 MQA가 있는 잔차 블록

순환 블록(0-1)이 있는 잔차 블록 두 개가 생성되면 로컬 MQA(2)가 있는 잔차 블록이 이어집니다. 글로벌 어텐션 사용에 있어 주요 단점 중 하나는 계산 복잡도가 시퀀스 길이 면에서 이차 함수적으로 증가한다는 점입니다. 이 문제를 해결하기 위해 RecurrentGemma는 로컬 슬라이딩 윈도우 어텐션을 사용합니다. 이를 통해 각 위치는 과거에 고정된 수의 토큰에만 관여할 수 있습니다.

로컬 MQA 블록(시간 혼합 블록)에서 이 모델은 측정기준(모델 너비) 2560을 입력으로 받습니다. 선형 투영(q_proj, k_proj, v_proj, o_proj)을 사용하여 쿼리, 키, 값 및 출력 표현을 만듭니다. k_proj와 v_proj는 256 크기의 동일한 헤드를 공유하므로 k_projv_proj에 대한 out_features는 256입니다. 반면 q_projo_proj는 병렬로 10개의 헤드(256 x 10 = 2560)를 갖습니다.

기본 Gemma 모델과 마찬가지로, RoPE(rotary positional embeddings)를 위한 rotary_emb(RecurrentGemmaRotaryEmbedding)를 통합합니다.

RMSNorm과 MLP 블록의 적용은 이전의 잔차 블록과 동일합니다.


다음 단계는?

이 기사에서는 RecurrentGemma에 대해 알아보았습니다.

다음 게시물에서는 경량 개방형 비전 언어 모델(VLM)인 PaliGemma에 대해 살펴보겠습니다.

읽어주셔서 감사합니다! 앞으로도 계속 지켜봐 주세요!


참고 문헌

논문


코드 예


📋 전체 Gemma 아키텍처 시리즈