AI Agent:tools:LocalSearchClient.py: Source Code
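
"""Local, file-backed stand-in for an OpenSearch client.

Documents and their non-vector fields are stored in SQLite; each knn_vector
field is kept in a per-index hnswlib index, so a small subset of the
OpenSearch API (indices.create, index, get, update, delete, search) works
without a running cluster.
"""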

import sqlite3
import hnswlib
import numpy as np
from pathlib import Path
import json
import os
from os import path

if "__main__" == __name__:
    import tools_define
else:
    import tools.tools_define as tools_define


class LocalIndices:
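    """Minimal stand-in for the OpenSearch `indices` API (exists / create / get_alias)."""
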
    def __init__(self, parent):
        self.parent = parent

    def exists(self, index: str):
        return index in self.parent.index_metadata

    def create(self, index: str, body=None):
        # Do nothing if the index already exists
        if index in self.parent.index_metadata:
            return {"acknowledged": True, "created": False}
        self.parent._create_index_dir(index)
        # Persist the mapping
        self.parent.index_metadata[index] = body or {}
        self.parent._save_index_metadata()

        # --- Parse the mapping ---
        props = (body or {}).get("mappings", {}).get("properties", {})

        # Extract the knn_vector (embedding) fields
        embedding_fields = []
        for name, spec in props.items():
            if spec.get("type") == "knn_vector":
                embedding_fields.append(name)

        self.parent.embedding_fields = embedding_fields

        # Determine the vector dimension from the first embedding field
        if embedding_fields:
            first = embedding_fields[0]
            dim = props[first]["dimension"]
            self.parent.dim = dim

        # --- Update the SQLite schema ---
        with self.parent._connect() as conn:
            # Fetch the existing columns
            cur = conn.execute("PRAGMA table_info(items)")
            existing_cols = {row[1] for row in cur.fetchall()}

            for name, spec in props.items():
                if name in embedding_fields:
                    continue  # embeddings are not stored in SQLite

                if name not in existing_cols:
                    sql_type = "TEXT"
                    if spec.get("type") == "integer":
                        sql_type = "INTEGER"
                    elif spec.get("type") == "binary":
                        sql_type = "BLOB"

                    conn.execute(
                        f"ALTER TABLE items ADD COLUMN {name} {sql_type}"
                    )

        # --- Rebuild the hnswlib indexes ---
        self.parent._init_hnsw_indexes()

        self.parent.save_index(index)

        return {"acknowledged": True, "created": True}

    def get_alias(self, name="*"):
        return {idx: {} for idx in self.parent.index_metadata.keys()}


class LocalSearchClient:
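    """SQLite + hnswlib backed stand-in for an OpenSearch client.

    Each index gets its own folder under tools_define.DATA_SPACE_DIR/LocalSearchClient,
    containing a SQLite database for documents and one hnswlib .bin file per
    knn_vector field.
    """
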
    def __init__(self, db_path="local.db", index_dir="vector_indexes", dim=768):
        self._db_path_base = db_path
        self._index_dir_base = index_dir
        self.dim = dim
        self.embedding_fields = []
        self.index_metadata = {}

        # indices API
        self.indices = LocalIndices(self)

    def _create_index_dir(self, index):
        self._set_index_dir(index)
        # Initialize SQLite
        self._init_sqlite()

        # index metadata
        self.index_metadata = self._load_index_metadata()

        # hnswlib index
        self.vector_indexes = {}
        self._init_hnsw_indexes()          

    def _set_index_dir(self, index):
        data_dir = path.join(tools_define.DATA_SPACE_DIR, "LocalSearchClient").replace("\\", "/")
        Path(data_dir).mkdir(exist_ok=True)

        data_dir = path.join(data_dir, index).replace("\\", "/")
        Path(data_dir).mkdir(exist_ok=True)

        self.db_path = path.join(data_dir, self._db_path_base).replace("\\", "/")

        self.index_dir = Path(path.join(data_dir, self._index_dir_base).replace("\\", "/"))
        self.index_dir.mkdir(exist_ok=True)


    # -------------------------
    # SQLite connection
    # -------------------------
    def _connect(self):
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        return conn

    # -------------------------
    # SQLite initialization
    # -------------------------
    def _init_sqlite(self):
        with self._connect() as conn:
            conn.execute("""
            CREATE TABLE IF NOT EXISTS items (
                id TEXT PRIMARY KEY,
                hnsw_id INTEGER UNIQUE
            )
            """)

            conn.execute("""
            CREATE TABLE IF NOT EXISTS id_counter (
                name TEXT PRIMARY KEY,
                current INTEGER
            )
            """)

            conn.execute("""
            INSERT OR IGNORE INTO id_counter (name, current)
            VALUES ('hnsw', 0)
            """)

            conn.execute("""
            CREATE TABLE IF NOT EXISTS index_metadata (
                name TEXT PRIMARY KEY,
                body TEXT
            )
            """)

    # -------------------------
    # index metadata
    # -------------------------
    def _load_index_metadata(self):
        with self._connect() as conn:
            rows = conn.execute("SELECT name, body FROM index_metadata").fetchall()
            result = {}
            for r in rows:
                try:
                    result[r["name"]] = json.loads(r["body"])
                except Exception:
                    # Treat non-JSON bodies as an empty dict
                    result[r["name"]] = {}
            return result

            

    def _save_index_metadata(self):
        with self._connect() as conn:
            conn.execute("DELETE FROM index_metadata")
            for name, body in self.index_metadata.items():
                conn.execute(
                    "INSERT INTO index_metadata (name, body) VALUES (?, ?)",
                    (name, json.dumps(body)),
                )
    # -------------------------
    # hnswlib index
    # -------------------------
    def _init_hnsw_indexes(self):
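        """Create fresh, empty hnswlib indexes (L2 space) for every embedding field."""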
        self.vector_indexes = {}
        
        for field in self.embedding_fields:
            index = hnswlib.Index(space="l2", dim=self.dim)
            index.init_index(max_elements=200_000, ef_construction=200, M=16)
            index.set_ef(64)
            self.vector_indexes[field] = index

    def _save_hnsw_index(self, index: str):
        self._set_index_dir(index)
        index_dir = self.index_dir
        index_dir.mkdir(parents=True, exist_ok=True)

        for field, hnsw in self.vector_indexes.items():
            file_path = index_dir / f"{field}.bin"
            try:
                hnsw.save_index(str(file_path))
            except Exception as e:
                print(f"[WARN] Failed to save index for {field}: {e}")
    # -------------------------
    # hnsw_id allocation
    # -------------------------
    def _next_hnsw_id(self):
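        """Allocate the next integer hnsw label from the id_counter table."""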
        with self._connect() as conn:
            cur = conn.execute("SELECT current FROM id_counter WHERE name='hnsw'")
            current = cur.fetchone()[0]
            next_id = current + 1
            conn.execute("UPDATE id_counter SET current=?", (next_id,))
            return next_id



    def _validate_embedding(self, field, vec):
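        """Return True only for a 1-D numeric vector of the expected dimension."""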
        if vec is None:
            return False
        if not isinstance(vec, (list, tuple, np.ndarray)):
            return False

        vec = np.asarray(vec, dtype=np.float32)

        if vec.ndim != 1:
            return False

        # Per-field dimensions are not tracked; every knn_vector field shares self.dim
        expected_dim = self.dim

        if len(vec) != expected_dim:
            return False

        return True

    # -------------------------
    # OpenSearch-compatible API
    # -------------------------

    def index(self, index: str, id: str, body: dict):
        """OpenSearch の index() 相当"""
        self.load_index(index)
        hnsw_id = self._next_hnsw_id()
        
        # Persist to SQLite
        with self._connect() as conn:
            # Add any columns missing from items
            cur = conn.execute("PRAGMA table_info(items)")
            existing_cols = {row[1] for row in cur.fetchall()}

            for key in body.keys():
                # embedding fields are not stored in SQLite
                if key in self.embedding_fields:
                    continue
                # add the column if it does not exist yet
                if key not in existing_cols:
                    conn.execute(f"ALTER TABLE items ADD COLUMN {key} TEXT")

            # INSERT
            cols = ["id", "hnsw_id"]
            vals = [id, hnsw_id]

            for key, value in body.items():
                if key in self.embedding_fields:
                    continue

                cols.append(key)

                # Serialize lists/dicts as JSON
                if isinstance(value, (list, dict)):
                    vals.append(json.dumps(value, ensure_ascii=False))
                else:
                    vals.append(value)

            placeholders = ",".join("?" for _ in cols)
            conn.execute(
                f"INSERT OR REPLACE INTO items ({','.join(cols)}) VALUES ({placeholders})",
                vals,
            )

        # Add the embeddings to the hnswlib indexes
        for field in self.embedding_fields:
            if field in body:
                if not self._validate_embedding(field, body[field]):
                    # Log a warning and skip this field
                    print(f"[WARN] Invalid embedding for field '{field}', skipping")
                    continue

                vec = np.asarray(body[field], dtype=np.float32)
                self.vector_indexes[field].add_items(
                    vec.reshape(1, -1),
                    np.array([hnsw_id], dtype=np.int64)
                )
        self.save_index(index)
        

        return {"result": "created", "_id": id}

    def get(self, index: str, id: str):
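        """Return the stored row for id as a dict, or None if it does not exist."""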
        self.load_index(index)
        with self._connect() as conn:
            row = conn.execute("SELECT * FROM items WHERE id=?", (id,)).fetchone()
            return dict(row) if row else None

    def update(self, index: str, id: str, body: dict):
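        """Partial update, OpenSearch-style: body = {"doc": {<fields to change>}}."""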
        self.load_index(index)
        doc = body.get("doc", {})
        existing = self.get(index, id)
        if not existing:
            return None

        # Keep the existing hnsw_id unchanged
        hnsw_id = existing["hnsw_id"]

        # --- Update the SQLite metadata ---
        with self._connect() as conn:
            # Check the existing columns
            cur = conn.execute("PRAGMA table_info(items)")
            existing_cols = {row[1] for row in cur.fetchall()}

            # Add columns for any new fields
            for key in doc.keys():
                if key in self.embedding_fields:
                    continue
                if key not in existing_cols:
                    conn.execute(f"ALTER TABLE items ADD COLUMN {key} TEXT")

            # Build the UPDATE statement (lists/dicts are stored as JSON, as in index())
            set_clause = ", ".join(f"{k}=?" for k in doc.keys() if k not in self.embedding_fields)
            values = [
                json.dumps(doc[k], ensure_ascii=False) if isinstance(doc[k], (list, dict)) else doc[k]
                for k in doc.keys() if k not in self.embedding_fields
            ]

            if set_clause:
                conn.execute(
                    f"UPDATE items SET {set_clause} WHERE id=?",
                    values + [id]
                )

        # --- Update (overwrite) the embeddings in hnswlib ---
        for field in self.embedding_fields:
            if field in doc:
                if not self._validate_embedding(field, doc[field]):
                    # Log a warning and skip this field
                    print(f"[WARN] Invalid embedding for field '{field}', skipping")
                    continue

                vec = np.asarray(doc[field], dtype=np.float32)
                self.vector_indexes[field].add_items(
                    vec.reshape(1, -1),
                    np.array([hnsw_id], dtype=np.int64)
                )
        self.save_index(index)
        return {"result": "updated", "_id": id}


    def delete(self, index: str, id: str):
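        """Delete a document: mark its vectors as deleted in hnswlib and drop the SQLite row."""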
        self.load_index(index)
        with self._connect() as conn:
            row = conn.execute("SELECT hnsw_id FROM items WHERE id=?", (id,)).fetchone()
            if not row:
                return {"result": "not_found"}

            hnsw_id = row["hnsw_id"]

            # hnswlib mark_deleted
            for field in self.embedding_fields:
                try:
                    self.vector_indexes[field].mark_deleted(hnsw_id)
                except RuntimeError:
                    pass

            conn.execute("DELETE FROM items WHERE id=?", (id,))
        self.save_index(index)
        return {"result": "deleted"}

    def search(self, index: str, body: dict):
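        """Run a query: either a "knn" vector search or a simple "match" LIKE filter.

        Two knn shapes are accepted: {"field": ..., "query_vector": [...], "k": ...}
        or {"<field>": {"vector": [...], "k": ...}}.
        """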
        self.load_index(index)

        query = body.get("query", {})

        # --- KNN search ---
        if "knn" in query:
            knn = query["knn"]

            # Form A: {"field": "...", "query_vector": [...], "k": 5}
            if "field" in knn:
                field = knn["field"]
                vec = np.asarray(knn["query_vector"], dtype=np.float32)
                k = knn.get("k", 5)

            # Form B: {"embedding": {"vector": [...], "k": 5}}
            else:
                field = list(knn.keys())[0]
                inner = knn[field]
                vec = np.asarray(inner["vector"], dtype=np.float32)
                k = inner.get("k", 5)
            try:
                labels, distances = self.vector_indexes[field].knn_query(
                    vec.reshape(1, -1), k=k
                )
            except (KeyError, RuntimeError):
                # Unknown field or a broken index: return empty results
                return {"hits": {"hits": []}}

            results = []
            with self._connect() as conn:
                for hnsw_id, dist in zip(labels[0], distances[0]):
                    row = conn.execute(
                        "SELECT * FROM items WHERE hnsw_id=?", (int(hnsw_id),)
                    ).fetchone()
                    if row:
                        results.append({"_source": dict(row), "_score": float(dist)})

            return {"hits": {"hits": results}}

        # --- Simple metadata (match) search ---
        if "match" in query:
            key, value = list(query["match"].items())[0]
            with self._connect() as conn:
                rows = conn.execute(
                    f"SELECT * FROM items WHERE {key} LIKE ?", (f"%{value}%",)
                ).fetchall()
                return {"hits": {"hits": [{"_source": dict(r)} for r in rows]}}

        return {"hits": {"hits": []}}
    
    def save_index(self, index: str):
        """Save the hnswlib indexes into this index's folder."""
        self._save_hnsw_index(index)

    def load_index(self, index: str):
        """Load the hnswlib indexes from this index's folder."""
        pre_index = getattr(self, "index_dir", None)
        self._set_index_dir(index)
        index_dir = self.index_dir
        if index_dir == pre_index:
            return  # already loaded
        self.index_metadata = self._load_index_metadata()

        # Nothing to do if there is no mapping for this index
        if index not in self.index_metadata:
            print(f"[WARN] No metadata for index: {index}")
            return

        # Restore embedding_fields from the mapping
        props = self.index_metadata[index].get("mappings", {}).get("properties", {})
        self.embedding_fields = [
            name for name, spec in props.items()
            if spec.get("type") == "knn_vector"
        ]
        # print("self.embedding_fields", self.embedding_fields)

        # dimension を復元
        if self.embedding_fields:
            first = self.embedding_fields[0]
            self.dim = props[first]["dimension"]

        # Load the hnswlib indexes, or create them fresh
        self.vector_indexes = {}

        for field in self.embedding_fields:
            file_path = index_dir / f"{field}.bin"
            index_obj = hnswlib.Index(space="l2", dim=self.dim)

            if file_path.exists():
                try:
                    index_obj.load_index(str(file_path))
                    print(f"[INFO] Loaded index: {file_path}")
                except Exception as e:
                    print(f"[WARN] Failed to load {file_path}, reinitializing: {e}")
                    index_obj.init_index(max_elements=200_000, ef_construction=200, M=16)
            else:
                index_obj.init_index(max_elements=200_000, ef_construction=200, M=16)

            index_obj.set_ef(64)
            self.vector_indexes[field] = index_obj
        
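

# Usage sketch (not part of the original module): a minimal end-to-end example,
# assuming tools_define.DATA_SPACE_DIR points to a writable directory. The index
# name "demo" and the 4-dimensional vectors are illustrative only.
if __name__ == "__main__":
    client = LocalSearchClient(dim=4)

    # Create an index whose mapping declares one knn_vector field
    client.indices.create("demo", body={
        "mappings": {"properties": {
            "title": {"type": "text"},
            "embedding": {"type": "knn_vector", "dimension": 4},
        }}
    })

    # Index a document with an embedding, then run a KNN search against it
    client.index("demo", id="doc1",
                 body={"title": "hello world", "embedding": [0.1, 0.2, 0.3, 0.4]})
    res = client.search("demo", body={
        "query": {"knn": {"field": "embedding", "query_vector": [0.1, 0.2, 0.3, 0.4], "k": 1}}
    })
    print(res["hits"]["hits"])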