noriho137’s diary

機械学習, 時々Web

Transformers, LangChain & Chromaによるローカルのテキストデータを参照したテキスト生成

前回Hugging FaceTransgormersLangChainを用いたテキスト生成を実装しました。 今回はさらにChromaを用いて、ローカルのDB上のデータを参照して質問応答を行うテキスト生成を実装してみます。

ChromaはいわゆるベクトルDBの一種です。 ベクトルDBは埋め込みベクトルのような高次元のベクトルデータを扱うのに適したDBです。 ベクトルDBの実装には色々とありますが、今回はサーバを構築しなくても簡単に試せるChromaを使ってみました。

このベクトルDBと大規模言語モデル(LLM: Large Language Model)を用いたテキスト生成を組み合わせることで、ユーザからの質問に関連した文章をベクトルDBから取得して、その文章を基に回答を生成する、といったことが可能になります。 全体構成は次の図のようなイメージです。

全体構成の概要

インターネット上に公開していない(あるいは社外秘などで外部には公開できない)ドメイン固有の文書があり、生成AIを使ってそれらの文書をベースに質問応答を行いたい場合は、このような方法が有効になると思います。

実装の流れはこんな感じです。

  1. ベクトルDBの構築
  2. 学習済みモデルをロード
  3. タスクやモデルなどを指定してTransformers Pipelineを構築
  4. PipelineとPromptTemplateを指定してLangChainのLLMChainを構築
  5. 質問文を入力してベクトルDBから類似度の高い文章を検索
  6. 推論の実行(LLMChainに質問文と検索結果を与えて回答を生成する)

なお、今回のコードはこちらです。

github.com

ベクトルDBの構築

今回使用する文書はIPAが公開している「アジャイルソフトウェア開発宣言の読みとき方」です。 PDFファイルをページ単位で埋め込みベクトルに変換してChromaに格納します。

LangChainにはPDFファイルを読み込むPDFMinerLoaderがあるので、それを使用して読み込みます。 PDFファイルを読み込むLoaderは他にもありますが、日本語が文字化けしたり、ページがきちんと認識できなかったりしたので、いくつか試した結果、PDFMinerLoaderを使うことにしました。 ちなみに、名前のとおり、裏ではPDFMiner(正確にはpdfminer.six)が動いています。

from langchain.document_loaders import PDFMinerLoader

file = '000065601.pdf'

loader = PDFMinerLoader(file)
text = loader.load()
pages = text[0].page_content.split('\x0c')

ファイルをページ単位に分割できたら、HuggingFaceのsentence_transformers埋め込みモデルを使って、テキストデータを埋め込みベクトルに変換し、Chromaに格納します。 今回も大規模言語モデルサイバーエージェント社のOpenCALM-1Bを使います。

from langchain.embeddings.huggingface import HuggingFaceEmbeddings

model_name = 'cyberagent/open-calm-1b'
embeddings = HuggingFaceEmbeddings(model_name=model_name)

モデルをロードしたら、テキストデータを入力して埋め込みベクトルに変換します。 Chromaのオリジナルのインターフェースをそのまま使用しても良いですが、LangChainにChromaのラッパーがあるので、それを使うことにします。 langchain.vectorstores.ChromaでChromaのクライアントを生成します。 Chromaの引数embedding_functionに先ほどロードした言語モデルを指定します。 引数persist_directoryを指定すると、ストレージ上の指定した場所にデータが保存されます。 なお、引数persist_directoryを指定しない場合はインメモリとなります。

from langchain.vectorstores import Chroma

vectordb = Chroma(embedding_function=embeddings, persist_directory='./db')

続いて、先ほど読み込んだ文書について、ページ単位でループして、add_textsでDBに登録していきます。 このとき、引数textsに渡したテキストデータが埋め込みベクトルに変換されて登録されます。

for i, page in enumerate(pages):
  if page == '':
    continue
  vectordb.add_texts(texts=[page],
                     metadatas=[{'source': file}],
                     ids=[f'id{i+1}'])

学習済みモデルをロード

学習済みモデルもサイバーエージェント社のOpenCALM-1Bを使います。 なお、埋め込みベクトルを生成した際に使用したモデルと同じものを使う必要があります。

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Pipelineを構築

続いてTransgormersPipelineを構築します。 テキスト生成の場合はtransformers.pipelineの引数tasktext-generationを指定し、modeltokenizerに先ほどロードした学習済みモデルとトークナイザーを指定します。 kwargsの値はOpenCALM-1BのUsageを参考にしています。

from transformers import pipeline

task='text-generation'

kwargs = {
    'max_new_tokens': 64,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'repetition_penalty': 1.05,
    'pad_token_id': tokenizer.pad_token_id
}

pipe = pipeline(
    task=task,
    model=model,
    tokenizer=tokenizer,
    device=device_id,
    torch_dtype=torch.float16,
    **kwargs
)

LLMChainを構築

langchain.llms.HuggingFacePipelineに先ほど構築したPipelineを指定します。

from langchain.llms import HuggingFacePipeline

llm = HuggingFacePipeline(pipeline=pipe)

また、PromptTemplateで大規模言語モデルに指示を出すためのプロンプトのテンプレートを作成します。 テンプレートにはパラメータを含めることができ、{query}のような形式で記述します。 のちほどテキスト生成を実行する際、このパラメータに値を渡します。

