from langchain.tools import tool
from pydantic import BaseModel, Field


import sqlite3
from typing import List, Tuple


class ThreadedBBS:
    def __init__(self, db_path: str = "bbs.db"):
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        conn = sqlite3.connect(self.db_path)
        c = conn.cursor()
        # スレッドテーブル
        c.execute("""
            CREATE TABLE IF NOT EXISTS threads (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                title TEXT,
                purpose TEXT
            )
        """)

        # 投稿テーブル（スレッドに紐付け）
        c.execute("""
            CREATE TABLE IF NOT EXISTS posts (
                id INTEGER PRIMARY KEY AUTOINCREMENT,   -- 全体ユニークID
                thread_id INTEGER,
                seq INTEGER,                            -- スレッド内番号
                name TEXT,
                message TEXT,
                FOREIGN KEY(thread_id) REFERENCES threads(id)
            )
        """)

        conn.commit()
        conn.close()

    # --- スレッド管理 ---
    def create_thread(self, title: str, purpose: str = "") -> int:
        """新しいスレッドを作成してIDを返す"""
        conn = sqlite3.connect(self.db_path)
        c = conn.cursor()
        c.execute("INSERT INTO threads (title, purpose) VALUES (?, ?)", (title, purpose))
        conn.commit()
        tid = c.lastrowid
        conn.close()
        return tid


    def list_threads(self) -> List[Tuple[int, str, str]]:
        """スレッド一覧を取得 (id, title, purpose)"""
        conn = sqlite3.connect(self.db_path)
        c = conn.cursor()
        c.execute("SELECT * FROM threads ORDER BY id ASC")
        threads = c.fetchall()
        conn.close()
        return threads

    # --- 投稿管理 ---
    def add_post(self, thread_id: int, name: str, message: str) -> None:
        conn = sqlite3.connect(self.db_path)
        c = conn.cursor()
        # スレッド内の最大seqを取得して +1
        c.execute("SELECT COALESCE(MAX(seq), 0) + 1 FROM posts WHERE thread_id=?", (thread_id,))
        seq = c.fetchone()[0]
        c.execute("INSERT INTO posts (thread_id, seq, name, message) VALUES (?, ?, ?, ?)",
                  (thread_id, seq, name, message))
        conn.commit()
        conn.close()


    def list_posts(self, thread_id: int, limit: int = 10) -> List[Tuple[int, int, str, str]]:
        """指定スレッドの投稿一覧を取得（古い順）"""
        conn = sqlite3.connect(self.db_path)
        c = conn.cursor()
        c.execute("SELECT * FROM posts WHERE thread_id=? ORDER BY id ASC LIMIT ?", 
                  (thread_id, limit))
        posts = c.fetchall()
        conn.close()
        return posts



g_bbs = ThreadedBBS()

# --- スレッド作成 ---
class CreateThreadArgs(BaseModel):
    title: str = Field(..., description="スレッドのタイトル")
    purpose: str = Field(..., description="スレッドの目的")

@tool(args_schema=CreateThreadArgs)
def create_thread(title: str, purpose: str) -> str:
    """新しいスレッドを作成する"""
    tid = g_bbs.create_thread(title, purpose)
    return f"Thread {tid} created (title={title}, purpose={purpose})"


# --- 投稿追加 ---
class AddPostArgs(BaseModel):
    thread_id: int = Field(..., description="スレッドID")
    name: str = Field(..., description="投稿者名")
    message: str = Field(..., description="投稿内容")

@tool(args_schema=AddPostArgs)
def add_post(thread_id: int, name: str, message: str) -> str:
    """指定スレッドに投稿を追加する"""
    g_bbs.add_post(thread_id, name, message)
    return f"Post added to thread {thread_id} by {name}"


# --- スレッド一覧取得 ---
class ListThreadsArgs(BaseModel):
    pass  # 引数なし

@tool(args_schema=ListThreadsArgs)
def list_threads() -> str:
    """スレッド一覧を取得する"""
    threads = g_bbs.list_threads()
    return str([{"id": t[0], "title": t[1], "purpose": t[2]} for t in threads])


# --- 投稿一覧取得 ---
class ListPostsArgs(BaseModel):
    thread_id: int = Field(..., description="スレッドID")

@tool(args_schema=ListPostsArgs)
def list_posts(thread_id: int) -> str:
    """指定スレッドの投稿一覧を取得する"""
    posts = g_bbs.list_posts(thread_id)
    return str([{"seq": p[2], "name": p[3], "message": p[4]} for p in posts])



def get_threaded_bbs_tools() -> List:
    return [create_thread, add_post, list_threads, list_posts]

#if __name__ == "__main__":
#    bbs = ThreadedBBS()
#
#    # スレッド作成
#    thread_id = bbs.create_thread("最初のスレッド")
#
#    # 投稿追加
#    bbs.add_post(thread_id, "Alice", "こんにちは！")
#    bbs.add_post(thread_id, "Bob", "こんばんは！")
#
#    # スレッド一覧表示
#    threads = bbs.list_threads()
#    print("スレッド一覧:")
#    for t in threads:
#        print(f"ID: {t[0]}, タイトル: {t[1]}")
#
#    # 投稿一覧表示
#    posts = bbs.list_posts(thread_id)
#    print(f"\nスレッドID {thread_id} の投稿一覧:")
#    for p in posts:
#        print(f"ID: {p[0]}, 名前: {p[2]}, メッセージ: {p[3]}")