#!/usr/bin/env python3
import argparse
import json
import os
import shutil
import sqlite3
import sys
import urllib.request
import warnings
from datetime import datetime
from pathlib import Path
warnings.filterwarnings("ignore")
import numpy as np
try:
    import cv2
except Exception as exc:
    raise SystemExit(f"Missing opencv-contrib-python-headless: {exc}")
try:
    import pymysql
except Exception:
    pymysql = None
# Project root: this script is expected to live one level below it (e.g. scripts/).
BASE_DIR = Path(__file__).resolve().parents[1]
# ONNX models fetched on demand from the OpenCV Zoo: YuNet for face
# detection, SFace for face embeddings. Each entry maps a role to the
# local file name and its download URL (see ensure_model).
MODEL_FILES = {
    "detector": {
        "name": "face_detection_yunet_2023mar.onnx",
        "url": "https://github.com/opencv/opencv_zoo/raw/main/models/face_detection_yunet/face_detection_yunet_2023mar.onnx",
    },
    "recognizer": {
        "name": "face_recognition_sface_2021dec.onnx",
        "url": "https://github.com/opencv/opencv_zoo/raw/main/models/face_recognition_sface/face_recognition_sface_2021dec.onnx",
    },
}
def load_env():
    """Parse the project's .env file into a dict.

    Blank lines, comment lines, and lines without "=" are ignored; values
    lose surrounding whitespace plus one layer of double then single quotes.
    Returns an empty dict when the file does not exist.
    """
    values = {}
    env_file = BASE_DIR / ".env"
    if not env_file.exists():
        return values
    for raw in env_file.read_text(encoding="utf-8").splitlines():
        entry = raw.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        key, _sep, val = entry.partition("=")
        values[key.strip()] = val.strip().strip("\"").strip("'")
    return values
def env_value(env, key, default=None):
    """Resolve a setting: process environment wins over the .env mapping,
    which wins over `default`. Empty strings are treated as unset."""
    for candidate in (os.getenv(key), env.get(key)):
        if candidate:
            return candidate
    return default
def db_placeholder(db_type):
    """Return the SQL bind-parameter marker for the driver:
    "?" for sqlite, "%s" for anything else (pymysql)."""
    if db_type == "sqlite":
        return "?"
    return "%s"
def connect_db(env, db_config_path=None):
    """Open a database connection.

    Settings come from the JSON file at `db_config_path` when provided and
    readable; otherwise from process-environment / .env values (Laravel-style
    DB_* keys).

    Args:
        env: dict produced by load_env().
        db_config_path: optional path to a JSON config with keys
            "connection", "database", "host", "port", "username", "password".

    Returns:
        (connection, db_type) where db_type is "sqlite" or "mysql".

    Raises:
        SystemExit: when pymysql is missing for a MySQL/MariaDB target, or
            the env-configured DB_CONNECTION value is unsupported.
    """
    # If a db-config JSON file is provided, use it (overrides env)
    if db_config_path and Path(db_config_path).exists():
        try:
            with open(db_config_path, "r") as f:
                cfg = json.load(f)
            connection = cfg.get("connection", "sqlite")
            if connection == "sqlite":
                db_path = Path(cfg.get("database", str(BASE_DIR / "database" / "database.sqlite")))
                conn = sqlite3.connect(db_path)
                # Row factory gives dict-style access, matching pymysql's DictCursor.
                conn.row_factory = sqlite3.Row
                return conn, "sqlite"
            if connection in ("mysql", "mariadb"):
                if pymysql is None:
                    raise SystemExit("Missing pymysql dependency for MySQL/MariaDB.")
                conn = pymysql.connect(
                    host=cfg.get("host", "127.0.0.1"),
                    port=int(cfg.get("port", 3306)),
                    user=cfg.get("username", "root"),
                    password=cfg.get("password", ""),
                    database=cfg.get("database", "laravel"),
                    charset="utf8mb4",
                    cursorclass=pymysql.cursors.DictCursor,
                )
                return conn, "mysql"
            # NOTE(review): an unrecognized "connection" value in the JSON
            # silently falls through to the env-based path below — presumably
            # intentional fallback; confirm with callers.
        except (json.JSONDecodeError, KeyError):
            pass  # Fall through to env-based connection
    connection = env_value(env, "DB_CONNECTION", "sqlite")
    if connection == "sqlite":
        db_path = env_value(env, "DB_DATABASE", str(BASE_DIR / "database" / "database.sqlite"))
        db_path = Path(db_path)
        conn = sqlite3.connect(db_path)
        conn.row_factory = sqlite3.Row
        return conn, "sqlite"
    if connection in ("mysql", "mariadb"):
        if pymysql is None:
            raise SystemExit("Missing pymysql dependency for MySQL/MariaDB.")
        conn = pymysql.connect(
            host=env_value(env, "DB_HOST", "127.0.0.1"),
            port=int(env_value(env, "DB_PORT", 3306)),
            user=env_value(env, "DB_USERNAME", "root"),
            password=env_value(env, "DB_PASSWORD", ""),
            database=env_value(env, "DB_DATABASE", "laravel"),
            charset=env_value(env, "DB_CHARSET", "utf8mb4"),
            cursorclass=pymysql.cursors.DictCursor,
        )
        return conn, "mysql"
    raise SystemExit(f"Unsupported DB_CONNECTION: {connection}")
