from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload
from src.domain.ports.repositories.conversation_repository import ConversationRepositoryPort
from src.domain.entities.conversation import Conversation, ConversationMessage
from src.infrastructure.persistence.models.conversation_models import ConversationModel, ConversationMessageModel
from typing import List, Optional
import logging

# Module-level logger, named after this module's import path so its records
# can be filtered/configured per module.
logger = logging.getLogger(name=__name__)

class SqlConversationRepository(ConversationRepositoryPort):
    """Async SQLAlchemy implementation of ``ConversationRepositoryPort``.

    Translates between the ORM models (``ConversationModel``,
    ``ConversationMessageModel``) and the domain entities (``Conversation``,
    ``ConversationMessage``).

    Write methods commit on success. On any error they log the exception
    (with traceback), roll the session back and re-raise. Read methods log
    and re-raise without rolling back.
    """

    def __init__(self, session: AsyncSession):
        # All queries, commits and rollbacks go through this session; the
        # repository never closes it.
        self._session = session

    # ------------------------------------------------------------------
    # Model <-> entity mapping helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_message_entity(msg: ConversationMessageModel) -> ConversationMessage:
        """Build a ``ConversationMessage`` entity from a message ORM row."""
        return ConversationMessage(
            id=msg.id,
            conversation_id=msg.conversation_id,
            user_id=msg.user_id,
            message_type=msg.message_type,
            content=msg.content,
            timestamp=msg.timestamp,
        )

    @classmethod
    def _to_conversation_entity(cls, model: ConversationModel) -> Conversation:
        """Build a ``Conversation`` entity, including its messages, from an ORM row.

        Every query in this class loads ``model.messages`` eagerly via
        ``selectinload`` before calling this helper.
        """
        return Conversation(
            id=model.id,
            user_id=model.user_id,
            chart_image_id=model.chart_image_id,
            title=model.title,
            created_at=model.created_at,
            updated_at=model.updated_at,
            is_active=model.is_active,
            messages=[cls._to_message_entity(msg) for msg in model.messages],
        )

    @staticmethod
    def _to_message_model(
        message: ConversationMessage, conversation_id: str
    ) -> ConversationMessageModel:
        """Build a message ORM row from an entity, stored under ``conversation_id``."""
        return ConversationMessageModel(
            id=message.id,
            conversation_id=conversation_id,
            user_id=message.user_id,
            message_type=message.message_type,
            content=message.content,
            timestamp=message.timestamp,
        )

    async def _fetch_conversations(self, *criteria, limit: int) -> List[Conversation]:
        """Return up to ``limit`` conversations matching all ``criteria``.

        Results are ordered by ``updated_at`` descending (most recently
        updated first), and each carries its messages.
        """
        stmt = (
            select(ConversationModel)
            .options(selectinload(ConversationModel.messages))
            .where(*criteria)
            .order_by(ConversationModel.updated_at.desc())
            .limit(limit)
        )
        result = await self._session.execute(stmt)
        return [self._to_conversation_entity(model) for model in result.scalars().all()]

    # ------------------------------------------------------------------
    # Conversations
    # ------------------------------------------------------------------

    async def create_conversation(self, conversation: Conversation) -> Conversation:
        """Insert a new conversation row and commit.

        Only the conversation's own columns are written; ``conversation.messages``
        is not persisted by this method.

        Returns:
            The same ``conversation`` object that was passed in.
        Raises:
            Whatever the session raises (after rolling back).
        """
        try:
            conversation_model = ConversationModel(
                id=conversation.id,
                user_id=conversation.user_id,
                chart_image_id=conversation.chart_image_id,
                title=conversation.title,
                created_at=conversation.created_at,
                updated_at=conversation.updated_at,
                is_active=conversation.is_active,
            )
            self._session.add(conversation_model)
            await self._session.commit()
        except Exception as e:
            logger.exception("Error creating conversation: %s", e)
            await self._session.rollback()
            raise
        # The input entity is returned as-is, so no refresh() round trip is needed.
        return conversation

    async def get_conversation_by_id(self, conversation_id: str) -> Optional[Conversation]:
        """Return the conversation with ``conversation_id`` and its messages.

        Returns ``None`` when no such conversation exists.
        """
        try:
            stmt = (
                select(ConversationModel)
                .options(selectinload(ConversationModel.messages))
                .where(ConversationModel.id == conversation_id)
            )
            result = await self._session.execute(stmt)
            conversation_model = result.scalar_one_or_none()
            if conversation_model is None:
                return None
            return self._to_conversation_entity(conversation_model)
        except Exception as e:
            logger.exception("Error getting conversation by ID: %s", e)
            raise

    async def get_conversations_by_user_id(self, user_id: str, limit: int = 20) -> List[Conversation]:
        """Return up to ``limit`` of the user's conversations (active or not).

        Conversations are ordered most recently updated first.
        """
        try:
            return await self._fetch_conversations(
                ConversationModel.user_id == user_id, limit=limit
            )
        except Exception as e:
            logger.exception("Error getting conversations by user ID: %s", e)
            raise

    async def update_conversation(self, conversation: Conversation) -> Conversation:
        """Update a conversation's mutable columns and insert any new messages.

        ``title``, ``updated_at`` and ``is_active`` are written to the
        conversation row. Messages whose ids are not yet stored are inserted.
        Messages that already exist are left unchanged.

        Returns:
            The same ``conversation`` object that was passed in.
        Raises:
            Whatever the session raises (after rolling back).
        """
        try:
            await self._session.execute(
                update(ConversationModel)
                .where(ConversationModel.id == conversation.id)
                .values(
                    title=conversation.title,
                    updated_at=conversation.updated_at,
                    is_active=conversation.is_active,
                )
            )

            # One IN query instead of a session.get() per message (N+1).
            message_ids = [message.id for message in conversation.messages]
            existing_ids = set()
            if message_ids:
                result = await self._session.execute(
                    select(ConversationMessageModel.id).where(
                        ConversationMessageModel.id.in_(message_ids)
                    )
                )
                existing_ids = set(result.scalars().all())

            for message in conversation.messages:
                if message.id in existing_ids:
                    continue
                self._session.add(
                    self._to_message_model(message, message.conversation_id)
                )
                # Guard against the same id appearing twice in one call.
                existing_ids.add(message.id)

            await self._session.commit()
            return conversation
        except Exception as e:
            logger.exception("Error updating conversation: %s", e)
            await self._session.rollback()
            raise

    async def delete_conversation(self, conversation_id: str) -> bool:
        """Hard-delete a conversation and all of its messages.

        Returns:
            ``True`` if a conversation row was deleted, ``False`` if none
            matched ``conversation_id``.
        Raises:
            Whatever the session raises (after rolling back).
        """
        try:
            # Child rows go first, presumably so the messages' reference to the
            # conversation is never left dangling (no ON DELETE CASCADE assumed).
            await self._session.execute(
                delete(ConversationMessageModel).where(
                    ConversationMessageModel.conversation_id == conversation_id
                )
            )
            result = await self._session.execute(
                delete(ConversationModel).where(ConversationModel.id == conversation_id)
            )
            await self._session.commit()
        except Exception as e:
            logger.exception("Error deleting conversation: %s", e)
            await self._session.rollback()
            raise
        return result.rowcount > 0

    async def get_active_conversations_by_user_id(self, user_id: str, limit: int = 20) -> List[Conversation]:
        """Return up to ``limit`` of the user's active conversations.

        Conversations are ordered most recently updated first.
        """
        try:
            return await self._fetch_conversations(
                ConversationModel.user_id == user_id,
                ConversationModel.is_active == True,  # noqa: E712 - SQL expression, not a Python bool test
                limit=limit,
            )
        except Exception as e:
            logger.exception("Error getting active conversations: %s", e)
            raise

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    async def add_message_to_conversation(self, conversation_id: str, message: ConversationMessage) -> ConversationMessage:
        """Insert ``message`` into the conversation ``conversation_id`` and commit.

        The row is stored under the ``conversation_id`` argument. A message
        whose own ``conversation_id`` is empty is accepted.

        Returns:
            The same ``message`` object that was passed in.
        Raises:
            ValueError: if ``message.conversation_id`` is set and differs from
                ``conversation_id``.
            Whatever the session raises (after rolling back).
        """
        # Previously the argument was ignored and message.conversation_id was
        # persisted, which could attach the message to a different conversation.
        if message.conversation_id and message.conversation_id != conversation_id:
            raise ValueError(
                f"Message {message.id} belongs to conversation "
                f"{message.conversation_id}, not {conversation_id}"
            )
        try:
            self._session.add(self._to_message_model(message, conversation_id))
            await self._session.commit()
        except Exception as e:
            logger.exception("Error adding message to conversation: %s", e)
            await self._session.rollback()
            raise
        return message

    async def get_messages_by_conversation_id(self, conversation_id: str, limit: int = 50) -> List[ConversationMessage]:
        """Return up to ``limit`` messages of a conversation, oldest first.

        With more than ``limit`` messages stored, this returns the earliest
        ``limit`` of them (ordered by ``timestamp`` ascending).
        """
        try:
            stmt = (
                select(ConversationMessageModel)
                .where(ConversationMessageModel.conversation_id == conversation_id)
                .order_by(ConversationMessageModel.timestamp.asc())
                .limit(limit)
            )
            result = await self._session.execute(stmt)
            return [self._to_message_entity(msg) for msg in result.scalars().all()]
        except Exception as e:
            logger.exception("Error getting messages by conversation ID: %s", e)
            raise