from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

llm = HuggingFacePipeline(pipeline=pipe)

template = """
質問: {query}

回答:
"""
prompt = PromptTemplate(template=template, input_variables=['query'])

LLMChainで大規模言語モデルとプロンプトを統合します。

llm_chain = LLMChain(llm=llm, prompt=prompt)

質問文に関連する文章をベクトルDBから取得

質問文を入力して、ベクトルDB上で質問文と類似度の高いページを検索します。 similarity_searchの引数kで上位何件までを返すのかを指定できます。

query = 'アジャイルソフトウェア開発宣言で、プロセスやツールよりも重視していることは?'
docs = vectordb.similarity_search(query=query, k=5)

ちなみにここでdocsの中身を確認してみると、質問文に関連したページが含まれていることが分かります。

推論

ベクトルDBの検索結果と元の質問文をLLMChainに渡して、最終的な回答を生成します。

llm_chain.run(input_documents=docs, query=query)

出力結果はこんな感じになりました。

'アジャイルは「俊敏な」「型にはまらない」という2つの言葉で表現されますが、「柔軟性・流動性が重要だ。またその柔軟性を保証するための仕組みが重要である。」というのが答えです。「変化への対応は遅いかもしれないけれど,すぐに改善するんだ」。これがチームにとって大切なことです。”スピード”と簡単にいいますが'

この回答だと、ベクトルDBから取得した文章を活用しきれていないように見えますが、大規模言語モデルのパラメータ数や文書を分割してChromaに登録した際の粒度(今回はページ単位)も影響しているかと思います。 モデルを変えてみたり、文書の分割の粒度や分割方法を変えてみたりするなど、工夫の余地はありそうです。

Transformers PipelineとLangChainによるテキスト生成

前回は Hugging FaceのTransgormers Pipelineを用いたテキスト生成を実装しました。 今回はTransgormers Pipelineに加えてLangChainも使用したテキスト生成を実装してみます。

LangChain を利用すると大規模言語モデルを用いたアプリケーション開発が容易になります。 モデルやデータベース、データの取り込みなどが抽象化されており、統一されたインターフェースを用いて実装できます。 このLangChainとHugging Faceで公開されている学習済み大規模言語モデルを使えば、精度はさておき、無料でChatGPTライクなものが実装できるのでは?と思い、調べてみました。

ちなみにLangChainの使用例を検索すると、OpenAIのAPIを使用した実装例(※LangChainにはOpenAIのAPIのラッパーが用意されています)が多数出てきますが、Hugging FaceのTransformers Pipelineを使用した実装例を見つけられなかったので、試行錯誤しながらの実装となりました。

実装の流れはこんな感じです。

  1. 学習済みモデルをロード
  2. タスクやモデルなどを指定してTransformers Pipelineを構築
  3. PipelineとPromptTemplateを指定してLangChainのLLMChainを構築
  4. 推論の実行(LLMChainに入力を与えて出力を得る)

なお、今回のコードはこちらです。

github.com

学習済みモデルをロード

学習済みモデルとして、今回も、サイバーエージェント社が公開している、日本語データセットで学習済みのOpenCALM-1Bを使用します。

from transformers import AutoModelForCausalLM, AutoTokenizer

pretrained_model_name = 'cyberagent/open-calm-1b'

model = AutoModelForCausalLM.from_pretrained(pretrained_model_name, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)

Pipeline を構築

続いてパイプラインを構築します。 テキスト生成の場合はtransformers.pipelineの引数 task にtext-generationを指定し、modeltokenizerに、先ほどロードした学習済みモデルとトークナイザーを指定します。 kwargsの値はOpenCALM-1BのUsageを参考にしています。

from transformers import pipeline

task='text-generation'

kwargs = {
    'max_new_tokens': 64,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'repetition_penalty': 1.05,
    'pad_token_id': tokenizer.pad_token_id
}

pipe = pipeline(
    task=task,
    model=model,
    tokenizer=tokenizer,
    device=device_id,
    torch_dtype=torch.float16,
    **kwargs
)

LangChain

LangChainのHuggingFacePipelineに上で構築したPipelineを指定します。 また、PromptTemplateで大規模言語モデルに指示を出すためのプロンプトのテンプレートを作成します。 テンプレートにはパラメータを含めることができ、{query}のような形式で記述します。 のちほどテキスト生成を実行する際、このパラメータに値を渡します。 LLMChainで大規模言語モデルとプロンプトを統合します。

from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

llm = HuggingFacePipeline(pipeline=pipe)

template = """
{query}
"""

prompt = PromptTemplate(template=template, input_variables=['query'])

llm_chain = LLMChain(llm=llm, prompt=prompt)

推論

LLMChain.runで推論を実行します。このとき、テンプレートでパラメータになっていたqueryrunの引数で指定します。

llm_chain.run(query='アジャイルソフトウェア開発宣言の内容は')

ちなみに出力結果はこんな感じでした。

'「ソフトウェアは、目的を達成するために手段の1つとして存在しなくてはならない」という原則と、「反復可能な開発によってのみプロジェクトが成功するわけではない。そのプロジェクトは目的に達しないものである」(序文)との両点から成り立っています。「顧客のニーズを満たし(Price of Customer)」を目的に掲げている企業にとってこの宣言'

