Commit adcc7a11 authored by ZeinabRm13's avatar ZeinabRm13

Add Docs dir

parent deb050c8
......@@ -5,6 +5,7 @@ from src.dependencies import get_llm_service_by_query, get_llm_service
from src.config import settings
from typing import Optional
import logging
import time # Required for time.time()
# Set up logging
logging.basicConfig(level=logging.INFO)
......@@ -59,35 +60,42 @@ async def analyze_chart_specialized(
- advanced_analysis: When True, adds specific prompt engineering for charts
"""
try:
logger.info(f"Starting specialized analysis with question: {question}, advanced: {advanced_analysis}")
logger.info(f"Starting analysis - Question: '{question}', Advanced: {advanced_analysis}")
# Force ChartGemma for specialized chart analysis
service = get_llm_service("chartgemma")
image_bytes = await image.read()
# Enhanced prompt for charts if needed
processed_question = f"PROGRAM OF THOUGHT: {question}" if advanced_analysis else question
# Enhanced prompt handling
processed_question = question
if advanced_analysis:
processed_question = f"ANALYZE IN DETAIL: {question}"
logger.info("Using advanced analysis mode")
image_bytes = await image.read()
request = LLMRequestDTO(
image_bytes=image_bytes,
question=processed_question
)
logger.info("Calling ChartGemma service...")
logger.debug("Calling ChartGemma service...")
start_time = time.time()
response = await service.analyze(request)
logger.info(f"ChartGemma service response received: {response.answer[:100]}...")
processing_time = time.time() - start_time
logger.info(f"Analysis completed in {processing_time:.2f}s")
return {
"answer": response.answer,
"model": "ChartGemma",
"analysis_type": "advanced" if advanced_analysis else "basic",
"processing_time": response.processing_time
"processing_time": processing_time # Now we calculate it ourselves
}
except Exception as e:
logger.error(f"Error in analyze_chart_specialized: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
logger.error(f"Analysis failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Chart analysis error: {str(e)}"
)
@router.get("/test-chartgemma")
async def test_chartgemma_connection():
"""
......
......@@ -12,7 +12,7 @@ class Settings(BaseSettings):
# LLM Service Configuration
OLLAMA_HOST: str = "http://localhost:11434"
OLLAMA_MODEL: str = "llava:34b"
CHARTGEMMA_GRADIO_URL: str = "https://3f4f53fba3f99b8778.gradio.live/"
CHARTGEMMA_GRADIO_URL: str = "https://c55d0e06f974d4ba67.gradio.live/"
DEFAULT_LLM_MODEL: str = "ollama"
# Use model_config instead of inner Config class
......
......@@ -15,43 +15,103 @@ class ChartGemmaService(LLMServicePort):
async def analyze(self, request: LLMRequestDTO) -> LLMResponseDTO:
    """Analyze a chart image with the remote ChartGemma Gradio deployment.

    Args:
        request: DTO carrying the raw image bytes and the user's question.

    Returns:
        LLMResponseDTO with the model answer, or with an "Error: ..." answer
        string on any failure (this method never raises; errors are folded
        into the DTO so the API layer can return them uniformly).
    """
    try:
        # Guard: nothing to analyze without image data.
        if not request.image_bytes:
            return LLMResponseDTO(
                answer="Error: No image data provided",
                model_used="ChartGemma",
                processing_time=None
            )

        # The Gradio endpoint expects a filesystem path, so persist the bytes
        # to a temp file. delete=False is required: we must close the handle
        # before the API call (Windows cannot re-open an open NamedTemporaryFile,
        # and delete=True would remove the file on close(), before the API
        # could read it). We unlink manually in the finally block instead.
        tmp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        try:
            tmp_file.write(request.image_bytes)
            tmp_file.flush()
            tmp_file.close()

            # Parameter name matches _call_gradio_api's signature (image_path).
            response = await self._call_gradio_api(
                image_path=tmp_file.name,
                question=request.question
            )
            return LLMResponseDTO(
                answer=response,
                model_used="ChartGemma",
                processing_time=None  # Gradio does not report server-side timing
            )
        finally:
            # Best-effort cleanup of the temp image; ignore if already gone.
            try:
                os.unlink(tmp_file.name)
            except OSError:
                pass
    except Exception as e:
        # Fold unexpected failures into an informative error response rather
        # than propagating, keeping the service contract exception-free.
        return LLMResponseDTO(
            answer=f"Error analyzing chart: {str(e)}",
            model_used="ChartGemma",
            processing_time=None
        )
async def _call_gradio_api(self, image_path: str, question: str) -> str:
    """Robust Gradio API caller with retries and endpoint fallback.

    Args:
        image_path: Path to a PNG file on the local filesystem.
        question: The question to ask about the chart.

    Returns:
        The model's answer string, or an "Error: ..." string if every
        endpoint attempt failed (this method never raises for HTTP or
        network errors).
    """
    payload = {
        "data": [
            image_path,  # File path
            question
        ]
    }

    # Configure HTTP transport with retries.
    transport = httpx.AsyncHTTPTransport(
        retries=3,       # Automatically retry failed requests
        # WARNING: verify=False disables TLS certificate verification.
        # Only acceptable for throwaway gradio.live tunnels; never ship
        # this to production.
        verify=False
    )

    async with httpx.AsyncClient(
        transport=transport,
        timeout=30.0,
        follow_redirects=True  # Important for Gradio's HTTPS redirects
    ) as client:
        # Derive endpoints from the configured base URL instead of
        # hard-coding the (rotating) gradio.live hostname. Try both
        # common Gradio endpoint layouts: /run/predict/ (older versions)
        # and the root path (newer versions).
        base_url = self.gradio_url.rstrip('/')
        endpoints = [
            f"{base_url}/run/predict/",  # Older versions
            f"{base_url}/"               # Newer versions
        ]

        for endpoint in endpoints:
            try:
                response = await client.post(
                    endpoint,
                    json=payload,
                    headers={"Content-Type": "application/json"}
                )
                response.raise_for_status()  # Raises HTTPStatusError for 4XX/5XX

                # Debug output (truncate long responses)
                print(f"Successful API call to {endpoint}")
                print(f"Response: {response.text[:200]}...")

                return response.json()["data"][0]
            except httpx.HTTPStatusError as e:
                error_detail = f"Status: {e.response.status_code}"
                if e.response.content:
                    error_detail += f" | Response: {e.response.text[:200]}"
                print(f"API Error at {endpoint}: {error_detail}")
                continue  # Try next endpoint
            except httpx.RequestError as e:
                print(f"Network error at {endpoint}: {str(e)}")
                continue  # Try next endpoint

    # All endpoints failed
    return "Error: Failed to connect to Gradio API after multiple attempts"
def _load_image(self, image_data: Optional[bytes]) -> Optional[Image.Image]:
if not image_data:
return None
......@@ -61,32 +121,4 @@ class ChartGemmaService(LLMServicePort):
print(f"Error loading image: {e}")
return None
async def _call_gradio_api(self, image: Image.Image, question: str) -> str:
    """Send a PIL image plus question to the Gradio /run/predict endpoint.

    Args:
        image: The chart image to analyze.
        question: The question to ask about the chart.

    Returns:
        The first element of the Gradio response's "data" list.

    Raises:
        httpx.HTTPStatusError: On a 4XX/5XX response.
        httpx.RequestError: On network failure.
    """
    # Gradio expects a filepath, so save the image to a temporary file first.
    # delete=False because the file must survive past this `with` block
    # until the API call completes; it is unlinked manually below.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        image.save(tmp_file, format='PNG')
        tmp_file_path = tmp_file.name

    try:
        # Prepare the payload according to Gradio's API format
        payload = {
            "data": [
                tmp_file_path,  # Filepath to the image
                question        # Optional question
            ]
        }

        async with httpx.AsyncClient() as client:
            # The Gradio API endpoint is /run/predict/
            api_url = f"{self.gradio_url.rstrip('/')}/run/predict/"
            response = await client.post(api_url, json=payload)
            response.raise_for_status()
            # Gradio returns the output in the "data" field as a list
            return response.json()["data"][0]
    finally:
        # Best-effort cleanup of the temporary file. Catch only OSError:
        # a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        try:
            os.unlink(tmp_file_path)
        except OSError:
            pass
\ No newline at end of file
\ No newline at end of file
# Scratch script: run PaddleX's PP-Chart2Table model on a sample chart and
# dump the extracted table. NOTE(review): standalone experiment with a
# hard-coded local image path — not wired into the service code above.
from paddlex import create_model

model = create_model('PP-Chart2Table')
results = model.predict(
    input={"image": "/home/zeinabrm/Documents/ChartToText/Project/ChartAnalyzer/multiset_barchart.png"},
    batch_size=1
)
for res in results:
    # Print a human-readable dump, then persist the parsed result as JSON.
    res.print()
    res.save_to_json(f"./output/res.json")
from gradio_client import Client
from PIL import Image
import base64
import io
def test_gradio():
    """Smoke-test the ChartGemma Gradio deployment with a sample chart image."""
    client = Client("https://5be063b83d5cf0d78c.gradio.live/")

    # Read the sample chart and base64-encode it as PNG, forcing RGB first
    # so the encoding is consistent regardless of the source color mode.
    img_path = "/home/zeinabrm/Documents/ChartToText/Project/ChartAnalyzer/multiset_barchart.png"
    with Image.open(img_path) as img:
        rgb_img = img if img.mode == 'RGB' else img.convert('RGB')
        buffer = io.BytesIO()
        rgb_img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode()

        # Gradio expects an image dict: a data-URI payload plus a filename.
        img_data = {
            "data": f"data:image/png;base64,{encoded}",
            "name": "chart.png"
        }

        # Fire the prediction and report the outcome either way.
        try:
            result = client.predict(
                img_data,
                "What does this chart show?",
                api_name="/predict"
            )
            print("✅ Success:", result)
        except Exception as e:
            print("❌ Failed:", str(e))
            if hasattr(e, 'response'):
                print("Server response:", e.response.text)


if __name__ == "__main__":
    test_gradio()
# Available models: models=[Model(model='qwen2.5:latest', modified_at=datetime.datetime(2025, 4, 20, 17, 56, 38, 480761, tzinfo=TzInfo(UTC)), digest='845dbda0ea48ed749caafd9e6037047aa19acfcfd82e704d7ca97d631a0b697e', size=4683087332, details=ModelDetails(parent_model='', format='gguf', family='qwen2', families=['qwen2'], parameter_size='7.6B', quantization_level='Q4_K_M'))
# , Model(model='almaghrabima/ALLaM-Thinking:latest', modified_at=datetime.datetime(2025, 3, 20, 19, 42, 49, 644450, tzinfo=TzInfo(UTC)), digest='dd366456e63970b19d414a5fcf6e210d25d3e754eb797aac1e8850885e93f365', size=4263195351, details=ModelDetails(parent_model='', format='gguf', family='llama', families=['llama'], parameter_size='7.0B', quantization_level='Q4_K_M')),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment