Keras에서 Gemma 모델 도입

2월 21, 2024
Martin Görner Product Manager Keras

Keras 팀은 Gemini 모델을 만드는 데 사용되었던 동일한 연구와 기술로 개발된 경량의 최첨단 개방형 모델 제품군인 Gemma를 이제 KerasNLP 컬렉션으로 제공합니다. Keras 3 덕분에 Gemma는 JAX, PyTorch, TensorFlow에서 실행됩니다. 이번 출시에서 Keras는 새로운 LoRA(Low Rank Adaptation) API와 대규모 모델 병렬 학습 기능 등 대형 언어 모델용으로 특별히 설계된 새로운 기능도 몇 가지 선보입니다.

코드 샘플을 직접 살펴보고 싶다면 다음 링크로 이동하세요.

시작하기

Gemma 모델은 휴대용 2B 및 7B의 매개변수 크기로 제공되며 유사한 개방형 모델과 몇몇 더 큰 모델에 비해서도 상당히 발전된 모델입니다. 예를 들면 다음과 같습니다.

  • Gemma 7B는 MMLU 언어 이해 벤치마크에서 동급 최고 수준인 64.3%의 정답률로 신기록 수립(Mistral-7B는 62.5%, Llama2-13B는 54.8%의 정답률 기록)
  • Gemma는 초등학교 수학 문제에 대한 GSM8K 벤치마크 점수에 +11%p 추가(Gemma 7B는 46.4%, Mistral-7B는 35.4%, Llama2-13B는 28.7%)
  • 코딩 챌린지인 HumanEval에서 +6.1%p의 정답률 추가(Gemma 7B는 32.3%, Mistral 7B는 26.2%, Llama2 13B는 18.3%)

Gemma 모델은 친숙한 KerasNLP API와 매우 읽기 쉬운 Keras 구현과 함께 제공됩니다. 다음과 같이 한 줄의 코드로 모델을 인스턴스화할 수 있습니다.

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

그리고 텍스트 프롬프트에서 직접 실행할 수 있습니다. 즉, 토큰화가 기본 제공된다는 뜻이며 필요하다면 쉽게 분할할 수 있습니다. 상세한 방법은 Keras NLP 가이드를 읽어 보세요.

gemma_lm.generate("Keras is a", max_length=32)
> "Keras is a popular deep learning framework for neural networks..."

직접 해보기: Gemma 모델 시작하기

Keras 3 덕분에 모델을 실행할 백엔드를 선택할 수 있습니다. 전환하는 방법은 다음과 같습니다.

os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
import keras # import keras after having selected the backend

Keras 3은 대형 언어 모델에 특화된 몇 가지 새로운 기능을 제공합니다. 그중 주요한 하나가 매개변수 효율성이 좋은 미세 조정을 위한 새로운 LoRA(Low Rank Adaptation) API입니다. 이 API의 활성화 방법은 다음과 같습니다.

gemma_lm.backbone.enable_lora(rank=4)
# 참고: rank=4는 관련 레이어의 가중치 행렬을 rank 4의
# 두 행렬의 곱 AxB로 바꾸어
# 학습 가능한 매개변수의 수를 줄입니다.

이 한 줄의 코드로 인해 학습 가능한 매개변수의 수가 25억 개에서 130만 개로 줄어듭니다!

직접 해보기: LoRA를 이용한 Gemma 모델 미세 조정

여러 GPU/TPU에서 Gemma 모델 미세 조정

Keras 3는 대규모 모델 학습도 지원하고, Gemma는 Keras 3를 사용해 보기에 완벽한 모델입니다. 새로운 Keras 배포 API는 데이터 병렬 및 모델 병렬 분산 학습 옵션을 제공합니다. 새로운 API는 다중 백엔드로 만들어졌지만 당분간은 JAX 백엔드용으로만 구현되는데, 그 확장성이 입증되었기 때문입니다(Gemma 모델은 JAX로 학습되었음).

더 큰 Gemma 7B를 미세 조정하려면 분산 설정이 유용합니다(예: Kaggle에서 무료로 이용 가능한 TPU 코어가 8개 있는 TPUv3 또는 Google Cloud의 8GPU 머신). 아래는 모델 병렬 처리를 사용하여 분산 학습용 모델을 구성하는 방법입니다.

device_mesh = keras.distribution.DeviceMesh(
   (1, 8), # Mesh topology
   ["batch", "model"], # named mesh axes
   devices=keras.distribution.list_devices() # actual accelerators
)
 
 
# 모델 구성
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
   None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (
   None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")
 
 
# 모델 구성 설정  모델 로드
model_parallel = keras.distribution.ModelParallel(
   device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# 준비: 이제 model.fit()으로 학습시키거나 generate() 텍스트를 생성할  있음

이 코드 스니펫의 역할은 가속기 8개를 1 x 8 행렬로 설정하는 것으로, 여기서 두 차원을 '배치'와 '모델'이라고 합니다. 모델 가중치는 '모델' 차원에서 분할되며, 여기서는 8개의 가속기 사이에서 분할됩니다. 반면 '배치' 차원이 1이므로 데이터 배치는 분할되지 않습니다.

직접 해보기: 여러 GPU/TPU에서 Gemma 모델 미세 조정

다음 단계

Transformer 모델을 올바르게 분할하고 위에 나온 6줄의 파티션 나누기 설정을 작성하는 방법을 보여주는 가이드를 곧 게시할 예정입니다. 그리 긴 내용은 아니지만 이 게시물에는 맞지 않아 따로 게시하겠습니다.

레이어 파티션 나누기는 레이어 이름에 대한 정규식을 통해 정의된다는 점을 알게 될 것입니다. 이 코드 스니펫으로 레이어 이름을 확인할 수 있습니다. 이 코드 스니펫을 실행하여 위의 LayoutMap을 구성했습니다.

# 이것은  번째 Transformer 블록에만 해당하지만
# 모두 동일한 구조를 가지고 있습니다
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
 print(f'{variable.path:<58}  {str(variable.shape):<16}')

Keras가 분산 계산의 다른 세부 정보를 모두 파악하는 강력한 XLA 컴파일러에 이러한 설정을 전달하므로 전체 GSPMD 모델 병렬 처리가 여기서는 단 몇 개의 파티션 나누기 힌트만으로 작동합니다.

Gemma 모델을 즐겁게 활용해 보시길 바랍니다. 명령어 튜닝 튜토리얼도 유용하게 사용하실 수 있을 겁니다. 이제 Kaggle Model Hub에서 사용자 조정 가중치 업로드를 지원하므로 미세 조정된 가중치를 커뮤니티와 공유할 수 있습니다. Kaggle의 Gemma 모델 페이지로 이동하여 다른 사람들이 벌써 무엇을 만들었는지 확인해 보세요!