noriho137’s diary

機械学習, 時々Web

Transformers PipelineとLangChainによるテキスト生成

前回は Hugging FaceのTransgormers Pipelineを用いたテキスト生成を実装しました。 今回はTransgormers Pipelineに加えてLangChainも使用したテキスト生成を実装してみます。

LangChain を利用すると大規模言語モデルを用いたアプリケーション開発が容易になります。 モデルやデータベース、データの取り込みなどが抽象化されており、統一されたインターフェースを用いて実装できます。 このLangChainとHugging Faceで公開されている学習済み大規模言語モデルを使えば、精度はさておき、無料でChatGPTライクなものが実装できるのでは?と思い、調べてみました。

ちなみにLangChainの使用例を検索すると、OpenAIのAPIを使用した実装例(※LangChainにはOpenAIのAPIのラッパーが用意されています)が多数出てきますが、Hugging FaceのTransformers Pipelineを使用した実装例を見つけられなかったので、試行錯誤しながらの実装となりました。

実装の流れはこんな感じです。

  1. 学習済みモデルをロード
  2. タスクやモデルなどを指定してTransformers Pipelineを構築
  3. PipelineとPromptTemplateを指定してLangChainのLLMChainを構築
  4. 推論の実行(LLMChainに入力を与えて出力を得る)

なお、今回のコードはこちらです。

github.com

学習済みモデルをロード

学習済みモデルとして、今回も、サイバーエージェント社が公開している、日本語データセットで学習済みのOpenCALM-1Bを使用します。

from transformers import AutoModelForCausalLM, AutoTokenizer

pretrained_model_name = 'cyberagent/open-calm-1b'

model = AutoModelForCausalLM.from_pretrained(pretrained_model_name, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)

Pipeline を構築

続いてパイプラインを構築します。 テキスト生成の場合はtransformers.pipelineの引数 task にtext-generationを指定し、modeltokenizerに、先ほどロードした学習済みモデルとトークナイザーを指定します。 kwargsの値はOpenCALM-1BのUsageを参考にしています。

from transformers import pipeline

task='text-generation'

kwargs = {
    'max_new_tokens': 64,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'repetition_penalty': 1.05,
    'pad_token_id': tokenizer.pad_token_id
}

pipe = pipeline(
    task=task,
    model=model,
    tokenizer=tokenizer,
    device=device_id,
    torch_dtype=torch.float16,
    **kwargs
)

LangChain

LangChainのHuggingFacePipelineに上で構築したPipelineを指定します。 また、PromptTemplateで大規模言語モデルに指示を出すためのプロンプトのテンプレートを作成します。 テンプレートにはパラメータを含めることができ、{query}のような形式で記述します。 のちほどテキスト生成を実行する際、このパラメータに値を渡します。 LLMChainで大規模言語モデルとプロンプトを統合します。

from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

llm = HuggingFacePipeline(pipeline=pipe)

template = """
{query}
"""

prompt = PromptTemplate(template=template, input_variables=['query'])

llm_chain = LLMChain(llm=llm, prompt=prompt)

推論

LLMChain.runで推論を実行します。このとき、テンプレートでパラメータになっていたqueryrunの引数で指定します。

llm_chain.run(query='アジャイルソフトウェア開発宣言の内容は')

ちなみに出力結果はこんな感じでした。

'「ソフトウェアは、目的を達成するために手段の1つとして存在しなくてはならない」という原則と、「反復可能な開発によってのみプロジェクトが成功するわけではない。そのプロジェクトは目的に達しないものである」(序文)との両点から成り立っています。「顧客のニーズを満たし(Price of Customer)」を目的に掲げている企業にとってこの宣言'