def resolve_media_path(disk, original_path, optimized_path):
    """Map a media row's disk name plus relative path to an absolute path
    under storage/app. Prefers the optimized file over the original.
    Returns None when no path is stored or the disk name is unknown."""
    relative = optimized_path if optimized_path else original_path
    if not relative:
        return None
    relative = str(relative).lstrip("/\\")
    root = BASE_DIR / "storage" / "app"
    # Sub-directory under storage/app for each known Laravel disk name.
    subdir_by_disk = {
        "public": ("public",),
        "client_media": ("private",),
        "client_media_original": ("private",),
        "local": (),
        "private": (),
        "": (),
        None: (),
    }
    try:
        subdir = subdir_by_disk[disk]
    except KeyError:
        return None
    return root.joinpath(*subdir, relative)
def ensure_model(model_dir, key):
    """Return the local path of the ONNX model for `key` ("detector" or
    "recognizer"), downloading it into `model_dir` on first use."""
    info = MODEL_FILES[key]
    model_dir.mkdir(parents=True, exist_ok=True)
    destination = model_dir / info["name"]
    if destination.exists():
        return destination
    # Download into a temp file first so an interrupted transfer never
    # leaves a truncated model at the final path.
    partial = destination.with_suffix(".tmp")
    response = urllib.request.urlopen(info["url"])
    try:
        with open(partial, "wb") as sink:
            shutil.copyfileobj(response, sink)
    finally:
        response.close()
    partial.replace(destination)
    return destination
def build_detector(model_dir, score_threshold):
    """Create a YuNet face detector with the given score threshold.

    The (320, 320) input size is a placeholder — detect_faces() resets it
    per image. Raises SystemExit when the installed OpenCV build lacks
    the FaceDetectorYN API.
    """
    if not hasattr(cv2, "FaceDetectorYN_create"):
        raise SystemExit("OpenCV build missing FaceDetectorYN. Install opencv-contrib-python-headless.")
    weights = ensure_model(model_dir, "detector")
    return cv2.FaceDetectorYN_create(str(weights), "", (320, 320), score_threshold, 0.3, 5000)
def build_recognizer(model_dir):
    """Create an SFace embedding model, downloading weights if needed.

    Raises SystemExit when the installed OpenCV build lacks the
    FaceRecognizerSF API.
    """
    if not hasattr(cv2, "FaceRecognizerSF_create"):
        raise SystemExit("OpenCV build missing FaceRecognizerSF. Install opencv-contrib-python-headless.")
    weights = ensure_model(model_dir, "recognizer")
    return cv2.FaceRecognizerSF_create(str(weights), "")