Transformaers Pipeline によるテキスト生成

Hugging Face の Transformers には推論を簡単に行うための Pipeline という仕組みがあります。 PyTorch や TensorFlow のような面倒なコードを書かずに、わずかなコーディング量で推論することができるので、ちょっとしたことを試すのにはとでも便利です。

Pipeline 自体は画像認識や音声認識など様々なタスクに使用できるようですが、今回は自然言語処理のタスク、特にテキスト生成を試してみます。

特定のモデルで Pipeline を使用する流れはこんな感じです。

  1. 学習済みモデルをロード
  2. タスクやモデルなどを指定して Pipeline を構築
  3. 推論の実行(パイプラインに入力を与えて出力を得る)

ちなみに、最も単純な場合は、明示的に学習済みモデルをロードせずにタスクを指定するだけでも良さそうですが、使用可能な言語が限られるなどの制限がありそうです。 今回は、最近公開された日本語の大規模言語モデルを使用して、日本語テキスト生成を試したかったので、上記のような流れで実装してみます。

なお、今回のコードはこちらです。

github.com

学習済みモデルをロード

学習済みモデルとして、今回は、サイバーエージェント社が公開している日本語データセットで学習済みの OpenCALM-1B を使用してみます。

from transformers import AutoModelForCausalLM, AutoTokenizer

pretrained_model_name = 'cyberagent/open-calm-1b'

model = AutoModelForCausalLM.from_pretrained(pretrained_model_name, torch_dtype=torch.float16).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)

パイプラインを構築

続いてパイプラインを構築する。 テキスト生成の場合は transformers.pipeline の引数 tasktext-generation を指定し、modeltokenizer に、先ほどロードした学習済みモデルとトークナイザーを指定する。 kwargs の値は OpenCALM-1BUsage を参考にしました。

from transformers import pipeline

task = 'text-generation'

kwargs = {
    'max_new_tokens': 64,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 0.9,
    'repetition_penalty': 1.05,
    'pad_token_id': tokenizer.pad_token_id
}

generator = pipeline( task=task, model=model, tokenizer=tokenizer, device=device_id, torch_dtype=torch.float16, **kwargs )

推論の実行

引数にテキストを与えると、それに続くテキストが生成されます。 たったのこれだけ。なんて簡単なんだ!

generator('アジャイルソフトウェア開発宣言の内容は')

ちなみに出力結果はこんな感じでした。

[{'generated_text': 'アジャイルソフトウェア開発宣言の内容は、ソフトウェアの開発・運用に関する方針を簡潔にまとめたものです。\nこの文書には、「開発」や「生産」「販売」、「保守サービスなどの各業務プロセスにおいて『顧客満足』を追求していくこと」、そしてそれらを実現するための考え方として、『SMILE(Smile:笑顔で)』と書かれています。『信頼』『誠実さ'}]

使用する学習済みモデルやモデルのパラメータ数によって精度も変わると思います。 様々なモデルで試してみたいですね。

PyCharmでプロジェクトの改行コードを一括変換

PyCharmでプロジェクトの改行コードを一括変換する方法をいつも忘れてしまうのでメモ。

プロジェクトをクリックして選択した上で、メニューから、File → File Properties → Line Separators と選択すればOK。

PyCharmの画面

※注:上記はWindows版です。

二値分類の評価指標

機械学習の二値分類の評価指標について、ついつい忘れがちなのでメモしておこう。

混同行列

二値分類の混同行列は次のとおり。

陰性と予測 陽性と予測
実際は陰性 真陰性 (True Negative; TN) 偽陽性 (False Positive; FP)
実際は陽性 偽陰性 (False Negative; FN) 真陽性 (True Positive; TP)

それぞれの件数を表形式で把握できるので視覚的にも分かりやすい。 以降でまとめている各種評価指標を計算する際にも使用する重要な情報。

正解率

正解率 (Accuracy) は次式で定義される。


\begin{align}
\text{Accuracy} &=\frac{\text{TN} + \text{TP}}{\text{TN} + \text{FP} + \text{FN} + \text{TP}} \\
\end{align}

全件中何件予測が正しかったかを表す指標になっている。 単純で分かりやすいが、陰性サンプル数と陽性サンプル数が不均衡な場合は適切な評価ができないので要注意。

適合率・再現率

上記の {TN, FP, FN, TP} を用いると適合率 (Precision)、再現率 (Recall) は次式で定義される。


\begin{align}
\text{Precision} &=\frac{\text{TP}}{\text{TP} + \text{FP}} \\
\text{Recall} &=\frac{\text{TP}}{\text{TP} + \text{FN}}
\end{align}

つまりはこういうことだ。

  • 適合率:陽性と予測したもののうち、実際に陽性だったものの割合

  • 再現率:実際に陽性であるもののうち、陽性と予測できたものの割合

F1値

F1値 (F1 score, F1 measure) は適合率と再現率の調和平均で定義される。


\begin{align}
F1 &= \frac{2}{\frac{1}{\text{Precison}} + \frac{1}{\text{Recall}}} = \frac{2 \times \text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}}
\end{align}

ちなみに適合率も再現率も比率である。比率に対する平均なので、単なる平均ではなく、調和平均が適切だということ。

Django + PyTorch で画像認識 Web アプリを作る

