Gemma para transmisión de aprendizaje automático con Dataflow

AGO 16, 2024
Reza Rokni Google Senior Staff Dataflow
Ravin Kumar Google Data Scientist Language Applications

Gemma 2 es la última versión de la familia de modelos abiertos ligeros y de última generación de Google, y se creó sobre la base de la misma investigación y tecnología utilizada para crear los modelos de Gemini. Los modelos de lenguaje grandes (LLM) como Gemma son notablemente versátiles, lo que ofrece la posibilidad de implementar muchas integraciones en los procesos comerciales. En esta entrada del blog, exploramos cómo puedes usar Gemma para conocer las opiniones en una conversación, resumir el contenido de esa conversación y ayudar a crear una respuesta para una conversación difícil que luego pueda aprobar una persona. Uno de los requisitos clave es que los clientes que expresaron una opinión negativa tengan sus necesidades atendidas casi en tiempo real, lo que significa que tendremos que hacer uso de una canalización de datos de transmisión que aproveche los LLM con una latencia mínima.


Gemma

Gemma 2 ofrece un rendimiento inigualable en relación con su tamaño. Se demostró que los modelos de Gemma logran resultados de referencia excepcionales, e incluso superan a algunos modelos más grandes. El tamaño pequeño de los modelos permite utilizar arquitecturas donde el modelo se implementa o incorpora directamente en la canalización de procesamiento de datos de transmisión, lo que brinda los siguientes beneficios, entre otros:

  • Localidad de datos con llamadas de workers locales, en lugar de RPC de datos a un sistema separado

  • Un solo sistema para escalar de manera automática, lo que permite usar métricas como la contrapresión en la fuente como señales directas al escalador automático

  • Un solo sistema para observar y supervisar en producción

Dataflow proporciona una plataforma de procesamiento por lotes y transmisión escalable y unificada. Con Dataflow, puedes usar el SDK de Apache Beam Python para desarrollar datos de transmisión y canalizaciones de procesamiento de eventos. Dataflow proporciona los siguientes beneficios:

  • Dataflow es completamente administrado y cuenta con ajuste de escala automático hacia arriba y hacia abajo en función de la demanda

  • Apache Beam proporciona un conjunto de transformaciones de código bajo listas para usar que pueden ahorrarte tiempo, esfuerzo y costo al escribir código estándar genérico. Después de todo, el mejor código es el que no tienes que escribir

  • Dataflow ML admite GPUs de forma directa, siempre que instales los controladores necesarios y proporciones acceso a una gama de dispositivos GPU

En el siguiente ejemplo, se muestra cómo incorporar el modelo Gemma dentro de la canalización de datos de transmisión para ejecutar inferencias utilizando Dataflow.


Situación

Esta situación gira en torno a una ajetreada cadena en la que se analiza y almacena un gran volumen de solicitudes de atención al cliente a través de varios canales de chat. Estas interacciones incluyen tanto chats generados por chatbots automatizados como conversaciones más matizadas que requieren la atención del personal de asistencia al cliente en vivo. En respuesta a este desafío, nos fijamos objetivos ambiciosos:

  • En primer lugar, queremos administrar y almacenar de manera eficiente los datos del chat resumiendo las interacciones positivas para facilitar la referencia y el análisis futuro.

  • En segundo lugar, queremos implementar la detección y resolución de problemas en tiempo real, y utilizar el análisis de opiniones para identificar rápidamente a los clientes insatisfechos y generar respuestas personalizadas a fin de abordar sus inquietudes.

La solución utiliza una canalización que procesa los mensajes de chat completados casi en tiempo real. Gemma se utiliza, en primera instancia, para realizar trabajos de análisis supervisando las opiniones en estos chats. Todos los chats se resumen y los que incluyen opiniones positivas o neutrales se envían directamente a una plataforma de datos, BigQuery, que utiliza E/S listas para usar con Dataflow. En el caso de los chats en los que se expresan opiniones negativas, usamos Gemma para pedirle al modelo que elabore una respuesta contextualmente apropiada para el cliente insatisfecho. Esta respuesta se envía a un humano para su revisión, lo que permite al personal de asistencia refinar el mensaje antes de que llegue a un cliente potencialmente insatisfecho.

