Memperkenalkan model Gemma di Keras

FEB 21, 2024
Martin Görner Product Manager Keras

Tim Keras dengan bangga mengumumkan Gemma, keluarga model terbuka yang ringan dan canggih yang dibangun dari riset dan teknologi yang sama dengan yang kami gunakan untuk membuat model Gemini, kini tersedia dalam koleksi KerasNLP. Berkat Keras 3, Gemma dapat berjalan di JAX, PyTorch, dan TensorFlow. Dengan rilis ini, Keras juga memperkenalkan beberapa fitur baru yang dirancang khusus untuk model bahasa besar: LoRA API (Low Rank Adaptation) baru dan kemampuan pelatihan model-paralel berskala besar.

Jika Anda ingin langsung mempelajari contoh kode, masuk ke sini:

Mulai

Model Gemma hadir dalam ukuran parameter portabel 2B dan 7B, dan memberikan peningkatan yang signifikan terhadap model terbuka yang serupa, bahkan beberapa model yang lebih besar. Misalnya:

  • Gemma 7B meraih skor 64,3% jawaban benar, terbaik di kelasnya dalam tolok ukur pemahaman bahasa MMLU (vs. 62,5% untuk Mistral-7B dan 54,8% untuk Llama2-13B)
  • Gemma menambahkan +11 poin persentase pada skor tolok ukur GSM8K untuk soal matematika sekolah dasar (46,4% untuk Gemma 7B vs. Mistral-7B 35,4%, Llama2-13B 28,7%)
  • dan +6,1 poin persentase jawaban benar di HumanEval, sebuah tantangan coding (32,3% untuk Gemma 7B, vs. Mistral 7B 26,2%, Llama2 13B 18,3%).

Model Gemma ditawarkan dengan KerasNLP API yang familier dan implementasi Keras yang sangat mudah dibaca. Anda bisa membuat instance model dengan satu baris kode:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

Dan menjalankannya secara langsung pada prompt teks – ya, tokenisasi sudah ada di dalamnya meskipun Anda bisa membaginya dengan mudah jika diperlukan - baca Panduan Keras NLP untuk mengetahui caranya.

gemma_lm.generate("Keras is a", max_length=32)
> "Keras is a popular deep learning framework for neural networks..."

Coba di sini: Memulai dengan model Gemma

Berkat Keras 3, Anda bisa memilih backend tempat menjalankan model. Berikut adalah cara untuk beralih:

os.environ["KERAS_BACKEND"] = "jax" # Or "tensorflow" or "torch".
import keras # import keras after having selected the backend

Keras 3 hadir dengan beberapa fitur baru yang khusus untuk model bahasa besar. Yang paling utama adalah LoRA API (Low Rank Adaptation) baru untuk menyesuaikan parameter secara efisien. Berikut adalah cara mengaktifkannya:

gemma_lm.backbone.enable_lora(rank=4)
# Note: rank=4 replaces the weights matrix of relevant layers with the
# product AxB of two matrices of rank 4, which reduces the number of
# trainable parameters.

Satu baris ini menurunkan jumlah parameter yang dapat dilatih dari 2,5 miliar menjadi 1,3 juta!

Coba di sini: Menyesuaikan model Gemma dengan LoRA.

Menyesuaikan model Gemma pada beberapa GPU/TPU

Keras 3 juga mendukung pelatihan model berskala besar dan Gemma adalah model yang sempurna untuk mencobanya. Keras distribution API yang baru menawarkan opsi pelatihan terdistribusi data-paralel dan model-paralel. API baru ini dirancang multi-backend, tetapi untuk saat ini, ia diimplementasikan hanya untuk backend JAX karena skalabilitasnya yang telah terbukti (model Gemma dilatih dengan JAX).

Untuk menyesuaikan Gemma 7B yang lebih besar, pengaturan terdistribusi sangat berguna, misalnya TPUv3 dengan 8 inti TPU yang bisa Anda dapatkan secara gratis di Kaggle, atau perangkat 8-GPU dari Google Cloud. Berikut adalah cara mengonfigurasi model untuk pelatihan terdistribusi, menggunakan paralelisme model:

device_mesh = keras.distribution.DeviceMesh(
   (1, 8), # Mesh topology
   ["batch", "model"], # named mesh axes
   devices=keras.distribution.list_devices() # actual accelerators
)
 
 
# Model config
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (None, "model")
layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
   None, "model", None)
layout_map["decoder_block.*attention_output.*kernel"] = (
   None, None, "model")
layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
layout_map["decoder_block.*ffw_linear.*kernel"] = (None, "model")
 
 
# Set the model config and load the model
model_parallel = keras.distribution.ModelParallel(
   device_mesh, layout_map, batch_dim_name="batch")
keras.distribution.set_distribution(model_parallel)
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_7b_en")
# Ready: you can now train with model.fit() or generate text with generate()

Yang dilakukan cuplikan kode ini adalah mengatur 8 akselerator ke dalam matriks 1 x 8 di mana dua dimensi tersebut disebut "batch" dan "model". Bobot model dipecah pada dimensi "model", di sini dibagi di antara 8 akselerator, sementara batch data tidak dipartisi karena dimensi "batch" adalah 1.

Coba di sini: Menyesuaikan model Gemma pada beberapa GPU/TPU.

Apa Berikutnya

Kami akan segera memublikasikan panduan yang menunjukkan kepada Anda cara mempartisi model Transformer dengan benar dan menulis 6 baris pengaturan partisi di atas. Panduannya tidak terlalu panjang, tetapi tidak akan muat dalam postingan ini.

Anda akan melihat bahwa partisi lapisan didefinisikan melalui ekspresi reguler pada nama lapisan. Anda bisa memeriksa nama lapisan dengan cuplikan kode ini. Kami menjalankannya untuk membuat LayoutMap di atas.

# This is for the first Transformer block only,
# but they all have the same structure
tlayer = gemma_lm.backbone.get_layer('decoder_block_0')
for variable in tlayer.weights:
 print(f'{variable.path:<58}  {str(variable.shape):<16}')

Paralelisme model GSPMD penuh bekerja di sini hanya dengan beberapa petunjuk partisi karena Keras meneruskan setelan ini ke compiler XLA kuat yang mengetahui semua detail lain dari komputasi terdistribusi.

Kami harap Anda suka bermain-main dengan model Gemma. Di sini juga ada tutorial panduan penyesuaian yang mungkin berguna bagi Anda. Dan omong-omong, jika Anda ingin membagikan bobot yang telah disesuaikan dengan komunitas, Kaggle Model Hub sekarang mendukung upload bobot yang telah diatur pengguna. Buka halaman model untuk model Gemma di Kaggle dan lihat kreasi yang sudah dibuat pengguna lainnya!