import asyncio
import base64
import hashlib
import hmac
import logging
import os
import secrets
import sqlite3
import threading
import time
from collections import defaultdict
from contextlib import asynccontextmanager, closing
from pathlib import Path
from urllib.parse import urlencode

from mcp.server.fastmcp import FastMCP
from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Mount, Route
 
try:
    import fitz as _fitz
except ImportError:
    _fitz = None
 
# ---------------------------------------------------------------------------
# LOGGING
# ---------------------------------------------------------------------------
 
# Process-wide logging configuration: INFO level and second-resolution
# timestamps without timezone. Note that logging.basicConfig() is a no-op if
# the root logger already has handlers when this module is imported.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
)
# Logger shared by every function in this module.
log = logging.getLogger("mcp-fileserver")
 
# ---------------------------------------------------------------------------
# CONFIGURAZIONE
# ---------------------------------------------------------------------------
 
def _require_env(name: str) -> str:
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"Variabile obbligatoria non impostata: {name}")
    return value
 
 
# Mandatory settings: _require_env raises RuntimeError at import time when any
# of these is missing or empty, so the server refuses to start misconfigured.
BASE_DIR            = _require_env("MCP_BASE_DIR")
OAUTH_CLIENT_ID     = _require_env("MCP_CLIENT_ID")
OAUTH_CLIENT_SECRET = _require_env("MCP_CLIENT_SECRET")
ALLOWED_HOST        = _require_env("MCP_ALLOWED_HOST")
# Optional settings with defaults; int() raises ValueError on non-numeric input.
BIND_PORT           = int(os.environ.get("MCP_PORT", "8000"))
# Bearer-token lifetime in seconds (default 86400 = 24 h).
TOKEN_EXPIRY        = int(os.environ.get("MCP_TOKEN_EXPIRY", "86400"))
DB_PATH             = os.environ.get("MCP_DB_PATH", "/opt/mcp-fileserver/mcp.db")

# Files larger than this are skipped by _index_single_file.
MAX_INDEX_FILE_BYTES = 50 * 1024 * 1024  # 50 MB

# Log the effective configuration at startup (the OAuth credentials are
# intentionally left out).
log.info(
    "BASE_DIR=%s ALLOWED_HOST=%s PORT=%d TOKEN_EXPIRY=%s DB_PATH=%s",
    BASE_DIR, ALLOWED_HOST, BIND_PORT, TOKEN_EXPIRY, DB_PATH,
)
 
# ---------------------------------------------------------------------------
# DATABASE
# ---------------------------------------------------------------------------
 
# Serialises every write helper below within this process. Each helper opens
# its own short-lived connection, so this lock (not the connection) is what
# keeps concurrent writers from the indexer, the watcher and the HTTP handlers
# from contending for SQLite's write lock.
_db_write_lock = threading.Lock()
 
 
def _db_connect() -> sqlite3.Connection:
    """Open a fresh connection to DB_PATH, configured for this server.

    ``check_same_thread=False`` allows the connection to be used from a
    thread other than the creating one. Each call switches the database to
    WAL journaling and sets ``synchronous=NORMAL``. The caller owns the
    returned connection and must close it.
    """
    connessione = sqlite3.connect(DB_PATH, check_same_thread=False)
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA synchronous=NORMAL"):
        connessione.execute(pragma)
    return connessione
 
 