Con este caso de uso, exploramos algunos aspectos interesantes del uso de un LLM dentro de una canalización. Por ejemplo, se presentan desafíos cuando se deben procesar las respuestas en código, dadas las respuestas no deterministas que se pueden aceptar. Por ejemplo, le pedimos a nuestro LLM que responda en JSON, lo cual no está garantizado. Esta solicitud requiere que analicemos y validemos la respuesta, que es un proceso similar a cómo procesarías normalmente los datos de fuentes que pueden no tener datos estructurados correctamente.

Gracias a esta solución, los clientes pueden disfrutar de respuestas más rápidas y recibir atención personalizada cuando surgen problemas. La automatización del resumen positivo del chat libera tiempo para el personal de asistencia, lo que permite a estos centrarse en interacciones más complejas. Además, el análisis en profundidad de los datos del chat puede impulsar la toma de decisiones basada en datos, mientras que la escalabilidad del sistema le permite adaptarse sin esfuerzo a los crecientes volúmenes de los chats, sin comprometer la calidad de las respuestas.


La canalización de procesamiento de datos

El flujo de la canalización se puede ver a continuación:

Data processing pipeline architecture

La canalización de alto nivel se puede describir con algunas líneas:

  1. Lee los datos de las opiniones de Pub/Sub, nuestra fuente de mensajería para eventos. Estos datos contienen el ID y el historial del chat como una carga útil JSON. Esta carga útil se procesa en la canalización.

2. La canalización pasa el texto de este mensaje a Gemma con un mensaje. La canalización solicita que se completen dos tareas.

  • Adjunta una puntuación de opinión al mensaje, utilizando los siguientes tres valores: 1 para un chat con opinión positiva, 0 para un chat neutral y -1 para un chat con opinión negativa.

  • Resume el chat con una sola oración.

3. A continuación, la canalización se ramifica según la puntuación de la opinión:

  • Si la puntuación es 1 o 0, el chat con resumen se envía a nuestro sistema de análisis de datos para su almacenamiento y uso en análisis futuros.

  • Si la puntuación es -1, le pedimos a Gemma que proporcione una respuesta. Esta respuesta, combinada con la información del chat, se envía a un sistema de mensajería para eventos que funciona como vínculo entre la canalización y otras aplicaciones. Este paso permite que una persona revise el contenido.


El código de la canalización

Configuración

Acceder a Gemma y descargarlo

En nuestro ejemplo, usamos Gemma a través de KerasNLP y usamos la variante “ajustada según instrucción” de Kaggle gemma2_keras_gemma2_instruct_2b_en. Debes descargar el modelo y almacenarlo en una ubicación a la que pueda acceder la canalización.


Usar el servicio de Dataflow

Si bien es posible usar CPUs para pruebas y desarrollo, dados los tiempos de inferencia, para un sistema de producción necesitamos usar GPUs en el servicio de Dataflow ML. El uso de GPUs con Dataflow se ve facilitado por un contenedor personalizado. Los detalles de esta configuración están disponibles en la compatibilidad de Dataflow con GPU. Te recomendamos que sigas la guía de desarrollo local para el desarrollo, que permite realizar una prueba rápida de la canalización. También puedes consultar la guía para usar Gemma en Dataflow, que incluye vínculos a un Dockerfile de ejemplo.


Controlador de modelos personalizado de Gemma

La transformación RunInference de Apache Beam es la parte más importante de esta solución: utiliza un controlador de modelos para la configuración y abstrae al usuario del código estándar repetitivo necesario para la producción. La mayoría de los tipos de modelos pueden ser compatibles con la configuración solo utilizando los controladores de modelos integrados de Beam, pero para Gemma, en este blog usamos un controlador de modelos personalizado, que nos da un control total de nuestras interacciones con el modelo mientras seguimos utilizando toda la maquinaria que RunInference proporciona para el procesamiento. La canalización custom_model_gemma.py tiene un ejemplo de GemmModelHandler que puedes usar. Ten en cuenta cómo se usa el valor max_length en la llamada model.generate() desde ese GemmModelHandler. Este valor controla la longitud máxima de la respuesta de Gemma a las consultas y deberá cambiarse para que coincida con las necesidades del caso de uso; para este blog, usamos el valor 512.

