Transformers, LangChain & Chromaによるローカルのテキストデータを参照したテキスト生成
前回はHugging FaceのTransgormersとLangChainを用いたテキスト生成を実装しました。 今回はさらにChromaを用いて、ローカルのDB上のデータを参照して質問応答を行うテキスト生成を実装してみます。
ChromaはいわゆるベクトルDBの一種です。 ベクトルDBは埋め込みベクトルのような高次元のベクトルデータを扱うのに適したDBです。 ベクトルDBの実装には色々とありますが、今回はサーバを構築しなくても簡単に試せるChromaを使ってみました。
このベクトルDBと大規模言語モデル(LLM: Large Language Model)を用いたテキスト生成を組み合わせることで、ユーザからの質問に関連した文章をベクトルDBから取得して、その文章を基に回答を生成する、といったことが可能になります。 全体構成は次の図のようなイメージです。
インターネット上に公開していない(あるいは社外秘などで外部には公開できない)ドメイン固有の文書があり、生成AIを使ってそれらの文書をベースに質問応答を行いたい場合は、このような方法が有効になると思います。
実装の流れはこんな感じです。
- ベクトルDBの構築
- 学習済みモデルをロード
- タスクやモデルなどを指定してTransformers Pipelineを構築
- PipelineとPromptTemplateを指定してLangChainのLLMChainを構築
- 質問文を入力してベクトルDBから類似度の高い文章を検索
- 推論の実行(LLMChainに質問文と検索結果を与えて回答を生成する)
なお、今回のコードはこちらです。
ベクトルDBの構築
今回使用する文書はIPAが公開している「アジャイルソフトウェア開発宣言の読みとき方」です。 PDFファイルをページ単位で埋め込みベクトルに変換してChromaに格納します。
LangChainにはPDFファイルを読み込むPDFMinerLoader
があるので、それを使用して読み込みます。
PDFファイルを読み込むLoaderは他にもありますが、日本語が文字化けしたり、ページがきちんと認識できなかったりしたので、いくつか試した結果、PDFMinerLoader
を使うことにしました。
ちなみに、名前のとおり、裏ではPDFMiner(正確にはpdfminer.six)が動いています。
from langchain.document_loaders import PDFMinerLoader file = '000065601.pdf' loader = PDFMinerLoader(file) text = loader.load() pages = text[0].page_content.split('\x0c')
ファイルをページ単位に分割できたら、HuggingFaceのsentence_transformers埋め込みモデルを使って、テキストデータを埋め込みベクトルに変換し、Chromaに格納します。 今回も大規模言語モデルはサイバーエージェント社のOpenCALM-1Bを使います。
from langchain.embeddings.huggingface import HuggingFaceEmbeddings model_name = 'cyberagent/open-calm-1b' embeddings = HuggingFaceEmbeddings(model_name=model_name)
モデルをロードしたら、テキストデータを入力して埋め込みベクトルに変換します。
Chromaのオリジナルのインターフェースをそのまま使用しても良いですが、LangChainにChromaのラッパーがあるので、それを使うことにします。
langchain.vectorstores.Chroma
でChromaのクライアントを生成します。
Chroma
の引数embedding_function
に先ほどロードした言語モデルを指定します。
引数persist_directory
を指定すると、ストレージ上の指定した場所にデータが保存されます。
なお、引数persist_directory
を指定しない場合はインメモリとなります。
from langchain.vectorstores import Chroma vectordb = Chroma(embedding_function=embeddings, persist_directory='./db')
続いて、先ほど読み込んだ文書について、ページ単位でループして、add_texts
でDBに登録していきます。
このとき、引数texts
に渡したテキストデータが埋め込みベクトルに変換されて登録されます。
for i, page in enumerate(pages): if page == '': continue vectordb.add_texts(texts=[page], metadatas=[{'source': file}], ids=[f'id{i+1}'])
学習済みモデルをロード
学習済みモデルもサイバーエージェント社のOpenCALM-1Bを使います。 なお、埋め込みベクトルを生成した際に使用したモデルと同じものを使う必要があります。
from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device) tokenizer = AutoTokenizer.from_pretrained(model_name)
Pipelineを構築
続いてTransgormersのPipelineを構築します。
テキスト生成の場合はtransformers.pipeline
の引数task
にtext-generation
を指定し、model
とtokenizer
に先ほどロードした学習済みモデルとトークナイザーを指定します。
kwargs
の値はOpenCALM-1BのUsageを参考にしています。
from transformers import pipeline task='text-generation' kwargs = { 'max_new_tokens': 64, 'do_sample': True, 'temperature': 0.7, 'top_p': 0.9, 'repetition_penalty': 1.05, 'pad_token_id': tokenizer.pad_token_id } pipe = pipeline( task=task, model=model, tokenizer=tokenizer, device=device_id, torch_dtype=torch.float16, **kwargs )
LLMChainを構築
langchain.llms.HuggingFacePipeline
に先ほど構築したPipeline
を指定します。
from langchain.llms import HuggingFacePipeline llm = HuggingFacePipeline(pipeline=pipe)
また、PromptTemplate
で大規模言語モデルに指示を出すためのプロンプトのテンプレートを作成します。
テンプレートにはパラメータを含めることができ、{query}
のような形式で記述します。
のちほどテキスト生成を実行する際、このパラメータに値を渡します。
from langchain.llms import HuggingFacePipeline from langchain.prompts import PromptTemplate from langchain.chains import LLMChain llm = HuggingFacePipeline(pipeline=pipe) template = """ 質問: {query} 回答: """ prompt = PromptTemplate(template=template, input_variables=['query'])
LLMChain
で大規模言語モデルとプロンプトを統合します。
llm_chain = LLMChain(llm=llm, prompt=prompt)
質問文に関連する文章をベクトルDBから取得
質問文を入力して、ベクトルDB上で質問文と類似度の高いページを検索します。
similarity_search
の引数k
で上位何件までを返すのかを指定できます。
query = 'アジャイルソフトウェア開発宣言で、プロセスやツールよりも重視していることは?' docs = vectordb.similarity_search(query=query, k=5)
ちなみにここでdocs
の中身を確認してみると、質問文に関連したページが含まれていることが分かります。
推論
ベクトルDBの検索結果と元の質問文をLLMChain
に渡して、最終的な回答を生成します。
llm_chain.run(input_documents=docs, query=query)
出力結果はこんな感じになりました。
'アジャイルは「俊敏な」「型にはまらない」という2つの言葉で表現されますが、「柔軟性・流動性が重要だ。またその柔軟性を保証するための仕組みが重要である。」というのが答えです。「変化への対応は遅いかもしれないけれど,すぐに改善するんだ」。これがチームにとって大切なことです。”スピード”と簡単にいいますが'
この回答だと、ベクトルDBから取得した文章を活用しきれていないように見えますが、大規模言語モデルのパラメータ数や文書を分割してChromaに登録した際の粒度(今回はページ単位)も影響しているかと思います。 モデルを変えてみたり、文書の分割の粒度や分割方法を変えてみたりするなど、工夫の余地はありそうです。
Transformers PipelineとLangChainによるテキスト生成
前回は Hugging FaceのTransgormers Pipelineを用いたテキスト生成を実装しました。 今回はTransgormers Pipelineに加えてLangChainも使用したテキスト生成を実装してみます。
LangChain を利用すると大規模言語モデルを用いたアプリケーション開発が容易になります。 モデルやデータベース、データの取り込みなどが抽象化されており、統一されたインターフェースを用いて実装できます。 このLangChainとHugging Faceで公開されている学習済み大規模言語モデルを使えば、精度はさておき、無料でChatGPTライクなものが実装できるのでは?と思い、調べてみました。
ちなみにLangChainの使用例を検索すると、OpenAIのAPIを使用した実装例(※LangChainにはOpenAIのAPIのラッパーが用意されています)が多数出てきますが、Hugging FaceのTransformers Pipelineを使用した実装例を見つけられなかったので、試行錯誤しながらの実装となりました。
実装の流れはこんな感じです。
- 学習済みモデルをロード
- タスクやモデルなどを指定してTransformers Pipelineを構築
- PipelineとPromptTemplateを指定してLangChainのLLMChainを構築
- 推論の実行(LLMChainに入力を与えて出力を得る)
なお、今回のコードはこちらです。
学習済みモデルをロード
学習済みモデルとして、今回も、サイバーエージェント社が公開している、日本語データセットで学習済みのOpenCALM-1Bを使用します。
from transformers import AutoModelForCausalLM, AutoTokenizer pretrained_model_name = 'cyberagent/open-calm-1b' model = AutoModelForCausalLM.from_pretrained(pretrained_model_name, torch_dtype=torch.float16).to(device) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
Pipeline を構築
続いてパイプラインを構築します。
テキスト生成の場合はtransformers.pipeline
の引数 task にtext-generation
を指定し、model
とtokenizer
に、先ほどロードした学習済みモデルとトークナイザーを指定します。
kwargs
の値はOpenCALM-1BのUsageを参考にしています。
from transformers import pipeline task='text-generation' kwargs = { 'max_new_tokens': 64, 'do_sample': True, 'temperature': 0.7, 'top_p': 0.9, 'repetition_penalty': 1.05, 'pad_token_id': tokenizer.pad_token_id } pipe = pipeline( task=task, model=model, tokenizer=tokenizer, device=device_id, torch_dtype=torch.float16, **kwargs )
LangChain
LangChainのHuggingFacePipeline
に上で構築したPipelineを指定します。
また、PromptTemplate
で大規模言語モデルに指示を出すためのプロンプトのテンプレートを作成します。
テンプレートにはパラメータを含めることができ、{query}
のような形式で記述します。
のちほどテキスト生成を実行する際、このパラメータに値を渡します。
LLMChain
で大規模言語モデルとプロンプトを統合します。
from langchain.llms import HuggingFacePipeline from langchain.prompts import PromptTemplate from langchain.chains import LLMChain llm = HuggingFacePipeline(pipeline=pipe) template = """ {query} """ prompt = PromptTemplate(template=template, input_variables=['query']) llm_chain = LLMChain(llm=llm, prompt=prompt)
推論
LLMChain.run
で推論を実行します。このとき、テンプレートでパラメータになっていたquery
をrun
の引数で指定します。
llm_chain.run(query='アジャイルソフトウェア開発宣言の内容は')
ちなみに出力結果はこんな感じでした。
'「ソフトウェアは、目的を達成するために手段の1つとして存在しなくてはならない」という原則と、「反復可能な開発によってのみプロジェクトが成功するわけではない。そのプロジェクトは目的に達しないものである」(序文)との両点から成り立っています。「顧客のニーズを満たし(Price of Customer)」を目的に掲げている企業にとってこの宣言'
Transformaers Pipeline によるテキスト生成
Hugging Face の Transformers には推論を簡単に行うための Pipeline という仕組みがあります。 PyTorch や TensorFlow のような面倒なコードを書かずに、わずかなコーディング量で推論することができるので、ちょっとしたことを試すのにはとでも便利です。
Pipeline 自体は画像認識や音声認識など様々なタスクに使用できるようですが、今回は自然言語処理のタスク、特にテキスト生成を試してみます。
特定のモデルで Pipeline を使用する流れはこんな感じです。
- 学習済みモデルをロード
- タスクやモデルなどを指定して Pipeline を構築
- 推論の実行(パイプラインに入力を与えて出力を得る)
ちなみに、最も単純な場合は、明示的に学習済みモデルをロードせずにタスクを指定するだけでも良さそうですが、使用可能な言語が限られるなどの制限がありそうです。 今回は、最近公開された日本語の大規模言語モデルを使用して、日本語テキスト生成を試したかったので、上記のような流れで実装してみます。
なお、今回のコードはこちらです。
学習済みモデルをロード
学習済みモデルとして、今回は、サイバーエージェント社が公開している日本語データセットで学習済みの OpenCALM-1B を使用してみます。
from transformers import AutoModelForCausalLM, AutoTokenizer pretrained_model_name = 'cyberagent/open-calm-1b' model = AutoModelForCausalLM.from_pretrained(pretrained_model_name, torch_dtype=torch.float16).to(device) tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
パイプラインを構築
続いてパイプラインを構築する。
テキスト生成の場合は transformers.pipeline
の引数 task
に text-generation
を指定し、model
と tokenizer
に、先ほどロードした学習済みモデルとトークナイザーを指定する。
kwargs
の値は OpenCALM-1B の Usage を参考にしました。
from transformers import pipeline task = 'text-generation' kwargs = { 'max_new_tokens': 64, 'do_sample': True, 'temperature': 0.7, 'top_p': 0.9, 'repetition_penalty': 1.05, 'pad_token_id': tokenizer.pad_token_id } generator = pipeline( task=task, model=model, tokenizer=tokenizer, device=device_id, torch_dtype=torch.float16, **kwargs )
推論の実行
引数にテキストを与えると、それに続くテキストが生成されます。 たったのこれだけ。なんて簡単なんだ!
generator('アジャイルソフトウェア開発宣言の内容は')
ちなみに出力結果はこんな感じでした。
[{'generated_text': 'アジャイルソフトウェア開発宣言の内容は、ソフトウェアの開発・運用に関する方針を簡潔にまとめたものです。\nこの文書には、「開発」や「生産」「販売」、「保守サービスなどの各業務プロセスにおいて『顧客満足』を追求していくこと」、そしてそれらを実現するための考え方として、『SMILE(Smile:笑顔で)』と書かれています。『信頼』『誠実さ'}]
使用する学習済みモデルやモデルのパラメータ数によって精度も変わると思います。 様々なモデルで試してみたいですね。
PyCharmでプロジェクトの改行コードを一括変換
PyCharmでプロジェクトの改行コードを一括変換する方法をいつも忘れてしまうのでメモ。
プロジェクトをクリックして選択した上で、メニューから、File → File Properties → Line Separators と選択すればOK。
※注:上記はWindows版です。
二値分類の評価指標
機械学習の二値分類の評価指標について、ついつい忘れがちなのでメモしておこう。
混同行列
二値分類の混同行列は次のとおり。
陰性と予測 | 陽性と予測 | |
---|---|---|
実際は陰性 | 真陰性 (True Negative; TN) | 偽陽性 (False Positive; FP) |
実際は陽性 | 偽陰性 (False Negative; FN) | 真陽性 (True Positive; TP) |
それぞれの件数を表形式で把握できるので視覚的にも分かりやすい。 以降でまとめている各種評価指標を計算する際にも使用する重要な情報。
正解率
正解率 (Accuracy) は次式で定義される。
全件中何件予測が正しかったかを表す指標になっている。 単純で分かりやすいが、陰性サンプル数と陽性サンプル数が不均衡な場合は適切な評価ができないので要注意。
適合率・再現率
上記の {TN, FP, FN, TP} を用いると適合率 (Precision)、再現率 (Recall) は次式で定義される。
つまりはこういうことだ。
適合率:陽性と予測したもののうち、実際に陽性だったものの割合
再現率:実際に陽性であるもののうち、陽性と予測できたものの割合
F1値
F1値 (F1 score, F1 measure) は適合率と再現率の調和平均で定義される。
ちなみに適合率も再現率も比率である。比率に対する平均なので、単なる平均ではなく、調和平均が適切だということ。
Django + PyTorch で画像認識 Web アプリを作る
DjangoとPyTorchで画像認識Webアプリを作ろうと思ったところ、Djangoのクラスベースビューを使いつつAjaxでグラフ描画するという例が見つからなかったので、試行錯誤して実装してみました。今回はその内容の共有をしたいと思います。
ソースコードはhttps://github.com/noriho137/image-classifierにアップしてあります。
はじめに
検証した環境は次のとおりです。
作るもの
画像ファイルをアップロードして画像認識を行い、認識結果をブラウザ上に表示するWebアプリケーションを作成します。 WebアプリのフレームワークにはDjangoを利用します。 画像認識にはPyTorchのImageNet学習済みモデルを利用します。 認識結果の上位$N$件までのラベル名と確率を表とグラフで表示します。 こんな感じです。
プロジェクト&アプリケーション作成
プロジェクト作成
まずはDjangoのプロジェクト作成。Django公式のチュートリアルによると django-admin startproject {プロジェクト名}
で作成できますが、こうすると {プロジェクト名}
と同じ名前で設定用ディレクトリが作成されてしまってややこしくなります。こんな感じ。
{プロジェクト名}/ ├─ {プロジェクト名}/ │ ├─ __init__.py │ ├─ asgi.py │ ├─ urls.py │ ├─ settings.py │ └─ wsgi.py └─ manage.py
これを回避するためにプロジェクト作成時に次のようなコマンドを実行します。最後のピリオドを忘れずに。これでカレントディレクトリにconfig
という名前で設定用ディレクトリが作成されます。
mkdir {プロジェクト名}/ cd {プロジェクト名}/ django-admin startproject config .
こんな感じのディレクトリ構成になります。
{プロジェクト名}/ ├─ config/ │ ├─ __init__.py │ ├─ asgi.py │ ├─ urls.py │ ├─ settings.py │ └─ wsgi.py └─ manage.py
アプリケーション作成
続いてアプリケーションの作成を行います。アプリケーションの作成コマンドは python manage.py startapp {アプリケーション名}
です。今回はclassifier
という名前にします。
python manage.py startapp classifier
この時点でのディレクトリ構成は次のようになっています。
image-recognition/ ├─ classifier/ │ ├─ migrations/ │ │ └─ __init__.py │ ├─ __init__.py │ ├─ admin.py │ ├─ apps.py │ ├─ models.py │ ├─ tests.py │ └─ views.py ├─ config/ │ ├─ __init__.py │ ├─ asgi.py │ ├─ settings.py │ ├─ urls.py │ └─ wsgi.py └─ manage.py
ちなみにここからファイルを作成して、最終的には次のような構成になります。
image-recognition/ ├─ classifier/ │ ├─ migrations/ │ │ └─ __init__.py │ ├─ admin.py │ ├─ apps.py │ ├─ forms.py │ ├─ models.py │ ├─ predictor.py │ ├─ tests.py │ ├─ urls.py │ └─ views.py ├─ config/ │ ├─ __init__.py │ ├─ asgi.py │ ├─ settings.py │ ├─ urls.py │ └─ wsgi.py ├─ pretrained/ ├─ static/ │ ├─ css/ │ │ └─ style.css │ └─ js/ │ ├─ ajax.js │ └─ barchart.js ├─ templates/ │ ├─ classifier.html │ └─ result.html ├─ .env └─ manage.py
全体的な設定
settings.py
config/settings.py
のINSTALLED_APPS
にclassifierアプリを追加します。
INSTALLED_APPS = [ 'django.contrib.admin', 'django.contrib.auth', 'django.contrib.contenttypes', 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', # Local 'classifier.apps.ClassifierConfig', ]
TEMPLATES
の'DIRS'
の行を編集して'DIRS': [str(BASE_DIR.joinpath('templates'))]
にします。これによりテンプレートを一か所(この場合はtemplates/
配下)で管理することができるようになり、プロジェクト内にアプリケーションを複数作成した場合でも管理しやすくなります。
TEMPLATES = [ { 'BACKEND': 'django.template.backends.django.DjangoTemplates', # 'DIRS': [], 'DIRS': [str(BASE_DIR.joinpath('templates'))], 'APP_DIRS': True, 'OPTIONS': { 'context_processors': [ 'django.template.context_processors.debug', 'django.template.context_processors.request', 'django.contrib.auth.context_processors.auth', 'django.contrib.messages.context_processors.messages', ], }, }, ]
アップロードした画像をブラウザ上に表示したいので、MEDIA_URL
でメディアファイルの公開先を指定します。また、MEDIA_ROOT
でDjangoが画像ファイルを保存するディレクトリを指定します。
MEDIA_URL = '/media/' MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
ファイル.env
に記述した各種環境変数を読み込みます。
from environs import Env (省略) env = Env() env.read_env() (省略) SECRET_KEY = env.str('DJANGO_SECRET_KEY') (省略) DEBUG = env.bool('DJANGO_DEBUG', default=False) (省略) PRETRAINED_MODEL_NAME = env.str('PRETRAINED_MODEL_NAME') PRETRAINED_MODEL_PATH = env.str('PRETRAINED_MODEL_PATH') CLASS_INDEX = env.str('CLASS_INDEX') MAX_RANK = env.int('MAX_RANK')
ここで
SECRET_KEY
:Djangoの暗号化署名のキーPRETRAINED_MODEL_NAME
:PyTorchのImageNet学習済みモデルの名前(例えばvgg16
)PRETRAINED_MODEL_PATH
:PyTorchのImageNet学習済みモデルのパス(例えば./pretrained/vgg16-397923af.pth
)CLASS_INDEX
:クラス番号とラベル名の関係が定義されたファイルのパス(例えば./pretrained/imagenet_class_index.json
)MAX_RANK
:最大何件まで結果を表示するかを指定
です。
環境変数
これらの変数にセットしている値は、ファイル.env
を作成しておき、そこに環境変数として定義しておきます。
# Django settings export DJANGO_SECRET_KEY={Django secret key generated above} export DJANGO_DEBUG=True # Classification model export PRETRAINED_MODEL_NAME=vgg16 export PRETRAINED_MODEL_PATH=./pretrained/vgg16-397923af.pth export CLASS_INDEX=./pretrained/imagenet_class_index.json export MAX_RANK=5
上記の{Django secret key generated above}
の部分は次のコマンドで生成した値を記述します(Django for Professionalsを参考にしました)。暗号化署名のキーなので各自で生成して、公開しないように注意しなければなりません。間違ってGitHubなどに掲載しないように。
python -c "import secrets; print(secrets.token_urlsafe(38))"
PRETRAINED_MODEL_PATH
で指定している学習済みモデルは例えば下記があります。事前にダウンロードしておきます。
CLASS_INDEX
で指定しているjsonファイルは下記からダウンロードできます。
https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json
urls.py
config/urls.py
にルーティングを記述します。http://{IP Address}:{Port #}/
へのアクセスがあったら、classifier/urls.py
に処理を渡すようにします。
from django.contrib import admin from django.contrib.staticfiles.urls import static from django.urls import include, path from . import settings urlpatterns = [ path('admin/', admin.site.urls), path('', include('classifier.urls')), ] urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)
classifier/urls.py
でビューClassifierView
(これから作成するもの)と結びつけます。
from django.urls import path from .views import ClassifierView app_name = 'classifier' urlpatterns = [ path('', ClassifierView.as_view(), name='classifier'), ]
Model
それではいよいよ処理の中身を実装していきます。DjangoのアーキテクチャはMVT(Model View Template)なので、その順に説明します。
classifierアプリケーションのモデルをclassifier/models.py
に実装します。モデルはDBに画像ファイルを保存するために、django.db.models.ImageFiled
を使ってフィールドを定義します。
from django.db import models class ImageModel(models.Model): image = models.ImageField(verbose_name='Image file', upload_to='images')
ちなみに厳密にはDBに格納されるのは画像ファイルのパスであり、ファイルの実体はconfig/settings.py
にてMEDIA_ROOT
で指定したディレクトリ(のさらに下にImageField
の引数upload_to
で指定したディレクトリが作成されて、その配下)に格納されます。
Form
フォームで入力するのは画像ファイルだけなので、フォーム用のクラスはdjango.forms.ModelForm
を継承して作成することにします。そうすることでモデルのフィールド定義を流用でき、フォームの記述がコンパクトになるメリットがあります(今回はフィールドが1個しかないのであまり変わりませんが)。
classifier/forms.py
from django import forms from .models import ImageModel class ImageForm(forms.ModelForm): class Meta: model = ImageModel fields = ('image',) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) for field in self.fields.values(): field.widget.attrs['class'] = 'form-control'
画面にBootstrapを利用しているので、__init__
の中でBootstrapのスタイルを指定しています。今回はBootstrapの説明は割愛します。
ちなみに、フォームに入力した画像データはAjaxのPOSTリクエストでサーバに送信して、ビューで処理します。ビューで画像認識を行い、その結果をAjaxで返すようにしました。
View
classifierアプリケーションのビューをファイルclassifier/views.py
に実装します。今回はフォームを扱っているので、django.views.generic.FormView
を継承して作成します。フィールドtemplate_name
でテンプレートと、form_class
でフォームと結びつけます。また、フォームのバリデーションがOKだった場合に実行されるメソッドform_valid
とNGだった場合に実行されるメソッドform_invalid
をオーバーライドします。ビューの処理が実行される契機としては、初期表示時のGET処理と画像アップロード時のPOST処理があります。
- 初期表示時(GET)
- 単純に
template_name
で指定したテンプレートを描画
- 単純に
- 画像アップロード時(POST)
classifier/views.py
import json import logging from django.conf import settings from django.http import HttpResponse from django.urls import reverse_lazy from django.views import generic from .forms import ImageForm from .models import ImageModel from .predictor import predict logger = logging.getLogger(__name__) class ClassifierView(generic.FormView): template_name = 'classifier.html' form_class = ImageForm success_url = reverse_lazy('classifier:classifier') def form_valid(self, form): logger.info('form_valid start') if self.request.is_ajax(): logger.info('ajax request') image = form.save(commit=False) image.save() logger.info('saved') max_id = ImageModel.objects.latest('id').id obj = ImageModel.objects.get(id=max_id) top_k_predictions_tmp = predict(obj.image, k=settings.MAX_RANK) logger.info(top_k_predictions_tmp) top_k_predictions = [] for prediction in top_k_predictions_tmp: top_k_predictions.append({'label': prediction[1], 'probability': float(prediction[2])}) logger.info(self.template_name) logger.info(form) logger.info(obj.image.url) logger.info(top_k_predictions) data_json = json.dumps({'top_k_predictions': top_k_predictions, 'image_url': obj.image.url}) return HttpResponse(data_json, content_type='application/json') else: logger.info('not ajax request') return super().form_valid(form) def form_invalid(self, form): logger.info('form_invalid start') if self.request.is_ajax(): logger.info('ajax request') response = HttpResponse({'error': 'Unprocessable Entity'}) response.status_code = 422 return response else: logger.info('not ajax request') return super().form_invalid(form)
画像認識の処理(PyTorch)
画像認識はPyTorchを用いて実装します。
classifier/predictor.py
from django.conf import settings from django.core.cache import cache import json import logging from PIL import Image import torch import torch.nn as nn from torchvision import models from torchvision import transforms logger = logging.getLogger(__name__) def make_model(model_name): logger.info(f'create {model_name}') method = getattr(models, model_name) model = method() return model def predict(image_file, k=5): logger.info(image_file) logger.info(k) img = Image.open(image_file) # Preprocess preprocess_img = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ]) img_transformed = preprocess_img(img) inputs = img_transformed.unsqueeze_(0) # Create model cache_key = 'model' model = cache.get(cache_key) if model is None: logger.info('model is not cached') model = make_model(settings.PRETRAINED_MODEL_NAME) model.load_state_dict(torch.load(settings.PRETRAINED_MODEL_PATH), strict=False) cache.set(cache_key, model, None) else: logger.info('use cached model') # Set evaluation mode model.eval() # Input the image to the model and input the output to softmax out = model(inputs) out = nn.functional.softmax(out, dim=1) # Get top k probabilities and class indices top_k_values = out.topk(k).values.detach().numpy()[0] top_k_indices = out.topk(k).indices.detach().numpy()[0] # Create a dictionary that converts class indices to label names class_index = json.load(open(settings.CLASS_INDEX, 'r')) labels = {int(key): value for key, value in class_index.items()} # Get top k probabilities and labels top_k_label_prob = [(labels[idx][0], labels[idx][1], prob) for prob, idx in zip(top_k_values, top_k_indices)] return top_k_label_prob
前処理
まずtransforms.Compose
で前処理を行います。引数に渡しているパラメータは適当に決めた訳ではなく、学習済みモデルで実施した前処理と一致させる必要があるのでこのような値になっています。
モデル生成と学習済み重みの復元
画像認識モデルはサイズが大きいので、リクエストのたびにストレージから読み込んでいたのでは処理に時間がかかってしまいます。そこで、Djangoのキャッシュの仕組みを利用して、画像認識モデルをキャッシュに保持するようにします。cache.set
でキャッシュに保持し、cache.get
でキャッシュから読み込むことができます。
キャッシュに存在しない場合はmake_model
でモデルを生成していますが、ここは少々トリッキーです。torchvision.models
の中には様々な画像認識モデルのクラスが定義されており、それらのクラスのオブジェクトを生成する関数も定義されています。method = getattr(models, model_name)
でオブジェクトmodels
の属性model_name
を取得できます。このとき取得できるのがモデルを生成する関数なので、引き続いてmethod()
でモデルを生成することができます。
例えば、AlexNetモデルを生成する関数alexnet()
がありますが、method = getattr(models, 'alexnet')
とすると、AlexNetモデルを生成する関数alexnet
が取得でき、method()
とすることでalexnet()
が実行されます。
設定ファイルで様々なモデルを指定できるようにするためにこのような実装にしました。
モデルが生成出来たら、model.load_state_dict
でストレージから学習済みの重みを読み込んでモデルにセットします。
モデルを生成する際に学習済みモデルをダウンロードして取得することもできますが、ダウンロードに時間がかかるので、今回は予めダウンロードしておいて、そのパスを指定する方針にしました。
推論
model.eval()
で推論モードにセットします。
後はmodel(inputs)
のようにモデルに前処理済みの画像データinput
を与えれば、推論結果が得られます。
最終的には各カテゴリの確率が欲しいのでソフトマックス関数を使用します。
出力結果から上位$K$個を抜き出して、タプル(インデックス, ラベル名, 確率)
のリストを作成して呼び出し元に返却します。
Template
classifierアプリケーションのテンプレートをtemplates/classifier.html
に作成します。
templates/classifier.html
<!DOCTYPE html> {% load static %} <html lang="ja"> <head> <meta charset="UTF-8"> <title>Image Classifier</title> <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous"> <script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.9.3/Chart.min.js"></script> </head> <body class="container"> <div class="card mt-4 mb-4"> <div class="card-header"> <h1 class="display-4 text-center">Image Classifier</h1> </div> <div class="card-body"> <div class="mx-2 my-4"> <form method="post" enctype="multipart/form-data"> {% csrf_token %} <div class="input-group"> <div class="custom-file"> <input type="file" class="custom-file-input" id="inputGroupFile" name="{{ form.image.html_name }}" required> <label class="custom-file-label" for="inputGroupFile" aria-describedby="inputGroupFileAddon">Choose image file</label> </div> <div class="input-group-append"> <button class="btn btn-primary" type="submit" id="inputGroupFileAddon">Submit</button> </div> </div> </form> </div> <div class="mx-4 my-2"> {# replace by Ajax #} <div class="row"> <div class="col-lg-5"> <div> <img id="submittedImage" /> </div> <div> <table id="rankingTable"></table> </div> </div> <div class="col-lg-7"> <div> <canvas id="barChart" height="0"></canvas> </div> </div> </div> </div> </div> </div> <script src="https://code.jquery.com/jquery-3.5.1.min.js" integrity="sha256-9/aliU8dGd2tb6OSsuzixeV4y/faTqgFtohetphbbj0=" crossorigin="anonymous"></script> <script src="https://cdn.jsdelivr.net/npm/bs-custom-file-input/dist/bs-custom-file-input.js"></script> <script type="text/javascript" src="{% static 'js/ajax.js' %}"></script> <script type="text/javascript" src="{% static 'js/barchart.js' %}"></script> <script> bsCustomFileInput.init(); </script> </body> </html>
ここでのポイントは2つ。
画像ファイル送信用のフォームを作成
Djangoでは{{ form.as_p }}
のようにすればフォームを表示できるのですが、見栄えをよくするためにBootstrapを利用しています。当初はDjangoのテンプレートタグのみで描画しようと考え、いろいろと試行錯誤しましたが、細かく整えるのは無理そうとの結論に至りました。Djangoの機能を使うことに固執しすぎるのも良くないですよね。<form method="post" enctype="multipart/form-data"> {% csrf_token %} <div class="input-group"> <div class="custom-file"> <input type="file" class="custom-file-input" id="inputGroupFile" name="{{ form.image.html_name }}" required> <label class="custom-file-label" for="inputGroupFile" aria-describedby="inputGroupFileAddon">Choose image file</label> </div> <div class="input-group-append"> <button class="btn btn-primary" type="submit" id="inputGroupFileAddon">Submit</button> </div> </div> </form>
Ajaxで描画する領域を用意
Ajaxでレスポンスを受け取ったら、その内容をここに埋め込むという方針です。<img id="submittedImage" />
:アップロードした画像を埋め込む場所<table id="rankingTable"></table>
:予測したラベルとその確率の一覧表を埋め込む場所<canvas id="barChart" height="0"></canvas>
:予測したラベルとその確率の水平棒グラフを埋め込む場所
<div class="row"> <div class="col-lg-5"> <div> <img id="submittedImage" /> </div> <div> <table id="rankingTable"></table> </div> </div> <div class="col-lg-7"> <div> <canvas id="barChart" height="0"></canvas> </div> </div> </div>
Ajax
Ajaxの実装にはjQueryを利用しています。formでsubmitされたのを契機にして、POSTリクエストで画像ファイルを送信するようにします。$.ajax
にてリクエストの設定をします。また、done
(正常時の処理)、fail
(異常時の処理)、always
(正常時でも異常時でも実行される後処理)をそれぞれ実装します。処理概要は次のとおりです:
done
- レスポンスを受け取り、レスポンスに含まれるJSONをパース
- パースした内容を基に、画像の表示、表の作成、グラフの描画を実施
fail
- ブラウザ上にアラートでエラーを通知
always
- 特になし
ちなみにgetCookie
の実装はDjangoのドキュメントにあるCross Site Request Forgery protectionを流用させてもらいました。
static/js/ajax.js
/* * This function is from Django official web site: * https://docs.djangoproject.com/en/3.1/ref/csrf/ */ function getCookie(name) { let cookieValue = null; if (document.cookie && document.cookie !== '') { const cookies = document.cookie.split(';'); for (let i = 0; i < cookies.length; i++) { const cookie = cookies[i].trim(); // Does this cookie string begin with the name we want? if (cookie.substring(0, name.length + 1) === (name + '=')) { cookieValue = decodeURIComponent(cookie.substring(name.length + 1)); break; } } } return cookieValue; } const csrftoken = getCookie('csrftoken'); function csrfSafeMethod(method) { // these HTTP methods do not require CSRF protection return (/^(GET|HEAD|OPTIONS|TRACE)$/.test(method)); } $.ajaxSetup({ beforeSend: function(xhr, settings) { if (!csrfSafeMethod(settings.type) && !this.crossDomain) { xhr.setRequestHeader('X-CSRFToken', csrftoken); } } }); $('form').submit(function(event) { event.preventDefault(); var form = $(this); var formData = new FormData($('form').get(0)); $.ajax({ type: 'POST', url: form.prop('action'), method: form.prop('method'), data: formData, processData: false, contentType: false, timeout: 10000, dataType: 'text', }).done(function(response) { console.log('done'); var parsedResponse = JSON.parse(response); // show image showImage(parsedResponse.image_url); // create table createTable(parsedResponse.top_k_predictions); // draw chart drawChart(parsedResponse.top_k_predictions); }).fail(function(jqXHR, textStatus, errorThrown) { console.log('fail'); console.log(jqXHR.status); console.log(textStatus); console.log(errorThrown); alert(jqXHR.status + ' Error: ' + errorThrown); }).always(function() { console.log('always'); }); }); function showImage(image_url) { console.log('showImage start'); var img = $('#submittedImage'); img.attr('src', image_url); img.addClass('img-fluid mx-auto d-block mt-4 mb-4') console.log('showImage end'); } function createTable(top_k_predictions) { console.log('createTable start'); var table = $('#rankingTable'); table.empty(); table.append('<caption>Prediction result</caption>'); table.append('<thead><th>Rank</th><th>Label</th><th>Probability</th></thead>'); table.append('<tbody></tbody>'); for (var i = 0; i < top_k_predictions.length; i++) { var row = $("<tr></tr>"); row.append($('<td align="right"></td>').text(i + 1)); row.append($('<td></td>').text(top_k_predictions[i].label)); row.append($('<td align="right"></td>').text(top_k_predictions[i].probability.toFixed(3))); table.append(row); } table.addClass('table bg-light'); console.log('createTable end'); }
グラフ描画
グラフ描画にはChart.jsを利用します。
static/js/barchart.js
function setData(top_k_predictions) { var labels = []; var probabilities = []; for (var i = 0; i < top_k_predictions.length; i++) { labels.push(top_k_predictions[i].label); probabilities.push(top_k_predictions[i].probability); } var data = { labels: labels, fontSize: 18, datasets: [{ label: 'probability', data: probabilities, backgroundColor: 'rgba(54, 162, 235, 0.7)' }] }; return data; } function setOptions() { var options = { responsive: true, maintainAspectRatio: true, scales: { xAxes: [{ gridLines: { display: true }, scaleLabel: { display: true, labelString: 'probability', fontSize: 14 }, ticks: { min: 0.0, max: 1.0, fontSize: 14 } }], yAxes: [{ scaleLabel: { display: true, labelString: 'label', fontSize: 14 }, ticks: { autoSkip: false, fontSize: 14 } }] } }; return options; } function drawChart(top_k_predictions) { console.log('drawChart start'); var ctx = document.getElementById('barChart'); var data = setData(top_k_predictions); var options = setOptions(); // set height according to the number of data ctx.height = 18 * data.datasets[0]['data'].length + 60; // destroy chart instance if chart instance already exists if (window.myBarChart) { window.myBarChart.destroy(); } window.myBarChart = new Chart(ctx, { type: 'horizontalBar', data: data, options: options }); $('#barChart').addClass('mt-4'); console.log('drawChart end'); }
実行
パッケージのインストールや実行手順はhttps://github.com/noriho137/image-classifierを参考にしてください。