def _db_init() -> None:
    """Create the SQLite schema if it does not exist yet (idempotent).

    Tables:
      * ``tokens``       - bearer token -> expiry timestamp (epoch seconds).
      * ``file_index``   - per-file mtime/size/indexing time, used to skip
                           files that have not changed since the last run.
      * ``file_content`` - FTS5 table holding the extracted text, tokenised
                           with ``unicode61`` and diacritics removed.

    The connection is closed via ``closing`` even if the script fails.
    """
    with _db_write_lock, closing(_db_connect()) as conn:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS tokens (
                token   TEXT PRIMARY KEY,
                expires REAL NOT NULL
            );

            CREATE TABLE IF NOT EXISTS file_index (
                path        TEXT PRIMARY KEY,
                mtime       REAL NOT NULL,
                size        INTEGER NOT NULL,
                indexed_at  REAL NOT NULL
            );

            CREATE VIRTUAL TABLE IF NOT EXISTS file_content
            USING fts5(
                path     UNINDEXED,
                content,
                tokenize = 'unicode61 remove_diacritics 1'
            );
        """)
        conn.commit()
 
 
def _db_token_save(token: str, expires: float) -> None:
    """Persist (insert or overwrite) a bearer token and its expiry timestamp.

    Serialised by ``_db_write_lock``. The connection is closed even when the
    write fails, so an uncommitted transaction is discarded immediately
    instead of lingering until garbage collection.
    """
    with _db_write_lock, closing(_db_connect()) as conn:
        conn.execute(
            "INSERT OR REPLACE INTO tokens (token, expires) VALUES (?, ?)",
            (token, expires),
        )
        conn.commit()
 
 
def _db_token_delete(token: str) -> None:
    """Remove a single persisted bearer token (no-op if it is not stored).

    Serialised by ``_db_write_lock``; the connection is always closed.
    """
    with _db_write_lock, closing(_db_connect()) as conn:
        conn.execute("DELETE FROM tokens WHERE token = ?", (token,))
        conn.commit()
 
 
def _db_token_cleanup(now: float) -> int:
    """Delete every persisted token whose expiry is at or before *now*.

    Returns:
        The number of rows removed (``Cursor.rowcount``).
    """
    with _db_write_lock, closing(_db_connect()) as conn:
        count = conn.execute(
            "DELETE FROM tokens WHERE expires <= ?", (now,)
        ).rowcount
        conn.commit()
    return count
 
 
def _db_load_active_tokens() -> dict[str, float]:
    """Return all persisted tokens that have not expired yet, as token -> expiry.

    Read-only, so it does not take ``_db_write_lock``. The connection is
    closed even if the query fails.
    """
    now = time.time()
    with closing(_db_connect()) as conn:
        rows = conn.execute(
            "SELECT token, expires FROM tokens WHERE expires > ?", (now,)
        ).fetchall()
    return dict(rows)
 
 
def _db_index_status(rel_path: str) -> tuple[float, int] | None:
    """Return the ``(mtime, size)`` recorded when *rel_path* was last indexed.

    Returns None if the path is not in ``file_index``. The connection is
    closed even if the query fails.
    """
    with closing(_db_connect()) as conn:
        return conn.execute(
            "SELECT mtime, size FROM file_index WHERE path = ?", (rel_path,)
        ).fetchone()
 
 
def _db_upsert_file(rel_path: str, testo: str, mtime: float, size: int) -> None:
    """Replace the indexed text of *rel_path* and record its mtime/size.

    The FTS row is deleted and re-inserted (FTS5 has no UPSERT), and the
    bookkeeping row in ``file_index`` is replaced, all in one transaction.
    If any statement fails, the connection is closed without commit, so
    the partial change is discarded right away rather than holding SQLite's
    write lock until garbage collection.
    """
    now = time.time()
    with _db_write_lock, closing(_db_connect()) as conn:
        conn.execute("DELETE FROM file_content WHERE path = ?", (rel_path,))
        conn.execute(
            "INSERT INTO file_content (path, content) VALUES (?, ?)",
            (rel_path, testo),
        )
        conn.execute(
            "INSERT OR REPLACE INTO file_index (path, mtime, size, indexed_at) "
            "VALUES (?, ?, ?, ?)",
            (rel_path, mtime, size, now),
        )
        conn.commit()
 
 
def _db_remove_file(rel_path: str) -> None:
    """Drop *rel_path* from both the FTS table and ``file_index``.

    This is a no-op for unknown paths. Both deletes commit together, and the
    connection is closed even on failure.
    """
    with _db_write_lock, closing(_db_connect()) as conn:
        conn.execute("DELETE FROM file_content WHERE path = ?", (rel_path,))
        conn.execute("DELETE FROM file_index WHERE path = ?", (rel_path,))
        conn.commit()
 
 
def _db_search(keyword: str, max_results: int) -> list[tuple[str, str]]:
    """Run an FTS5 ``MATCH`` query and return up to *max_results* (path, snippet) rows.

    *keyword* is passed verbatim as an FTS5 query expression, so malformed
    syntax raises ``sqlite3.OperationalError``; search_content catches it.
    The snippet is taken from column 1 (``content``), with matches wrapped in
    ``[`` / ``]``, ``...`` as the ellipsis and at most 25 tokens.

    The connection is closed via ``closing`` even when the query raises.
    Rejected queries are an expected, user-triggered path, so previously
    each one left a connection open.
    """
    with closing(_db_connect()) as conn:
        return conn.execute(
            """
            SELECT path,
                   snippet(file_content, 1, '[', ']', '...', 25)
            FROM   file_content
            WHERE  file_content MATCH ?
            LIMIT  ?
            """,
            (keyword, max_results),
        ).fetchall()
 
 
def _db_index_count() -> int:
    """Return how many files are recorded in ``file_index``.

    The connection is closed even if the query fails.
    """
    with closing(_db_connect()) as conn:
        return conn.execute("SELECT COUNT(*) FROM file_index").fetchone()[0]
 
# ---------------------------------------------------------------------------
# STATO IN MEMORIA
# ---------------------------------------------------------------------------
 
# auth_codes: one-time authorization code -> {"code_challenge", "redirect_uri",
# "expires"}. Filled by authorize(), popped by token_endpoint(), purged by
# _cleanup_expired(). Kept only in memory.
auth_codes: dict[str, dict] = {}
# active_tokens: bearer token -> expiry (epoch seconds). Mirrored in the
# `tokens` table via _db_token_save/_db_token_delete. NOTE(review): presumably
# repopulated at startup from _db_load_active_tokens(); confirm in the caller.
active_tokens: dict[str, float] = {}

# Progress of the current indexing run, read by health() and search_content().
# It is written by the indexer thread without a lock, so readers may observe
# a slightly stale snapshot.
_index_state = {
    "running": False,  # True while _run_initial_indexing is in progress
    "total":   0,      # files queued for (re)indexing in this run
    "done":    0,      # files indexed successfully
    "errors":  0,      # files skipped (too large / extraction error) or failed
}
 
# ---------------------------------------------------------------------------
# INDICIZZATORE
# ---------------------------------------------------------------------------
 
# Lower-case file extensions (with the leading dot) that the indexer and the
# watcher process and that read_file() accepts; everything else is ignored
# or rejected.
ESTENSIONI_LEGGIBILI = {
    ".pdf", ".docx", ".xlsx", ".xls",
    ".txt", ".md", ".csv", ".json", ".xml", ".html", ".htm",
}
 
 
def _needs_reindex(rel_path: str, full_path: str) -> bool:
    """Tell whether *full_path* must be (re)indexed.

    Returns True for files never indexed, or whose mtime or size differ from
    what ``file_index`` recorded. Returns False when the file cannot be
    stat'ed (e.g. it vanished).
    """
    try:
        attuale = os.stat(full_path)
    except OSError:
        return False

    registrato = _db_index_status(rel_path)
    if registrato is None:
        return True

    mtime_salvato, size_salvata = registrato
    return (attuale.st_mtime, attuale.st_size) != (mtime_salvato, size_salvata)
 
 
def _index_single_file(full_path: str, rel_path: str) -> bool:
    """Extract the text of one file and store it in the full-text index.

    Returns False, without touching the index, when the file exceeds
    MAX_INDEX_FILE_BYTES or extraction returns an error marker
    (``[Errore...`` / ``[PDF protetto...``). Any other exception is logged and
    also yields False. Other placeholder texts (e.g. the "no extractable
    text" PDF marker) are indexed as content.
    """
    try:
        info = os.stat(full_path)
        if info.st_size > MAX_INDEX_FILE_BYTES:
            log.info("Saltato (troppo grande): %s", rel_path)
            return False

        contenuto = _estrai_testo(full_path, max_chars=500_000)
        if contenuto.startswith(("[Errore", "[PDF protetto")):
            log.warning("Saltato per errore estrazione: %s", rel_path)
            return False

        _db_upsert_file(rel_path, contenuto, info.st_mtime, info.st_size)
    except Exception as exc:
        log.error("Errore indicizzazione %s: %s", rel_path, exc)
        return False
    return True
 
 
def _db_indexed_paths() -> set[str]:
    """Return every relative path currently recorded in ``file_index``."""
    with closing(_db_connect()) as conn:
        return {row[0] for row in conn.execute("SELECT path FROM file_index")}


def _scan_base_dir() -> tuple[list[tuple[str, str]], set[str]]:
    """Walk BASE_DIR in sorted order and classify eligible files.

    Office lock files (``~$...``) are skipped, consistent with search_files.

    Returns:
        ``(pending, seen)``. *pending* is the list of ``(full_path, rel_path)``
        entries that need (re)indexing; *seen* is the set of every eligible
        relative path found on disk.
    """
    pending: list[tuple[str, str]] = []
    seen: set[str] = set()
    for root, dirs, files in os.walk(BASE_DIR):
        dirs.sort()
        for nome in sorted(files):
            if nome.startswith("~$"):
                continue
            if Path(nome).suffix.lower() not in ESTENSIONI_LEGGIBILI:
                continue
            full_path = os.path.join(root, nome)
            rel_path = os.path.relpath(full_path, BASE_DIR)
            seen.add(rel_path)
            if _needs_reindex(rel_path, full_path):
                pending.append((full_path, rel_path))
    return pending, seen


def _run_initial_indexing() -> None:
    """Bring the full-text index in sync with BASE_DIR (runs in the indexer thread).

    Steps:
      1. Scan the tree for new or changed files.
      2. Drop index entries for files that no longer exist on disk (e.g.
         deleted while the server was down).
      3. Index every pending file, updating ``_index_state`` so that health()
         and search_content() can report progress.

    ``_index_state["running"]`` is reset in ``finally``, so an unexpected
    error cannot leave the index reported as "in costruzione" forever. Such
    an error is logged here, since this function is a thread entry point.
    """
    _index_state["running"] = True
    log.info("Indicizzazione iniziale avviata...")
    try:
        da_indicizzare, presenti = _scan_base_dir()

        # Only remove paths that are really gone. A file created after the
        # scan may already have been indexed by the watcher and must survive.
        orfani = [
            rel_path
            for rel_path in sorted(_db_indexed_paths() - presenti)
            if not os.path.exists(os.path.join(BASE_DIR, rel_path))
        ]
        for rel_path in orfani:
            _db_remove_file(rel_path)
        if orfani:
            log.info("Rimossi dall'indice %d file non più presenti", len(orfani))

        _index_state["total"] = len(da_indicizzare)
        _index_state["done"] = 0
        _index_state["errors"] = 0
        log.info("File da indicizzare: %d", len(da_indicizzare))

        for full_path, rel_path in da_indicizzare:
            if _index_single_file(full_path, rel_path):
                _index_state["done"] += 1
            else:
                _index_state["errors"] += 1
    except Exception:
        log.exception("Indicizzazione iniziale interrotta da un errore imprevisto")
        return
    finally:
        _index_state["running"] = False

    log.info(
        "Indicizzazione completata: %d indicizzati, %d errori, %d totale nell'indice",
        _index_state["done"],
        _index_state["errors"],
        _db_index_count(),
    )
 
 
def _start_indexing_thread() -> threading.Thread:
    """Launch _run_initial_indexing in a daemon thread named "indexer" and return it.

    Because the thread is a daemon, it does not keep the process alive at
    shutdown.
    """
    worker = threading.Thread(
        target=_run_initial_indexing,
        name="indexer",
        daemon=True,
    )
    worker.start()
    return worker
 
# ---------------------------------------------------------------------------
# WATCHER
# ---------------------------------------------------------------------------
 
def _start_file_watcher() -> None:
    """Start a watchdog observer that keeps the full-text index in sync with BASE_DIR.

    If ``watchdog`` is not installed, a warning is logged and the function
    returns, leaving the index refreshed only by the initial indexing run.
    The observer runs in a daemon thread. Handler callbacks never let a
    database error escape: an unhandled exception there would end watchdog's
    dispatch thread and silently stop live updates.
    """
    try:
        from watchdog.observers import Observer
        from watchdog.events import FileSystemEventHandler
    except ImportError:
        log.warning("watchdog non installato: l'indice non si aggiornerà in tempo reale.")
        return

    class _Handler(FileSystemEventHandler):
        def _rel(self, path: str) -> str:
            """Path relative to BASE_DIR, the key used in the index."""
            return os.path.relpath(path, BASE_DIR)

        def _eligible(self, path: str) -> bool:
            """True for supported extensions (used for removals)."""
            return Path(path).suffix.lower() in ESTENSIONI_LEGGIBILI

        def _indexable(self, path: str) -> bool:
            """Eligible and not an Office lock file (``~$...``), which is
            transient and not a real document; search_files skips these too."""
            return self._eligible(path) and not Path(path).name.startswith("~$")

        def _remove(self, path: str) -> None:
            """Drop *path* from the index, logging (not raising) DB errors."""
            rel = self._rel(path)
            try:
                _db_remove_file(rel)
            except Exception:
                log.exception("Watcher: errore rimozione dall'indice di %s", rel)

        def on_created(self, event):
            if event.is_directory or not self._indexable(event.src_path):
                return
            rel = self._rel(event.src_path)
            log.info("Watcher: nuovo file: %s", rel)
            # _index_single_file already catches and logs every exception.
            _index_single_file(event.src_path, rel)

        def on_modified(self, event):
            if event.is_directory or not self._indexable(event.src_path):
                return
            rel = self._rel(event.src_path)
            log.info("Watcher: file modificato: %s", rel)
            _index_single_file(event.src_path, rel)

        def on_deleted(self, event):
            if event.is_directory or not self._eligible(event.src_path):
                return
            log.info("Watcher: file eliminato: %s", self._rel(event.src_path))
            self._remove(event.src_path)

        def on_moved(self, event):
            if event.is_directory:
                return
            if self._eligible(event.src_path):
                self._remove(event.src_path)
            if self._indexable(event.dest_path):
                _index_single_file(event.dest_path, self._rel(event.dest_path))

    observer = Observer()
    observer.schedule(_Handler(), BASE_DIR, recursive=True)
    observer.daemon = True
    observer.start()
    log.info("Watcher avviato su %s", BASE_DIR)
 
# ---------------------------------------------------------------------------
# RATE LIMITING
# ---------------------------------------------------------------------------
 
# Sliding-window rate limiter state, keyed by "<ip>:<limit>" and holding the
# timestamps of the requests accepted within the last RATE_LIMIT_WINDOW
# seconds. Guarded by _rate_lock.
_rate_lock = threading.Lock()
_rate_counters: dict[str, list[float]] = defaultdict(list)

RATE_LIMIT_WINDOW    = 60   # window length in seconds
RATE_LIMIT_MAX_AUTH  = 10   # /authorize and /token endpoints: stricter
RATE_LIMIT_MAX_API   = 60   # MCP endpoints: more permissive
# NOTE(review): RATE_LIMIT_MAX_API is only the default of _is_rate_limited;
# no caller in this file applies it (SecurityMiddleware does not rate-limit).
 
 
def _is_rate_limited(ip: str, limit: int = RATE_LIMIT_MAX_API) -> bool:
    """Sliding-window check: may *ip* make another request under *limit*?

    Buckets are keyed by both IP and limit, so the auth and API limits are
    counted independently. Timestamps older than RATE_LIMIT_WINDOW seconds
    are discarded on every call.

    Returns:
        True if the limit is already reached (the request is not recorded);
        otherwise records the request and returns False.
    """
    now = time.time()
    bucket = f"{ip}:{limit}"
    with _rate_lock:
        recenti = [t for t in _rate_counters[bucket] if now - t < RATE_LIMIT_WINDOW]
        _rate_counters[bucket] = recenti
        if len(recenti) >= limit:
            return True
        recenti.append(now)
        return False
 
# ---------------------------------------------------------------------------
# CLEANUP PERIODICO
# ---------------------------------------------------------------------------
 
async def _cleanup_expired() -> None:
    """Background task: every 10 minutes, purge expired state.

    Each pass:
      * removes expired bearer tokens and authorization codes from memory;
      * removes rate-limiter buckets whose timestamps have all left the
        window. Without this, every distinct client IP would keep an entry
        in ``_rate_counters`` for the lifetime of the process;
      * deletes expired tokens from SQLite.

    Each pass is wrapped in ``try``/``except Exception``, so a transient
    failure (e.g. a locked database) is logged instead of ending the loop
    forever. Task cancellation (``asyncio.CancelledError``, a BaseException)
    still propagates.
    """
    while True:
        await asyncio.sleep(600)
        try:
            now = time.time()

            expired_tokens = [t for t, exp in list(active_tokens.items()) if exp <= now]
            for t in expired_tokens:
                active_tokens.pop(t, None)

            expired_codes = [c for c, data in list(auth_codes.items()) if data["expires"] <= now]
            for c in expired_codes:
                auth_codes.pop(c, None)

            # _is_rate_limited appends in time order, so the last timestamp
            # is the newest. Once it is outside the window the bucket is
            # dead; _is_rate_limited recreates it on demand (defaultdict).
            with _rate_lock:
                stale_buckets = [
                    key for key, stamps in _rate_counters.items()
                    if not stamps or now - stamps[-1] >= RATE_LIMIT_WINDOW
                ]
                for key in stale_buckets:
                    del _rate_counters[key]

            # Run the blocking SQLite DELETE in a worker thread. It waits on
            # _db_write_lock, which the indexer may hold for a long time.
            db_deleted = await asyncio.to_thread(_db_token_cleanup, now)
        except Exception:
            log.exception("Cleanup periodico fallito; nuovo tentativo al prossimo ciclo")
            continue

        if expired_tokens or expired_codes or db_deleted:
            log.info(
                "Cleanup: %d token (memoria), %d codici, %d token (DB)",
                len(expired_tokens), len(expired_codes), db_deleted,
            )
 
# ---------------------------------------------------------------------------
# PKCE
# ---------------------------------------------------------------------------
 
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
    digest = hashlib.sha256(code_verifier.encode()).digest()
    computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
    # Confronto a tempo costante per prevenire timing attack
    return hmac.compare_digest(computed, code_challenge)
 
# ---------------------------------------------------------------------------
# UTILITÀ
# ---------------------------------------------------------------------------
 
def _get_client_ip(request) -> str:
    # Cloudflare imposta CF-Connecting-IP con l'IP reale del client
    cf_ip = request.headers.get("cf-connecting-ip")
    if cf_ip:
        return cf_ip
    return request.client.host if request.client else "unknown"
 
 
def _risolvi_path(path: str) -> tuple[str, str | None]:
    """Resolve a user-supplied relative path against BASE_DIR, guarding against traversal.

    Leading/trailing slashes and whitespace are stripped, so an empty
    result refers to BASE_DIR itself. Symlinks are resolved with
    ``realpath``, and the result must be BASE_DIR or lie beneath it.

    Returns:
        ``(absolute_path, None)`` on success, or ``("", error_message)`` when
        the path escapes BASE_DIR.
    """
    radice = os.path.realpath(BASE_DIR)
    relativo = path.strip("/").strip()
    if not relativo:
        return radice, None

    candidato = os.path.realpath(os.path.join(BASE_DIR, relativo))
    dentro_radice = candidato == radice or candidato.startswith(radice + os.sep)
    if not dentro_radice:
        return "", "Accesso negato: percorso non valido."
    return candidato, None
 
 
def _dimensione_leggibile(n_bytes: int) -> str:
    for unita in ("B", "KB", "MB", "GB"):
        if n_bytes < 1024:
            return f"{n_bytes:.0f} {unita}"
        n_bytes /= 1024
    return f"{n_bytes:.1f} GB"
 
# ---------------------------------------------------------------------------
# ENDPOINT OAUTH
# ---------------------------------------------------------------------------
 
async def well_known(request):
    """Serve OAuth 2.0 Authorization Server Metadata (RFC 8414).

    The issuer is derived from the request's base URL, so advertised
    endpoints follow whatever host and scheme the client used.
    """
    issuer = str(request.base_url).rstrip("/")
    metadata = {
        "issuer": issuer,
        "authorization_endpoint": issuer + "/authorize",
        "token_endpoint": issuer + "/token",
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "client_credentials"],
        "code_challenge_methods_supported": ["S256"],
    }
    return JSONResponse(metadata)
 
 
async def authorize(request):
    """OAuth 2.0 authorization endpoint: authorization-code grant with mandatory PKCE S256.

    Validates ``client_id``, requires ``redirect_uri`` and an S256
    ``code_challenge``, then issues a one-time code valid for 300 s and
    redirects (302) to ``redirect_uri``. The code is appended as a properly
    URL-encoded query string: ``&`` is used when ``redirect_uri`` already has
    a query, and ``state`` is echoed only when the client sent one
    (RFC 6749 §4.1.2).

    NOTE(review): there is no user login/consent step, and ``redirect_uri``
    is not checked against an allow-list. Anyone who knows MCP_CLIENT_ID can
    obtain a code; confirm this is acceptable for the deployment.
    """
    ip = _get_client_ip(request)
    if _is_rate_limited(ip, RATE_LIMIT_MAX_AUTH):
        log.warning("Rate limit /authorize IP=%s", ip)
        return Response("Troppe richieste", status_code=429)

    params = dict(request.query_params)
    client_id             = params.get("client_id", "")
    redirect_uri          = params.get("redirect_uri", "")
    code_challenge        = params.get("code_challenge", "")
    code_challenge_method = params.get("code_challenge_method", "")
    state                 = params.get("state", "")

    # Compare as bytes: hmac.compare_digest raises TypeError for non-ASCII
    # str, which an attacker could otherwise use to trigger a 500.
    if not hmac.compare_digest(client_id.encode(), OAUTH_CLIENT_ID.encode()):
        log.warning("client_id non valido da IP=%s", ip)
        return Response("Non autorizzato", status_code=401)

    if not redirect_uri:
        return Response("redirect_uri obbligatorio", status_code=400)

    if code_challenge_method != "S256":
        return Response("Solo S256 supportato per code_challenge_method", status_code=400)

    if not code_challenge:
        return Response("code_challenge obbligatorio", status_code=400)

    code = secrets.token_hex(32)
    auth_codes[code] = {
        "code_challenge": code_challenge,
        "redirect_uri":   redirect_uri,
        "expires":        time.time() + 300,
    }
    log.info("Codice autorizzazione emesso per IP=%s", ip)

    risposta = {"code": code}
    if state:
        risposta["state"] = state
    separatore = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(
        url=f"{redirect_uri}{separatore}{urlencode(risposta)}",
        status_code=302,
    )
 
 
async def _emit_access_token() -> JSONResponse:
    """Mint a bearer token, register it, and build the RFC 6749 §5.1 success response.

    The token (64 hex chars) is stored in ``active_tokens`` and persisted
    with ``_db_token_save``. The SQLite write runs in a worker thread
    because it waits on ``_db_write_lock``, which the indexer may hold
    during large index writes; running it inline would stall the event loop.
    """
    token   = secrets.token_hex(32)
    expires = time.time() + TOKEN_EXPIRY
    active_tokens[token] = expires
    await asyncio.to_thread(_db_token_save, token, expires)
    return JSONResponse({
        "access_token": token,
        "token_type":   "bearer",
        "expires_in":   TOKEN_EXPIRY,
    })


async def token_endpoint(request):
    """OAuth 2.0 token endpoint.

    Client credentials come from the form body or, when present, from an
    HTTP Basic ``Authorization`` header (which takes precedence).

    Supported grants:
      * ``authorization_code``: client_id must match; client_secret is
        checked only if supplied (public clients rely on PKCE). The code is
        single-use (popped before the remaining checks). It must be
        unexpired, issued for the same redirect_uri, and the PKCE verifier
        must match.
      * ``client_credentials``: both client_id and client_secret must match.

    Errors: ``invalid_client`` (401), ``invalid_grant`` (400),
    ``unsupported_grant_type`` (400); rate-limited requests get 429.
    Credentials are compared as bytes, because ``hmac.compare_digest`` raises
    TypeError on non-ASCII ``str`` (previously an HTTP 500).
    """
    ip = _get_client_ip(request)
    if _is_rate_limited(ip, RATE_LIMIT_MAX_AUTH):
        log.warning("Rate limit /token IP=%s", ip)
        return Response("Troppe richieste", status_code=429)

    form          = await request.form()
    # str(): a multipart form may yield UploadFile objects instead of strings.
    client_id     = str(form.get("client_id", ""))
    client_secret = str(form.get("client_secret", ""))

    auth_header = request.headers.get("authorization", "")
    if auth_header.startswith("Basic "):
        try:
            decoded = base64.b64decode(auth_header[6:]).decode()
            client_id, _, client_secret = decoded.partition(":")
        except ValueError:  # binascii.Error and UnicodeDecodeError
            return JSONResponse({"error": "invalid_client"}, status_code=401)

    grant_type = form.get("grant_type", "")

    if grant_type == "authorization_code":
        code          = form.get("code", "")
        code_verifier = str(form.get("code_verifier", ""))
        redirect_uri  = form.get("redirect_uri", "")

        if not hmac.compare_digest(client_id.encode(), OAUTH_CLIENT_ID.encode()):
            log.warning("Token: client_id non valido da IP=%s", ip)
            return JSONResponse({"error": "invalid_client"}, status_code=401)

        if client_secret and not hmac.compare_digest(
            client_secret.encode(), OAUTH_CLIENT_SECRET.encode()
        ):
            log.warning("Token: client_secret errato da IP=%s", ip)
            return JSONResponse({"error": "invalid_client"}, status_code=401)

        stored = auth_codes.pop(code, None)
        if not stored or time.time() > stored["expires"]:
            return JSONResponse({"error": "invalid_grant"}, status_code=400)

        if stored["redirect_uri"] != redirect_uri:
            return JSONResponse({"error": "invalid_grant"}, status_code=400)

        if not _verify_pkce(code_verifier, stored["code_challenge"]):
            log.warning("PKCE fallita da IP=%s", ip)
            return JSONResponse({"error": "invalid_grant"}, status_code=400)

        response = await _emit_access_token()
        log.info("Token emesso per IP=%s (scade tra %ss)", ip, TOKEN_EXPIRY)
        return response

    if grant_type == "client_credentials":
        # Evaluate both comparisons so the timing does not reveal which failed.
        id_ok     = hmac.compare_digest(client_id.encode(), OAUTH_CLIENT_ID.encode())
        secret_ok = hmac.compare_digest(client_secret.encode(), OAUTH_CLIENT_SECRET.encode())
        if not (id_ok and secret_ok):
            log.warning("Client credentials non valide da IP=%s", ip)
            return JSONResponse({"error": "invalid_client"}, status_code=401)
        response = await _emit_access_token()
        log.info("Token (client_credentials) emesso per IP=%s", ip)
        return response

    return JSONResponse({"error": "unsupported_grant_type"}, status_code=400)
 
 
async def health(request):
    """Unauthenticated health check reporting index size and indexing progress.

    The path is listed in SecurityMiddleware.PERCORSI_PUBBLICI.
    """
    indicizzati = _db_index_count()
    payload = {
        "status":   "ok",
        "indexed":  indicizzati,
        "indexing": _index_state["running"],
        "progress": {chiave: _index_state[chiave] for chiave in ("done", "total", "errors")},
    }
    return JSONResponse(payload)
 
# ---------------------------------------------------------------------------
# MIDDLEWARE DI SICUREZZA
# ---------------------------------------------------------------------------
 
class SecurityMiddleware(BaseHTTPMiddleware):
    """Bearer-token gate plus hardening headers on every HTTP response.

    Requests to PERCORSI_PUBBLICI (OAuth endpoints, discovery documents,
    health check) and ``OPTIONS`` requests (presumably CORS preflight) pass
    without a token. Every other request needs
    ``Authorization: Bearer <token>`` with a token present in
    ``active_tokens`` and not yet expired.
    """

    # Exact request paths reachable without authentication.
    PERCORSI_PUBBLICI = {
        "/authorize",
        "/token",
        "/health",
        "/.well-known/oauth-authorization-server",
        "/.well-known/oauth-protected-resource",
    }

    @staticmethod
    def _aggiungi_header_sicurezza(response: Response) -> Response:
        """Set the security headers on *response* and return it."""
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"]        = "DENY"
        response.headers["Cache-Control"]          = "no-store"
        response.headers["Referrer-Policy"]        = "no-referrer"
        return response

    async def dispatch(self, request, call_next):
        """Authenticate non-public requests, then add security headers.

        The headers go on every response, including the 401s produced here
        (previously those were returned without them).
        """
        if request.url.path not in self.PERCORSI_PUBBLICI and request.method != "OPTIONS":
            auth_response = await self._check_auth(request)
            if auth_response is not None:
                return self._aggiungi_header_sicurezza(auth_response)

        response = await call_next(request)
        return self._aggiungi_header_sicurezza(response)

    async def _check_auth(self, request) -> Response | None:
        """Return a 401 response unless the request carries a valid, unexpired bearer token.

        The auth scheme is matched case-insensitively (RFC 7235 §2.1), so
        ``bearer <token>`` is accepted as well as ``Bearer <token>``.
        Surrounding whitespace around the token is ignored. An expired
        token is evicted from memory and from SQLite. The DB delete runs in
        a worker thread so that waiting on ``_db_write_lock`` does not block
        the event loop.

        Returns:
            None when the request is authorised, otherwise the 401 response.
        """
        auth_header = request.headers.get("authorization", "")
        scheme, _, token = auth_header.partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            ip = _get_client_ip(request)
            log.warning("Richiesta senza Bearer da IP=%s path=%s", ip, request.url.path)
            return Response("Non autorizzato", status_code=401)

        expiry = active_tokens.get(token)

        if expiry is None:
            ip = _get_client_ip(request)
            log.warning("Token sconosciuto da IP=%s", ip)
            return Response("Token non valido", status_code=401)

        if time.time() > expiry:
            active_tokens.pop(token, None)
            await asyncio.to_thread(_db_token_delete, token)
            return Response("Token scaduto", status_code=401)

        return None
 
# ---------------------------------------------------------------------------
# ESTRAZIONE TESTO
# ---------------------------------------------------------------------------
 
def _estrai_testo(full_path: str, max_chars: int = 200_000) -> str:
    ext = Path(full_path).suffix.lower()
 
    if ext == ".pdf":
        if _fitz is None:
            return "[Errore: PyMuPDF non installato. Esegui: pip install pymupdf]"
        try:
            doc = _fitz.open(full_path)
            pagine = []
            totale = 0
            for i, pagina in enumerate(doc, 1):
                testo = pagina.get_text("text")
                if not testo.strip():
                    continue
                h = f"--- Pagina {i} ---\n"
                pagine.append(h + testo)
                totale += len(h) + len(testo)
                if totale >= max_chars:
                    break
            doc.close()
            return "\n\n".join(pagine) if pagine else "[PDF senza testo estraibile: probabilmente scansionato]"
        except _fitz.FileDataError:
            return "[PDF protetto da password o danneggiato]"
        except Exception as e:
            return f"[Errore lettura PDF: {e}]"
 
    if ext == ".xls":
        try:
            import xlrd
            wb = xlrd.open_workbook(full_path)
            righe = []
            totale = 0
            for nome in wb.sheet_names():
                foglio = wb.sheet_by_name(nome)
                righe.append(f"=== Foglio: {nome} ===")
                for i in range(foglio.nrows):
                    r = "\t".join(str(foglio.cell_value(i, j)) for j in range(foglio.ncols))
                    righe.append(r)
                    totale += len(r)
                    if totale >= max_chars:
                        return "\n".join(righe)
            return "\n".join(righe) if righe else "[XLS vuoto]"
        except Exception as e:
            return f"[Errore lettura XLS: {e}]"
 
    if ext == ".docx":
        try:
            from docx import Document
            from docx.oxml.ns import qn
            doc = Document(full_path)
            righe: list[str] = []
            totale = 0
 
            def _aggiungi(testo: str) -> bool:
                nonlocal totale
                if not testo.strip():
                    return True
                righe.append(testo)
                totale += len(testo)
                return totale < max_chars
 
            for p in doc.paragraphs:
                if not _aggiungi(p.text):
                    return "\n".join(righe)
 
            for tabella in doc.tables:
                for riga in tabella.rows:
                    if not _aggiungi("\t".join(c.text for c in riga.cells)):
                        return "\n".join(righe)
 
            try:
                for txbx in doc.element.body.iter(qn("w:txbxContent")):
                    for p in txbx.iter(qn("w:p")):
                        testo_p = "".join(t.text for t in p.iter(qn("w:t")) if t.text)
                        if not _aggiungi("[TextBox] " + testo_p):
                            return "\n".join(righe)
            except Exception:
                pass
 
            return "\n".join(righe) if righe else "[DOCX vuoto]"
        except Exception as e:
            return f"[Errore lettura DOCX: {e}]"
 
    if ext == ".xlsx":
        try:
            import openpyxl
            wb = openpyxl.load_workbook(full_path, read_only=True, data_only=True)
            righe = []
            totale = 0
            for nome in wb.sheetnames:
                foglio = wb[nome]
                righe.append(f"=== Foglio: {nome} ===")
                for riga in foglio.iter_rows(values_only=True):
                    if any(c is not None for c in riga):
                        r = "\t".join(str(c) if c is not None else "" for c in riga)
                        righe.append(r)
                        totale += len(r)
                        if totale >= max_chars:
                            return "\n".join(righe)
            return "\n".join(righe) if righe else "[XLSX vuoto]"
        except Exception as e:
            return f"[Errore lettura XLSX: {e}]"
 
    try:
        with open(full_path, "r", encoding="utf-8", errors="replace") as f:
            return f.read(max_chars)
    except Exception as e:
        return f"[Errore lettura file: {e}]"
 
# ---------------------------------------------------------------------------
# MCP SERVER E TOOLS
# ---------------------------------------------------------------------------
 
# FastMCP server instance; the tools and the prompt below register on it via
# decorators. The streamable-HTTP endpoint is served at "/" of wherever the
# app is mounted.
# NOTE(review): `host=ALLOWED_HOST` is FastMCP's `host` setting. Confirm
# whether the installed mcp version treats it as the bind address or uses it
# for Host-header / DNS-rebinding checks, since the variable name suggests
# the latter.
mcp = FastMCP(
    "FileServer Progetto",
    streamable_http_path="/",
    host=ALLOWED_HOST,
)
 
 
# MCP prompt exposing the usage instructions (in Italian) for the five file
# tools: mandatory lookup before answering, recommended search flow, path
# rules and the [TRONCATO] convention.
# Documented with comments instead of a docstring because FastMCP exposes a
# function's docstring to clients as the prompt description.
@mcp.prompt()
def istruzioni_file_server() -> str:
    return """
Hai accesso al file server del progetto tramite cinque strumenti:
get_structure, list_all, read_file, search_files e search_content.
 
COMPORTAMENTO OBBLIGATORIO:
- Prima di rispondere a qualsiasi domanda su documenti, specifiche,
  requisiti, distinte, procedure o qualunque contenuto del progetto,
  DEVI sempre cercare e leggere i file rilevanti. Non rispondere mai
  a memoria su contenuti del progetto.
- Se l'utente menziona un documento (es. "la lista ACU", "il planning",
  "la distinta"), cerca subito con search_files e poi leggi con read_file.
- Se non trovi per nome, usa search_content con parole chiave dalla domanda.
 
FLUSSO CONSIGLIATO:
1. search_files con parte del nome del file (se conosci il nome)
2. read_file per leggere il contenuto
3. Se non trovi per nome, search_content per cercare nel testo
4. get_structure solo se hai bisogno di orientarti nella struttura
5. list_all per esplorare una cartella specifica
 
REGOLE SUI PERCORSI:
- Usa sempre percorsi RELATIVI alla radice del server.
- Non usare mai "/", ".", "..", o percorsi assoluti.
- I percorsi validi sono solo quelli mostrati dagli strumenti.
 
Se un file è troncato (vedi [TRONCATO] alla fine), richiama read_file
con max_chars più alto. Comunica in italiano.
"""
 
 
@mcp.tool()
def get_structure() -> str:
    """
    Restituisce l'albero delle CARTELLE del file server senza elencare i file.
    Usalo per orientarti nella struttura prima di qualsiasi altra operazione.
    Non richiede parametri.
    """
    # The docstring above is the tool description shown to MCP clients (kept
    # in Italian on purpose).
    #
    # Render the folder tree under BASE_DIR, one line per directory
    # ("<indent><name>/  (N file)"), followed by the total file count.
    # The root is level 0 and each path component below it adds one level.
    # relpath() of a first-level folder contains no separator, so counting
    # os.sep alone would wrongly print first-level folders at the root's
    # indentation. Files are counted during the same walk instead of
    # walking the whole tree a second time.
    base = os.path.realpath(BASE_DIR)
    righe = []
    totale = 0
    for root, dirs, files in os.walk(base):
        dirs.sort()
        rel = os.path.relpath(root, base)
        livello   = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent    = "  " * livello
        nome      = os.path.basename(root.rstrip(os.sep)) if root != base else "."
        n_file    = len(files)
        totale   += n_file
        etichetta = f"  ({n_file} file)" if n_file else ""
        righe.append(f"{indent}{nome}/{etichetta}")
    righe.append(f"\nTotale file: {totale}")
    return "\n".join(righe)
 
 
@mcp.tool()
def list_all(path: str = "") -> str:
    """
    Elenca file e cartelle dentro una directory specifica del file server.
    path: percorso relativo della cartella (es. "Design" o "Distinta/Rev2").
    Lascia vuoto per la radice.
    Ogni file mostra dimensione e data di modifica.
    """
    # The docstring above is the tool description shown to MCP clients (kept
    # in Italian on purpose).
    #
    # Recursively list the folder *path* (relative to BASE_DIR): one line per
    # directory and one per file with "[size, dd/mm/yyyy]", then a total.
    # Traversal outside BASE_DIR is rejected by _risolvi_path.
    full_path, errore = _risolvi_path(path)
    if errore:
        return errore
    if not os.path.exists(full_path):
        return f"Percorso non trovato: '{path}'. Usa get_structure per vedere le cartelle disponibili."
    if not os.path.isdir(full_path):
        # os.walk() yields nothing for a regular file, which used to produce
        # a bare "Totale file: 0"; point the caller at read_file instead.
        return f"'{path}' è un file, non una cartella. Usa read_file per leggerne il contenuto."

    righe    = []
    n_totale = 0
    for root, dirs, files in os.walk(full_path):
        dirs.sort()
        files.sort()
        # The listed folder is level 0 and each component below adds one.
        # Counting os.sep alone put first-level subfolders at level 0, so
        # they were shown with the requested folder's name instead of
        # their own.
        rel           = os.path.relpath(root, full_path)
        livello       = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent        = "  " * livello
        nome_cartella = os.path.basename(root) if livello > 0 else (path.strip("/") or ".")
        righe.append(f"{indent}{nome_cartella}/")
        for nome_file in files:
            fp = os.path.join(root, nome_file)
            try:
                stat = os.stat(fp)
                dim  = _dimensione_leggibile(stat.st_size)
                data = time.strftime("%d/%m/%Y", time.localtime(stat.st_mtime))
                info = f"  [{dim}, {data}]"
            except OSError:
                # e.g. a broken symlink or a file removed mid-walk: list it
                # without metadata.
                info = ""
            righe.append(f"{indent}  {nome_file}{info}")
            n_totale += 1
    righe.append(f"\nTotale file: {n_totale}")
    return "\n".join(righe)
 
 
@mcp.tool()
def read_file(path: str, max_chars: int = 50_000) -> str:
    """
    Legge il contenuto di un file del progetto.
    Supporta PDF, DOCX, XLSX, XLS, TXT, MD, CSV, JSON, XML, HTML.
    path: percorso relativo esatto come mostrato dagli altri strumenti.
    max_chars: limite caratteri (default 50000, massimo 500000).
    Se vedi [TRONCATO] alla fine, richiama con max_chars più alto.
    """
    # The docstring above is the tool description shown to MCP clients (kept
    # in Italian on purpose).
    #
    # Return a header (path, size, mtime) followed by the extracted text,
    # cut at max_chars with a [TRONCATO] notice when longer.
    #
    # Clamp to [1, 500000]. Without the lower bound, a negative value
    # reached f.read(max_chars + 1) as a negative size, which reads the whole
    # file, and then produced a nonsensical slice and notice.
    max_chars = max(1, min(max_chars, 500_000))

    full_path, errore = _risolvi_path(path)
    if errore:
        return errore
    if not os.path.isfile(full_path):
        return f"File non trovato: '{path}'. Usa list_all o search_files per trovare il percorso corretto."

    ext = Path(full_path).suffix.lower()
    if ext not in ESTENSIONI_LEGGIBILI:
        return (
            f"Estensione '{ext}' non supportata. "
            f"Formati leggibili: {', '.join(sorted(ESTENSIONI_LEGGIBILI))}"
        )

    try:
        stat = os.stat(full_path)
        dim  = _dimensione_leggibile(stat.st_size)
        data = time.strftime("%d/%m/%Y %H:%M", time.localtime(stat.st_mtime))
    except OSError:
        dim, data = "?", "?"

    intestazione = f"[File: {path} | Dimensione: {dim} | Ultima modifica: {data}]\n\n"
    # Ask for one extra character so that "longer than max_chars" can be
    # detected without reading the entire file.
    risultato    = _estrai_testo(full_path, max_chars=max_chars + 1)

    if len(risultato) > max_chars:
        return (
            intestazione + risultato[:max_chars]
            + f"\n\n[TRONCATO: il file supera {max_chars} caratteri. "
            f"Richiama con max_chars più alto per leggere il resto.]"
        )
    return intestazione + risultato
 
 
@mcp.tool()
def search_files(keyword: str) -> str:
    """
    Cerca file il cui NOME contiene una parola chiave (case-insensitive).
    keyword: parola o parte di parola da cercare nel nome del file.
    Restituisce percorsi relativi da usare con read_file.
    Se non conosci il nome ma sai cosa c'è scritto dentro, usa search_content.
    """
    # The docstring above is the tool description shown to MCP clients (kept
    # in Italian on purpose).
    #
    # Case-insensitive substring match on file names across BASE_DIR, in
    # sorted walk order. Office lock files ("~$...") are ignored. Keywords
    # must be 2-200 characters long.
    if len(keyword) < 2:
        return "La parola chiave deve essere di almeno 2 caratteri."
    if len(keyword) > 200:
        return "Parola chiave troppo lunga."

    ago = keyword.lower()
    trovati: list[str] = []
    for cartella, sottocartelle, nomi in os.walk(BASE_DIR):
        sottocartelle.sort()  # in-place, so os.walk descends in sorted order
        for nome in sorted(nomi):
            if nome.startswith("~$") or ago not in nome.lower():
                continue
            trovati.append(os.path.relpath(os.path.join(cartella, nome), BASE_DIR))

    if trovati:
        return f"File trovati ({len(trovati)}):\n" + "\n".join(trovati)
    return (
        f"Nessun file trovato con '{keyword}' nel nome. "
        "Prova con una parola chiave diversa o usa search_content."
    )
 
 
@mcp.tool()
def search_content(keyword: str, max_results: int = 10) -> str:
    """
    Cerca una parola o frase nel CONTENUTO dei file tramite indice full-text.
    keyword: parola o frase. Supporta operatori FTS5: AND, OR, NOT,
             "frase esatta", prefisso* (es: "requisit*", "motore AND potenza").
    max_results: numero massimo di file restituiti (default 10, massimo 50).
    Per ogni file trovato mostra il percorso e un estratto del contesto
    attorno alla parola trovata (la parola è racchiusa tra []).
    Usa i percorsi restituiti con read_file per leggere i file completi.
    """
    # Full-text search over the content index via _db_search.
    # Returns a user/LLM-facing string: an optional "index still building"
    # notice, then either a "no match" hint or one path + context excerpt per hit.
    #
    # NOTE: the Italian docstring above is intentionally left verbatim — FastMCP
    # uses it as the tool description sent to clients.
    if len(keyword) < 2:
        return "La parola chiave deve essere di almeno 2 caratteri."
    if len(keyword) > 500:
        return "Parola chiave troppo lunga."

    # Clamp to [1, 50]. The lower bound matters: 0 would always yield no rows,
    # and if _db_search forwards this value to an SQLite LIMIT clause, a
    # negative LIMIT means "no upper bound" and would bypass the 50-row cap.
    max_results = max(1, min(max_results, 50))

    # While the background indexer is running, warn that results may be partial.
    avviso = ""
    if _index_state["running"]:
        done  = _index_state["done"]
        total = _index_state["total"]
        avviso = (
            f"[Indice in costruzione: {done}/{total} file elaborati. "
            f"Risultati potrebbero essere parziali.]\n\n"
        )

    try:
        rows = _db_search(keyword, max_results)
    except Exception as e:
        # Most failures here are malformed FTS5 queries, which are reported
        # back to the caller. Log them too, so non-syntax errors (e.g. DB
        # problems) that get reported the same way remain visible to operators.
        log.warning("search_content: ricerca fallita per %r: %s", keyword, e)
        return (
            f"Sintassi di ricerca non valida: {e}\n"
            "Usa parole semplici o operatori FTS5 validi: AND, OR, NOT, "
            "\"frase esatta\", prefisso*."
        )

    n_totale = _db_index_count()

    if not rows:
        return (
            avviso
            + f"Nessun file contiene '{keyword}' (indice: {n_totale} file). "
            "Prova con una parola chiave più generica o verifica l'ortografia."
        )

    intestazione = f"File che contengono '{keyword}' ({len(rows)} trovati"
    # Hitting the limit exactly means there may be more matches than shown.
    if len(rows) == max_results:
        intestazione += f", mostrati i primi {max_results}"
    intestazione += f", su {n_totale} file nell'indice):\n\n"

    righe = [f"{path}\n  contesto: {estratto}" for path, estratto in rows]
    return avviso + intestazione + "\n\n".join(righe)
 
# ---------------------------------------------------------------------------
# AVVIO
# ---------------------------------------------------------------------------
 
if __name__ == "__main__":
    import uvicorn

    # One-time startup: initialise the database, restore the persisted
    # active tokens into the in-memory `active_tokens` map, then start the
    # background indexing thread and the file watcher.
    _db_init()
    active_tokens.update(_db_load_active_tokens())
    log.info("Token attivi ricaricati dal DB: %d", len(active_tokens))

    _start_indexing_thread()
    _start_file_watcher()

    @asynccontextmanager
    async def lifespan(app):
        """Run the token-cleanup task and the MCP session manager for the app's lifetime."""
        cleanup_task = asyncio.create_task(_cleanup_expired())
        log.info("Task di cleanup token avviato")
        try:
            async with mcp.session_manager.run():
                log.info("MCP session manager avviato")
                yield
        finally:
            # Always stop the background task, even if the session manager
            # failed to start or exited with an error.
            cleanup_task.cancel()
            # Wait for the cancellation to complete. With return_exceptions=True
            # the task's own CancelledError is returned instead of raised, while
            # a cancellation of this coroutine itself still propagates.
            await asyncio.gather(cleanup_task, return_exceptions=True)
            log.info("Shutdown completato")

    routes = [
        Route("/.well-known/oauth-authorization-server", well_known,    methods=["GET"]),
        Route("/.well-known/oauth-protected-resource",   well_known,    methods=["GET"]),
        Route("/authorize",                              authorize,      methods=["GET"]),
        Route("/token",                                  token_endpoint, methods=["POST"]),
        Route("/health",                                 health,         methods=["GET"]),
        # Catch-all mount for the MCP streamable-HTTP endpoint; must stay last
        # so the explicit routes above take precedence.
        Mount("/", app=mcp.streamable_http_app()),
    ]

    app = Starlette(routes=routes, lifespan=lifespan)
    app.add_middleware(SecurityMiddleware)

    # The server listens on localhost ONLY.
    # All external access goes through the cloudflared tunnel,
    # which handles TLS, Cloudflare Access authentication and forwarding.
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=BIND_PORT,
    )