Sugerencia: para este blog, descubrimos que usar el backend de jax keras permitió obtener un rendimiento mucho más alto. Para habilitar esta función, el DockerFile debe contener la instrucción ENV KERAS_BACKEND="jax". Debes configurar esta opción en tu contenedor antes de que el worker inicie Beam (que importa Keras)


Crear la canalización

El primer paso en la canalización es estándar en todos los sistemas de procesamiento de eventos: necesitamos leer los mensajes JSON que crearon nuestros sistemas ascendentes, que empaquetan los mensajes de chat en una estructura simple que incluye el ID del chat.

chats = ( pipeline | "Read Topic" >>
                        beam.io.ReadFromPubSub(subscription=args.messages_subscription)
| "Decode" >> beam.Map(lambda x: x.decode("utf-8")
   )

En el siguiente ejemplo, se muestra uno de estos mensajes JSON, así como un debate muy importante sobre la piña y la pizza, y el ID 221 es nuestro cliente.

{
"id": 1, 
"user_id": 221, 
"chat_message": "\\nid 221: Hay I am really annoyed that your menu includes a pizza with pineapple on it! \\nid 331: Sorry to hear that , but pineapple is nice on pizza\\nid 221: What a terrible thing to say! Its never ok, so unhappy right now! \\n"
}

Ahora tenemos una PCollection de objetos de chat de Python. En el siguiente paso, extraemos los valores necesarios de estos mensajes de chat y los incorporamos en un mensaje para pasar a nuestro LLM ajustado según la instrucción. Para completar este paso, creamos una plantilla de mensaje que proporciona instrucciones para el modelo.

prompt_template = """
<prompt>
Provide the results of doing these two tasks on the chat history provided below for the user {}
task 1 : assess if the tone is happy = 1 , neutral = 0 or angry = -1
task 2 : summarize the text with a maximum of 512 characters
Output the results as a json with fields [sentiment, summary]
 
@@@{}@@@
<answer>
"""

El siguiente es un ejemplo de un mensaje que se envía al modelo:

<prompt>
Provide the results of doing these two tasks on the chat history provided below for the user 221
task 1 : assess if the tone is happy = 1 , neutral = 0 or angry = -1
task 2 : summarize the text with a maximum of 512 characters
Output the results as a json with fields [sentiment, summary]
 
@@@"\\nid 221: Hay I am really annoyed that your menu includes a pizza with pineapple on it! \\nid 331: Sorry to hear that , but pineapple is nice on pizza\\nid 221: What a terrible thing to say! Its never ok, so unhappy right now! \\n"@@@
<answer>

Algunas notas sobre el mensaje:

  1. Este mensaje pretende ser un ejemplo ilustrativo. Para crear tus propias indicaciones, ejecuta un análisis completo con datos indicativos para tu aplicación.

  • En el caso de la creación de prototipos, puedes usar aistudio.google.com para probar el comportamiento de Gemma y Gemini rápidamente. También hay una clave de API de un clic que puedes usar si deseas hacer pruebas mediante programación.

2. Con modelos más pequeños y menos potentes, es posible obtener mejores respuestas simplificando las instrucciones a una sola tarea y haciendo múltiples llamadas al modelo.

3. Limitamos los resúmenes de los mensajes de chat a un máximo de 512 caracteres. Haz coincidir este valor con el valor que se proporciona en la configuración max_length a la llamada de generación de Gemma.

4. Las tres Y comerciales, '@@ @', se utilizan como un truco para permitirnos extraer los chats originales del mensaje después del procesamiento. Entre otras formas en que podemos hacer esta tarea, se incluyen las siguientes:

  • Usa todo el mensaje de chat como clave en el par clave-valor.

  • Reúne los resultados con los datos originales. Este enfoque requiere una combinación aleatoria.

5. Como necesitamos procesar la respuesta en código, le pedimos al LLM que cree una representación JSON de su respuesta con dos campos: sentiment (opinión) y summary (resumen).

Para crear el mensaje, debemos analizar la información de nuestro mensaje JSON de origen y luego insertarla en la plantilla. Encapsulamos este proceso en un DoFN de Beam y lo usamos en nuestra canalización. En nuestra sentencia de rendimiento, construimos una estructura clave-valor, con el ID del chat como clave. Esta estructura nos permite hacer coincidir el chat con la inferencia cuando llamamos al modelo.

# Create the prompt using the information from the chat
class CreatePrompt(beam.DoFn):
  def process(self, element, *args, **kwargs):
    user_chat = json.loads(element)
    chat_id = user_chat['id']
    user_id = user_chat['user_id']
    messages = user_chat['chat_message']
    yield (chat_id, prompt_template.format(user_id, messages))
 
prompts = chats |  "Create Prompt" >> beam.ParDo(CreatePrompt())

Ya tenemos todo listo para llamar a nuestro modelo. Gracias a la maquinaria RunInference, este paso es sencillo. Envolvemos el GemmaModelHandler dentro de un KeyedModelhandler, que le indica a RunInference que acepte los datos entrantes como una tupla de par clave-valor. Durante el desarrollo y las pruebas, el modelo se almacena en el directorio gemma2. Al ejecutar el modelo en el servicio de Dataflow ML, el modelo se almacena en Google Cloud Storage, con el formato URI gs://<your_bucket>/gemma-directory.

keyed_model_handler = KeyedModelHandler(GemmaModelHandler('gemma2'))
results =  prompts | "RunInference-Gemma" >> RunInference(keyed_model_handler)

La colección de resultados ahora contiene los resultados de la llamada al LLM. En este punto, las cosas se ponen un poco interesantes: aunque la llamada al LLM es código, a diferencia de las llamadas a otras funciones, los resultados no son deterministas. Aquí se incluye ese bit final de nuestra solicitud de aviso "Output the results as a JSON with fields [sentiment, summary]". En general, la respuesta coincide con esa forma, pero no está garantizada. Necesitamos estar un poco a la defensiva y validar nuestra opinión. Si falla la validación, enviamos los resultados a una colección de errores. En esta muestra, dejamos esos valores ahí. Para un proceso de producción, te conviene dejar que el LLM lo intente por segunda vez y ejecute los resultados de la recopilación de errores en RunInference nuevamente, y luego alinear la respuesta con la colección de resultados. Debido a que las canalizaciones de Beam son gráficos acíclicos dirigidos, no podemos crear un bucle en este caso.

Ahora tomamos la colección de resultados y procesamos la salida del LLM. Para procesar los resultados de RunInference, creamos un nuevo SentimentAnalysis de DoFn y la función extract_model_reply. Este paso muestra un objeto de tipo PredictionResult:

def extract_model_reply(model_inference):
    match = re.search(r"(\{[\s\S]*?\})", model_inference)
    json_str = match.group(1)
    result = json.loads(json_str)
    if all(key in result for key in ['sentiment', 'summary']):
        return result
    raise Exception('Malformed model reply')
class SentimentAnalysis(beam.DoFn):
    def process(self, element):
        key = element[0]                          
        match = re.search(r"@@@([\s\S]*?)@@@", element[1].example)
        chats = match.group(1)
 
        try:
            # The result will contain the prompt, replace the prompt with ""
            result = extract_model_reply(element[1].inference.replace(element[1].example, ""))
            processed_result = (key, chats, result['sentiment'], result['summary'])           
 
            if (result['sentiment'] <0):
              output = beam.TaggedOutput('negative', processed_result)
            else:
              output = beam.TaggedOutput('main', processed_result)
 
        except Exception as err:
            print("ERROR!" + str(err))
            output = beam.TaggedOutput('error', element)
 
        yield output

Vale la pena dedicar unos minutos al hecho de que necesitamos extract_model_reply(). Debido a que el modelo es autoalojado, no podemos garantizar que el texto sea una salida JSON. Para asegurarnos de que obtenemos una salida JSON, necesitamos ejecutar un par de comprobaciones. Uno de los beneficios de usar la API de Gemini es que incluye una función que garantiza que la salida siempre sea JSON. A esta función se la conoce como decodificación restringida.

Ahora, usemos estas funciones en nuestra canalización:

filtered_results = (results | "Process Results" >> beam.ParDo(SentimentAnalysis()).with_outputs('main','negative','error'))

El uso de with_outputs permite crear múltiples colecciones accesibles en filter_results. La colección principal tiene opiniones y resúmenes de revisiones positivas y neutrales, mientras que la de errores contiene cualquier respuesta no analizable del LLM. Puedes enviar estas colecciones a otras fuentes, como BigQuery, con una transformación de escritura. En este ejemplo, no se demuestra este paso; sin embargo, queremos implementar la colección de opiniones negativas dentro de esta canalización.


Procesamiento de opiniones negativas

Garantizar que los clientes estén contentos es fundamental para la retención. Si bien utilizamos un ejemplo alegre en nuestro debate sobre la piña en la pizza, las interacciones directas con un cliente siempre deben lograr empatía y respuestas positivas de todas las partes de una organización. En esta etapa, pasamos este chat a uno de los representantes de atención capacitados, pero podemos ver si el LLM puede ayudar a esa persona a reducir el tiempo de resolución.

En este paso, hacemos una llamada al modelo y le pedimos que formule una respuesta. Volvemos a utilizar el modelo Gemma 2B para esta llamada en el código.

generated_responses = (results.negative 
       | "Generate Response" >> beam.Map(lambda x: ((x[0], x[3]), "<prompt>Generate an apology response for the user in this chat text: " + x[1] + "<answer>"))
       | "Gemma-Response" >> RunInference(keyed_model_handler)

En general, envuelves el código de creación de mensajes en un DoFn, pero también es posible usar una lambda simple en el propio código de la canalización. Aquí generamos un mensaje que contiene el mensaje de chat original, que se extrajo en la función SentimentAnalysis.

Para la ejecución y prueba local, podemos usar algunas sentencias de impresión simples a fin de ver los resultados en las diversas PCollections:

generated_responses | "Print Response" >> beam.Map(print)
filtered_results.main | "Print Main" >> beam.Map(print)
filtered_results.error | "Print Errors" >> beam.Map(print)

Por supuesto, para el uso real, estas salidas se enviarán a varios receptores, como Pub/Sub y BigQuery.


Ejecutar la canalización

Veamos cómo le va al modelo con el mensaje JSON anterior:

Paso 1: Análisis de opiniones y resumen

"sentiment": -1,

"summary": "User 221 is very unhappy about the presence of pineapple on pizza."

Las respuestas que generó el modelo 2B no son malas. La opinión es correcta y, debido a que los resultados del resumen son más subjetivos, la exactitud de la respuesta depende de los usos posteriores de esta información.

Paso 2: Respuesta generada

"I understand that you're upset about the pineapple pizza. It's a very personal preference, and I apologize that it might have caused you some frustration. We strive to offer a diverse menu to cater to a wide range of tastes, and we're always open to feedback. Would you like to share your thoughts on the pizza with pineapple?"

¿Son aceptables estas respuestas? En esta etapa, tenemos la intención de enviar todo el paquete de datos a un representante de atención para que lo analice y, si queda satisfecho, puede enviarlo tal cual está, o bien puede realizar algunas modificaciones y ajustes.


Próximos pasos

Es posible que en esta etapa queramos usar un modelo con más parámetros, como Gemma2 9B o 27B. También podríamos usar un modelo que sea lo suficientemente grande como para requerir una llamada de la API a una llamada de un servicio externo, como Gemini, en lugar de cargarse en un worker. Después de todo, redujimos el trabajo necesario para hacer envíos a estos modelos más grandes utilizando el modelo más pequeño como filtro. Tomar estas decisiones no es solo una decisión técnica, sino también una decisión comercial. Es necesario medir los costos y beneficios. En este punto, también podemos usar Dataflow para configurar más fácilmente las pruebas A/B.

También puedes optar por ajustar un modelo personalizado según tu caso de uso. Esta es una forma de cambiar la “voz” del modelo para que se adapte a tus necesidades.


Pruebas A/B

En nuestro paso de generación, pasamos todos los chats negativos entrantes a nuestro modelo 2B. Si queremos enviar una parte de la colección a otro modelo, podemos usar la función Partition en Beam con la colección filter_responses.negative. Al dirigir algunos mensajes de los clientes a diferentes modelos y hacer que el personal de asistencia califique las respuestas generadas antes de enviarlas, podemos recopilar comentarios valiosos sobre la calidad de la respuesta y los márgenes de mejora.


Resumen

Con esas pocas líneas de código, creamos un sistema capaz de procesar los datos de opiniones de los clientes con gran velocidad y variabilidad. Al utilizar el modelo abierto Gemma 2, con su “rendimiento inigualable en relación con su tamaño”, pudimos incorporar este poderoso LLM dentro de un caso de uso de procesamiento de transmisión que ayuda a crear una mejor experiencia para los clientes.