"""
Gestion de la base de données avec SQLAlchemy
"""
import json
import os
from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Dict, Iterator, List, Optional

from sqlalchemy import and_, create_engine, func, or_
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

from app.db_models import APIResponseModel, Base, LogModel, QueueModel
from app.logger import get_logger
from app.models import APIResponse, QueuedRequest, RequestStatus

# Module-level logger. Calls in this module pass an event name plus keyword
# fields (e.g. logger.info("database_initialized", dbpath=...)), so get_logger
# presumably returns a structured (structlog-style) logger — confirm in
# app.logger.
logger = get_logger(__name__)


class Database:
    """SQLite persistence layer built on SQLAlchemy.

    Stores queued requests (``QueueModel``), the API responses recorded while
    processing them (``APIResponseModel``) and free-form log lines
    (``LogModel``). Every public method opens its own short-lived session via
    :meth:`get_session`, so each call runs in its own transaction.
    """

    def __init__(self, dbpath: str):
        """Create the engine for the SQLite file at ``dbpath`` and create tables.

        Args:
            dbpath: Path to the SQLite database file. Its parent directory is
                created if missing. If the path itself is a symlink, the engine
                is opened on the resolved target; ``self.dbpath`` keeps the
                path exactly as given.
        """
        self.dbpath = dbpath

        # Create the parent directory if needed. dirname() is "" for a bare
        # filename, and os.makedirs("") would raise.
        parent_dir = os.path.dirname(dbpath)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Open the real file when the database path is a symlink.
        if os.path.islink(dbpath):
            dbpath = os.path.realpath(dbpath)

        # check_same_thread=False allows the sqlite3 connection to be used from
        # threads other than the one that created it; StaticPool hands out one
        # single connection to every session.
        # NOTE(review): that one connection is shared by all threads, so
        # concurrent sessions are not isolated from each other — confirm that
        # callers serialize database access.
        self.engine = create_engine(
            f"sqlite:///{dbpath}",
            connect_args={"check_same_thread": False},
            poolclass=StaticPool,
            echo=False,  # Set to True to log the emitted SQL while debugging.
        )

        # Session factory; expire_on_commit is left at its default (True).
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
        )

        self.init_db()

    def init_db(self):
        """Create any missing tables declared on ``Base.metadata``."""
        Base.metadata.create_all(bind=self.engine)
        logger.info("database_initialized", dbpath=self.dbpath)

    @contextmanager
    def get_session(self) -> Iterator[Session]:
        """Yield a session that commits on success and rolls back on error.

        The session is always closed on exit. An exception raised inside the
        ``with`` block is logged, the transaction is rolled back, and the
        exception is re-raised.

        Because the session commits on exit with ``expire_on_commit`` at its
        default, ORM objects loaded inside the block have their attributes
        expired once the block ends; copy out what you need (or expunge the
        object) before leaving the block.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error("database_session_error", error=str(e), exc_info=True)
            raise
        finally:
            session.close()

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _dumps_or_none(value: Any) -> Optional[str]:
        """Serialize ``value`` to JSON, mapping only ``None`` to SQL NULL.

        Falsy but meaningful payloads such as ``{}``, ``[]``, ``0``, ``""``
        or ``False`` are stored as JSON instead of being collapsed to NULL.
        """
        return None if value is None else json.dumps(value)

    @staticmethod
    def _loads_or_none(text: Optional[str]) -> Any:
        """Deserialize a nullable JSON column; NULL or empty text gives ``None``."""
        return json.loads(text) if text else None

    @classmethod
    def _to_queued_request(cls, row: QueueModel) -> QueuedRequest:
        """Convert a ``QueueModel`` row into a ``QueuedRequest``.

        Must be called while ``row`` is still attached to its open session,
        since its attributes are expired when the session commits.
        """
        return QueuedRequest(
            id=row.id,
            event_type=row.event_type,
            data=json.loads(row.data),
            priority=row.priority,
            metadata=cls._loads_or_none(row.request_metadata),
            status=row.status,
            retry_count=row.retry_count,
            max_retries=row.max_retries,
            created_at=row.created_at,
        )

    # ------------------------------------------------------------------
    # Queue
    # ------------------------------------------------------------------

    def add_to_queue(self, request: QueuedRequest) -> int:
        """Insert ``request`` into the queue and return the new row ID.

        ``request.data`` is always stored as JSON; ``request.metadata`` is
        stored as JSON unless it is ``None``.
        """
        with self.get_session() as session:
            queue_item = QueueModel(
                event_type=request.event_type,
                data=json.dumps(request.data),
                priority=request.priority,
                request_metadata=self._dumps_or_none(request.metadata),
                status=request.status,
                max_retries=request.max_retries,
            )
            session.add(queue_item)
            # Flush so the database assigns the ID before the commit on exit.
            session.flush()
            request_id = queue_item.id

            logger.debug(
                "queue_item_added",
                request_id=request_id,
                event_type=request.event_type,
            )

        return request_id

    def get_pending_requests(self, limit: int = 10) -> List[QueuedRequest]:
        """Return up to ``limit`` requests that are ready to be processed.

        A request is ready when its status is PENDING, or when it is RETRYING
        and its ``next_retry_at`` is not in the future. Results are ordered by
        ascending ``priority``, then by ``created_at`` (oldest first).
        """
        # NOTE(review): datetime.utcnow() is naive UTC, so this assumes
        # next_retry_at is stored as naive UTC as well — confirm with callers.
        now = datetime.utcnow()

        with self.get_session() as session:
            rows = (
                session.query(QueueModel)
                .filter(
                    or_(
                        QueueModel.status == RequestStatus.PENDING,
                        and_(
                            QueueModel.status == RequestStatus.RETRYING,
                            QueueModel.next_retry_at <= now,
                        ),
                    )
                )
                .order_by(
                    QueueModel.priority.asc(),
                    QueueModel.created_at.asc(),
                )
                .limit(limit)
                .all()
            )
            # Convert inside the session: row attributes expire on commit.
            return [self._to_queued_request(row) for row in rows]

    def update_request_status(
        self,
        request_id: int,
        status: RequestStatus,
        retry_count: Optional[int] = None,
        next_retry_at: Optional[datetime] = None,
        error_details: Optional[Dict[str, Any]] = None
    ):
        """Update the status (and optionally retry fields) of a queued request.

        ``retry_count``, ``next_retry_at`` and ``error_details`` are written
        only when not ``None``. ``processed_at`` is set to the current naive
        UTC time when ``status`` is COMPLETED or FAILED. If no row has ID
        ``request_id``, a warning is logged and nothing is changed.
        """
        with self.get_session() as session:
            queue_item = session.query(QueueModel).filter(
                QueueModel.id == request_id
            ).first()

            if queue_item is None:
                logger.warning("queue_item_not_found", request_id=request_id)
                return

            queue_item.status = status

            if retry_count is not None:
                queue_item.retry_count = retry_count

            if next_retry_at is not None:
                queue_item.next_retry_at = next_retry_at

            if error_details is not None:
                # Pretty-printed and non-ASCII preserved for human inspection.
                queue_item.error_details = json.dumps(
                    error_details, ensure_ascii=False, indent=2
                )

            # Terminal states record when processing finished.
            if status in (RequestStatus.COMPLETED, RequestStatus.FAILED):
                queue_item.processed_at = datetime.utcnow()

            logger.debug(
                "queue_status_updated",
                request_id=request_id,
                status=status.value,
                retry_count=retry_count,
            )

    def get_queue_stats(self) -> Dict[str, int]:
        """Return the number of queue rows per status, keyed by status value.

        Statuses with no rows are absent from the result.
        """
        with self.get_session() as session:
            results = session.query(
                QueueModel.status,
                func.count(QueueModel.id).label("count"),
            ).group_by(QueueModel.status).all()

            return {row.status.value: row.count for row in results}

    def cleanup_old_records(self, days: int = 30) -> int:
        """Delete COMPLETED/FAILED requests processed more than ``days`` days ago.

        Returns:
            The number of queue rows deleted.

        Uses a bulk DELETE (``synchronize_session=False``), which bypasses
        ORM-level relationship cascades.
        NOTE(review): API response and log rows that reference the deleted
        queue IDs are not removed here unless the schema declares
        ON DELETE CASCADE and SQLite foreign keys are enabled — confirm in
        app.db_models.
        """
        cutoff_date = datetime.utcnow() - timedelta(days=days)

        with self.get_session() as session:
            deleted = session.query(QueueModel).filter(
                and_(
                    QueueModel.status.in_(
                        [RequestStatus.COMPLETED, RequestStatus.FAILED]
                    ),
                    QueueModel.processed_at < cutoff_date,
                )
            ).delete(synchronize_session=False)

            logger.info(
                "old_records_cleaned",
                deleted_count=deleted,
                cutoff_days=days,
            )

            return deleted

    def get_request_by_id(self, request_id: int) -> Optional[QueuedRequest]:
        """Return the queued request with ID ``request_id``, or ``None``."""
        with self.get_session() as session:
            row = session.query(QueueModel).filter(
                QueueModel.id == request_id
            ).first()

            if row is None:
                return None
            # Convert inside the session: row attributes expire on commit.
            return self._to_queued_request(row)

    def get_queue_model_by_id(self, request_id: int) -> Optional[QueueModel]:
        """Return the raw ``QueueModel`` row with ID ``request_id``, or ``None``.

        Kept for compatibility with main.py. The returned object is detached
        from any session: its column attributes loaded by this query are
        readable, but changes to it are not persisted and lazy-loaded
        relationships (if any) cannot be accessed.
        """
        with self.get_session() as session:
            row = session.query(QueueModel).filter(
                QueueModel.id == request_id
            ).first()

            if row is not None:
                # Detach before get_session commits: commit expires only the
                # instances still in the session, so expunging keeps the loaded
                # attributes usable instead of raising DetachedInstanceError
                # on first access after the session is closed.
                session.expunge(row)
            return row

    # ------------------------------------------------------------------
    # API responses
    # ------------------------------------------------------------------

    def save_api_response(self, response: APIResponse):
        """Persist one API response attached to queue row ``response.queue_id``.

        ``request_payload`` and ``response_data`` are stored as JSON unless
        they are ``None``; falsy payloads such as ``{}`` or ``0`` are kept.
        """
        with self.get_session() as session:
            api_response = APIResponseModel(
                queue_id=response.queue_id,
                destination=response.destination,
                request_payload=self._dumps_or_none(response.request_payload),
                status_code=response.status_code,
                response_data=self._dumps_or_none(response.response_data),
                error=response.error,
                duration_ms=response.duration_ms,
            )
            session.add(api_response)

            logger.debug(
                "api_response_saved",
                queue_id=response.queue_id,
                destination=response.destination,
                status_code=response.status_code,
            )

    def get_request_responses(self, queue_id: int) -> List[APIResponse]:
        """Return every API response recorded for queue row ``queue_id``."""
        with self.get_session() as session:
            rows = session.query(APIResponseModel).filter(
                APIResponseModel.queue_id == queue_id
            ).all()

            # Convert inside the session: row attributes expire on commit.
            return [
                APIResponse(
                    queue_id=row.queue_id,
                    destination=row.destination,
                    request_payload=self._loads_or_none(row.request_payload),
                    status_code=row.status_code,
                    response_data=self._loads_or_none(row.response_data),
                    error=row.error,
                    duration_ms=row.duration_ms,
                    timestamp=row.timestamp,
                )
                for row in rows
            ]

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    def get_connection(self):
        """Backward-compatibility shim for code written against raw sqlite3.

        Despite its name, this returns the :meth:`get_session` context
        manager, not a connection and not a bare session. It must be used as
        ``with db.get_connection() as session: ...``.
        """
        return self.get_session()

    def add_log(self, queue_id: Optional[int], level: str, message: str):
        """Insert a log line, optionally linked to queue row ``queue_id``."""
        with self.get_session() as session:
            session.add(
                LogModel(
                    queue_id=queue_id,
                    level=level,
                    message=message,
                )
            )
