AIエージェントのためのクラス。
システムプロンプトは、すぐ忘れえしまうので、ユーザーの入力の前に付け足すようにしている。そして、解答がかえってきた後、メモリ捜査で、そのシステムプロンプトを除いたプロンプトで、該当メモリを上書きして、削除している。削除しているのは、何度もユーザー入力として入れるとプロンプトの命令が強くなりすぎるため。
import streamlit as st
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_google_genai.chat_models import ChatGoogleGenerativeAIError
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.messages import HumanMessage, AIMessage
# models
from langchain_google_genai import ChatGoogleGenerativeAI
# custom tools
######################################################
class AIAgent():
def __init__(self, agent_name, system_prompt, tools, private_memory):
self.chat_history = []
self.sysytem_prompt = ""
self.llm = None
self.prompt = None
self.name = agent_name
self.llm = ChatGoogleGenerativeAI(
temperature=0,
model="gemini-1.5-flash"
)
self.private_memory = private_memory
self.__create_memory()
self.sysytem_prompt = system_prompt
self.tools = tools
self.agent = self.create_agent()
self.update_system_prompt(self.sysytem_prompt)
def __create_agent(self):
agent = create_tool_calling_agent(self.llm, self.tools, self.prompt)
self.agent = AgentExecutor(
agent=agent,
tools=self.tools,
verbose=True,
memory=self.chat_history
)
return self.agent
def __create_memory(self):
if self.private_memory:
st.session_state[self.name + 'memory'] = ConversationBufferWindowMemory(
return_messages=True,
memory_key=self.name + "chat_history",
k=100
)
self.chat_history = st.session_state[self.name + 'memory']
else:
if "messages" not in st.session_state:
st.session_state['memory'] = ConversationBufferWindowMemory(
return_messages=True,
memory_key="chat_history",
k=100
)
self.chat_history = st.session_state['memory']
def create_agent(self):
if self.private_memory:
print("self.sysytem_prompt",self.sysytem_prompt)
print("self.name",self.name)
self.prompt = ChatPromptTemplate.from_messages([
("system", self.sysytem_prompt),
MessagesPlaceholder(variable_name=self.name + "chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])
else:
self.prompt = ChatPromptTemplate.from_messages([
("system", self.sysytem_prompt),
MessagesPlaceholder(variable_name="chat_history"),
("user", "{input}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
])
return self.__create_agent()
def clear_memory(self):
print("self.chat_history.chat_memory.messages",
self.chat_history.chat_memory.messages)
self.chat_history.chat_memory.messages = []
def updata_temperature(self, temperature):
self.modllmel = ChatGoogleGenerativeAI(
temperature=temperature,
model="gemini-1.5-flash"
)
self.__create_agent()
def update_system_prompt(self, prompt):
self.sysytem_prompt = prompt
self.create_agent()
def get_respons(self, prompt, st_cb):
response = ""
if None is not self.agent:
# command = self.get_user_message(command, pre_agent)
# カスタムコールバックのインスタンスを作成
# エージェントを実行
try:
print("self.sysytem_prompt", self.sysytem_prompt, prompt)
response = self.agent.invoke(
{'input': self.sysytem_prompt+"\r\n"+prompt},
config=RunnableConfig({'callbacks': [st_cb]})
)
self.update_last_input(prompt)
except ChatGoogleGenerativeAIError as e:
print(f"エラーが発生しました: {e}")
print("function_response:", response)
if dict != type(response):
response = {}
response["output"] = f"エラーが発生しました: {e}\r\nこのまま続けられます。"
print(type(response))
return response
def update_last_input(self, imput_prompt):
# chat_history の内容を取得
iend = len(self.chat_history.chat_memory.messages) - 1
for i in range(len(self.chat_history.chat_memory.messages)):
if HumanMessage == type(self.chat_history.chat_memory.messages[iend-i]):
self.chat_history.chat_memory.messages[iend-i].content = imput_prompt
break
return
def modify_prompt(self, data):
self.sysytem_prompt = data
def get_history(self):
return self.chat_history
def append_aimessage(self, message):
self.chat_history.chat_memory.messages.append(AIMessage(message))
AIの解説
このコードは、LangChainライブラリを使用して、Google GeminiモデルをベースとしたAIエージェントを作成するものです。
クラス AIAgent
このクラスは、AIエージェントの機能をカプセル化します。
- コンストラクタ
agent_name
: エージェントの名前system_prompt
: エージェントのシステムプロンプトtools
: エージェントが使用できるツールprivate_memory
: エージェントがプライベートメモリを使用するかどうか
- メソッド
__create_agent()
: エージェントを作成します。__create_memory()
: エージェントのメモリを作成します。create_agent()
: エージェントを作成します。clear_memory()
: エージェントのメモリをクリアします。updata_temperature()
: エージェントの温度を更新します。update_system_prompt()
: エージェントのシステムプロンプトを更新します。get_respons()
: エージェントにプロンプトを送信し、応答を取得します。update_last_input()
: エージェントの最後の入力を更新します。modify_prompt()
: エージェントのプロンプトを変更します。get_history()
: エージェントの履歴を取得します。append_aimessage()
: エージェントの履歴にAIMessageを追加します。
コードの機能
このコードは、LangChainライブラリを使用して、Google GeminiモデルをベースとしたAIエージェントを作成します。エージェントは、システムプロンプト、ツール、メモリを使用して、ユーザーの入力に応答します。
使用方法
このコードを使用するには、まずAIAgentクラスのインスタンスを作成します。次に、エージェントにプロンプトを送信して、応答を取得します。
例
# エージェントを作成
agent = AIAgent(
agent_name="my_agent",
system_prompt="あなたは役に立つAIアシスタントです。",
tools=[],
private_memory=False
)
# エージェントにプロンプトを送信
response = agent.get_respons("こんにちは", None)
# 応答を出力
print(response)
参照