import ast
from typing import Optional, Dict, Set, List


class SafetyChecker(ast.NodeVisitor):
    """AST sandbox checker with per-level security policies (v5).

    Policy violations are never raised; each one is appended to
    ``self.errors`` as a human-readable message while the tree is visited.

    level:
        - "none"   : no policy checks are performed by this visitor
        - "low"    : bans meta-execution calls only (exec / eval / __import__)
        - "middle" : low +
                     - bans open() and imports / calls of OS, file and
                       process modules
                     - records, per function definition, the names it calls
                       in ``function_calls`` (input for recursion detection)
        - "high"   : middle +
                     - bans every import
                     - bans def / async def / class / ``while <truthy constant>``
                     - bans dangerous attribute access
                     - bans type() / compile() / input() / vars() / dir() and
                       calls into pickle / marshal / importlib / network modules
                     - detects memory attacks (huge int literals, huge
                       range(), huge sequence repetition)
                     Note: the call graph is not recorded at "high" because
                     function definitions are banned there.

    Args:
        level: Security level name (case-insensitive).
        large_number_threshold: Integers strictly greater than this value
            are treated as "huge" by the memory-attack checks ("high" only).

    Raises:
        ValueError: If ``level`` is not one of the known level names.
    """

    def __init__(
        self,
        level: str = "middle",
        large_number_threshold: int = 10**7,
    ):
        self.level = level.lower()
        self.errors: List[str] = []

        # Per-level policy, filled in by _configure_policy().
        self.banned_calls: Set[str] = set()
        self.banned_modules: Set[str] = set()
        self.banned_attributes: Set[str] = set()
        self.ban_all_imports: bool = False
        self.ban_functions: bool = False
        self.ban_classes: bool = False

        # Maps a locally bound name to the root module it was imported from.
        self.alias_map: Dict[str, str] = {}

        # Call-graph bookkeeping for recursion detection ("middle" only).
        self.current_function: Optional[str] = None
        self.function_calls: Dict[str, Set[str]] = {}

        # Threshold for the memory-attack checks ("high" only).
        self.large_number_threshold = large_number_threshold

        self._configure_policy()

    # -------------------------
    # Policy configuration
    # -------------------------
    def _configure_policy(self):
        """Populate the ban sets/flags for ``self.level``.

        The levels are cumulative, so each tier adds to the previous one.

        Raises:
            ValueError: If ``self.level`` is not a known level.
        """
        if self.level not in ("none", "low", "middle", "high"):
            raise ValueError(f"Unknown security level: {self.level}")

        if self.level == "none":
            return

        # low: meta-execution builtins.
        self.banned_calls |= {"exec", "eval", "__import__"}
        if self.level == "low":
            return

        # middle: file access plus OS / filesystem / process modules.
        self.banned_calls.add("open")
        self.banned_modules |= {
            "os",
            "sys",
            "subprocess",
            "pathlib",
            "shutil",
        }
        if self.level == "middle":
            return

        # high: everything above plus introspection, network, serialization.
        self.banned_calls |= {
            "compile",
            "input",
            "type",
            "vars",
            "dir",
        }
        self.banned_modules |= {
            "socket",
            "http",
            "urllib",
            "ftplib",
            "ssl",
            "pickle",
            "marshal",
            "importlib",
        }

        self.ban_all_imports = True
        self.ban_functions = True
        self.ban_classes = True

        self.banned_attributes |= {
            "__subclasses__",
            "__mro__",
            "__globals__",
            "__dict__",
            "__getattribute__",
            "__setattr__",
            "__delattr__",
            "__reduce__",
            "__reduce_ex__",
            "__class__",
        }

    # -------------------------
    # Node visitors
    # -------------------------

    def visit_Import(self, node: ast.Import):
        """Report banned ``import`` statements and record bound names."""
        if self.ban_all_imports:
            self.errors.append("import 文は禁止されています（security=high）")
        else:
            for alias in node.names:
                root = alias.name.split('.')[0]
                # ``import a.b`` binds only the top-level name ``a``;
                # ``import a.b as c`` binds ``c``.
                bound_name = alias.asname or root
                self.alias_map[bound_name] = root

                if root in self.banned_modules:
                    self.errors.append(f"禁止された import: {alias.name}")
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        """Report banned ``from ... import`` statements and record bound names."""
        if self.ban_all_imports:
            self.errors.append("from import 文は禁止されています（security=high）")
        else:
            # node.module is None for ``from . import x``; nothing to resolve.
            if node.module:
                root = node.module.split('.')[0]
                for alias in node.names:
                    asname = alias.asname or alias.name
                    self.alias_map[asname] = root

                if root in self.banned_modules:
                    self.errors.append(f"禁止された import: {node.module}")
        self.generic_visit(node)

    def _visit_function_like(self, node):
        """Shared handling for ``def`` and ``async def`` nodes."""
        if self.ban_functions:
            self.errors.append(f"関数定義は禁止されています（security=high）: {node.name}")

        # The call graph is only recorded at "middle".
        if self.level == "middle":
            prev = self.current_function
            self.current_function = node.name
            self.function_calls.setdefault(node.name, set())
            self.generic_visit(node)
            # Restore the enclosing function so nested defs do not leak.
            self.current_function = prev
        else:
            self.generic_visit(node)

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Apply the function-definition policy to a ``def``."""
        self._visit_function_like(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        """Apply the function-definition policy to an ``async def``.

        Without this visitor an ``async def`` would bypass both the "high"
        function ban and the "middle" call-graph recording.
        """
        self._visit_function_like(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        """Report class definitions when they are banned."""
        if self.ban_classes:
            self.errors.append(f"クラス定義は禁止されています（security=high）: {node.name}")
        self.generic_visit(node)

    def visit_While(self, node: ast.While):
        """Report ``while`` loops whose condition is a truthy constant ("high")."""
        if self.level == "high":
            # Any truthy constant (True, 1, "x", ...) loops forever just like
            # ``while True``, so all of them are rejected.
            if isinstance(node.test, ast.Constant) and node.test.value:
                self.errors.append("無限ループの可能性がある while True は禁止されています（security=high）")
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute):
        """Report access to attributes listed in ``banned_attributes``."""
        if node.attr in self.banned_attributes:
            self.errors.append(f"危険な属性アクセス: {node.attr}")
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call):
        """Check a call against banned names, banned modules and huge ranges."""
        func_name = self._get_func_name(node.func)
        root = func_name.split('.')[0] if func_name else ""

        # Resolve ``o`` in ``o.remove()`` back to ``os`` after ``import os as o``.
        resolved_root = self.alias_map.get(root, root)

        if func_name in self.banned_calls or root in self.banned_calls:
            self.errors.append(f"禁止された関数呼び出し: {func_name}")

        if resolved_root in self.banned_modules:
            self.errors.append(f"危険なモジュール呼び出し: {func_name}")

        # Record the call edge for recursion detection ("middle" only).
        if self.level == "middle" and self.current_function and func_name:
            self.function_calls[self.current_function].add(func_name)

        # Huge literal range() ("high" only); only the first argument is checked.
        if self.level == "high" and func_name == "range":
            if node.args:
                arg0 = node.args[0]
                if isinstance(arg0, ast.Constant) and isinstance(arg0.value, int):
                    if arg0.value > self.large_number_threshold:
                        self.errors.append(
                            f"巨大 range は禁止されています: range({arg0.value})"
                        )

        self.generic_visit(node)

    def visit_BinOp(self, node: ast.BinOp):
        """Report ``<expr> * <huge int literal>`` repetition ("high" only)."""
        if self.level == "high" and isinstance(node.op, ast.Mult):
            if isinstance(node.right, ast.Constant) and isinstance(node.right.value, int):
                if node.right.value > self.large_number_threshold:
                    self.errors.append(
                        f"巨大な繰り返しによるリスト/文字列生成は禁止されています: * {node.right.value}"
                    )
        self.generic_visit(node)

    def visit_Constant(self, node: ast.Constant):
        """Report integer literals above the threshold ("high" only)."""
        if self.level == "high" and isinstance(node.value, int):
            if node.value > self.large_number_threshold:
                self.errors.append(
                    f"巨大な整数リテラルは禁止されています: {node.value}"
                )
        self.generic_visit(node)

    # -------------------------
    # Utilities
    # -------------------------
    def _get_func_name(self, func) -> str:
        """Return the dotted name of a call target, or ``""`` if it has none.

        ``Name`` nodes yield their identifier and ``Attribute`` chains yield
        ``base.attr``. When the base is not a plain name (for example the
        result of another call) only the attribute name is returned.
        """
        if isinstance(func, ast.Name):
            return func.id
        elif isinstance(func, ast.Attribute):
            base = self._get_func_name(func.value)
            return f"{base}.{func.attr}" if base else func.attr
        return ""


def _detect_recursive_functions(function_calls: Dict[str, Set[str]]) -> Set[str]:
    recursive: Set[str] = set()

    def dfs(start: str, current: str, visited: Set[str]):
        if current not in function_calls:
            return
        for callee in function_calls[current]:
            if callee == start:
                recursive.add(start)
                return
            if callee in visited:
                continue
            visited.add(callee)
            dfs(start, callee, visited)

    for func in function_calls.keys():
        dfs(func, func, set())

    return recursive


def check_code_safety(
    code: str,
    level: str = "middle",
    large_number_threshold: int = 10**7,
) -> List[str]:
    """Parse ``code`` and return the list of policy violations found.

    Args:
        code: Python source text to inspect.
        level: Security level name, passed to ``SafetyChecker``
            (case-insensitive).
        large_number_threshold: Threshold passed to ``SafetyChecker`` for the
            memory-attack checks.

    Returns:
        A list of error messages; empty when nothing was reported. If the
        source cannot be parsed, a single syntax-error message is returned.

    Raises:
        ValueError: Propagated from ``SafetyChecker`` for an unknown level.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError) as e:
        # ValueError: Python versions before 3.12 report null bytes in the
        # source this way instead of as a SyntaxError.
        return [f"構文エラー: {e}"]

    checker = SafetyChecker(level=level, large_number_threshold=large_number_threshold)
    checker.visit(tree)

    # Recursion check applies to "middle" only. Compare against the checker's
    # normalized level so that e.g. level="MIDDLE" is not silently skipped.
    if checker.level == "middle":
        recursive_funcs = _detect_recursive_functions(checker.function_calls)
        for name in sorted(recursive_funcs):
            checker.errors.append(
                f"再帰関数は禁止されています（security=middle）: {name}() が再帰的に呼び出されています"
            )

    return checker.errors


