昨日書いた記事の中でブログのカテゴリを設計した。 また、別の記事では、このブログを管理するhugoのメタデータを編集しやすくするクラスを作った。
ここではそれらを前提にしてChatGPTを利用し、ブログ記事にカテゴリを付与していく。
今回は後述のクラスを作った。
以下のように実行すると、カテゴリが自動で設定される。
editor = LlmEditor(post)
editor.set_categories()
import json
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import HumanMessage
# あらかじめ用意したカテゴリとその説明
CATEGORIES = {
"ブログ構築記録": "ブログを構築する過程で得た知見を記録する",
"新技術の試用": "新しい技術を試してみた結果を記録する",
"DIYシステム": "自作したシステムの構築過程を記録する",
"日常の気づき": "日常の中で気づいたことを記録する",
"草の根活動報告": "草の根活動の報告を記録する",
"個人活動の振り返り": "個人活動の振り返りを記録する",
}
# カテゴリーを選択するプロンプト
GET_CATEGORIES_PROMPT = """
Set categories for text separated by triple backsticks.
Provide in JSON format with the following key: categories
Constraint: Categories must be one or more but not more than two of the following.
{category_list}
```{input_text}```
JSON:
"""
class LlmEditor:
"""
LLMを利用して記事を編集する
"""
def __init__(self, post):
self.post = post
self.llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
def set_categories(self):
"""
記事に適切なカテゴリーを設定する
"""
# カテゴリーの候補を取得する
categories = self._get_categories()
logger.info(f"カテゴリーを設定します: {categories}")
# カテゴリーを設定する
self.post.categories = categories
# 記事を保存する
self.post.save()
def _get_categories(self):
"""
記事に適切なカテゴリーを取得する
"""
category_list = "\n".join(
[
f"- {category}: {description}"
for category, description in CATEGORIES.items()
]
)
# バックスラッシュを除去する
input_text = f"{self.post.title}\n{self.post.content}".replace("`", "")
prompt = PromptTemplate(
input_variables=[
"category_list",
"input_text",
],
template=GET_CATEGORIES_PROMPT,
)
message_content = prompt.format(
category_list=category_list,
input_text=input_text,
)
messages = [
HumanMessage(content=message_content),
]
response = self.llm(messages)
logger.info(f"response: {response}")
# カテゴリーを抽出する
response_json = json.loads(response.content)
categories = response_json["categories"]
# カテゴリーが適切かどうかを確認する
if not categories:
raise ValueError("カテゴリーが空です")
elif len(categories) > 2:
logger.warning(f"カテゴリーが多すぎます: {categories}, {self.post.title}")
return categories[:2]
elif not all([category in CATEGORIES for category in categories]):
raise ValueError("カテゴリーが不正です", categories)
return categories