モデルをどこでロードするのか問題
Chanlitで大規模言語モデル(LLM)をロードする場合、下記のどこで行うのが適切でしょうか?
- グローバルスコープ
- デコレータ
@cl.on_chat_start
を付与した関数内 - デコレータ
@cl.on_message
を付与した関数内
結論から言うと、デコレータ@cl.on_chat_start
を付与した関数内で行うのが適切だと思います。
各方法のメリットやデメリットは下表のとおりです。
LLMをロードする場所 | メリット | デメリット |
---|---|---|
グローバルスコープ | アプリ起動時に1度だけモデルをロードするので効率的 |
|
デコレータ@cl.on_chat_start を付与した関数内 |
|
チャットを開始する度にモデルがロードされるので、最初のメッセージ送信まで待たされる |
デコレータ@cl.on_message を付与した関数内 |
本当に必要になった時にモデルがロードされる | メッセージ送信の度にモデルがロードされるので、非常に効率が悪い |
デコレータ@cl.on_chat_start
を付与した関数内(長いので以下「@cl.on_chat_start
内」と略記)でLLMをロードする場合、チャットを開始する度にLLMをロードする点がネックになりますが、これは次の対応を行うことで回避できます(それでも最初のロード時は待つしかありませんが…)。
- モデルのロード処理を非同期にする
- キャッシュを利用する
- セッションを利用する
非同期化
まず、モデルのロード処理の非同期化について。
Hugging Faceで公開されているモデルをロードする場合、transformers
のAutoModelForCausalLM.from_pretrained
などを使用すると思いますが、これは同期処理です。
同期処理なので、モデルのダウンロードや読み込みが完了するまで他の処理を実行できません。
モデルのロードはI/O待ちが長いため、単純に@cl.on_chat_start
内でAutoModelForCausalLM.from_pretrained
を実行すると、しばらくして、チャットの画面上にサーバに接続できませんでした
と表示されてしまいます。
これではユーザは困ってしまいます。
そこで、asyncio.to_thread
を用いてモデルのロード処理を非同期で実行するようにします。
import aysncio import chainlit as cl from transformers import AutoModelForCausalLM @cl.on_chat_start async def on_chat_start(): (略) model = await asyncio.to_thread(AutoModelForCausalLM.from_pretrained, model_name_or_path)
非同期処理にすることでチャットがフリーズせずに済みます。
キャッシュとセッション
ところで、非同期化してもチャットを開始する度にモデルのロードが行われること自体は変わらないので、チャットが可能になるまでその都度待たされてしまう点は同じです。 その解決策としてキャッシュとセッションを使用します。
ロードしたモデルをキャッシュすることで、モデルのロードを繰り返さないようにすることができます。
モデルをキャッシュするには、モデルをロードして返却する関数を定義し、その関数にデコレータ@cl.cache
を付与すればOKです。
また、ロードしたモデルをセッションに保存しておくことで、同一セッション内では、セッションからモデルを取り出して使うことができます。
- セッションへの保存:
cl.user_session.set(セッション名, 保存したいオブジェクト)
- セッションからの取得:
cl.user_session.get(セッション名)
まとめると以下のような実装イメージになります。
import aysncio import chainlit as cl from transformers import AutoModelForCausalLM @cl.cache def load_model(model_name_or_path): model = AutoModelForCausalLM.from_pretrained(model_name_or_path) return model @cl.on_chat_start async def on_chat_start(): # model_name_or_pathにはモデル名またはパスを指定 model = await asyncio.to_thread(load_model, model_name_or_path) # 名前を指定してセッションにモデルを保存 cl.user_session.set('model', model) @cl.on_message async def on_message(message): # 保存したときと同じ名前を指定してセッションからモデルを取得 model = cl.user_session.get('model')
処理状況をユーザに通知
これまでの対応で処理としては問題ありませんが、ユーザにはモデルがロード中なのか終わったのかが分かりません。
そのようなときはcl.Step
を使用すると便利です。
下記のようにwith cl.Step
節内でモデルのロードを行うことで、チャット上ではモデルをロード中である旨が表示されるようになります。
import aysncio import chainlit as cl from transformers import AutoModelForCausalLM @cl.cache def load_model(model_name_or_path): model = AutoModelForCausalLM.from_pretrained(model_name_or_path) return model @cl.on_chat_start async def on_chat_start(): message = cl.Message(content='モデルをロード中です。しばらくお待ちください。') await message.send() with cl.Step(name='モデルをロード', type='llm'): # model_name_or_pathにはモデル名またはパスを指定 model = await asyncio.to_thread(load_model, model_name_or_path) # 名前を指定してセッションにモデルを保存 cl.user_session.set('model', model) message.content = 'モデルのロードが完了しました。' await message.update() await cl.Message(content='ようこそ!ご用件は何でしょうか?').send()
モデルのロード中は使用中 モデルをロード
と表示され、モデルのロード処理が継続していることが分かります。
モデルのロードが完了すると表示が使用済み モデルをロード
に変わります。
まとめ
ChainlitでLLMを使用する場合は下記対応を行うと良いでしょう。
- モデルのロードはデコレータ
@cl.on_chat_start
を付与した関数内で行う - モデルのロードはI/O待ちが長いので
asyncio.to_thread
で非同期実行する - キャッシュを利用する
- セッションを利用する
cl.Step
でユーザに進捗を通知する