import base64
import io
import json
import logging
import os
import tempfile
from typing import Optional

import httpx
from PIL import Image

from src.application.dtos.LLM import LLMRequestDTO, LLMResponseDTO
from src.application.ports.llm_service_port import LLMServicePort

class ChartGemmaService(LLMServicePort):
    """LLM service adapter that sends chart images to a ChartGemma Gradio app.

    Each call to :meth:`analyze` decodes the request's image bytes, writes the
    image to a temporary PNG file, and posts that file's path together with the
    question to the Gradio app's ``/run/predict/`` endpoint. Failures are
    returned as an ``LLMResponseDTO`` whose ``answer`` starts with ``"Error"``
    instead of being raised.
    """

    # Model name reported in every LLMResponseDTO produced by this adapter.
    MODEL_NAME = "ChartGemma"

    _logger = logging.getLogger(__name__)

    def __init__(self, gradio_url: str, timeout: Optional[float] = 5.0):
        """Create the adapter.

        Args:
            gradio_url: Base URL of the Gradio app; a trailing slash is allowed.
            timeout: Seconds allowed for each phase (connect/read/write/pool) of
                the HTTP request, passed to ``httpx.AsyncClient``. The default of
                5.0 equals httpx's own default; model inference may need more.
                ``None`` disables the timeout.
        """
        self.gradio_url = gradio_url
        self.timeout = timeout

    async def analyze(self, request: LLMRequestDTO) -> LLMResponseDTO:
        """Answer ``request.question`` about the chart in ``request.image_bytes``.

        Any ``Exception`` raised while loading the image or calling the Gradio
        app is logged and converted into an error message in ``answer``; it is
        not propagated to the caller.

        Returns:
            An ``LLMResponseDTO`` with ``model_used`` set to ``MODEL_NAME`` and
            ``processing_time=None`` (the Gradio response carries no timing).
        """
        try:
            if not request.image_bytes:
                return self._make_response("Error: No image data provided")

            image = self._load_image(request.image_bytes)
            if image is None:
                return self._make_response("Error: Could not load image data")

            answer = await self._call_gradio_api(
                image=image,
                question=request.question,
            )
            return self._make_response(answer)
        except Exception as e:
            # Service boundary: turn any failure into an error DTO, but keep the
            # traceback in the logs because it is not returned to the caller.
            self._logger.exception("ChartGemma analysis failed")
            return self._make_response(f"Error analyzing chart: {e}")

    def _make_response(self, answer) -> LLMResponseDTO:
        """Build a response DTO tagged with this model's name and no timing data."""
        return LLMResponseDTO(
            answer=answer,
            model_used=self.MODEL_NAME,
            processing_time=None,
        )

    def _load_image(self, image_data: Optional[bytes]) -> Optional[Image.Image]:
        """Decode encoded image bytes into an RGB PIL image.

        Returns:
            The decoded image converted to RGB, or ``None`` if ``image_data`` is
            empty/``None`` or Pillow fails to decode it (the error is logged).
        """
        if not image_data:
            return None
        try:
            return Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            # Pillow can raise several unrelated exception types for corrupt or
            # unsupported input; treat them all as "not an image".
            self._logger.warning("Error loading image: %s", e)
            return None

    async def _call_gradio_api(self, image: Image.Image, question: str) -> str:
        """Post the image and question to the Gradio app and return its first output.

        The image is written to a temporary PNG because the payload references
        the image by file path. The file is removed afterwards in all cases.

        Raises:
            httpx.HTTPError: On transport failures, timeouts, or a non-2xx status.
            ValueError: If the response is not JSON or has no ``data[0]`` entry.
        """
        # delete=False: the file must outlive this handle so its path can be
        # sent; it is removed in the finally clause. The try begins right after
        # creation so a failing image.save() cannot leak the file.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
        tmp_file_path = tmp_file.name
        try:
            with tmp_file:
                image.save(tmp_file, format="PNG")

            # Inputs in the order the Gradio app declares them: image filepath,
            # then the (optional) question.
            # NOTE(review): tmp_file_path is a path on *this* machine; a Gradio
            # server can only read it if it shares this filesystem. Confirm the
            # deployment, or send the image inline (e.g. base64) instead.
            payload = {"data": [tmp_file_path, question]}

            async with httpx.AsyncClient(timeout=self.timeout) as client:
                # NOTE(review): confirm the deployed Gradio version serves
                # /run/predict/ (some versions expose /api/predict/ instead).
                api_url = f"{self.gradio_url.rstrip('/')}/run/predict/"
                response = await client.post(api_url, json=payload)
                response.raise_for_status()
                body = response.json()

            try:
                # Gradio returns the app's outputs as a list under "data".
                return body["data"][0]
            except (KeyError, IndexError, TypeError) as e:
                raise ValueError(f"Unexpected Gradio response: {body!r}") from e
        finally:
            try:
                os.unlink(tmp_file_path)
            except OSError:
                # Best-effort cleanup: a leftover temp file should not fail the
                # request, but it should not go unnoticed either.
                self._logger.warning(
                    "Could not remove temporary file %s", tmp_file_path
                )