DjangoとPyTorchで画像認識Webアプリを作ろうと思ったところ、Djangoのクラスベースビューを使いつつAjaxでグラフ描画するという例が見つからなかったので、試行錯誤して実装してみました。今回はその内容の共有をしたいと思います。

ソースコードhttps://github.com/noriho137/image-classifierにアップしてあります。

  • Django

    • クラスベースビューでAjax

    • Ajaxでグラフ描画(グラフ描画自体はChart.jsを利用)

  • PyTorch

    • ImageNet学習済みモデルを利用

はじめに

検証した環境は次のとおりです。

作るもの

画像ファイルをアップロードして画像認識を行い、認識結果をブラウザ上に表示するWebアプリケーションを作成します。 WebアプリのフレームワークにはDjangoを利用します。 画像認識にはPyTorchのImageNet学習済みモデルを利用します。 認識結果の上位$N$件までのラベル名と確率を表とグラフで表示します。 こんな感じです。

f:id:noriho137:20220211190216p:plain
認識結果イメージ

プロジェクト&アプリケーション作成

プロジェクト作成

まずはDjangoのプロジェクト作成。Django公式のチュートリアルによると django-admin startproject {プロジェクト名} で作成できますが、こうすると {プロジェクト名} と同じ名前で設定用ディレクトリが作成されてしまってややこしくなります。こんな感じ。

{プロジェクト名}/
├─ {プロジェクト名}/
│  ├─ __init__.py
│  ├─ asgi.py
│  ├─ urls.py
│  ├─ settings.py
│  └─ wsgi.py
└─ manage.py

これを回避するためにプロジェクト作成時に次のようなコマンドを実行します。最後のピリオドを忘れずに。これでカレントディレクトリにconfigという名前で設定用ディレクトリが作成されます。

mkdir {プロジェクト名}/
cd {プロジェクト名}/
django-admin startproject config .

こんな感じのディレクトリ構成になります。

{プロジェクト名}/
├─ config/
│  ├─ __init__.py
│  ├─ asgi.py
│  ├─ urls.py
│  ├─ settings.py
│  └─ wsgi.py
└─ manage.py

アプリケーション作成

続いてアプリケーションの作成を行います。アプリケーションの作成コマンドは python manage.py startapp {アプリケーション名} です。今回はclassifierという名前にします。

python manage.py startapp classifier

この時点でのディレクトリ構成は次のようになっています。

image-recognition/
├─ classifier/
│  ├─ migrations/
│  │  └─ __init__.py
│  ├─ __init__.py
│  ├─ admin.py
│  ├─ apps.py
│  ├─ models.py
│  ├─ tests.py
│  └─ views.py
├─ config/
│  ├─ __init__.py
│  ├─ asgi.py
│  ├─ settings.py
│  ├─ urls.py
│  └─ wsgi.py
└─ manage.py

ちなみにここからファイルを作成して、最終的には次のような構成になります。

image-recognition/
├─ classifier/
│  ├─ migrations/
│  │  └─ __init__.py
│  ├─ admin.py
│  ├─ apps.py
│  ├─ forms.py
│  ├─ models.py
│  ├─ predictor.py
│  ├─ tests.py
│  ├─ urls.py
│  └─ views.py
├─ config/
│  ├─ __init__.py
│  ├─ asgi.py
│  ├─ settings.py
│  ├─ urls.py
│  └─ wsgi.py
├─ pretrained/
├─ static/
│  ├─ css/
│  │  └─ style.css
│  └─ js/
│     ├─ ajax.js
│     └─ barchart.js
├─ templates/
│  ├─ classifier.html
│  └─ result.html
├─ .env
└─ manage.py

全体的な設定

settings.py

config/settings.pyINSTALLED_APPSにclassifierアプリを追加します。

INSTALLED_APPS = [
    'django.contrib.admin',
    'django.contrib.auth',
    'django.contrib.contenttypes',
    'django.contrib.sessions',
    'django.contrib.messages',
    'django.contrib.staticfiles',

    # Local
    'classifier.apps.ClassifierConfig',
]

TEMPLATES'DIRS'の行を編集して'DIRS': [str(BASE_DIR.joinpath('templates'))]にします。これによりテンプレートを一か所(この場合はtemplates/配下)で管理することができるようになり、プロジェクト内にアプリケーションを複数作成した場合でも管理しやすくなります。

TEMPLATES = [
    {
        'BACKEND': 'django.template.backends.django.DjangoTemplates',
        # 'DIRS': [],
        'DIRS': [str(BASE_DIR.joinpath('templates'))],
        'APP_DIRS': True,
        'OPTIONS': {
            'context_processors': [
                'django.template.context_processors.debug',
                'django.template.context_processors.request',
                'django.contrib.auth.context_processors.auth',
                'django.contrib.messages.context_processors.messages',
            ],
        },
    },
]

アップロードした画像をブラウザ上に表示したいので、MEDIA_URLでメディアファイルの公開先を指定します。また、MEDIA_ROOTDjangoが画像ファイルを保存するディレクトリを指定します。

MEDIA_URL = '/media/'
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')

ファイル.envに記述した各種環境変数を読み込みます。

from environs import Env
(省略)
env = Env()
env.read_env()
(省略)
SECRET_KEY = env.str('DJANGO_SECRET_KEY')
(省略)
DEBUG = env.bool('DJANGO_DEBUG', default=False)
(省略)
PRETRAINED_MODEL_NAME = env.str('PRETRAINED_MODEL_NAME')
PRETRAINED_MODEL_PATH = env.str('PRETRAINED_MODEL_PATH')
CLASS_INDEX = env.str('CLASS_INDEX')
MAX_RANK = env.int('MAX_RANK')

ここで

  • SECRET_KEYDjangoの暗号化署名のキー
  • DEBUGDjangoデバッグモードで起動するかどうかのフラグ

  • PRETRAINED_MODEL_NAME:PyTorchのImageNet学習済みモデルの名前(例えばvgg16

  • PRETRAINED_MODEL_PATH:PyTorchのImageNet学習済みモデルのパス(例えば./pretrained/vgg16-397923af.pth
  • CLASS_INDEX:クラス番号とラベル名の関係が定義されたファイルのパス(例えば./pretrained/imagenet_class_index.json
  • MAX_RANK:最大何件まで結果を表示するかを指定

です。

環境変数

これらの変数にセットしている値は、ファイル.envを作成しておき、そこに環境変数として定義しておきます。

# Django settings
export DJANGO_SECRET_KEY={Django secret key generated above}
export DJANGO_DEBUG=True

# Classification model
export PRETRAINED_MODEL_NAME=vgg16
export PRETRAINED_MODEL_PATH=./pretrained/vgg16-397923af.pth
export CLASS_INDEX=./pretrained/imagenet_class_index.json
export MAX_RANK=5

上記の{Django secret key generated above}の部分は次のコマンドで生成した値を記述します(Django for Professionalsを参考にしました)。暗号化署名のキーなので各自で生成して、公開しないように注意しなければなりません。間違ってGitHubなどに掲載しないように。

python -c "import secrets; print(secrets.token_urlsafe(38))"

PRETRAINED_MODEL_PATHで指定している学習済みモデルは例えば下記があります。事前にダウンロードしておきます。

Name URL
alexnet https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth
densenet201 https://download.pytorch.org/models/densenet201-c1103571.pth
googlenet https://download.pytorch.org/models/googlenet-1378be20.pth
inception_v3 https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth
mobilenet_v2 https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
resnet50 https://download.pytorch.org/models/resnet50-19c8e357.pth
squeezenet1_0 https://download.pytorch.org/models/squeezenet1_0-a815701f.pt
vgg16 https://download.pytorch.org/models/vgg16-397923af.pth

CLASS_INDEXで指定しているjsonファイルは下記からダウンロードできます。

https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json

urls.py

config/urls.pyにルーティングを記述します。http://{IP Address}:{Port #}/へのアクセスがあったら、classifier/urls.pyに処理を渡すようにします。

from django.contrib import admin
from django.contrib.staticfiles.urls import static
from django.urls import include, path
from . import settings

urlpatterns = [
    path('admin/', admin.site.urls),
    path('', include('classifier.urls')),
]

urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT)

classifier/urls.pyでビューClassifierView(これから作成するもの)と結びつけます。

from django.urls import path
from .views import ClassifierView


app_name = 'classifier'
urlpatterns = [
    path('', ClassifierView.as_view(), name='classifier'),
]

Model

それではいよいよ処理の中身を実装していきます。DjangoアーキテクチャはMVT(Model View Template)なので、その順に説明します。

classifierアプリケーションのモデルをclassifier/models.pyに実装します。モデルはDBに画像ファイルを保存するために、django.db.models.ImageFiledを使ってフィールドを定義します。

from django.db import models


class ImageModel(models.Model):
    image = models.ImageField(verbose_name='Image file', upload_to='images')

ちなみに厳密にはDBに格納されるのは画像ファイルのパスであり、ファイルの実体はconfig/settings.pyにてMEDIA_ROOTで指定したディレクトリ(のさらに下にImageFieldの引数upload_toで指定したディレクトリが作成されて、その配下)に格納されます。

Form

フォームで入力するのは画像ファイルだけなので、フォーム用のクラスはdjango.forms.ModelFormを継承して作成することにします。そうすることでモデルのフィールド定義を流用でき、フォームの記述がコンパクトになるメリットがあります(今回はフィールドが1個しかないのであまり変わりませんが)。

classifier/forms.py

from django import forms
from .models import ImageModel


class ImageForm(forms.ModelForm):
    class Meta:
        model = ImageModel
        fields = ('image',)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for field in self.fields.values():
            field.widget.attrs['class'] = 'form-control'

画面にBootstrapを利用しているので、__init__の中でBootstrapのスタイルを指定しています。今回はBootstrapの説明は割愛します。

ちなみに、フォームに入力した画像データはAjaxのPOSTリクエストでサーバに送信して、ビューで処理します。ビューで画像認識を行い、その結果をAjaxで返すようにしました。

View

classifierアプリケーションのビューをファイルclassifier/views.pyに実装します。今回はフォームを扱っているので、django.views.generic.FormViewを継承して作成します。フィールドtemplate_nameでテンプレートと、form_classでフォームと結びつけます。また、フォームのバリデーションがOKだった場合に実行されるメソッドform_validとNGだった場合に実行されるメソッドform_invalidをオーバーライドします。ビューの処理が実行される契機としては、初期表示時のGET処理と画像アップロード時のPOST処理があります。

  • 初期表示時(GET)
    • 単純にtemplate_nameで指定したテンプレートを描画
  • 画像アップロード時(POST)
    • メソッドform_validform_invalidにて、リクエストがAjaxかどうかで場合分けして処理を実装。リクエストがAjaxかどうかはself.request.is_ajaxで判定
    • form_validの処理概要
      • アップロードされた画像ファイルを保存(パスをDBに、実体を指定の場所に保存)
      • PyTorchの学習済みモデルで画像認識を行う
      • 認識結果をJSON形式にしてレスポンスを返却
    • form_invalidの処理概要
      • アップロードされたファイルが正しく処理できないので、レスポンスで「422 Unprocessable Entity」を返却

classifier/views.py

import json
import logging

from django.conf import settings
from django.http import HttpResponse
from django.urls import reverse_lazy
from django.views import generic

from .forms import ImageForm
from .models import ImageModel
from .predictor import predict

logger = logging.getLogger(__name__)


class ClassifierView(generic.FormView):
    template_name = 'classifier.html'
    form_class = ImageForm
    success_url = reverse_lazy('classifier:classifier')

    def form_valid(self, form):
        logger.info('form_valid start')

        if self.request.is_ajax():
            logger.info('ajax request')

            image = form.save(commit=False)
            image.save()
            logger.info('saved')

            max_id = ImageModel.objects.latest('id').id
            obj = ImageModel.objects.get(id=max_id)

            top_k_predictions_tmp = predict(obj.image, k=settings.MAX_RANK)
            logger.info(top_k_predictions_tmp)

            top_k_predictions = []
            for prediction in top_k_predictions_tmp:
                top_k_predictions.append({'label': prediction[1],
                                          'probability': float(prediction[2])})

            logger.info(self.template_name)
            logger.info(form)
            logger.info(obj.image.url)
            logger.info(top_k_predictions)

            data_json = json.dumps({'top_k_predictions': top_k_predictions,
                                    'image_url': obj.image.url})
            return HttpResponse(data_json, content_type='application/json')
        else:
            logger.info('not ajax request')
            return super().form_valid(form)

    def form_invalid(self, form):
        logger.info('form_invalid start')
        if self.request.is_ajax():
            logger.info('ajax request')
            response = HttpResponse({'error': 'Unprocessable Entity'})
            response.status_code = 422
            return response
        else:
            logger.info('not ajax request')
            return super().form_invalid(form)

画像認識の処理(PyTorch)

画像認識はPyTorchを用いて実装します。

classifier/predictor.py

from django.conf import settings
from django.core.cache import cache

import json
import logging
from PIL import Image

import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

logger = logging.getLogger(__name__)


def make_model(model_name):
    logger.info(f'create {model_name}')
    method = getattr(models, model_name)
    model = method()
    return model


def predict(image_file, k=5):
    logger.info(image_file)
    logger.info(k)

    img = Image.open(image_file)

    # Preprocess
    preprocess_img = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    img_transformed = preprocess_img(img)
    inputs = img_transformed.unsqueeze_(0)

    # Create model
    cache_key = 'model'
    model = cache.get(cache_key)
    if model is None:
        logger.info('model is not cached')
        model = make_model(settings.PRETRAINED_MODEL_NAME)
        model.load_state_dict(torch.load(settings.PRETRAINED_MODEL_PATH),
                              strict=False)
        cache.set(cache_key, model, None)
    else:
        logger.info('use cached model')

    # Set evaluation mode
    model.eval()

    # Input the image to the model and input the output to softmax
    out = model(inputs)
    out = nn.functional.softmax(out, dim=1)

    # Get top k probabilities and class indices
    top_k_values = out.topk(k).values.detach().numpy()[0]
    top_k_indices = out.topk(k).indices.detach().numpy()[0]

    # Create a dictionary that converts class indices to label names
    class_index = json.load(open(settings.CLASS_INDEX, 'r'))
    labels = {int(key): value for key, value in class_index.items()}

    # Get top k probabilities and labels
    top_k_label_prob = [(labels[idx][0], labels[idx][1], prob)
                        for prob, idx in zip(top_k_values, top_k_indices)]

    return top_k_label_prob

前処理

まずtransforms.Composeで前処理を行います。引数に渡しているパラメータは適当に決めた訳ではなく、学習済みモデルで実施した前処理と一致させる必要があるのでこのような値になっています。

モデル生成と学習済み重みの復元

画像認識モデルはサイズが大きいので、リクエストのたびにストレージから読み込んでいたのでは処理に時間がかかってしまいます。そこで、Djangoのキャッシュの仕組みを利用して、画像認識モデルをキャッシュに保持するようにします。cache.setでキャッシュに保持し、cache.getでキャッシュから読み込むことができます。

キャッシュに存在しない場合はmake_modelでモデルを生成していますが、ここは少々トリッキーです。torchvision.modelsの中には様々な画像認識モデルのクラスが定義されており、それらのクラスのオブジェクトを生成する関数も定義されています。method = getattr(models, model_name)でオブジェクトmodelsの属性model_nameを取得できます。このとき取得できるのがモデルを生成する関数なので、引き続いてmethod()でモデルを生成することができます。

例えば、AlexNetモデルを生成する関数alexnet()がありますが、method = getattr(models, 'alexnet')とすると、AlexNetモデルを生成する関数alexnetが取得でき、method()とすることでalexnet()が実行されます。

設定ファイルで様々なモデルを指定できるようにするためにこのような実装にしました。

モデルが生成出来たら、model.load_state_dictでストレージから学習済みの重みを読み込んでモデルにセットします。

モデルを生成する際に学習済みモデルをダウンロードして取得することもできますが、ダウンロードに時間がかかるので、今回は予めダウンロードしておいて、そのパスを指定する方針にしました。

推論

model.eval()で推論モードにセットします。

後はmodel(inputs)のようにモデルに前処理済みの画像データinputを与えれば、推論結果が得られます。

最終的には各カテゴリの確率が欲しいのでソフトマックス関数を使用します。

出力結果から上位$K$個を抜き出して、タプル(インデックス, ラベル名, 確率)のリストを作成して呼び出し元に返却します。

Template

classifierアプリケーションのテンプレートをtemplates/classifier.htmlに作成します。

templates/classifier.html

<!DOCTYPE html>
{% load static %}
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <title>Image Classifier</title>
   <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css" integrity="sha384-JcKb8q3iqJ61gNV9KGb8thSsNjpSL0n8PARn9HuZOnIxN0hoP+VmmDGMN5t9UJ0Z" crossorigin="anonymous">
    <script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.9.3/Chart.min.js"></script>
</head>
<body class="container">
    <div class="card mt-4 mb-4">
        <div class="card-header">
            <h1 class="display-4 text-center">Image Classifier</h1>
        </div>
        <div class="card-body">
            <div class="mx-2 my-4">
                <form method="post" enctype="multipart/form-data">
                    {% csrf_token %}
                    <div class="input-group">
                        <div class="custom-file">
                            <input type="file" class="custom-file-input" id="inputGroupFile" name="{{ form.image.html_name }}" required>
                            <label class="custom-file-label" for="inputGroupFile" aria-describedby="inputGroupFileAddon">Choose image file</label>
                        </div>
                        <div class="input-group-append">
                            <button class="btn btn-primary" type="submit" id="inputGroupFileAddon">Submit</button>
                        </div>
                    </div>
                </form>
            </div>
            <div class="mx-4 my-2">
                {# replace by Ajax #}
                <div class="row">
                    <div class="col-lg-5">
                        <div>
                            <img id="submittedImage" />
                        </div>
                        <div>
                            <table id="rankingTable"></table>
                        </div>
                    </div>
                    <div class="col-lg-7">
                        <div>
                            <canvas id="barChart" height="0"></canvas>
                        </div>
                    </div>
                </div>
            </div>
        </div>
    </div>
    <script src="https://code.jquery.com/jquery-3.5.1.min.js" integrity="sha256-9/aliU8dGd2tb6OSsuzixeV4y/faTqgFtohetphbbj0=" crossorigin="anonymous"></script>
    <script src="https://cdn.jsdelivr.net/npm/bs-custom-file-input/dist/bs-custom-file-input.js"></script>
    <script type="text/javascript" src="{% static 'js/ajax.js' %}"></script>
    <script type="text/javascript" src="{% static 'js/barchart.js' %}"></script>
    <script>
       bsCustomFileInput.init();
   </script>
</body>
</html>

ここでのポイントは2つ。

  • 画像ファイル送信用のフォームを作成
    Djangoでは{{ form.as_p }}のようにすればフォームを表示できるのですが、見栄えをよくするためにBootstrapを利用しています。当初はDjangoのテンプレートタグのみで描画しようと考え、いろいろと試行錯誤しましたが、細かく整えるのは無理そうとの結論に至りました。Djangoの機能を使うことに固執しすぎるのも良くないですよね。

    <form method="post" enctype="multipart/form-data">
        {% csrf_token %}
        <div class="input-group">
            <div class="custom-file">
                <input type="file" class="custom-file-input" id="inputGroupFile" name="{{ form.image.html_name }}" required>
                <label class="custom-file-label" for="inputGroupFile" aria-describedby="inputGroupFileAddon">Choose image file</label>
            </div>
            <div class="input-group-append">
                <button class="btn btn-primary" type="submit" id="inputGroupFileAddon">Submit</button>
            </div>
        </div>
    </form>
    
  • Ajaxで描画する領域を用意
    Ajaxでレスポンスを受け取ったら、その内容をここに埋め込むという方針です。

    • <img id="submittedImage" />:アップロードした画像を埋め込む場所
    • <table id="rankingTable"></table>:予測したラベルとその確率の一覧表を埋め込む場所
    • <canvas id="barChart" height="0"></canvas>:予測したラベルとその確率の水平棒グラフを埋め込む場所
    <div class="row">
        <div class="col-lg-5">
            <div>
                <img id="submittedImage" />
            </div>
            <div>
                <table id="rankingTable"></table>
            </div>
        </div>
        <div class="col-lg-7">
            <div>
                <canvas id="barChart" height="0"></canvas>
            </div>
        </div>
    </div>
    

Ajax

Ajaxの実装にはjQueryを利用しています。formでsubmitされたのを契機にして、POSTリクエストで画像ファイルを送信するようにします。$.ajaxにてリクエストの設定をします。また、done(正常時の処理)、fail(異常時の処理)、always(正常時でも異常時でも実行される後処理)をそれぞれ実装します。処理概要は次のとおりです:

  • done
    • レスポンスを受け取り、レスポンスに含まれるJSONをパース
    • パースした内容を基に、画像の表示、表の作成、グラフの描画を実施
  • fail
    • ブラウザ上にアラートでエラーを通知
  • always
    • 特になし

ちなみにgetCookieの実装はDjangoのドキュメントにあるCross Site Request Forgery protectionを流用させてもらいました。

static/js/ajax.js

/*
 * This function is from Django official web site:
 *   https://docs.djangoproject.com/en/3.1/ref/csrf/
 */
function getCookie(name) {
    let cookieValue = null;
    if (document.cookie && document.cookie !== '') {
        const cookies = document.cookie.split(';');
        for (let i = 0; i < cookies.length; i++) {
            const cookie = cookies[i].trim();
            // Does this cookie string begin with the name we want?
            if (cookie.substring(0, name.length + 1) === (name + '=')) {
                cookieValue = decodeURIComponent(cookie.substring(name.length + 1));
                break;
            }
        }
    }
    return cookieValue;
}

const csrftoken = getCookie('csrftoken');

function csrfSafeMethod(method) {
    // these HTTP methods do not require CSRF protection
    return (/^(GET|HEAD|OPTIONS|TRACE)$/.test(method));
}

$.ajaxSetup({
    beforeSend: function(xhr, settings) {
        if (!csrfSafeMethod(settings.type) && !this.crossDomain) {
            xhr.setRequestHeader('X-CSRFToken', csrftoken);
        }
    }
});

$('form').submit(function(event) {
    event.preventDefault();
    var form = $(this);
    var formData = new FormData($('form').get(0));
    $.ajax({
        type: 'POST',
        url: form.prop('action'),
        method: form.prop('method'),
        data: formData,
        processData: false,
        contentType: false,
        timeout: 10000,
        dataType: 'text',
    }).done(function(response) {
        console.log('done');

        var parsedResponse = JSON.parse(response);

        // show image
        showImage(parsedResponse.image_url);

        // create table
        createTable(parsedResponse.top_k_predictions);

        // draw chart
        drawChart(parsedResponse.top_k_predictions);
    }).fail(function(jqXHR, textStatus, errorThrown) {
        console.log('fail');
        console.log(jqXHR.status);
        console.log(textStatus);
        console.log(errorThrown);
        alert(jqXHR.status + ' Error: ' + errorThrown);
    }).always(function() {
        console.log('always');
    });
});

function showImage(image_url) {
    console.log('showImage start');
    var img = $('#submittedImage');
    img.attr('src', image_url);
    img.addClass('img-fluid mx-auto d-block mt-4 mb-4')
    console.log('showImage end');
}

function createTable(top_k_predictions) {
    console.log('createTable start');
    var table = $('#rankingTable');
    table.empty();
    table.append('<caption>Prediction result</caption>');
    table.append('<thead><th>Rank</th><th>Label</th><th>Probability</th></thead>');
    table.append('<tbody></tbody>');
    for (var i = 0; i < top_k_predictions.length; i++) {
        var row = $("<tr></tr>");
        row.append($('<td align="right"></td>').text(i + 1));
        row.append($('<td></td>').text(top_k_predictions[i].label));
        row.append($('<td align="right"></td>').text(top_k_predictions[i].probability.toFixed(3)));
        table.append(row);
    }
    table.addClass('table bg-light');
    console.log('createTable end');
}

グラフ描画

グラフ描画にはChart.jsを利用します。

static/js/barchart.js

function setData(top_k_predictions) {
    var labels = [];
    var probabilities = [];

    for (var i = 0; i < top_k_predictions.length; i++) {
        labels.push(top_k_predictions[i].label);
        probabilities.push(top_k_predictions[i].probability);
    }

    var data = {
        labels: labels,
        fontSize: 18,
        datasets: [{
            label: 'probability',
            data: probabilities,
            backgroundColor: 'rgba(54, 162, 235, 0.7)'
        }]
    };
    return data;
}

function setOptions() {
    var options = {
        responsive: true,
        maintainAspectRatio: true,
        scales: {
            xAxes: [{
                gridLines: {
                    display: true
                },
                scaleLabel: {
                    display: true,
                    labelString: 'probability',
                    fontSize: 14
                },
                ticks: {
                    min: 0.0,
                    max: 1.0,
                    fontSize: 14
                }
            }],
            yAxes: [{
                scaleLabel: {
                    display: true,
                    labelString: 'label',
                    fontSize: 14
                },
                ticks: {
                    autoSkip: false,
                    fontSize: 14
                }
            }]
        }
    };
    return options;
}

function drawChart(top_k_predictions) {
    console.log('drawChart start');

    var ctx = document.getElementById('barChart');
    var data = setData(top_k_predictions);
    var options = setOptions();

    // set height according to the number of data
    ctx.height = 18 * data.datasets[0]['data'].length + 60;

    // destroy chart instance if chart instance already exists
    if (window.myBarChart) {
        window.myBarChart.destroy();
    }

    window.myBarChart = new Chart(ctx, {
        type: 'horizontalBar',
        data: data,
        options: options
    });

    $('#barChart').addClass('mt-4');

    console.log('drawChart end');
}

実行

パッケージのインストールや実行手順はhttps://github.com/noriho137/image-classifierを参考にしてください。