Transformers PipelineとLangChainによるテキスト生成
前回は Hugging FaceのTransgormers Pipelineを用いたテキスト生成を実装しました。 今回はTransgormers Pipelineに加えてLangChainも使用したテキスト生成を実装してみます。
LangChain を利用すると大規模言語モデルを用いたアプリケーション開発が容易になります。 モデルやデータベース、データの取り込みなどが抽象化されており、統一されたインターフェースを用いて実装できます。 このLangChainとHugging Faceで公開されている学習済み大規模言語モデルを使えば、精度はさておき、無料でChatGPTライクなものが実装できるのでは?と思い、調べてみました。
ちなみにLangChainの使用例を検索すると、OpenAIのAPIを使用した実装例(※LangChainにはOpenAIのAPIのラッパーが用意されています)が多数出てきますが、Hugging FaceのTransformers Pipelineを使用した実装例を見つけられなかったので、試行錯誤しながらの実装となりました。
実装の流れはこんな感じです。
- 学習済みモデルをロード
- タスクやモデルなどを指定してTransformers Pipelineを構築
- PipelineとPromptTemplateを指定してLangChainのLLMChainを構築
- 推論の実行(LLMChainに入力を与えて出力を得る)
なお、今回のコードはこちらです。
学習済みモデルをロード
学習済みモデルとして、今回も、サイバーエージェント社が公開している、日本語データセットで学習済みのOpenCALM-1Bを使用します。
from transformers import AutoModelForCausalLM, AutoTokenizer pretrained_model_name = 'cyberagent/open-calm-1b' model = AutoModelForCausalLM.from_pretrained(pretrained_model_name, torch_dtype=torch.float16).to(device) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
Pipeline を構築
続いてパイプラインを構築します。
テキスト生成の場合はtransformers.pipeline
の引数 task にtext-generation
を指定し、model
とtokenizer
に、先ほどロードした学習済みモデルとトークナイザーを指定します。
kwargs
の値はOpenCALM-1BのUsageを参考にしています。
from transformers import pipeline task='text-generation' kwargs = { 'max_new_tokens': 64, 'do_sample': True, 'temperature': 0.7, 'top_p': 0.9, 'repetition_penalty': 1.05, 'pad_token_id': tokenizer.pad_token_id } pipe = pipeline( task=task, model=model, tokenizer=tokenizer, device=device_id, torch_dtype=torch.float16, **kwargs )
LangChain
LangChainのHuggingFacePipeline
に上で構築したPipelineを指定します。
また、PromptTemplate
で大規模言語モデルに指示を出すためのプロンプトのテンプレートを作成します。
テンプレートにはパラメータを含めることができ、{query}
のような形式で記述します。
のちほどテキスト生成を実行する際、このパラメータに値を渡します。
LLMChain
で大規模言語モデルとプロンプトを統合します。
from langchain.llms import HuggingFacePipeline from langchain.prompts import PromptTemplate from langchain.chains import LLMChain llm = HuggingFacePipeline(pipeline=pipe) template = """ {query} """ prompt = PromptTemplate(template=template, input_variables=['query']) llm_chain = LLMChain(llm=llm, prompt=prompt)
推論
LLMChain.run
で推論を実行します。このとき、テンプレートでパラメータになっていたquery
をrun
の引数で指定します。
llm_chain.run(query='アジャイルソフトウェア開発宣言の内容は')
ちなみに出力結果はこんな感じでした。
'「ソフトウェアは、目的を達成するために手段の1つとして存在しなくてはならない」という原則と、「反復可能な開発によってのみプロジェクトが成功するわけではない。そのプロジェクトは目的に達しないものである」(序文)との両点から成り立っています。「顧客のニーズを満たし(Price of Customer)」を目的に掲げている企業にとってこの宣言'