noriho137’s diary

機械学習, 時々Web

Chainlitで大規模言語モデルを非同期にロードしてキャッシュする

モデルをどこでロードするのか問題

Chanlitで大規模言語モデル(LLM)をロードする場合、下記のどこで行うのが適切でしょうか?

  1. グローバルスコープ
  2. デコレータ@cl.on_chat_startを付与した関数内
  3. デコレータ@cl.on_messageを付与した関数内

結論から言うと、デコレータ@cl.on_chat_startを付与した関数内で行うのが適切だと思います。

各方法のメリットやデメリットは下表のとおりです。

LLMをロードする場所 メリット デメリット
グローバルスコープ アプリ起動時に1度だけモデルをロードするので効率的
  • アプリの起動に時間がかかる
  • ユーザによるモデルの選択ができない
デコレータ@cl.on_chat_startを付与した関数内
  • ユーザごとに独立したモデルを使用可能
  • アプリの起動時間を短縮できる
チャットを開始する度にモデルがロードされるので、最初のメッセージ送信まで待たされる
デコレータ@cl.on_messageを付与した関数内 本当に必要になった時にモデルがロードされる メッセージ送信の度にモデルがロードされるので、非常に効率が悪い

デコレータ@cl.on_chat_startを付与した関数内(長いので以下「@cl.on_chat_start内」と略記)でLLMをロードする場合、チャットを開始する度にLLMをロードする点がネックになりますが、これは次の対応を行うことで回避できます(それでも最初のロード時は待つしかありませんが…)。

  • モデルのロード処理を非同期にする
  • キャッシュを利用する
  • セッションを利用する

非同期化

まず、モデルのロード処理の非同期化について。

Hugging Faceで公開されているモデルをロードする場合、transformersAutoModelForCausalLM.from_pretrainedなどを使用すると思いますが、これは同期処理です。 同期処理なので、モデルのダウンロードや読み込みが完了するまで他の処理を実行できません。 モデルのロードはI/O待ちが長いため、単純に@cl.on_chat_start内でAutoModelForCausalLM.from_pretrainedを実行すると、しばらくして、チャットの画面上にサーバに接続できませんでしたと表示されてしまいます。 これではユーザは困ってしまいます。

同期処理でモデルをロードするとフリーズ

そこで、asyncio.to_threadを用いてモデルのロード処理を非同期で実行するようにします。

import aysncio
import chainlit as cl
from transformers import AutoModelForCausalLM

@cl.on_chat_start
async def on_chat_start():
(略)
    model = await asyncio.to_thread(AutoModelForCausalLM.from_pretrained,
                                    model_name_or_path)

非同期処理にすることでチャットがフリーズせずに済みます。

キャッシュとセッション

ところで、非同期化してもチャットを開始する度にモデルのロードが行われること自体は変わらないので、チャットが可能になるまでその都度待たされてしまう点は同じです。 その解決策としてキャッシュとセッションを使用します。

ロードしたモデルをキャッシュすることで、モデルのロードを繰り返さないようにすることができます。 モデルをキャッシュするには、モデルをロードして返却する関数を定義し、その関数にデコレータ@cl.cacheを付与すればOKです。

また、ロードしたモデルをセッションに保存しておくことで、同一セッション内では、セッションからモデルを取り出して使うことができます。

  • セッションへの保存:cl.user_session.set(セッション名, 保存したいオブジェクト)
  • セッションからの取得:cl.user_session.get(セッション名)

まとめると以下のような実装イメージになります。

import aysncio
import chainlit as cl
from transformers import AutoModelForCausalLM


@cl.cache
def load_model(model_name_or_path):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    return model


@cl.on_chat_start
async def on_chat_start():
    # model_name_or_pathにはモデル名またはパスを指定
    model = await asyncio.to_thread(load_model,
                                    model_name_or_path)

    # 名前を指定してセッションにモデルを保存
    cl.user_session.set('model', model)


@cl.on_message
async def on_message(message):
    # 保存したときと同じ名前を指定してセッションからモデルを取得
    model = cl.user_session.get('model')

処理状況をユーザに通知

これまでの対応で処理としては問題ありませんが、ユーザにはモデルがロード中なのか終わったのかが分かりません。 そのようなときはcl.Stepを使用すると便利です。 下記のようにwith cl.Step節内でモデルのロードを行うことで、チャット上ではモデルをロード中である旨が表示されるようになります。

import aysncio
import chainlit as cl
from transformers import AutoModelForCausalLM


@cl.cache
def load_model(model_name_or_path):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    return model


@cl.on_chat_start
async def on_chat_start():
    message = cl.Message(content='モデルをロード中です。しばらくお待ちください。')
    await message.send()

    with cl.Step(name='モデルをロード', type='llm'):
        # model_name_or_pathにはモデル名またはパスを指定
        model = await asyncio.to_thread(load_model,
                                        model_name_or_path)

    # 名前を指定してセッションにモデルを保存
    cl.user_session.set('model', model)

    message.content = 'モデルのロードが完了しました。'
    await message.update()

    await cl.Message(content='ようこそ!ご用件は何でしょうか?').send()

モデルのロード中は使用中 モデルをロードと表示され、モデルのロード処理が継続していることが分かります。

モデルのロード中

モデルのロードが完了すると表示が使用済み モデルをロードに変わります。

モデルのロード完了後

まとめ

ChainlitでLLMを使用する場合は下記対応を行うと良いでしょう。

  • モデルのロードはデコレータ@cl.on_chat_startを付与した関数内で行う
  • モデルのロードはI/O待ちが長いのでasyncio.to_threadで非同期実行する
  • キャッシュを利用する
  • セッションを利用する
  • cl.Stepでユーザに進捗を通知する