AI Agent:Agents:llms:LlmBase.py: Source Code

from abc import ABC, abstractmethod
from typing import Annotated, Any, Dict, List, Optional, TypedDict
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage

from langchain.agents import create_agent
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.tools import tool

#from langchain.agents import create_tool_calling_agent
from tools.exception import InterruptedException
from Agents.llms.LlmInterface import LLMInterface  # Import the interface
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver

try:
    # If the ollama library is installed, import its dedicated error type
    from ollama import ResponseError
except ImportError:
    # Fallback class so the `except ResponseError` clause in get_response stays
    # valid when ollama is absent (catching None would raise a TypeError)
    class ResponseError(Exception):
        pass

import time
from langgraph.checkpoint.base import Checkpoint  # may be needed later

from tools.word2contents import Word2Contents
#from langgraph.checkpoint.sqlite import SQLiteSaver
#from langchain_core.messages import SystemMessage
import os
import requests
from pathlib import Path
import base64
#import GUI.PaintGUI as gui

g_mime_map = {
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".bmp": "image/bmp",
    ".tif": "image/tiff",
    ".tiff": "image/tiff",
    ".svg": "image/svg+xml",
    ".webp": "image/webp",
    ".emf": "image/x-emf",
    ".wmf": "application/x-msmetafile",
}


class State(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]  # message history
    # Optionally accumulate tool results as well
    tool_results: list[str]  # e.g. a list of tool outputs
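
# A minimal sketch (hypothetical values) of how the add_messages reducer behaves:
# a node that returns {"messages": [...]} has those messages appended to the
# existing history instead of replacing it, e.g.
#
#   state = {"messages": [HumanMessage(content="hi")], "tool_results": []}
#   a node returns {"messages": [AIMessage(content="hello")]}
#   -> the merged state contains both the HumanMessage and the AIMessage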





def extract_user_input(messages):
    # Find the most recent HumanMessage
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            return msg.content

    raise ValueError("No HumanMessage found")

def extract_memory(messages):
    # Extract HumanMessage/AIMessage pairs as the conversation history
    history = []
    for msg in messages:
        if isinstance(msg, (HumanMessage, AIMessage)):
            history.append(msg)

    return history

def extract_scratchpad(messages):
    # Extract ToolMessage entries (tool results) as the agent scratchpad
    scratchpad = []
    for msg in messages:
        if isinstance(msg, ToolMessage):
            scratchpad.append(msg)

    return scratchpad

class LlmBase(LLMInterface): # LlmBase implements LLMInterface
    model_name: str
    temperature: float
    llm_kwargs: Dict[str, Any]
    _supports_images: bool
    llm: BaseChatModel # The actual Langchain LLM instance
    memory: MemorySaver 
    
    #chat_history : MemorySaver
    def __init__(self, model_identifier: str, temperature: float = 0, **kwargs):
        self.model_name = model_identifier
        self.temperature = temperature
        self.llm_kwargs = kwargs
        self.tools: List[Any] = []
        self.system_prompt: str = ""
        self._supports_images = False  # Default, to be overridden by subclasses
        self.llm = self._initialize_llm()  # Initialize the specific LLM in the constructor

        self.w2c = Word2Contents()
        self.private_memory = False
        self.name = ""
        if not hasattr(LlmBase, "memory"):
            # Create the shared checkpointer once; re-creating it on every
            # instantiation would silently wipe every agent's history
            LlmBase.memory = MemorySaver()  # SQLiteSaver.from_cwd()
        self.memory_key = "default"
        # Tool map used to dispatch tool calls by name
        self.tool_map = None
        self.app = self.create_character()
        self.__create_memory()

    @abstractmethod
    def _initialize_llm(self) -> BaseChatModel:
        """Abstract method to initialize the specific Langchain LLM instance."""
        pass

    def get_langchain_llm_instance(self) -> Optional[BaseChatModel]:
        return self.llm

#######################
    def create_character(self):
        return self.create_agent_executer()

    def create_agent_executer(self):
        # Rebuilt whenever settings change; get_response does not rebuild it, to
        # reduce overhead when messages are exchanged back to back.
        #self.agent = create_tool_calling_agent(self.llm, self.tools, self.prompt)

        self.agent = create_agent(
            model=self.llm,
            tools=self.tools,
            #prompt=self.prompt
        )

        return self.build()
    
    #####################################################################
    def append_tools(self, tools_list: list, new_tools: list) -> list:
        """Add new tools (plain functions or Tool instances) to the tool list.
        Functions already decorated with @tool are Tool instances, so they are added as-is.
        """
        for new_tool in new_tools:
            if callable(new_tool) and not hasattr(new_tool, 'invoke'):  # wrap plain functions with @tool
                new_tool = tool(new_tool)  # auto-decorate (defining tools up front is still recommended)
            elif isinstance(new_tool, list):  # support nested lists
                tools_list.extend(new_tool)
                continue
            tools_list.append(new_tool)
        self.tool_map = {t.name: t for t in tools_list}
        return tools_list
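
    # Sketch of how append_tools accepts both kinds of input (hypothetical
    # tools, for illustration only):
    #
    #   @tool
    #   def add(a: int, b: int) -> int:
    #       """Add two numbers."""
    #       return a + b
    #
    #   def shout(text: str) -> str:   # plain function: wrapped via tool() on append
    #       """Upper-case the text."""
    #       return text.upper()
    #
    #   self.append_tools(self.tools, [add, shout])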



    def agent_node(self, state: State) -> State:
        #print("state", state)
        return self.agent.invoke(state)  # the state can be passed through as-is
   
    # tool_node: automatically runs the registered tools (no manual branching)
    def tool_node(self, state: State) -> State:
        outputs = []
        last_message = state["messages"][-1]

        for tool_call in last_message.tool_calls:  # supports multiple tool calls
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]

            # Automatic dispatch: look the tool up by name in tool_map
            selected_tool = self.tool_map.get(tool_name)
            if selected_tool:
                tool_result = selected_tool.invoke(tool_args)  # dynamic execution
            else:
                tool_result = f"Unknown tool: {tool_name}"

            # One ToolMessage per tool call, each tagged with its own call id,
            # so the LLM can match results back to the calls it made
            outputs.append(ToolMessage(
                content=str(tool_result),
                tool_call_id=tool_call.get("id") or "",
            ))

        return {"messages": outputs}
    
    def should_continue(self, state: State):
        last_message = state["messages"][-1]
        return "tools" if getattr(last_message, "tool_calls", None) else END
    
    def build(self):
        workflow = StateGraph(state_schema=State)
        workflow.add_node("agent", self.agent_node)
        workflow.add_node("tools", self.tool_node)

        workflow.set_entry_point("agent")
        workflow.add_conditional_edges("agent", self.should_continue)
        workflow.add_edge("tools", "agent")
        return workflow.compile(checkpointer=LlmBase.memory)
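
    # Usage sketch for the compiled graph (hypothetical thread id and prompt):
    # because it is compiled with a checkpointer, every call must carry a
    # thread_id, and history accumulates per thread, e.g.
    #
    #   config = {"configurable": {"thread_id": "demo_thread"}}
    #   result = self.app.invoke(
    #       {"messages": [HumanMessage(content="What is 2 + 2?")]},
    #       config=config,
    #   )
    #   print(result["messages"][-1].content)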
###########################################################
    def flatten_message(self, msg) -> dict:
        return {
            "id": getattr(msg, "id", None),
            "content": getattr(msg, "content", None),
            "additional_kwargs": getattr(msg, "additional_kwargs", {}),
            "response_metadata": getattr(msg, "response_metadata", {}),
            "usage_metadata": getattr(msg, "usage_metadata", None),
            "tool_call_id": getattr(msg, "tool_call_id", None),  # needed to restore a ToolMessage
            "type": msg.__class__.__name__
        }
    def restore_message(self, d: dict) -> BaseMessage:
        msg_type = d.get("type", "AIMessage")
        cls_map = {
            "AIMessage": AIMessage,
            "HumanMessage": HumanMessage,
            "SystemMessage": SystemMessage,
            "ToolMessage": ToolMessage
        }
        cls = cls_map.get(msg_type, AIMessage)
        kwargs = {
            "content": d.get("content", ""),
            "additional_kwargs": d.get("additional_kwargs", {}),
            "response_metadata": d.get("response_metadata", {}),
            "id": d.get("id", None),
        }
        if cls is AIMessage:
            # usage_metadata is only a field on AIMessage
            kwargs["usage_metadata"] = d.get("usage_metadata")
        if cls is ToolMessage:
            # ToolMessage requires a tool_call_id
            kwargs["tool_call_id"] = d.get("tool_call_id") or ""
        return cls(**kwargs)
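
    # Round-trip sketch for flatten_message/restore_message (hypothetical message):
    #
    #   d = self.flatten_message(AIMessage(content="4", id="msg-1"))
    #   d["type"]                         # -> "AIMessage"
    #   self.restore_message(d).content   # -> "4"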
    def clear_memory(self):
        LlmBase.memory.delete_thread(self.memory_key)
        self.set_system_prompt(self.system_prompt)
        
    def __set_index_contents(self, index: int, type_str: str, contents: str):
        config = {
            "configurable": {
                "thread_id": self.memory_key,
            }
        }
        state = self.app.get_state(config)

        dict_data = self.flatten_message(state.values["messages"][index])
        dict_data["type"] = type_str
        dict_data["content"] = contents
        new_msg = self.restore_message(dict_data)

        new_messages = state.values["messages"].copy()
        new_messages[index] = new_msg  # was hard-coded to [1]; use the given index
        new_state = {"messages": new_messages}
        self.app.update_state(config, new_state)

    def is_memory_Nodata(self) -> bool:
        config = {
            "configurable": {
                "thread_id": self.memory_key,
            }
        }
        state = self.app.get_state(config)        
        if state is None:
            return True
        if "messages" not in state.values:
            return True
        if len(state.values["messages"]) == 0:
            return True
        return False
    
    def set_system_prompt(self, system_prompt: str):
        self.system_prompt = system_prompt
        # Only the thread_id is needed; the same config structure works for
        # MemorySaver and SQLiteSaver alike.
        config = {
            "configurable": {
                "thread_id": self.memory_key,
            }
        }
        if self.is_memory_Nodata():
            return False

        self.__set_index_contents(0, "SystemMessage", self.system_prompt)
        state = self.app.get_state(config)
        #print("set_system_prompt:", state)
        return True
    def get_memory(self):
        return LlmBase.memory

    def __create_memory(self):
        """
        Create the memory used to store conversation history.
        Based on the private_memory flag, the agent uses either its own
        memory thread or the shared one.
        """
        if self.private_memory:
            self.memory_key = self.name + "_chat_history"  # underscore-separated, the common convention
        else:
            self.memory_key = "chat_history"
        self.clear_memory()

    def is_private_memory(self, is_private):
        self.private_memory = is_private
        self.__create_memory()  # rebuild the memory key so the new flag takes effect

    def set_name(self, name):
        self.name = name
        self.__create_memory()
    def append_message(self, human_message, ai_message):
        """
        Append a Human message and an AI message to memory.
        Args:
            human_message (str): content of the Human message to append.
            ai_message (str): content of the AI message to append.
        """
        self.append_human_message(human_message)
        self.append_ai_message(ai_message)

    def append_ai_message(self, content: str):
        """
        Append an AI message to memory.
        Args:
            content (str): content of the AI message to append.
        """
        config = {"configurable": {"thread_id": self.memory_key}}
        # update_state goes through the add_messages reducer; calling the raw
        # checkpointer's put() directly would bypass it (and requires more arguments)
        self.app.update_state(config, {
            "messages": [
                AIMessage(content=content)
            ]
        })

    def append_human_message(self, message):
        """
        Append a Human message to memory.
        Args:
            message (str): content of the Human message to append.
        """
        config = {"configurable": {"thread_id": self.memory_key}}
        self.app.update_state(config, {
            "messages": [
                HumanMessage(content=message)
            ]
        })


    def update_input(
        self,
        input_prompt_text: str,
        image_data_urls: Optional[List[str]] = None,
        target_index: Optional[int] = None):
        """
        Update a user input (HumanMessage) in the conversation history.
        - If target_index is given, update the HumanMessage at that position
        - Otherwise, update the last HumanMessage

        Args:
            input_prompt_text (str): the user input text to set
            image_data_urls (Optional[List[str]]): list of image data URLs
            target_index (Optional[int]): index to update (None means the last one)
        """
        pass
    def update_input_text(
        self,
        input_prompt_text: str,
        target_index: Optional[int] = None):
        """
        Update a user input (HumanMessage) in the conversation history.
        - If target_index is given, update the HumanMessage at that position
        - Otherwise, update the last HumanMessage

        Args:
            input_prompt_text (str): the user input text to set
            target_index (Optional[int]): index to update (None means the last one)
        """

        # Fetch the current state
        config = {
            "configurable": {
                "thread_id": self.memory_key,
            }
        }
        state = self.app.get_state(config)
        messages = state.values.get("messages", [])

        # Determine the index to update: default to the last HumanMessage
        if target_index is None:
            for i in range(len(messages) - 1, -1, -1):
                if messages[i].__class__.__name__ == "HumanMessage":
                    target_index = i
                    break

        if target_index is None or target_index < 0 or len(messages) <= target_index:
            print("No HumanMessage exists at the specified index")
            return

        self.__set_index_contents(target_index, "HumanMessage", input_prompt_text)

###################

    def update_tools(self, tools: List[Any]):
        """Updates the list of tools available to the LLM provider."""
        self.tools = tools
        self.tool_map = {t.name: t for t in tools}
        # Rebuild the compiled graph so the new tools actually take effect
        self.app = self.create_character()

    def update_system_prompt(self, system_prompt: str):
        """Updates the system prompt for the LLM provider."""
        self.system_prompt_str = system_prompt
        self.set_system_prompt(system_prompt)

#########################
    def _encode_image_to_data_url(self, image_path_or_url: str) -> dict:
        """Build a Base64-encoded image content part from an image path or URL."""
        if image_path_or_url.startswith("http://") or image_path_or_url.startswith("https://"):
            response = requests.get(image_path_or_url, timeout=10)
            response.raise_for_status()
            image_data = response.content
            mime_type = response.headers.get('Content-Type', 'image/jpeg')
        elif Path(image_path_or_url).exists():
            with open(image_path_or_url, "rb") as image_file:
                image_data = image_file.read()
            _, ext = os.path.splitext(image_path_or_url.lower())
            mime_type = g_mime_map.get(ext, "application/octet-stream")  # fallback for unknown extensions
        else:
            raise FileNotFoundError(f"Image path or URL not found or not accessible: {image_path_or_url}")

        base64_encoded_data = base64.b64encode(image_data).decode('utf-8')
        return {"type": "image", "base64": base64_encoded_data, "mime_type": mime_type}

        
    
    def __create_text_contents(self, text):
        content_parts = {"type": "text", "text": text}
        return content_parts

    def _create_image_contents(self, image_path_or_url: str):
        return self._encode_image_to_data_url(image_path_or_url)
        
    def get_current_message_content(self, prompt, file_paths):
        content_parts = []
        print("get_current_message_content file_paths", file_paths)
        for file_path in (file_paths or []):  # file_paths may be None
            if file_path.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                content_parts.append(self._encode_image_to_data_url(file_path))
            elif file_path.lower().endswith('.pdf'):
                content_parts.append(self.open_pdf(file_path))
            elif file_path.lower().endswith(('.txt', '.md', '.csv', '.json', '.log', '.xml', '.html', '.htm', '.css', '.js', '.sql', '.yaml', '.yml', '.ini',
                                             '.py', '.java', '.c', '.cpp', '.h', '.cs', '.rb', '.rs', '.php', '.swift', '.kt', '.m', '.sh', '.r', '.pl', '.vb', '.ts', '.dart')):
                content_parts.append(self.open_text(file_path))
            elif file_path.lower().endswith(('.docx', '.doc')):
                content_parts.extend(self.open_words(file_path))
        # Construct the HumanMessage content, handling multimodal input
        content_parts.append(self.__create_text_contents(prompt))
        #print("content_parts", content_parts)
        return content_parts


    def open_pdf(self, pdf_path):
        from pdfminer.high_level import extract_text
        results = ""
        
        text = extract_text(pdf_path)
        results += f"\n\n--- Content from {os.path.basename(pdf_path)} ---\n\n"
        results += text.encode("utf-8", errors="replace").decode("utf-8") + "\n"
        results += "--- End of Content ---\n\n"

        return self.__create_text_contents(results)
    
    def open_text(self, text_path):
        results = ""

        with open(text_path, 'r', encoding='utf-8', errors='ignore') as file:
            text = file.read()
        results += f"\n\n--- Content from {os.path.basename(text_path)} ---\n\n"
        results += text.encode("utf-8", errors="replace").decode("utf-8") + "\n"
        results += "\n--- End of Content ---\n\n"

        return self.__create_text_contents(results)
    
    def open_words(self, word_path):
        result = self.w2c.open(word_path)
        return result


    def get_response(self,
                     prompt: str,
                     #chat_history: Optional[MemorySaver] = None,
                     file_paths: Optional[List[str]] = None,
                     system_prompt: Optional[str] = None,
                     tools: Optional[List[Any]] = None,
                     callbacks: Optional[List[BaseCallbackHandler]] = None,
                     **kwargs) -> str:

        # Construct the HumanMessage content, handling multimodal input
        current_message_content = self.get_current_message_content(prompt, file_paths)
        #history = chat_history if chat_history else []
        if self.is_memory_Nodata():
            agent_input = {
                "messages": [SystemMessage(content=self.system_prompt), HumanMessage(content=current_message_content)],
            }
        else:
            agent_input = {
                "messages": [
                    HumanMessage(content=current_message_content)
                ],
            }

        run_config = {
            "configurable": {"thread_id": self.memory_key},
            "callbacks": callbacks
        }

        response = ""  # fallback in case the stream yields nothing
        try:
            for output in self.app.stream(agent_input, config=run_config, **kwargs):
                for key, value in output.items():
                    if "messages" in value:
                        for msg in value["messages"]:
                            if hasattr(msg, 'content') and msg.content:
                                # Keep the latest content; token-by-token display
                                # is handled by the callback (on_llm_new_token)
                                response = msg.content

            # Return the final output on success
            if isinstance(response, dict):
                return response.get("output", f"AgentExecutor returned a dict without 'output' key: {response}")
            #print("response", response)
            return str(response)
        except InterruptedException:
            print(f"{type(self).__name__} ({self.model_name}): AgentExecutor execution interrupted.")
            raise  # re-raise so the AIAgent side can handle it
        except NotImplementedError as e_not_implemented:
            # Raised when tools are used with a model whose LangChain wrapper lacks bind_tools
            import traceback
            error_msg = f"{type(self).__name__} ({self.model_name}) error: The selected model or its LangChain wrapper does not support the required tool-calling feature (bind_tools). Please use a tool-compatible model (like `devstral` with the latest `langchain-ollama` package) or disable tools. Original error: {e_not_implemented}"
            print(error_msg)
            traceback.print_exc()
            return error_msg
        except ResponseError as e:
            error_msg = f"Ollama server returned an error: {e}"
            print(error_msg)
            # Return this error to the GUI as well
            return f"{type(self).__name__} AgentExecutor error: {error_msg}"
        except Exception as e_agent:
            import traceback
            print(f"{type(self).__name__}: Error during AgentExecutor execution: {e_agent}")
            traceback.print_exc()  # print detailed error information for debugging
            return f"{type(self).__name__} AgentExecutor error: {e_agent}"
        except BaseException as e_agent:
            # Also catch non-Exception exits (e.g. KeyboardInterrupt) so the GUI still gets a message
            import traceback
            print(f"{type(self).__name__}: Error during AgentExecutor execution: {e_agent}")
            traceback.print_exc()
            return f"{type(self).__name__} AgentExecutor error: {e_agent}"



    @property
    @abstractmethod # This must be abstract in LlmBase as it's specific to each LLM
    def supports_images(self) -> bool:
        pass
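

# --- Minimal usage sketch (illustration only, not part of the original module) ---
# Shows how a concrete subclass is expected to plug into LlmBase, assuming the
# langchain_ollama package is installed and an Ollama server is running.
# OllamaLlm, the model name "llama3", and the demo prompt are all hypothetical.
#
# from langchain_ollama import ChatOllama
#
# class OllamaLlm(LlmBase):
#     def _initialize_llm(self) -> BaseChatModel:
#         return ChatOllama(model=self.model_name, temperature=self.temperature)
#
#     @property
#     def supports_images(self) -> bool:
#         return False
#
# if __name__ == "__main__":
#     agent = OllamaLlm("llama3")
#     agent.set_name("demo")
#     print(agent.get_response("Hello! Introduce yourself in one sentence."))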