# -------------------------
# 使用例
# -------------------------
if __name__ == "__main__":
    # Demo driver: run every sample through every security level and print
    # the messages reported for each combination.
    samples = {
        # --- exec ---
        "exec_usage": """
exec("print(1)")
""",

        # --- eval ---
        "eval_usage": """
x = eval("2+2")
""",

        # --- __import__ ---
        "dunder_import_usage": """
m = __import__("os")
""",

        # --- combined ---
        "mixed_usage": """
exec("x=1")
y = eval("x+1")
m = __import__("sys")
""",

        # --- accepted at level "low" (sanity checks) ---
        "import_os_ok": """
import os
os.listdir(".")
""",

        "function_ok": """
def f():
    return 1
f()
""",

        "class_ok": """
class X:
    pass
""",

        "attribute_ok": """
x = object.__subclasses__
""",

        # --- basics ---
        "safe": """
x = 1 + 2
y = x * 3
print(y)
""",

        # --- imports ---
        "import_os": """
import os
os.remove("test.txt")
""",

        "alias_import": """
import os as o
o.remove("test.txt")
""",

        "from_alias": """
from os import remove as rm
rm("test.txt")
""",

        # --- import inside a function ---
        "def_with_import": """
def f():
    import os
    os.remove("x")
""",

        # --- control flow ---
        "while_true": """
while True:
    pass
""",

        "class_def": """
class X:
    pass
""",

        # --- dangerous attributes ---
        "dangerous_attribute": """
x = object.__subclasses__()
""",

        # --- type() ---
        "type_call": """
C = type("X", (), {})
""",

        # --- pickle ---
        "pickle_usage": """
import pickle
pickle.loads(b"test")
""",

        # --- memory attacks (detectable from the AST) ---
        "huge_int_literal": """
x = 9999999999
""",

        "huge_range_literal": """
for i in range(9999999999):
    pass
""",

        "huge_list_literal": """
x = [0] * 9999999999
""",

        # --- memory attacks (the value is not visible in the AST; kept as test cases) ---
        "huge_int_expr": """
x = 10**12
""",

        "huge_range_expr": """
for i in range(10**12):
    pass
""",

        "huge_list_expr": """
x = [0] * (10**12)
""",

        # --- recursion ---
        "direct_recursion": """
def f():
    return f()
f()
""",

        "indirect_recursion": """
def f():
    return g()
def g():
    return f()
f()
""",
    }

    for sample_name, source in samples.items():
        print(f"=== sample: {sample_name} ===")
        for security_level in ("none", "low", "middle", "high"):
            messages = check_code_safety(source, level=security_level)
            print(f"[level={security_level}]")
            if not messages:
                print("  OK")
                continue
            for message in messages:
                print("  -", message)
        print()