# Obfuscated media files start with this 8-byte magic; the full header is
# SNAP_HEADER_SIZE bytes (magic plus opaque header bytes), after which the
# raw image payload begins (see read_image).
SNAP_MAGIC = b"SNAPFILE"
SNAP_HEADER_SIZE = 32  # total header length in bytes, including the magic

def read_image(path):
    """Load an image with OpenCV, transparently unwrapping the SNAPFILE
    obfuscation container when present.

    Returns a BGR ndarray, or None when decoding fails.
    """
    path = str(path)
    with open(path, "rb") as handle:
        # Obfuscated files are identified by an 8-byte magic prefix.
        if handle.read(8) == SNAP_MAGIC:
            # Jump past the full header and decode the raw payload in memory.
            handle.seek(SNAP_HEADER_SIZE)
            payload = np.frombuffer(handle.read(), np.uint8)
            return cv2.imdecode(payload, cv2.IMREAD_COLOR)
    # Plain file — let OpenCV read it directly (returns None on failure).
    return cv2.imread(path)
def detect_faces(detector, image, min_score):
    """Run the detector over `image` (sized to the image's own dimensions)
    and return the detection rows whose confidence (column 4) is at least
    `min_score`. Returns an empty list when nothing is detected."""
    rows, cols = image.shape[:2]
    detector.setInputSize((cols, rows))
    _status, detections = detector.detect(image)
    if detections is None:
        return []
    return [candidate for candidate in detections if float(candidate[4]) >= min_score]
def face_embedding(recognizer, image, face):
    """Align-crop one detected face and return its L2-normalized SFace
    embedding as a flat float32 vector (or None when the recognizer
    produces no feature)."""
    crop = recognizer.alignCrop(image, face)
    feature = recognizer.feature(crop)
    if feature is None:
        return None
    vector = feature.flatten().astype(np.float32)
    magnitude = np.linalg.norm(vector)
    # Leave an all-zero vector untouched to avoid dividing by zero.
    return vector / magnitude if magnitude > 0 else vector
def fetch_event_media(cursor, db_type, event_id):
    """Return all event_media rows (id, disk, paths, type, status) for one event."""
    ph = db_placeholder(db_type)
    sql = (
        "SELECT id, disk, original_path, optimized_path, file_type, status "
        f"FROM event_media WHERE event_id = {ph}"
    )
    cursor.execute(sql, (event_id,))
    return cursor.fetchall()
def fetch_indexed_media_ids(cursor, db_type, event_id):
    """Return the set of event_media ids that already have stored faces,
    so indexing can skip photos processed in earlier runs."""
    ph = db_placeholder(db_type)
    cursor.execute(
        f"SELECT DISTINCT event_media_id FROM event_media_faces WHERE event_id = {ph}",
        (event_id,),
    )
    return {int(record["event_media_id"]) for record in cursor.fetchall()}
def insert_face(cursor, db_type, event_id, media_id, face, embedding):
    """Insert one detected face row into event_media_faces.

    Args:
        cursor/db_type: DB cursor and driver tag ("sqlite"/"mysql") used to
            pick the bind-parameter placeholder.
        event_id/media_id: owning event and event_media row ids.
        face: detector output row — [0:4] is the bbox (x, y, w, h) stored as
            JSON, [4] the confidence score.
        embedding: 1-D vector, stored as a raw float32 blob.

    The caller is responsible for committing the transaction.
    """
    from datetime import timezone  # module header imports only the datetime class

    placeholder = db_placeholder(db_type)
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
    # timestamp formats to the identical "YYYY-MM-DD HH:MM:SS" string.
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
    embedding_blob = embedding.astype(np.float32).tobytes()
    bbox = [float(face[0]), float(face[1]), float(face[2]), float(face[3])]
    bbox_json = json.dumps(bbox)
    cursor.execute(
        f"""
        INSERT INTO event_media_faces
            (event_id, event_media_id, embedding, confidence, bbox, created_at, updated_at)
        VALUES
            ({placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder}, {placeholder})
        """,
        (event_id, media_id, embedding_blob, float(face[4]), bbox_json, now, now),
    )
