ブログ記事に自動でカテゴリを割り当てる

昨日書いた記事の中でブログのカテゴリを設計した。 また、別の記事では、このブログを管理するhugoのメタデータを編集しやすくするクラスを作った。

ここではそれらを前提にしてChatGPTを利用し、ブログ記事にカテゴリを付与していく。

今回は後述のクラスを作った。

以下のように実行すると、カテゴリが自動で設定される。

editor = LlmEditor(post)
editor.set_categories()
import json
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import HumanMessage

# あらかじめ用意したカテゴリとその説明
CATEGORIES = {
    "ブログ構築記録": "ブログを構築する過程で得た知見を記録する",
    "新技術の試用": "新しい技術を試してみた結果を記録する",
    "DIYシステム": "自作したシステムの構築過程を記録する",
    "日常の気づき": "日常の中で気づいたことを記録する",
    "草の根活動報告": "草の根活動の報告を記録する",
    "個人活動の振り返り": "個人活動の振り返りを記録する",
}

# カテゴリーを選択するプロンプト
GET_CATEGORIES_PROMPT = """
Set categories for text separated by triple backsticks.
Provide in JSON format with the following key: categories
Constraint: Categories must be one or more but not more than two of the following.
{category_list}
```{input_text}```
JSON:
"""


class LlmEditor:
    """
    LLMを利用して記事を編集する
    """

    def __init__(self, post):
        self.post = post
        self.llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")

    def set_categories(self):
        """
        記事に適切なカテゴリーを設定する
        """
        # カテゴリーの候補を取得する
        categories = self._get_categories()
        logger.info(f"カテゴリーを設定します: {categories}")

        # カテゴリーを設定する
        self.post.categories = categories

        # 記事を保存する
        self.post.save()

    def _get_categories(self):
        """
        記事に適切なカテゴリーを取得する
        """
        category_list = "\n".join(
            [
                f"- {category}: {description}"
                for category, description in CATEGORIES.items()
            ]
        )
        # バックスラッシュを除去する
        input_text = f"{self.post.title}\n{self.post.content}".replace("`", "")

        prompt = PromptTemplate(
            input_variables=[
                "category_list",
                "input_text",
            ],
            template=GET_CATEGORIES_PROMPT,
        )
        message_content = prompt.format(
            category_list=category_list,
            input_text=input_text,
        )
        messages = [
            HumanMessage(content=message_content),
        ]
        response = self.llm(messages)
        logger.info(f"response: {response}")

        # カテゴリーを抽出する
        response_json = json.loads(response.content)
        categories = response_json["categories"]

        # カテゴリーが適切かどうかを確認する
        if not categories:
            raise ValueError("カテゴリーが空です")
        elif len(categories) > 2:
            logger.warning(f"カテゴリーが多すぎます: {categories}, {self.post.title}")
            return categories[:2]
        elif not all([category in CATEGORIES for category in categories]):
            raise ValueError("カテゴリーが不正です", categories)

        return categories