import inspect
import json
import os
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Optional

from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from langchain_core.messages import ToolMessage  # for tool results
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.sqlite import SqliteSaver

from ...core.shared import define

def extract_user_input(messages):
    """Return the content of the newest HumanMessage in *messages*.

    Raises:
        ValueError: when *messages* holds no HumanMessage at all.
    """
    newest_human = next(
        (candidate for candidate in reversed(messages) if isinstance(candidate, HumanMessage)),
        None,
    )
    if newest_human is None:
        raise ValueError("HumanMessage が見つかりません")
    return newest_human.content

def extract_memory(messages):
    """Return the Human and AI messages of *messages*, in their original order.

    Every other message type (system, tool, ...) is dropped.
    """
    return [entry for entry in messages if isinstance(entry, (HumanMessage, AIMessage))]

def extract_scratchpad(messages):
    """Return the ToolMessage entries of *messages* (the agent scratchpad), in order."""
    return [entry for entry in messages if isinstance(entry, ToolMessage)]


class ChatDataBase:
    """SQLite-backed storage for LangGraph chat threads and their metadata.

    Files are created under ``define.DATA_SPACE_DIR/<data_folder>``:

    * the metadata database (default ``chat_meta.db``) holds the
      ``threads_meta`` table (id, title, timestamps) and the
      ``model_registry`` table used to rank recently used models;
    * the checkpoint database (default ``chat.db``) is opened through
      ``SqliteSaver`` unless ``temp_mode`` is set, in which case checkpoints
      live in a ``MemorySaver`` (``_delete_thread`` still targets the file).

    Writes to the conversation state go through ``update_function(config,
    values)``; presumably the compiled graph's ``update_state`` -- TODO confirm
    with the caller.
    """

    # Tables that may hold per-thread checkpoint rows. "writes" is the table
    # print_chat_threads_raw_data reads; each table is only touched if present.
    _CHECKPOINT_TABLES = ("checkpoints", "checkpoint_blobs", "checkpoint_writes", "writes")

    def __init__(self, id="0", private_memory: bool = True, update_function=None, **kwargs):
        """Create the data folder, the metadata table and the checkpointer.

        Args:
            id: Initial thread id (name kept for keyword-compatible callers).
            private_memory: When False the shared thread id "chat_history"
                is used instead of *id*.
            update_function: Callable ``(config, values)`` used to write state.
            **kwargs: ``data_folder`` (default "chat_data"), ``file_name``
                and ``temp_mode`` (default False).

        Note:
            Construction clears any stored checkpoints of the resulting thread
            id (``__create_memory`` calls ``clear_thread``).
        """
        super().__init__()

        data_folder = self._get_kwargs_data(kwargs, "data_folder", "chat_data")
        chat_dir = os.path.join(define.DATA_SPACE_DIR, data_folder)
        os.makedirs(chat_dir, exist_ok=True)

        # NOTE(review): the single "file_name" kwarg overrides BOTH database
        # names below, so passing it puts metadata and checkpoints in one file.
        meta_name = self._get_kwargs_data(kwargs, "file_name", "chat_meta.db")
        self._app_metapath = Path(os.path.join(chat_dir, meta_name))
        self.ensure_threads_meta_table()

        chat_name = self._get_kwargs_data(kwargs, "file_name", "chat.db")
        self._persist_path = Path(os.path.join(chat_dir, chat_name))

        if self._get_kwargs_data(kwargs, "temp_mode", False):
            # Checkpoints are kept in process memory only.
            self.memory = MemorySaver()
        else:
            # from_conn_string() is a context manager; enter it by hand and keep
            # the manager so close() can exit it and release the connection.
            self.memory_cm = SqliteSaver.from_conn_string(self._persist_path)
            self.memory = self.memory_cm.__enter__()

        self.thread_id = id
        self.private_memory = private_memory
        self.update_function = update_function
        self.__create_memory()

    # ------------------------------------------------------------------ helpers
    @staticmethod
    @contextmanager
    def _sqlite(path):
        """Open a SQLite connection for one unit of work.

        Commits when the body succeeds, rolls back when it raises, and always
        closes the connection (``with sqlite3.connect(...)`` alone never closes).
        """
        conn = sqlite3.connect(path)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _thread_config(self):
        """Return the LangGraph config that addresses the current thread."""
        return {"configurable": {"thread_id": self.thread_id}}

    def _get_messages(self):
        """Return the stored message list of the current thread, or [] if none.

        Checkpoints returned by ``memory.get`` keep channel state under
        "channel_values" (the same key get_chat_data_list reads).
        """
        saved_state = self.memory.get(self._thread_config())
        if saved_state is None:
            return []
        return (saved_state.get("channel_values") or {}).get("messages", [])

    def _get_kwargs_data(self, kwargs, key, default_value):
        """Return ``kwargs[key]`` when present, otherwise *default_value*."""
        return kwargs.get(key, default_value)

    # ---------------------------------------------------------------- accessors
    def get_memory(self):
        """Return the checkpointer (SqliteSaver or MemorySaver)."""
        return self.memory

    def get_thread_id(self):
        """Return the id of the thread currently addressed."""
        return self.thread_id

    def is_thread_Nodata(self):
        """Return True when the current thread has no stored messages."""
        return not self._get_messages()

    def set_private_memory(self, private_memory: bool):
        """Switch between a private and the shared thread (clears that thread)."""
        self.private_memory = private_memory
        self.__create_memory()

    def set_update_function(self, update_function):
        """Set the ``(config, values)`` callable used to write state."""
        self.update_function = update_function

    # ------------------------------------------------------- checkpoint cleanup
    def clear_thread(self):
        """Delete every stored checkpoint of the current thread."""
        self._delete_thread(self.thread_id)

    def _delete_thread(self, thread_id):
        """Delete all checkpoint rows of *thread_id* from the checkpoint database.

        This works on the SQLite file directly, also in temp_mode (the
        in-memory saver is not touched in that case).
        """
        with self._sqlite(self._persist_path) as conn:
            existing = {
                row[0]
                for row in conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
            }
            for table in self._CHECKPOINT_TABLES:
                if table in existing:
                    # Table names come from the fixed tuple above, never from input.
                    conn.execute(f"DELETE FROM {table} WHERE thread_id = ?", (str(thread_id),))

    # ---------------------------------------------------- message (de)serialize
    def flatten_message(self, msg) -> dict:
        """Flatten a LangChain message into a plain dict (see restore_message).

        "tool_call_id" is included only for messages that carry one, because a
        ToolMessage cannot be rebuilt without it.
        """
        data = {
            "id": getattr(msg, "id", None),
            "content": getattr(msg, "content", None),
            "additional_kwargs": getattr(msg, "additional_kwargs", {}),
            "response_metadata": getattr(msg, "response_metadata", {}),
            "usage_metadata": getattr(msg, "usage_metadata", {}),
            "type": msg.__class__.__name__,
        }
        tool_call_id = getattr(msg, "tool_call_id", None)
        if tool_call_id is not None:
            data["tool_call_id"] = tool_call_id
        return data

    def restore_message(self, d: dict) -> "BaseMessage":
        """Rebuild a message from a dict produced by flatten_message.

        Unknown or missing "type" values fall back to AIMessage.
        """
        cls_map = {
            "AIMessage": AIMessage,
            "HumanMessage": HumanMessage,
            "SystemMessage": SystemMessage,
            "ToolMessage": ToolMessage,
        }
        cls = cls_map.get(d.get("type", "AIMessage"), AIMessage)
        fields = {
            "content": d.get("content", ""),
            "additional_kwargs": d.get("additional_kwargs", {}),
            "response_metadata": d.get("response_metadata", {}),
            "id": d.get("id", None),
            "usage_metadata": d.get("usage_metadata", {}),
        }
        if cls is ToolMessage:
            # tool_call_id is a required ToolMessage field.
            fields["tool_call_id"] = d.get("tool_call_id") or ""
        return cls(**fields)

    def __set_index_contents(self, index: int, type_str: str, contents: str) -> bool:
        """Replace message *index* with a *type_str* message holding *contents*.

        The old message's id and metadata are carried over so the replacement
        lands at the same position.

        Returns:
            False when no update_function is set or *index* is out of range
            (including an empty thread); True once the update was submitted.
        """
        if self.update_function is None:
            return False
        messages = self._get_messages()
        if not 0 <= index < len(messages):
            return False

        dict_data = self.flatten_message(messages[index])
        dict_data["type"] = type_str
        dict_data["content"] = contents

        new_messages = list(messages)
        new_messages[index] = self.restore_message(dict_data)
        self.update_function(self._thread_config(), {"messages": new_messages})
        return True

    def set_system_prompt(self, system_prompt: str):
        """Remember *system_prompt* and rewrite the thread's leading SystemMessage.

        Returns:
            True if the stored state was updated; False for an empty thread, an
            empty prompt, a thread whose first message is not a SystemMessage,
            or a missing update_function.
        """
        self.system_prompt = system_prompt
        messages = self._get_messages()
        if not messages:
            return False
        if system_prompt == "":
            return False
        # Only an existing system message is replaced; converting a user turn
        # at index 0 would silently drop it.
        if not isinstance(messages[0], SystemMessage):
            return False
        return self.__set_index_contents(0, "SystemMessage", system_prompt)

    def __create_memory(self):
        """Pick the thread id for the private_memory setting and clear its data.

        When private_memory is falsy the shared id "chat_history" is used.
        NOTE(review): this wipes the thread's stored history each time it runs
        (construction and set_private_memory) -- confirm that is intended.
        """
        print("self.private_memory", self.private_memory)
        if not self.private_memory:
            self.thread_id = "chat_history"
        self.clear_thread()

    # ------------------------------------------------------------ append / edit
    def append_message(self, huma_message, ai_message):
        """Append a Human message followed by an AI message.

        Args:
            huma_message (str): Content of the Human message.
            ai_message (str): Content of the AI message.
        """
        self.append_human_message(huma_message)
        self.append_ai_message(ai_message)

    def _append_to_thread(self, message):
        """Submit ``{"messages": [message]}`` for the current thread.

        This relies on the state's "messages" channel appending (e.g. an
        add_messages reducer) -- TODO confirm against the graph definition.

        Raises:
            RuntimeError: when no update_function has been set.
        """
        if self.update_function is None:
            raise RuntimeError("update_function is not set; call set_update_function() first")
        self.update_function(self._thread_config(), {"messages": [message]})

    def append_ai_message(self, content: str):
        """Append an AIMessage with *content* to the current thread."""
        self._append_to_thread(AIMessage(content=content))

    def append_human_message(self, message):
        """Append a HumanMessage with content *message* to the current thread."""
        self._append_to_thread(HumanMessage(content=message))

    def update_input(
        self,
        imput_prompt_text: str,
        image_data_urls: Optional[List[str]] = None,
        target_index: int = None):
        """Update a user input (HumanMessage), optionally with images.

        Not implemented yet: this is currently a no-op.

        Args:
            imput_prompt_text (str): New user input text.
            image_data_urls (Optional[List[str]]): Image URL list.
            target_index (Optional[int]): Message index to update (None = last).
        """
        pass

    def update_input_text(
        self,
        imput_prompt_text: str,
        target_index: int = None):
        """Replace the text of a user input in the conversation history.

        With *target_index* the message at that index is replaced (and becomes
        a HumanMessage); otherwise the last HumanMessage is used. Prints an
        error and returns when no valid index is found.

        Args:
            imput_prompt_text (str): New user input text.
            target_index (Optional[int]): Message index to update (None = last
                HumanMessage).
        """
        messages = self._get_messages()

        if target_index is None:
            target_index = next(
                (
                    i
                    for i in reversed(range(len(messages)))
                    if messages[i].__class__.__name__ == "HumanMessage"
                ),
                None,
            )

        if target_index is None or not 0 <= target_index < len(messages):
            print("指定されたインデックスに HumanMessage が存在しません")
            return
        self.__set_index_contents(target_index, "HumanMessage", imput_prompt_text)

    # ------------------------------------------------------------- threads_meta
    def ensure_threads_meta_table(self):
        """Create threads_meta if missing; purge existing rows whose title is NULL.

        The current schema declares title NOT NULL, so the purge presumably only
        matters for tables created by an older schema.
        """
        with self._sqlite(self._app_metapath) as conn:
            exists = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='threads_meta';"
            ).fetchone()
            if exists:
                conn.execute("DELETE FROM threads_meta WHERE title IS NULL;")
            self._create_meta_table(conn)

    def _create_meta_table(self, cursor):
        """Create threads_meta (autoincrement id, title, timestamps) if absent."""
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS threads_meta (
                thread_id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT NOT NULL DEFAULT '無題',
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
            )
        """)

    def set_thread_id(self, thread_id):
        """Address *thread_id* in subsequent operations."""
        print("set_thread_id thread_id", thread_id)
        self.thread_id = thread_id

    def get_thread_ids_list(self):
        """Return every thread_id stored in threads_meta."""
        with self._sqlite(self._app_metapath) as conn:
            rows = conn.execute("SELECT thread_id FROM threads_meta").fetchall()
        return [row[0] for row in rows]

    def set_thread_title(self, thread_id, title):
        """Set *title* for *thread_id*, inserting the row when it does not exist.

        An existing row keeps its created_at; updated_at is refreshed.
        """
        print("set_thread_title thread_id", thread_id)
        with self._sqlite(self._app_metapath) as conn:
            cur = conn.execute(
                "UPDATE threads_meta SET title = ?, updated_at = datetime('now') WHERE thread_id = ?",
                (title, thread_id),
            )
            if cur.rowcount == 0:
                conn.execute(
                    "INSERT INTO threads_meta (thread_id, title, created_at, updated_at) "
                    "VALUES (?, ?, datetime('now'), datetime('now'))",
                    (thread_id, title),
                )

    def update_thread_timestamp(self, thread_id):
        """Set updated_at of *thread_id* to the current UTC time."""
        print("update_thread_timestamp thread_id", thread_id)
        with self._sqlite(self._app_metapath) as conn:
            conn.execute(
                "UPDATE threads_meta SET updated_at = datetime('now') WHERE thread_id = ?",
                (thread_id,),
            )

    def get_all_threads_meta_sorted(self):
        """Return all threads_meta rows as dicts, most recently updated first."""
        with self._sqlite(self._app_metapath) as conn:
            rows = conn.execute("""
                SELECT thread_id, title, created_at, updated_at
                FROM threads_meta
                ORDER BY updated_at DESC
            """).fetchall()
        keys = ("thread_id", "title", "created_at", "updated_at")
        return [dict(zip(keys, row)) for row in rows]

    def create_thread_id(self):
        """Insert a new threads_meta row, make it current and return its id."""
        with self._sqlite(self._app_metapath) as conn:
            self._create_meta_table(conn)
            cur = conn.execute("INSERT INTO threads_meta DEFAULT VALUES")
            self.thread_id = cur.lastrowid
        return self.thread_id

    def delete_thread(self, thread_id):
        """Remove *thread_id* from threads_meta and delete its checkpoints."""
        with self._sqlite(self._app_metapath) as conn:
            conn.execute("DELETE FROM threads_meta WHERE thread_id = ?", (thread_id,))
        self._delete_thread(thread_id)

    # ----------------------------------------------------------------- lifetime
    def close(self):
        """Release the SqliteSaver connection if one was opened; safe to repeat."""
        memory_cm = getattr(self, "memory_cm", None)
        if memory_cm is None:
            return
        self.memory_cm = None
        memory_cm.__exit__(None, None, None)

    def __del__(self):
        # Best-effort cleanup; close() is idempotent, so an explicit close()
        # followed by garbage collection does not exit the saver twice.
        self.close()

    # --------------------------------------------------------------- read views
    @staticmethod
    def _human_row(content):
        """Build the display row for a HumanMessage's content.

        String content is used as is. For list content, the "text" of each dict
        part (and any bare string part) is concatenated. A list without text
        yields an empty dict.
        """
        if isinstance(content, str):
            return {"user": content}
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and "text" in part:
                texts.append(part["text"])
        return {"user": "".join(texts)} if texts else {}

    @staticmethod
    def _ai_row(message):
        """Build the display row (thinking / model / ai) for an AIMessage."""
        row = {}
        if "reasoning_content" in message.additional_kwargs:
            row["thinking"] = message.additional_kwargs["reasoning_content"]
        meta = message.response_metadata
        if "model_name" not in meta:
            row["model"] = "unknown"
        elif "model_provider" in meta:
            row["model"] = f"{meta['model_provider']}:{meta['model_name']}"
        else:
            row["model"] = meta["model_name"]
        row["ai"] = message.content
        return row

    def get_chat_data_list(self):
        """Convert the current thread into display rows.

        Returns one dict per Human/AI message, in order:
        HumanMessage -> {"user": text}; AIMessage -> {"thinking"?, "model", "ai"}.
        Other message types are skipped; [] when the thread has no state.
        """
        config = self._thread_config()
        saved_state = self.memory.get(config)
        if saved_state is None:
            print("saved_state is None", config)
            return []

        result = []
        for data in (saved_state.get("channel_values") or {}).get("messages", []):
            if isinstance(data, HumanMessage):
                result.append(self._human_row(data.content))
            elif isinstance(data, AIMessage):
                result.append(self._ai_row(data))
        return result

    def is_first_message(self):
        """Return True while the current thread stores at most two messages."""
        return len(self._get_messages()) <= 2

    def print_chat_threads_raw_data(self):
        """Debug helper: dump the checkpoint schema and raw "writes" rows.

        Rows on the "messages" channel whose value decodes as JSON are grouped
        with get_chat_turns, printed and returned. Values that are not JSON
        (presumably msgpack blobs -- TODO confirm) are reported and skipped.
        """
        with self._sqlite(self._persist_path) as conn:
            cur = conn.cursor()
            cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
            print("cur.fetchall 01", cur.fetchall())
            for table in ("checkpoints", "checkpoint_blobs", "checkpoint_writes",
                          "writes", "sqlite_sequence"):
                cur.execute(f"PRAGMA table_info({table})")
                print(table, cur.fetchall())

            cur.execute(
                "SELECT thread_id, checkpoint_id, idx, channel, type, value "
                "FROM writes ORDER BY thread_id, checkpoint_id, idx"
            )
            rows = cur.fetchall()

        print("rows:", rows)
        messages = []
        for thread_id, checkpoint_id, idx, channel, type_, blob in rows:
            print("----")
            print("thread_id:", thread_id)
            print("checkpoint_id:", checkpoint_id)
            print("idx:", idx)
            print("channel:", channel)
            print("type:", type_)
            try:
                decoded = json.loads(blob)
            except (TypeError, ValueError) as e:
                print("decode error:", e)
                print("raw blob:", blob)
                continue
            print("json:", decoded)
            if channel == "messages":
                messages.append(decoded)

        print("messages", messages)
        turns = self.get_chat_turns(messages)
        print("turns", turns)
        return turns

    def get_chat_turns(self, messages):
        """Group role-tagged message dicts into turns.

        Each turn is {"user", "assistant", "thinking"}. A "user" message opens
        a new turn. An "assistant" message fills "thinking" when it has that
        key, otherwise "assistant" from its "content". An assistant message
        seen before any user message opens a turn with user=None. Other roles
        are ignored.
        """
        turns = []
        current = None
        for msg in messages:
            role = msg.get("role")
            if role == "user":
                current = {"user": msg.get("content"), "assistant": None, "thinking": None}
                turns.append(current)
            elif role == "assistant":
                if current is None:
                    current = {"user": None, "assistant": None, "thinking": None}
                    turns.append(current)
                if "thinking" in msg:
                    current["thinking"] = msg["thinking"]
                else:
                    current["assistant"] = msg.get("content")
        return turns

    # ------------------------------------------------------------ model ranking
    def _ensure_model_registry_table(self):
        """Create the model_registry table if it does not exist."""
        with self._sqlite(self._app_metapath) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_registry (
                    id               INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_provider   TEXT NOT NULL,
                    model_name       TEXT NOT NULL,
                    registered_at    TEXT NOT NULL DEFAULT (datetime('now'))
                );
            """)

    def register_model(self, model_provider: str, model_name: str, *, max_entries: int = 100) -> int:
        """Record one use of (*model_provider*, *model_name*).

        A new row is always inserted (duplicates are allowed). Before the
        insert, the oldest rows are deleted so that at most *max_entries* rows
        remain afterwards.

        Returns:
            int: id of the newly inserted row.

        Raises:
            ValueError: if *max_entries* is smaller than 1.
        """
        if max_entries < 1:
            raise ValueError("max_entries must be >= 1")
        self._ensure_model_registry_table()
        with self._sqlite(self._app_metapath) as conn:
            (count,) = conn.execute("SELECT COUNT(*) FROM model_registry").fetchone()
            excess = count - max_entries + 1  # +1 makes room for the new row
            if excess > 0:
                conn.execute(
                    "DELETE FROM model_registry WHERE id IN "
                    "(SELECT id FROM model_registry ORDER BY id ASC LIMIT ?)",
                    (excess,),
                )
            cur = conn.execute(
                "INSERT INTO model_registry (model_provider, model_name) VALUES (?, ?)",
                (model_provider, model_name),
            )
            new_id = cur.lastrowid
        return new_id

    def _calc_importance(self, reg_at: str) -> float:
        """Score one registration as 1 / (age_seconds + 1).

        Naive timestamps are taken as UTC (registered_at defaults to SQLite's
        datetime('now'), which is UTC). Timestamps in the future count as age 0,
        so the score stays in (0, 1].
        """
        now = datetime.now(timezone.utc)
        reg_dt = datetime.fromisoformat(reg_at)
        if reg_dt.tzinfo is None:
            reg_dt = reg_dt.replace(tzinfo=timezone.utc)
        age_seconds = max(0.0, (now - reg_dt).total_seconds())
        return 1.0 / (age_seconds + 1)

    def get_sorted_models(self) -> List[dict]:
        """Aggregate registrations per (model_provider, model_name), best first.

        Each returned dict has:
            - model_provider
            - model_name
            - count                : number of registrations of the pair
            - importance_sum       : sum of _calc_importance scores (sort key)
            - latest_registered_at : newest registered_at string
        """
        self._ensure_model_registry_table()
        aggregate: dict[tuple[str, str], dict] = {}
        with self._sqlite(self._app_metapath) as conn:
            rows = conn.execute(
                "SELECT id, model_provider, model_name, registered_at "
                "FROM model_registry ORDER BY id ASC"
            ).fetchall()

        for _id, provider, name, reg_at in rows:
            agg = aggregate.setdefault((provider, name), {
                "model_provider": provider,
                "model_name": name,
                "count": 0,
                "importance_sum": 0.0,
                "latest_registered_at": reg_at,
            })
            agg["count"] += 1
            agg["importance_sum"] += self._calc_importance(reg_at)
            if reg_at > agg["latest_registered_at"]:
                agg["latest_registered_at"] = reg_at

        # Highest total importance first (sorted() is stable for ties).
        return sorted(aggregate.values(), key=lambda x: x["importance_sum"], reverse=True)