def index_event(conn, db_type, detector, recognizer, event_id, min_score, force):
    """Detect and store face embeddings for every image of an event.

    Args:
        conn/db_type: open DB connection and its driver tag ("sqlite"/"mysql").
        detector/recognizer: detector and recognizer objects (see
            build_detector / build_recognizer).
        event_id: event whose media rows are scanned.
        min_score: minimum detection confidence for a face to be kept.
        force: when True, delete previously stored faces for the event and
            re-index everything.

    Returns:
        dict with "inserted" (face rows written) and "skipped" (media rows
        passed over: already indexed, missing file, unreadable, or no faces).
    """
    cursor = conn.cursor()
    placeholder = db_placeholder(db_type)
    if force:
        cursor.execute(
            f"DELETE FROM event_media_faces WHERE event_id = {placeholder}",
            (event_id,),
        )
        conn.commit()
    # Unless forcing, media that already has stored faces is skipped.
    indexed = set() if force else fetch_indexed_media_ids(cursor, db_type, event_id)
    media_rows = fetch_event_media(cursor, db_type, event_id)
    inserted = 0
    skipped = 0
    for row in media_rows:
        if row["file_type"] != "image":
            continue  # non-image media is ignored entirely (not counted as skipped)
        media_id = int(row["id"])
        if media_id in indexed:
            skipped += 1
            continue
        path = resolve_media_path(row["disk"], row["original_path"], row["optimized_path"])
        if not path or not path.exists():
            skipped += 1
            continue
        image = read_image(path)
        if image is None:
            skipped += 1
            continue
        faces = detect_faces(detector, image, min_score)
        if not faces:
            skipped += 1
            continue
        for face in faces:
            embedding = face_embedding(recognizer, image, face)
            if embedding is None or embedding.size == 0:
                continue
            insert_face(cursor, db_type, event_id, media_id, face, embedding)
            inserted += 1
        # Commit after each photo so a crash mid-run keeps completed work.
        conn.commit()
    return {"inserted": inserted, "skipped": skipped}
def load_embeddings(cursor, db_type, event_id):
    """Load all stored face embeddings for an event.

    Returns two parallel lists: media ids and their float32 embedding
    vectors (one entry per face row; a media id may repeat). Rows with a
    NULL or empty blob are dropped.
    """
    ph = db_placeholder(db_type)
    cursor.execute(
        f"SELECT event_media_id, embedding FROM event_media_faces WHERE event_id = {ph}",
        (event_id,),
    )
    media_ids, embeddings = [], []
    for record in cursor.fetchall():
        blob = record["embedding"]
        # sqlite3 may hand back a memoryview; pymysql gives bytes.
        if isinstance(blob, memoryview):
            blob = blob.tobytes()
        if blob is None:
            continue
        vector = np.frombuffer(blob, dtype=np.float32)
        if vector.size == 0:
            continue
        media_ids.append(int(record["event_media_id"]))
        embeddings.append(vector)
    return media_ids, embeddings
