from abc import ABC, abstractmethod
from typing import Any, List, Optional, Dict
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage, AIMessageChunk
from langchain.agents import create_agent
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from tools.exception import InterruptedException
from Agents.llms.LlmInterface import LLMInterface # Import the interface
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
try:
    # If the ollama package is installed, import its dedicated error type
    from ollama import ResponseError
except ImportError:
    # Fallback stub so the `except ResponseError` clause below stays valid
    class ResponseError(Exception):
        pass
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages
from langchain_core.messages import ToolMessage  # for tool results
import time
from tools.word2contents import Word2Contents
import os
import requests
from pathlib import Path
import base64
# Map file extensions to MIME types for image attachments
g_mime_map = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".bmp": "image/bmp",
".tif": "image/tiff",
".tiff": "image/tiff",
".svg": "image/svg+xml",
".webp": "image/webp",
".emf": "image/x-emf",
".wmf": "application/x-msmetafile",
}
class State(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]  # conversation history
    # Optional: collect tool outputs alongside the message history
    tool_results: list[str]  # e.g. a list of tool output strings
def extract_user_input(messages):
    # Find the most recent HumanMessage
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            return msg.content
    raise ValueError("No HumanMessage found")
def extract_memory(messages):
    # Extract HumanMessage/AIMessage pairs as the conversation history
    history = []
    for msg in messages:
        if isinstance(msg, (HumanMessage, AIMessage)):
            history.append(msg)
    return history
def extract_scratchpad(messages):
    # Extract ToolMessage entries (tool results) as the agent scratchpad
    scratchpad = []
    for msg in messages:
        if isinstance(msg, ToolMessage):
            scratchpad.append(msg)
    return scratchpad
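# Example (illustrative values): how the extractors above slice one message list.
#     msgs = [HumanMessage(content="Hi"),
#             AIMessage(content="Hello!"),
#             ToolMessage(content="42", tool_call_id="call_0")]
#     extract_user_input(msgs)   # -> "Hi"
#     extract_memory(msgs)       # -> the Human/AI pair
#     extract_scratchpad(msgs)   # -> [ToolMessage(...)]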
class LlmBase(LLMInterface): # LlmBase implements LLMInterface
model_name: str
temperature: float
llm_kwargs: Dict[str, Any]
_supports_images: bool
llm: BaseChatModel # The actual Langchain LLM instance
memory: MemorySaver
    def __init__(self, model_identifier: str, temperature: float = 0, **kwargs):
        self.model_name = model_identifier
        self.temperature = temperature
        self.llm_kwargs = kwargs
        self.tools: List[Any] = []
        self.system_prompt: str = ""
        self._supports_images = False  # Default; subclasses override
        self.llm = self._initialize_llm()  # Initialize the specific LLM in the constructor
        self.w2c = Word2Contents()
        self.private_memory = False
        self.name = ""
        # Shared checkpointer across all agents; create it only once so that
        # constructing a second agent does not wipe existing conversation threads
        if getattr(LlmBase, "memory", None) is None:
            LlmBase.memory = MemorySaver()
        self.memory_key = "default"
        # Tool map (populated by append_tools) for dispatching tool calls by name
        self.tool_map = None
        self.app = self.create_character()
        self.__create_memory()
@abstractmethod
def _initialize_llm(self) -> BaseChatModel:
"""Abstract method to initialize the specific Langchain LLM instance."""
pass
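    # Sketch of a minimal concrete subclass (hypothetical class and model names):
    #     class OllamaLlm(LlmBase):
    #         def _initialize_llm(self) -> BaseChatModel:
    #             from langchain_ollama import ChatOllama
    #             return ChatOllama(model=self.model_name, temperature=self.temperature)
    #         @property
    #         def supports_images(self) -> bool:
    #             return False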
def get_langchain_llm_instance(self) -> Optional[BaseChatModel]:
return self.llm
#######################
    def create_character(self):
        return self.create_agent_executer()
    def create_agent_executer(self):
        # Rebuilt whenever settings change; get_response does not rebuild it,
        # which keeps overhead low across consecutive messages.
        self.agent = create_agent(
            model=self.llm,
            tools=self.tools,
        )
        return self.build()
#####################################################################
    def append_tools(self, tools_list: list, new_tools: list) -> list:
        """Append new tools (plain functions or Tool instances) to a tool list.
        Functions already decorated with @tool are Tool instances and are
        appended as-is; plain callables are wrapped on the fly.
        """
        for new_tool in new_tools:
            if isinstance(new_tool, list):  # flatten nested lists of tools
                tools_list.extend(new_tool)
                continue
            if callable(new_tool) and not hasattr(new_tool, 'invoke'):
                new_tool = tool(new_tool)  # wrap a plain function (pre-decorating is preferred)
            tools_list.append(new_tool)
        self.tool_map = {t.name: t for t in tools_list}  # name -> tool, for dispatch
        return tools_list
    def agent_node(self, state: State) -> State:
        return self.agent.invoke(state)  # the prebuilt agent accepts the state as-is
    # tool_node: run whichever registered tool the model called (no manual branching)
    def tool_node(self, state: State) -> State:
        outputs = []
        last_message = state["messages"][-1]
        for tool_call in last_message.tool_calls:  # supports multiple tool calls
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            # Dispatch by name via tool_map
            selected_tool = self.tool_map.get(tool_name) if self.tool_map else None
            if selected_tool:
                tool_result = selected_tool.invoke(tool_args)
            else:
                tool_result = f"Unknown tool: {tool_name}"
            # One ToolMessage per call, tagged with that call's id so the LLM can match it
            outputs.append(ToolMessage(
                content=str(tool_result),
                tool_call_id=tool_call.get("id"),
            ))
        return {"messages": outputs}
    def should_continue(self, state: State):
        last_message = state["messages"][-1]
        return "tools" if getattr(last_message, "tool_calls", None) else END
    def build(self):
        workflow = StateGraph(state_schema=State)
        workflow.add_node("agent", self.agent_node)
        workflow.add_node("tools", self.tool_node)
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges("agent", self.should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=LlmBase.memory)
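    # Sketch: the compiled graph can also be driven directly, outside get_response:
    #     config = {"configurable": {"thread_id": self.memory_key}}
    #     result = self.app.invoke({"messages": [HumanMessage(content="Hello")]}, config=config)
    #     result["messages"][-1].content  # -> final AI reply for this turn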
###########################################################
    def flatten_message(self, msg) -> dict:
        return {
            "id": getattr(msg, "id", None),
            "content": getattr(msg, "content", None),
            "additional_kwargs": getattr(msg, "additional_kwargs", {}),
            "response_metadata": getattr(msg, "response_metadata", {}),
            "usage_metadata": getattr(msg, "usage_metadata", None),
            "tool_call_id": getattr(msg, "tool_call_id", None),  # only present on ToolMessage
            "type": msg.__class__.__name__
        }
    def restore_message(self, d: dict) -> BaseMessage:
        msg_type = d.get("type", "AIMessage")
        cls_map = {
            "AIMessage": AIMessage,
            "HumanMessage": HumanMessage,
            "SystemMessage": SystemMessage,
            "ToolMessage": ToolMessage
        }
        cls = cls_map.get(msg_type, AIMessage)
        kwargs = {
            "content": d.get("content", ""),
            "additional_kwargs": d.get("additional_kwargs", {}),
            "response_metadata": d.get("response_metadata", {}),
            "id": d.get("id", None),
        }
        if cls is AIMessage:
            kwargs["usage_metadata"] = d.get("usage_metadata", None)  # AIMessage-only field
        if cls is ToolMessage:
            kwargs["tool_call_id"] = d.get("tool_call_id", "")  # required by ToolMessage
        return cls(**kwargs)
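    # Sketch: the flatten/restore round trip used by __set_index_contents below.
    #     d = self.flatten_message(AIMessage(content="hi"))
    #     d["content"] = "edited"
    #     self.restore_message(d)  # -> AIMessage(content="edited"), same id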
def clear_memory(self):
LlmBase.memory.delete_thread(self.memory_key)
self.set_system_prompt(self.system_prompt)
    def __set_index_contents(self, index: int, type_str: str, contents: str):
        config = {
            "configurable": {
                "thread_id": self.memory_key,
            }
        }
        state = self.app.get_state(config)
        dict_data = self.flatten_message(state.values["messages"][index])
        dict_data["type"] = type_str
        dict_data["content"] = contents
        new_msg = self.restore_message(dict_data)
        new_messages = state.values["messages"].copy()
        new_messages[index] = new_msg  # replace the message at the requested index
        new_state = {"messages": new_messages}
        self.app.update_state(config, new_state)
def is_memory_Nodata(self) -> bool:
config = {
"configurable": {
"thread_id": self.memory_key,
}
}
state = self.app.get_state(config)
if state is None:
return True
if "messages" not in state.values:
return True
if len(state.values["messages"]) == 0:
return True
return False
    def set_system_prompt(self, system_prompt: str):
        self.system_prompt = system_prompt
        if self.is_memory_Nodata():
            return False
        # The first message in the thread is the SystemMessage; overwrite it in place
        self.__set_index_contents(0, "SystemMessage", self.system_prompt)
        return True
    def get_memory(self):
        return LlmBase.memory
    def __create_memory(self):
        """
        Create the memory thread that stores the conversation history.
        Depending on the private_memory flag, the agent uses its own thread
        or the shared one.
        """
        if self.private_memory:
            self.memory_key = self.name + "_chat_history"
        else:
            self.memory_key = "chat_history"
        self.clear_memory()
    def is_private_memory(self, is_private):
        # Setter despite the name: toggles whether this agent uses its own thread
        self.private_memory = is_private
def set_name(self, name):
self.name = name
self.__create_memory()
    def append_message(self, human_message, ai_message):
        """
        Append a human message and an AI message to memory.
        Args:
            human_message (str): content of the human message to append.
            ai_message (str): content of the AI message to append.
        """
        self.append_human_message(human_message)
        self.append_ai_message(ai_message)
    def append_ai_message(self, content: str):
        """
        Append an AI message to memory.
        Args:
            content (str): content of the AI message to append.
        """
        config = {"configurable": {"thread_id": self.memory_key}}
        # update_state (rather than checkpointer.put) lets the add_messages
        # reducer append to the existing history
        self.app.update_state(config, {"messages": [AIMessage(content=content)]})
    def append_human_message(self, message):
        """
        Append a human message to memory.
        Args:
            message (str): content of the human message to append.
        """
        config = {"configurable": {"thread_id": self.memory_key}}
        self.app.update_state(config, {"messages": [HumanMessage(content=message)]})
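    # Sketch: seeding memory with a past exchange (e.g. restored from a log):
    #     self.append_message("What is 2+2?", "4")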
    def update_input(
            self,
            input_prompt_text: str,
            image_data_urls: Optional[List[str]] = None,
            target_index: Optional[int] = None):
        """
        Update a user input (HumanMessage) in the conversation history.
        - If target_index is given, update the HumanMessage at that position.
        - Otherwise, update the last HumanMessage.
        Args:
            input_prompt_text (str): new user input text
            image_data_urls (Optional[List[str]]): list of image URLs
            target_index (Optional[int]): index to update (None for the last one)
        """
        pass
    def update_input_text(
            self,
            input_prompt_text: str,
            target_index: Optional[int] = None):
        """
        Update a user input (HumanMessage) in the conversation history.
        - If target_index is given, update the HumanMessage at that position.
        - Otherwise, update the last HumanMessage.
        Args:
            input_prompt_text (str): new user input text
            target_index (Optional[int]): index to update (None for the last one)
        """
        # Fetch the current state of this thread
        config = {
            "configurable": {
                "thread_id": self.memory_key,
            }
        }
        state = self.app.get_state(config)
        messages = state.values.get("messages", [])
        # Determine the target index: default to the last HumanMessage
        if target_index is None:
            for i in range(len(messages) - 1, -1, -1):
                if isinstance(messages[i], HumanMessage):
                    target_index = i
                    break
        if target_index is None or target_index < 0 or len(messages) <= target_index:
            print("No HumanMessage exists at the specified index")
            return
        self.__set_index_contents(target_index, "HumanMessage", input_prompt_text)
###################
    def update_tools(self, tools: List[Any]):
        """Updates the list of tools available to the LLM provider and rebuilds the agent."""
        self.tools = self.append_tools([], tools)  # also refreshes tool_map
        self.app = self.create_character()  # rebuild so the new tools take effect
    def update_system_prompt(self, system_prompt: str):
        """Updates the system prompt for the LLM provider."""
        self.set_system_prompt(system_prompt)
#########################
    def _encode_image_to_data_url(self, image_path_or_url: str) -> dict:
        """Build a Base64-encoded image content part (dict) from a local path or URL."""
        if image_path_or_url.startswith("http://") or image_path_or_url.startswith("https://"):
            response = requests.get(image_path_or_url, timeout=10)
            response.raise_for_status()
            image_data = response.content
            mime_type = response.headers.get('Content-Type', 'image/jpeg')
        elif Path(image_path_or_url).exists():
            with open(image_path_or_url, "rb") as image_file:
                image_data = image_file.read()
            _, ext = os.path.splitext(image_path_or_url.lower())
            mime_type = g_mime_map.get(ext, "application/octet-stream")
        else:
            raise FileNotFoundError(f"Image path or URL not found or not accessible: {image_path_or_url}")
        base64_encoded_data = base64.b64encode(image_data).decode('utf-8')
        return {"type": "image", "base64": base64_encoded_data, "mime_type": mime_type}
    def __create_text_contents(self, text):
        content_parts = {"type": "text", "text": text}
        return content_parts
    def _create_image_contents(self, image_path_or_url: str):
        return self._encode_image_to_data_url(image_path_or_url)
    def get_current_message_content(self, prompt, file_paths):
        content_parts = []
        for file_path in file_paths or []:  # tolerate file_paths=None
            lower = file_path.lower()
            if lower.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                content_parts.append(self._encode_image_to_data_url(file_path))
            if lower.endswith('.pdf'):
                content_parts.append(self.open_pdf(file_path))
            if lower.endswith(('.txt', '.md', '.csv', '.json', '.log', '.xml', '.html', '.htm', '.css', '.js', '.sql', '.yaml', '.yml', '.ini',
                               '.py', '.java', '.c', '.cpp', '.h', '.cs', '.rb', '.rs', '.php', '.swift', '.kt', '.m', '.sh', '.r', '.pl', '.vb', '.ts', '.dart')):
                content_parts.append(self.open_text(file_path))
            if lower.endswith(('.docx', '.doc')):
                content_parts.extend(self.open_words(file_path))
        # The prompt text goes last so it follows any attached file contents
        content_parts.append(self.__create_text_contents(prompt))
        return content_parts
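    # Sketch: for a prompt plus one PNG, the returned list looks like:
    #     [{"type": "image", "base64": "<...>", "mime_type": "image/png"},
    #      {"type": "text", "text": "Describe this picture"}]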
    def open_pdf(self, pdf_path):
        from pdfminer.high_level import extract_text  # lazy import: pdfminer is only needed for PDFs
        text = extract_text(pdf_path)
        results = f"\n\n--- Content from {os.path.basename(pdf_path)} ---\n\n"
        results += text.encode("utf-8", errors="replace").decode("utf-8") + "\n"
        results += "--- End of Content ---\n\n"
        return self.__create_text_contents(results)
    def open_text(self, text_path):
        with open(text_path, 'r', encoding='utf-8', errors='ignore') as file:
            text = file.read()
        results = f"\n\n--- Content from {os.path.basename(text_path)} ---\n\n"
        results += text + "\n"
        results += "\n--- End of Content ---\n\n"
        return self.__create_text_contents(results)
    def open_words(self, word_path):
        # Delegate .docx/.doc extraction to Word2Contents
        return self.w2c.open(word_path)
    def get_response(self,
                     prompt: str,
                     file_paths: Optional[List[str]] = None,
                     system_prompt: Optional[str] = None,
                     tools: Optional[List[Any]] = None,
                     callbacks: Optional[List[BaseCallbackHandler]] = None,
                     **kwargs) -> str:
        # Construct the HumanMessage content, handling multimodal input
        current_message_content = self.get_current_message_content(prompt, file_paths)
        if self.is_memory_Nodata():
            # First turn on this thread: seed it with the system prompt
            agent_input = {
                "messages": [SystemMessage(content=self.system_prompt), HumanMessage(content=current_message_content)],
            }
        else:
            agent_input = {
                "messages": [
                    HumanMessage(content=current_message_content)
                ],
            }
        run_config = {
            "configurable": {"thread_id": self.memory_key},
            "callbacks": callbacks
        }
        response = ""  # fallback in case the stream yields nothing
        try:
            for output in self.app.stream(agent_input, config=run_config, **kwargs):
                for key, value in output.items():
                    if "messages" in value:
                        for msg in value["messages"]:
                            if hasattr(msg, 'content') and msg.content:
                                # Keep the latest content; token-by-token display is
                                # handled by the callbacks (on_llm_new_token)
                                response = msg.content
            # Return the final response on success
            if isinstance(response, dict):
                return response.get("output", f"AgentExecutor returned a dict without 'output' key: {response}")
            return str(response)
        except InterruptedException:
            print(f"{type(self).__name__} ({self.model_name}): AgentExecutor execution interrupted.")
            raise  # re-raise so the AIAgent side can handle it
        except NotImplementedError as e_not_implemented:
            # Raised when tools are used with a model whose wrapper lacks bind_tools
            import traceback
            error_msg = f"{type(self).__name__} ({self.model_name}) error: The selected model or its LangChain wrapper does not support the required tool-calling feature (bind_tools). Please use a tool-compatible model (like `devstral` with the latest `langchain-ollama` package) or disable tools. Original error: {e_not_implemented}"
            print(error_msg)
            traceback.print_exc()
            return error_msg
        except ResponseError as e:
            # Surface Ollama server errors to the GUI as well
            error_msg = f"Ollama server returned an error: {e}"
            print(error_msg)
            return f"{type(self).__name__} AgentExecutor error: {error_msg}"
        except BaseException as e_agent:
            # Catch-all, including non-Exception BaseExceptions raised during streaming
            import traceback
            print(f"{type(self).__name__}: Error during AgentExecutor execution: {e_agent}")
            traceback.print_exc()  # detailed traceback for debugging
            return f"{type(self).__name__} AgentExecutor error: {e_agent}"
@property
@abstractmethod # This must be abstract in LlmBase as it's specific to each LLM
def supports_images(self) -> bool:
pass
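# Usage sketch (assumes a concrete subclass such as the hypothetical OllamaLlm above):
#     agent = OllamaLlm("llama3.1")
#     agent.set_name("assistant")
#     agent.update_system_prompt("You are a helpful assistant.")
#     print(agent.get_response("Summarize this file", file_paths=["notes.md"]))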