Commit da949056 authored by ZeinabRm13

Add ChartGemma service

parent 80763c81
 from fastapi import FastAPI
 from src.infrastructure.api.fastapi.routes import auth, charts
 from fastapi.middleware.cors import CORSMiddleware
 from sqlalchemy import create_engine
 from sqlalchemy.exc import OperationalError
+from src.config import settings
-app = FastAPI(debug = "True",
+app = FastAPI(
+    debug=True,
     title="Chart Analyzer API",
     description="API for analyzing charts and managing users.",
     version="1.0.0",
@@ -13,11 +14,11 @@ app = FastAPI(debug = "True",
 @app.on_event("startup")
 async def startup_event():
-    engine = create_engine("postgresql+psycopg://chart_analyzer_user:chartanalyzer13@localhost:5432/chart_analyzer")
+    engine = create_engine(settings.DATABASE_URL)
     try:
         with engine.connect() as conn:
             print("✅ Database connection successful")
-            print(f"🔗 Connection string: postgresql+psycopg://chart_analyzer_user:chartanalyzer13@localhost:5432/chart_analyzer")
+            print(f"🔗 Connection string: {settings.DATABASE_URL}")
     except OperationalError as e:
         print("❌ Database connection failed")
         print(f"Error: {e}")
@@ -31,10 +32,17 @@ app.add_middleware(
     allow_headers=["*"],
 )
+# Include routers
 app.include_router(auth.router, prefix="/auth")
 app.include_router(charts.router, prefix="/charts")
 @app.get("/")
 async def root():
-    return {"message": "Welcome to Chart Analyzer"}
+    return {
+        "message": "Welcome to Chart Analyzer API",
+        "version": "1.0.0",
+        "endpoints": {
+            "authentication": "/auth",
+            "chart_analysis": "/charts"
+        }
+    }
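# --- Example (not part of this commit): a minimal local-run sketch, assuming
# --- this module lives at src/main.py and uvicorn is installed.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("src.main:app", host="127.0.0.1", port=8000, reload=True)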
@@ -11,7 +11,6 @@ class LLMRequestDTO(BaseModel):
"""Input for LLM analysis requests""" """Input for LLM analysis requests"""
image_bytes: bytes = Field(..., description="Binary image data (PNG/JPEG)") image_bytes: bytes = Field(..., description="Binary image data (PNG/JPEG)")
question: str = Field(..., max_length=500, description="Question about the chart") question: str = Field(..., max_length=500, description="Question about the chart")
provider: LLMProvider = Field(LLMProvider.HUGGINGFACE, description="LLM service provider")
max_tokens: int = Field(300, gt=0, le=2000, description="Max response length") max_tokens: int = Field(300, gt=0, le=2000, description="Max response length")
# temperature: float = Field(0.3, ge=0.1, le=1.0, description="Creativity control") # temperature: float = Field(0.3, ge=0.1, le=1.0, description="Creativity control")
......
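# --- Example (not part of this commit): constructing the request DTO above,
# --- assuming a local chart image; max_tokens keeps its default of 300.
with open("chart.png", "rb") as f:  # hypothetical file
    req = LLMRequestDTO(image_bytes=f.read(), question="Describe the overall trend.")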
from pydantic import BaseModel
from typing import List, Optional

class CreateConversationRequestDTO(BaseModel):
    chart_image_id: str
    title: str

class SendMessageRequestDTO(BaseModel):
    conversation_id: str
    message: str

class ConversationResponseDTO(BaseModel):
    id: str
    title: str
    created_at: str
    updated_at: str
    message_count: int

class MessageResponseDTO(BaseModel):
    id: str
    message_type: str
    content: str
    timestamp: str

class ConversationDetailResponseDTO(BaseModel):
    id: str
    title: str
    created_at: str
    updated_at: str
    messages: List[MessageResponseDTO]
\ No newline at end of file
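# --- Example (not part of this commit): building the response DTOs above with
# --- illustrative values.
msg = MessageResponseDTO(
    id="msg_1",
    message_type="assistant",
    content="Revenue peaks in March.",
    timestamp="2024-01-01T00:00:00Z",
)
detail = ConversationDetailResponseDTO(
    id="conv_1",
    title="Quarterly revenue chart",
    created_at="2024-01-01T00:00:00Z",
    updated_at="2024-01-01T00:00:00Z",
    messages=[msg],
)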
import httpx
from src.application.ports.llm_service_port import LLMServicePort
from src.application.dtos.LLM import LLMRequestDTO, LLMResponseDTO
from typing import Optional
import base64
import io
from PIL import Image
import json
import tempfile
import os

class ChartGemmaService(LLMServicePort):
    def __init__(self, gradio_url: str):
        self.gradio_url = gradio_url

    async def analyze(self, request: LLMRequestDTO) -> LLMResponseDTO:
        try:
            # Check if image data is provided
            if not request.image_bytes:
                return LLMResponseDTO(
                    answer="Error: No image data provided",
                    model_used="ChartGemma",
                    processing_time=None
                )
            # Convert the image bytes to a PIL Image
            image = self._load_image(request.image_bytes)
            if not image:
                return LLMResponseDTO(
                    answer="Error: Could not load image data",
                    model_used="ChartGemma",
                    processing_time=None
                )
            # Call the Gradio API
            response = await self._call_gradio_api(
                image=image,
                question=request.question
            )
            return LLMResponseDTO(
                answer=response,
                model_used="ChartGemma",
                processing_time=None  # Gradio doesn't return this
            )
        except Exception as e:
            # Return a more informative error response
            return LLMResponseDTO(
                answer=f"Error analyzing chart: {str(e)}",
                model_used="ChartGemma",
                processing_time=None
            )

    def _load_image(self, image_data: Optional[bytes]) -> Optional[Image.Image]:
        if not image_data:
            return None
        try:
            return Image.open(io.BytesIO(image_data)).convert("RGB")
        except Exception as e:
            print(f"Error loading image: {e}")
            return None

    async def _call_gradio_api(self, image: Image.Image, question: str) -> str:
        # Save the image to a temporary file (since Gradio expects a filepath)
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
            image.save(tmp_file, format='PNG')
            tmp_file_path = tmp_file.name
        try:
            # Prepare the payload according to Gradio's API format
            payload = {
                "data": [
                    tmp_file_path,  # Filepath to the image
                    question        # Optional question
                ]
            }
            async with httpx.AsyncClient() as client:
                # The Gradio API endpoint is /run/predict/
                api_url = f"{self.gradio_url.rstrip('/')}/run/predict/"
                response = await client.post(api_url, json=payload)
                response.raise_for_status()
                # Gradio returns the output in the "data" field as a list
                return response.json()["data"][0]
        finally:
            # Clean up the temporary file
            try:
                os.unlink(tmp_file_path)
            except OSError:
                pass
\ No newline at end of file
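# --- Example (not part of this commit): driving ChartGemmaService directly,
# --- assuming a reachable Gradio share URL and a local PNG; inside the API this
# --- wiring is done by the dependency factories instead.
import asyncio

async def demo() -> None:
    service = ChartGemmaService(gradio_url="https://example.gradio.live")  # hypothetical URL
    with open("chart.png", "rb") as f:  # hypothetical file
        request = LLMRequestDTO(image_bytes=f.read(), question="What is the highest bar?")
    response = await service.analyze(request)
    print(response.model_used, response.answer)

# asyncio.run(demo())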
from src.domain.entities.chart_analysis import ChartAnalysis
from src.domain.ports.repositories.analysis_repository import AnalysisRepositoryPort
from src.application.ports.llm_service_port import LLMServicePort
from src.application.dtos.LLM import LLMRequestDTO
import uuid
from datetime import datetime, timezone
from typing import List

class ChatConversationUseCase:
    def __init__(
        self,
        analysis_repo: AnalysisRepositoryPort,
        llm_service: LLMServicePort
    ):
        self._analysis_repo = analysis_repo
        self._llm_service = llm_service

    async def execute(self, chart_image_id: str, conversation_history: List[str], new_question: str) -> ChartAnalysis:
        """Handle an interactive chat conversation about a chart"""
        # Build context from conversation history
        context = self._build_conversation_context(conversation_history)
        # Create enhanced question with context
        enhanced_question = f"Context from previous conversation: {context}\n\nNew question: {new_question}"
        # Create request for LLM
        request = LLMRequestDTO(
            image_bytes=b"",  # Will be loaded by service
            question=enhanced_question
        )
        # Get response from LLM (execute must be async because of this await)
        response = await self._llm_service.analyze(request)
        # Save the conversation turn
        analysis = ChartAnalysis(
            id=str(uuid.uuid4()),
            chart_image_id=chart_image_id,
            question=new_question,
            answer=response.answer,
            created_at=datetime.now(timezone.utc)
        )
        self._analysis_repo.save_analysis(analysis)
        return analysis

    def _build_conversation_context(self, history: List[str]) -> str:
        """Build context string from conversation history"""
        if not history:
            return ""
        context_parts = []
        for i, entry in enumerate(history[-5:], 1):  # Last 5 entries for context
            context_parts.append(f"Turn {i}: {entry}")
        return " | ".join(context_parts)
\ No newline at end of file
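# --- Example (not part of this commit): what _build_conversation_context
# --- produces; only the last five history entries are kept.
history = [
    "Q: What does the chart show? A: Monthly revenue.",
    "Q: Which month is highest? A: March.",
]
use_case = ChatConversationUseCase(analysis_repo=None, llm_service=None)  # no repo/LLM needed for context building
assert use_case._build_conversation_context(history) == (
    "Turn 1: Q: What does the chart show? A: Monthly revenue. | "
    "Turn 2: Q: Which month is highest? A: March."
)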
from src.domain.ports.repositories.analysis_repository import AnalysisRepositoryPort
from src.domain.entities.chart_analysis import ChartAnalysis
from typing import List

class GetAnalysisHistoryUseCase:
    def __init__(self, analysis_repo: AnalysisRepositoryPort):
        self._analysis_repo = analysis_repo

    def execute(self, user_id: str, limit: int = 50) -> List[ChartAnalysis]:
        """Get analysis history for a user"""
        return self._analysis_repo.get_analyses_by_user_id(user_id, limit)
\ No newline at end of file
from src.domain.entities.chart_analysis import ChartAnalysis
from src.domain.ports.repositories.analysis_repository import AnalysisRepositoryPort
from src.application.dtos.LLM import LLMResponseDTO
import uuid
from datetime import datetime, timezone

class SaveAnalysisUseCase:
    def __init__(self, analysis_repo: AnalysisRepositoryPort):
        self._analysis_repo = analysis_repo

    def execute(self, chart_image_id: str, question: str, llm_response: LLMResponseDTO) -> ChartAnalysis:
        """Save analysis result to database for history tracking"""
        analysis = ChartAnalysis(
            id=str(uuid.uuid4()),
            chart_image_id=chart_image_id,
            question=question,
            answer=llm_response.answer,
            created_at=datetime.now(timezone.utc)
        )
        self._analysis_repo.save_analysis(analysis)
        return analysis
\ No newline at end of file
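# --- Example (not part of this commit): saving one analysis turn; `repo` and
# --- `resp` stand in for an AnalysisRepositoryPort implementation and an
# --- LLMResponseDTO returned by one of the services above.
def save_turn(repo: AnalysisRepositoryPort, resp: LLMResponseDTO) -> ChartAnalysis:
    use_case = SaveAnalysisUseCase(repo)
    return use_case.execute(
        chart_image_id="img_1",
        question="What is the overall trend?",
        llm_response=resp,
    )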
@@ -9,6 +9,12 @@ class Settings(BaseSettings):
     JWT_ALGORITHM: str = "HS256"
     JWT_EXPIRE_MINUTES: int = 30
+    # LLM Service Configuration
+    OLLAMA_HOST: str = "http://localhost:11434"
+    OLLAMA_MODEL: str = "llava:34b"
+    CHARTGEMMA_GRADIO_URL: str = "https://3f4f53fba3f99b8778.gradio.live/"
+    DEFAULT_LLM_MODEL: str = "ollama"
     # Use model_config instead of inner Config class
     model_config = ConfigDict(
         env_file=".env",
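# --- Example (not part of this commit): overriding the new settings, assuming
# --- pydantic-settings reads the process environment / .env for this Settings
# --- class. Equivalent .env lines would be:
# ---   OLLAMA_HOST=http://gpu-box:11434
# ---   DEFAULT_LLM_MODEL=chartgemma
import os

os.environ["DEFAULT_LLM_MODEL"] = "chartgemma"  # must be set before Settings() is instantiated
from src.config import settings
print(settings.DEFAULT_LLM_MODEL, settings.OLLAMA_HOST)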
-from fastapi import Depends
+from fastapi import Depends, Query
 from typing import AsyncGenerator
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
 from sqlalchemy.orm import sessionmaker
@@ -10,30 +10,15 @@ from src.domain.ports.repositories.charts_repository import ChartsRepositoryPort
 from src.infrastructure.adapters.sqlserver.sql_user_repository import SqlUserRepository
 from src.infrastructure.adapters.sqlserver.sql_charts_repository import SqlChartsRepository
 from src.application.services.analyze_service import AnalyzeService
-from src.application.services.llm_service import LLMService
+from src.application.services.ollama_service import OllamaService
+from src.application.services.chartGemma_service import ChartGemmaService
 from src.application.ports.llm_service_port import LLMServicePort
 from src.infrastructure.adapters.sqlserver.sql_token_repository import SqlTokenRepository
-# from infrastructure.services.llm.openai_service import OpenAIService  # Concrete LLM impl
-# from infrastructure.services.image.pillow_service import PillowImageService  # Concrete impl
 from src.config import settings
-from src.application.services.ollama_service import OllamaService
 engine = create_async_engine(settings.DATABASE_URL, echo=True)
 AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
-def get_llm_service() -> LLMServicePort:
-    return OllamaService(host="http://172.25.1.141:11434")
 async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
     async with AsyncSessionLocal() as session:
         yield session
@@ -44,10 +29,9 @@ def get_user_repository(session: AsyncSession = Depends(get_db_session)) -> User
 def get_charts_repository(session: AsyncSession = Depends(get_db_session)) -> ChartsRepositoryPort:
     return SqlChartsRepository(session)
-def get_token_repo(session: AsyncSession = Depends(get_db_session)) -> SqlTokenRepository:
+def get_token_repo(session: AsyncSession = Depends(get_db_session)) -> TokenRepositoryPort:
     return SqlTokenRepository(session)
 def get_auth_service(
     user_repo: UserRepositoryPort = Depends(get_user_repository),
     token_repo: TokenRepositoryPort = Depends(get_token_repo)
@@ -57,11 +41,38 @@ def get_auth_service(
 def get_upload_use_case(charts_repo: ChartsRepositoryPort = Depends(get_charts_repository)) -> UploadChartUseCase:
     return UploadChartUseCase(charts_repo)
 def get_analysis_service() -> AnalyzeService:
     """Factory for the analysis service"""
-    llm_service = LLMService()
-    # image_service = PillowImageService()
+    llm_service = get_llm_service()
     return AnalyzeService(llm_service)
+def get_ollama_service() -> OllamaService:
+    """Factory for the Ollama LLM service"""
+    return OllamaService(host=settings.OLLAMA_HOST)
+def get_chartgemma_service() -> ChartGemmaService:
+    """Factory for the ChartGemma LLM service"""
+    return ChartGemmaService(gradio_url=settings.CHARTGEMMA_GRADIO_URL)
+def get_llm_service(model: str = settings.DEFAULT_LLM_MODEL) -> LLMServicePort:
+    """
+    Factory function to get the appropriate LLM service based on model selection.
+    This allows easy switching between different LLM providers.
+    """
+    if model.lower() == "ollama":
+        return get_ollama_service()
+    elif model.lower() == "chartgemma":
+        return get_chartgemma_service()
+    else:
+        # Default to Ollama if an unknown model is specified
+        return get_ollama_service()
+def get_llm_service_by_query(
+    model: str = Query(default=settings.DEFAULT_LLM_MODEL, description="LLM model to use (ollama or chartgemma)")
+) -> LLMServicePort:
+    """
+    Dependency function that allows model selection via query parameter.
+    This provides a clean way for users to select their preferred LLM model.
+    """
+    return get_llm_service(model)
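# --- Example (not part of this commit): the factory in use; FastAPI normally
# --- performs this wiring through Depends(get_llm_service_by_query).
chartgemma = get_llm_service("chartgemma")   # -> ChartGemmaService
fallback = get_llm_service("unknown-model")  # -> OllamaService (default branch)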
from pydantic import BaseModel, Field
from datetime import datetime, timezone
from typing import List, Optional

class ConversationMessage(BaseModel):
    id: str
    conversation_id: str
    user_id: str
    message_type: str  # "user" or "assistant"
    content: str
    # default_factory gives each instance a fresh timestamp; a plain
    # `= datetime.now(...)` default would be evaluated once at import time
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

class Conversation(BaseModel):
    id: str
    user_id: str
    chart_image_id: str
    title: str
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    is_active: bool = True
    messages: List[ConversationMessage] = Field(default_factory=list)
\ No newline at end of file
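# --- Example (not part of this commit): constructing the entities above with
# --- illustrative values; the timestamps are filled in by default_factory.
msg = ConversationMessage(id="m1", conversation_id="c1", user_id="u1", message_type="user", content="What does the chart show?")
conv = Conversation(id="c1", user_id="u1", chart_image_id="img1", title="Sales chart", messages=[msg])
print(conv.created_at.isoformat(), conv.is_active, len(conv.messages))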
from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Query
from src.application.ports.llm_service_port import LLMServicePort
from src.application.dtos.LLM import LLMRequestDTO
from src.dependencies import get_llm_service_by_query, get_llm_service
from src.config import settings
from typing import Optional
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter(tags=["charts"])

@router.post("/analyze")
async def analyze_general(
    image: UploadFile = File(...),
    question: str = Query(default="", description="Question about the chart"),
    llm_service: LLMServicePort = Depends(get_llm_service_by_query)
):
    """
    Unified endpoint that works with both the Ollama and ChartGemma models.
    Users can select the model via the 'model' query parameter.
    """
    try:
        logger.info(f"Starting analysis with question: {question}")
        logger.info(f"Image filename: {image.filename}, content_type: {image.content_type}")
        image_bytes = await image.read()
        logger.info(f"Image size: {len(image_bytes)} bytes")
        request = LLMRequestDTO(
            image_bytes=image_bytes,
            question=question
        )
        logger.info("Calling LLM service...")
        response = await llm_service.analyze(request)
        logger.info(f"LLM service response received: {response.answer[:100]}...")
        return {
            "answer": response.answer,
            "model": response.model_used,
            "processing_time": response.processing_time
        }
    except Exception as e:
        logger.error(f"Error in analyze_general: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@router.post("/ask")
async def analyze_chart_specialized(
    image: UploadFile = File(...),
    question: str = Query(default="", description="Question about the chart"),
    advanced_analysis: bool = Query(default=False, description="Enable advanced analysis with program of thought")
):
    """
    Specialized endpoint for chart analysis with ChartGemma.
    - advanced_analysis: when True, adds chart-specific prompt engineering
    """
    try:
        logger.info(f"Starting specialized analysis with question: {question}, advanced: {advanced_analysis}")
        # Force ChartGemma for specialized chart analysis
        service = get_llm_service("chartgemma")
        # Enhanced prompt for charts if needed
        processed_question = f"PROGRAM OF THOUGHT: {question}" if advanced_analysis else question
        image_bytes = await image.read()
        request = LLMRequestDTO(
            image_bytes=image_bytes,
            question=processed_question
        )
        logger.info("Calling ChartGemma service...")
        response = await service.analyze(request)
        logger.info(f"ChartGemma service response received: {response.answer[:100]}...")
        return {
            "answer": response.answer,
            "model": "ChartGemma",
            "analysis_type": "advanced" if advanced_analysis else "basic",
            "processing_time": response.processing_time
        }
    except Exception as e:
        logger.error(f"Error in analyze_chart_specialized: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

@router.get("/test-chartgemma")
async def test_chartgemma_connection():
    """
    Test endpoint to verify the ChartGemma service connection and response structure.
    """
    try:
        logger.info("Testing ChartGemma service connection...")
        # Create a simple test request without an image
        service = get_llm_service("chartgemma")
        # Test the service with a minimal request
        test_request = LLMRequestDTO(
            image_bytes=b"",  # Empty image for testing
            question="Test question"
        )
        response = await service.analyze(test_request)
        return {
            "status": "success",
            "service": "ChartGemma",
            "response": response.answer,
            "model_used": response.model_used
        }
    except Exception as e:
        logger.error(f"Error testing ChartGemma: {str(e)}", exc_info=True)
        return {
            "status": "error",
            "service": "ChartGemma",
            "error": str(e)
        }

@router.get("/models")
async def list_available_models():
    """
    List the available LLM services and their configurations.
    """
    return {
        "available_models": ["ollama", "chartgemma"],
        "default_model": settings.DEFAULT_LLM_MODEL,
        "model_configs": {
            "ollama": {
                "host": settings.OLLAMA_HOST,
                "model": settings.OLLAMA_MODEL,
                "description": "Local Ollama LLM service"
            },
            "chartgemma": {
                "gradio_url": settings.CHARTGEMMA_GRADIO_URL,
                "description": "Specialized chart analysis model"
            }
        }
    }
\ No newline at end of file
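# --- Example (not part of this commit): calling the unified endpoint with a
# --- model selection, assuming the API runs on http://localhost:8000 and this
# --- router is mounted under /charts (as in main.py).
import httpx

def analyze_chart_file(path: str, question: str, model: str = "chartgemma") -> dict:
    with open(path, "rb") as f:
        resp = httpx.post(
            "http://localhost:8000/charts/analyze",
            params={"model": model, "question": question},
            files={"image": ("chart.png", f, "image/png")},
            timeout=120.0,
        )
    resp.raise_for_status()
    return resp.json()

# analyze_chart_file("chart.png", "Which series grows fastest?")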
from fastapi import APIRouter, Depends, HTTPException, status
from src.application.dtos.conversation import (
    CreateConversationRequestDTO,
    SendMessageRequestDTO,
    ConversationResponseDTO,
    ConversationDetailResponseDTO,
    MessageResponseDTO
)
from src.application.use_cases.chat_conversation import ChatConversationUseCase
from src.dependencies import get_current_user
from typing import List
import logging

logger = logging.getLogger(__name__)

router = APIRouter(tags=["conversations"])

@router.post("/", response_model=ConversationResponseDTO, status_code=status.HTTP_201_CREATED)
async def create_conversation(
    request: CreateConversationRequestDTO,
    current_user: dict = Depends(get_current_user)
):
    """Create a new conversation for a chart"""
    try:
        # Implementation would use the conversation service;
        # for now, return a mock response
        return ConversationResponseDTO(
            id="conv_123",
            title=request.title,
            created_at="2024-01-01T00:00:00Z",
            updated_at="2024-01-01T00:00:00Z",
            message_count=0
        )
    except Exception as e:
        logger.error(f"Error creating conversation: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to create conversation")

@router.get("/", response_model=List[ConversationResponseDTO])
async def list_conversations(
    current_user: dict = Depends(get_current_user),
    limit: int = 20
):
    """List the user's conversations"""
    try:
        # Implementation would fetch from the repository
        return []
    except Exception as e:
        logger.error(f"Error listing conversations: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to list conversations")

@router.get("/{conversation_id}", response_model=ConversationDetailResponseDTO)
async def get_conversation(
    conversation_id: str,
    current_user: dict = Depends(get_current_user)
):
    """Get conversation details with messages"""
    try:
        # Implementation would fetch from the repository
        return ConversationDetailResponseDTO(
            id=conversation_id,
            title="Sample Conversation",
            created_at="2024-01-01T00:00:00Z",
            updated_at="2024-01-01T00:00:00Z",
            messages=[]
        )
    except Exception as e:
        logger.error(f"Error getting conversation: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to get conversation")

@router.post("/{conversation_id}/messages", response_model=MessageResponseDTO)
async def send_message(
    conversation_id: str,
    request: SendMessageRequestDTO,
    current_user: dict = Depends(get_current_user)
):
    """Send a message in a conversation"""
    try:
        # Implementation would use ChatConversationUseCase
        return MessageResponseDTO(
            id="msg_123",
            message_type="assistant",
            content="This is a sample response from the AI.",
            timestamp="2024-01-01T00:00:00Z"
        )
    except Exception as e:
        logger.error(f"Error sending message: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to send message")
\ No newline at end of file