def match_event(conn, db_type, detector, recognizer, event_id, image_path, threshold, limit, min_score):
    """Find event photos containing a face similar to the one in `image_path`.

    Takes the highest-confidence face in the query image, compares its
    embedding against every stored embedding for the event by dot product
    of unit vectors (cosine similarity), and keeps each photo's best score
    when it is at least `threshold`.

    Returns:
        dict with "matched_ids" (media ids, best score first, truncated to
        `limit` when limit is truthy) and "scores" (media id string -> best
        score). Any failure — missing file, unreadable image, no face, no
        usable embedding — yields empty results rather than raising.
    """
    # Indexing is done separately (by the queue job when photos are processed).
    # Do NOT call index_event here — it would re-scan all photos on every search.
    image_path = Path(image_path)
    if not image_path.exists():
        return {"matched_ids": [], "scores": {}}
    image = read_image(image_path)
    if image is None:
        return {"matched_ids": [], "scores": {}}
    faces = detect_faces(detector, image, min_score)
    if not faces:
        return {"matched_ids": [], "scores": {}}
    # Use the single most confident face as the query.
    faces = sorted(faces, key=lambda item: float(item[4]), reverse=True)
    query = face_embedding(recognizer, image, faces[0])
    if query is None or query.size == 0:
        return {"matched_ids": [], "scores": {}}
    media_ids, embeddings = load_embeddings(conn.cursor(), db_type, event_id)
    if not embeddings:
        return {"matched_ids": [], "scores": {}}
    # Re-normalize stored rows defensively; all-zero rows stay zero.
    matrix = np.vstack(embeddings)
    matrix_norm = np.linalg.norm(matrix, axis=1, keepdims=True)
    matrix = np.divide(matrix, matrix_norm, out=np.zeros_like(matrix), where=matrix_norm != 0)
    query_norm = np.linalg.norm(query)
    if query_norm == 0:
        return {"matched_ids": [], "scores": {}}
    query = query / query_norm
    scores = matrix.dot(query)
    # Keep only the best-scoring face per photo, at or above the threshold.
    best = {}
    for media_id, score in zip(media_ids, scores):
        score = float(score)
        if score < threshold:
            continue
        current = best.get(media_id)
        if current is None or score > current:
            best[media_id] = score
    sorted_ids = sorted(best.items(), key=lambda item: item[1], reverse=True)
    if limit:
        sorted_ids = sorted_ids[:limit]
    return {
        "matched_ids": [media_id for media_id, _score in sorted_ids],
        "scores": {str(media_id): score for media_id, score in sorted_ids},
    }
def main():
    """Command-line entry point.

    Sub-commands:
      index -- detect faces in every photo of an event and store embeddings.
      match -- find event photos whose faces resemble the supplied image.
    Both print a JSON result document to stdout.
    """
    default_model_dir = str(BASE_DIR / "storage" / "app" / "face-models")

    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest="command", required=True)

    cmd_index = commands.add_parser("index")
    cmd_index.add_argument("--event-id", type=int, required=True)
    cmd_index.add_argument("--model-dir", default=default_model_dir)
    cmd_index.add_argument("--min-score", type=float, default=0.6)
    cmd_index.add_argument("--force", action="store_true")
    cmd_index.add_argument("--no-force", action="store_true")
    cmd_index.add_argument("--db-config", default=None)

    cmd_match = commands.add_parser("match")
    cmd_match.add_argument("--event-id", type=int, required=True)
    cmd_match.add_argument("--image", required=True)
    cmd_match.add_argument("--threshold", type=float, default=0.50)
    cmd_match.add_argument("--limit", type=int, default=120)
    cmd_match.add_argument("--model-dir", default=default_model_dir)
    cmd_match.add_argument("--min-score", type=float, default=0.6)
    cmd_match.add_argument("--db-config", default=None)

    args = parser.parse_args()
    env = load_env()
    conn, db_type = connect_db(env, getattr(args, "db_config", None))
    models = Path(args.model_dir)
    detector = build_detector(models, args.min_score)
    recognizer = build_recognizer(models)

    if args.command == "index":
        # --no-force overrides --force when both flags are given.
        force = bool(args.force) and not args.no_force
        result = index_event(conn, db_type, detector, recognizer, args.event_id, args.min_score, force)
        print(json.dumps(result))
    elif args.command == "match":
        result = match_event(
            conn,
            db_type,
            detector,
            recognizer,
            args.event_id,
            args.image,
            args.threshold,
            args.limit,
            args.min_score,
        )
        print(json.dumps(result))


if __name__ == "__main__":
    main()
