Commit 8f656a91 authored by ZeinabRm13's avatar ZeinabRm13

Add filesystem storage

parent 59fc0d8b
from fastapi import FastAPI from fastapi import FastAPI
from src.infrastructure.api.fastapi.routes import auth, charts from api.fastapi.routes import auth, charts
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from src.domain.entities.user import User
from src.application.dtos.authentication import ( from src.application.dtos.authentication import (
RegisterRequestDTO, RegisterRequestDTO,
LoginRequestDTO, LoginRequestDTO,
...@@ -10,7 +8,36 @@ from src.application.dtos.authentication import ( ...@@ -10,7 +8,36 @@ from src.application.dtos.authentication import (
class AuthServicePort(ABC): class AuthServicePort(ABC):
@abstractmethod @abstractmethod
async def register(self, email, password) -> User: ... async def register(self, register_dto: RegisterRequestDTO) -> UserResponseDTO:
"""
Register a new user with email and password
:param register_dto: RegisterRequestDTO containing email and password
:return: UserResponseDTO containing user details
"""
...
@abstractmethod
async def login(self, login_dto: LoginRequestDTO) -> TokenResponseDTO:
"""
Authenticate user and return JWT token
:param login_dto: LoginRequestDTO containing email and password
:return: TokenResponseDTO containing access token
"""
...
@abstractmethod
async def logout(self, token: str) -> None:
"""
Invalidate a JWT token
:param token: JWT token to invalidate
"""
...
@abstractmethod @abstractmethod
async def login(self, email, password) -> str: ... async def validate_token(self, token: str) -> bool:
\ No newline at end of file """
Validate if a JWT token is still valid
:param token: JWT token to validate
:return: Boolean indicating token validity
"""
...
\ No newline at end of file
from abc import ABC, abstractmethod
from uuid import UUID
from typing import Tuple, Optional
from datetime import datetime
class FileStoragePort(ABC):
"""Abstract base class for file storage operations"""
@abstractmethod
async def save_chart_image(
self,
user_id: UUID,
image_data: bytes,
content_type: str
) -> str:
"""Save chart image to storage and return file path"""
pass
@abstractmethod
async def get_chart_image(self, file_path: str) -> Tuple[bytes, str]:
"""Retrieve chart image data and content type"""
pass
@abstractmethod
async def generate_thumbnail(
self,
original_path: str,
output_path: str,
dimensions: Tuple[int, int] = (300, 300)
) -> str:
"""Generate and save thumbnail version"""
pass
@abstractmethod
async def delete_file(self, file_path: str) -> bool:
"""Delete a stored file"""
pass
@abstractmethod
def generate_file_path(
self,
user_id: UUID,
extension: str
) -> str:
"""Generate storage path for a new file"""
pass
\ No newline at end of file
# src/application/services/auth.py # src/application/services/auth.py
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from jose import jwt from jose import jwt, JWTError
from passlib.context import CryptContext from passlib.context import CryptContext
from src.domain.ports.repositories.user_repository import UserRepositoryPort from src.domain.ports.repositories.user_repository import UserRepositoryPort
from src.domain.ports.repositories.token_repository import TokenRepositoryPort from src.domain.ports.repositories.token_repository import TokenRepositoryPort
...@@ -15,15 +15,11 @@ from src.application.dtos.authentication import ( ...@@ -15,15 +15,11 @@ from src.application.dtos.authentication import (
TokenResponseDTO TokenResponseDTO
) )
# Update src/application/services/auth.py
from jose import JWTError, jwt
from datetime import datetime, timezone, timedelta
class AuthService(AuthServicePort): class AuthService(AuthServicePort):
def __init__( def __init__(
self, self,
user_repo: UserRepositoryPort, user_repo: UserRepositoryPort,
token_repo: TokenRepositoryPort, # Add this token_repo: TokenRepositoryPort,
secret_key: str = Settings().JWT_SECRET, secret_key: str = Settings().JWT_SECRET,
algorithm: str = "HS256", algorithm: str = "HS256",
expires_minutes: int = 30 expires_minutes: int = 30
...@@ -35,8 +31,38 @@ class AuthService(AuthServicePort): ...@@ -35,8 +31,38 @@ class AuthService(AuthServicePort):
self._expires_minutes = expires_minutes self._expires_minutes = expires_minutes
self._pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") self._pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
async def register(self, register_dto: RegisterRequestDTO) -> UserResponseDTO:
"""
Register a new user with email and password
Returns UserResponseDTO containing user details
"""
if await self._user_repo.get_by_email(register_dto.email):
raise ValueError("Email already registered")
user = User(
id=str(uuid.uuid4()),
email=register_dto.email,
password_hash=self._hash_password(register_dto.password),
is_active=True
)
await self._user_repo.create_user(user)
return self._user_to_dto(user)
async def login(self, login_dto: LoginRequestDTO) -> TokenResponseDTO:
"""
Authenticate user and return JWT token
Returns TokenResponseDTO containing access token
"""
user = await self._user_repo.get_by_email(login_dto.email)
if not user or not self._verify_password(login_dto.password, user.password_hash):
raise ValueError("Invalid credentials")
access_token = self._create_access_token(user.email)
return TokenResponseDTO(access_token=access_token)
async def logout(self, token: str) -> None: async def logout(self, token: str) -> None:
"""Invalidate a JWT token""" """Invalidate a JWT token by adding it to blacklist"""
try: try:
payload = jwt.decode(token, self._secret_key, algorithms=[self._algorithm]) payload = jwt.decode(token, self._secret_key, algorithms=[self._algorithm])
exp = payload.get("exp") exp = payload.get("exp")
...@@ -47,7 +73,10 @@ class AuthService(AuthServicePort): ...@@ -47,7 +73,10 @@ class AuthService(AuthServicePort):
pass # Token is invalid anyway pass # Token is invalid anyway
async def validate_token(self, token: str) -> bool: async def validate_token(self, token: str) -> bool:
"""Check if token is valid and not blacklisted""" """
Check if token is valid and not blacklisted
Returns boolean indicating token validity
"""
try: try:
if await self._token_repo.is_blacklisted(token): if await self._token_repo.is_blacklisted(token):
return False return False
...@@ -57,36 +86,16 @@ class AuthService(AuthServicePort): ...@@ -57,36 +86,16 @@ class AuthService(AuthServicePort):
except JWTError: except JWTError:
return False return False
async def register(self, email, password) -> User:
if await self._user_repo.get_by_email(email):
raise ValueError("Email already registered")
user = User(
id=str(uuid.uuid4()),
email=email,
password_hash=self._hash_password(password)
)
await self._user_repo.create_user(user)
return user
async def login(self, email: str, password: str) -> str:
user = await self._user_repo.get_by_email(email)
if not user or not self._verify_password(password, user.password_hash):
raise ValueError("Invalid credentials")
# Return just the token string, not the whole response
return self._create_access_token(user.email)
def _hash_password(self, password: str) -> str: def _hash_password(self, password: str) -> str:
"""Hash password using bcrypt"""
return self._pwd_context.hash(password) return self._pwd_context.hash(password)
def _verify_password(self, plain_password: str, hashed_password: str) -> bool: def _verify_password(self, plain_password: str, hashed_password: str) -> bool:
"""Verify password against stored hash"""
return self._pwd_context.verify(plain_password, hashed_password) return self._pwd_context.verify(plain_password, hashed_password)
def _create_access_token(self, email: str) -> str: def _create_access_token(self, email: str) -> str:
"""Create JWT token with expiration"""
expires = datetime.now(timezone.utc) + timedelta(minutes=self._expires_minutes) expires = datetime.now(timezone.utc) + timedelta(minutes=self._expires_minutes)
return jwt.encode( return jwt.encode(
{"sub": email, "exp": expires}, {"sub": email, "exp": expires},
...@@ -95,6 +104,7 @@ class AuthService(AuthServicePort): ...@@ -95,6 +104,7 @@ class AuthService(AuthServicePort):
) )
def _user_to_dto(self, user: User) -> UserResponseDTO: def _user_to_dto(self, user: User) -> UserResponseDTO:
"""Convert User entity to UserResponseDTO"""
return UserResponseDTO( return UserResponseDTO(
id=user.id, id=user.id,
email=user.email, email=user.email,
......
from src.infrastructure.persistence.models.base import Base
from src.infrastructure.persistence.models.user_model import UserModel
from src.infrastructure.persistence.models.chart_model import ChartImageModel
from src.infrastructure.persistence.models.conversation_models import ConversationModel, ConversationMessageModel
from src.infrastructure.persistence.models.blacklisted_token_model import BlacklistedTokenModel
from src.infrastructure.persistence.models.analysis_model import ChartAnalysisModel
\ No newline at end of file
...@@ -4,17 +4,14 @@ from sqlalchemy.orm import relationship ...@@ -4,17 +4,14 @@ from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
import uuid import uuid
class ChartAnalysis(Base): class ChartAnalysisModel(Base, TimestampMixin):
__tablename__ = 'chart_analyses' __tablename__ = 'chart_analysis'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
chart_image_id = Column(UUID(as_uuid=True), ForeignKey('chart_images.id'), nullable=False) chart_image_id = Column(UUID(as_uuid=True), ForeignKey('chart_images.id'), nullable=False)
question = Column(Text, nullable=False) question = Column(Text, nullable=False)
answer = Column(Text, nullable=False) answer = Column(Text, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) metadata = Column(JSONB) # Added for additional analysis data
# Relationship # Relationship
chart_image = relationship("ChartImage", back_populates="analyses") chart_image = relationship("ChartImageModel", back_populates="analysis")
def __repr__(self):
return f"<ChartAnalysis(id={self.id}, chart_image_id={self.chart_image_id})>"
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, DateTime, func
from datetime import datetime, timezone
Base = declarative_base()
class TimestampMixin:
created_at = Column(DateTime(timezone=True), server_default=func.now())
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
\ No newline at end of file
...@@ -7,13 +7,10 @@ from sqlalchemy.dialects.postgresql import UUID ...@@ -7,13 +7,10 @@ from sqlalchemy.dialects.postgresql import UUID
import uuid import uuid
class BlacklistedToken(Base): class BlacklistedTokenModel(Base):
__tablename__ = 'blacklisted_tokens' __tablename__ = 'blacklisted_tokens'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
token = Column(String(512), unique=True, nullable=False) token = Column(String(512), unique=True, nullable=False)
expires_at = Column(DateTime(timezone=True), nullable=False) expires_at = Column(DateTime(timezone=True), nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now()) created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
return f"<BlacklistedToken(token={self.token[:10]}...)>"
\ No newline at end of file
...@@ -7,16 +7,15 @@ from sqlalchemy.dialects.postgresql import UUID ...@@ -7,16 +7,15 @@ from sqlalchemy.dialects.postgresql import UUID
import uuid import uuid
class ChartImage(Base): class ChartImageModel(Base, TimestampMixin):
__tablename__ = 'chart_images' __tablename__ = 'chart_images'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
user_id = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False)
image_data = Column(LargeBinary, nullable=False) # For storing binary data file_path = Column(String(512), nullable=False) # Changed from image_data to file_path
uploaded_at = Column(DateTime(timezone=True), server_default=func.now()) thumbnail_path = Column(String(512), nullable=True)
# Relationship # Relationships
analyses = relationship("ChartAnalysis", back_populates="chart_image") user = relationship("UserModel", back_populates="chart_images")
analysis = relationship("ChartAnalysisModel", back_populates="chart_image")
def __repr__(self): conversations = relationship("ConversationModel", back_populates="chart_image")
return f"<ChartImage(id={self.id}, user_id={self.user_id})>"
\ No newline at end of file
...@@ -3,32 +3,30 @@ from sqlalchemy.orm import relationship ...@@ -3,32 +3,30 @@ from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from src.infrastructure.persistence.models.base import Base from src.infrastructure.persistence.models.base import Base
from datetime import datetime, timezone from datetime import datetime, timezone
class ConversationModel(Base, TimestampMixin):
class ConversationModel(Base):
"""SQLAlchemy model for conversations"""
__tablename__ = "conversations" __tablename__ = "conversations"
id = Column(UUID(as_uuid=True), primary_key=True) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
user_id = Column(UUID(as_uuid=True), nullable=False) user_id = Column(UUID(as_uuid=True), ForeignKey('users.id'), nullable=False)
chart_image_id = Column(UUID(as_uuid=True), nullable=False) chart_image_id = Column(UUID(as_uuid=True), ForeignKey('chart_images.id'), nullable=True)
title = Column(String(255), nullable=False) title = Column(String(255), nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
# Relationship to messages # Relationships
messages = relationship("ConversationMessageModel", back_populates="conversation", cascade="all, delete-orphan") user = relationship("UserModel", back_populates="conversations")
chart_image = relationship("ChartImageModel", back_populates="conversations")
messages = relationship("ConversationMessageModel", back_populates="conversation",
cascade="all, delete-orphan")
class ConversationMessageModel(Base): class ConversationMessageModel(Base):
"""SQLAlchemy model for conversation messages"""
__tablename__ = "conversation_messages" __tablename__ = "conversation_messages"
id = Column(UUID(as_uuid=True), primary_key=True) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"), nullable=False) conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"), nullable=False)
user_id = Column(UUID(as_uuid=True), nullable=False)
message_type = Column(String(20), nullable=False) # 'user' or 'assistant' message_type = Column(String(20), nullable=False) # 'user' or 'assistant'
content = Column(Text, nullable=False) content = Column(Text, nullable=False)
timestamp = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), server_default=func.now())
metadata = Column(JSONB) # For additional message data
# Relationship to conversation
conversation = relationship("ConversationModel", back_populates="messages") # Relationship
\ No newline at end of file conversation = relationship("ConversationModel", back_populates="messages")
\ No newline at end of file
...@@ -4,20 +4,17 @@ from sqlalchemy.orm import declarative_base ...@@ -4,20 +4,17 @@ from sqlalchemy.orm import declarative_base
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
import uuid import uuid
from .base import Base, TimestampMixin
# 1. Create a Base class that all your models will inherit from. class UserModel(Base, TimestampMixin):
Base = declarative_base()
class User(Base):
__tablename__ = 'users' __tablename__ = 'users'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
email = Column(String(255), unique=True, nullable=False) email = Column(String(255), unique=True, nullable=False)
password_hash = Column(String(255), nullable=False) # Store only hashed passwords password_hash = Column(String(255), nullable=False)
is_active = Column(Boolean, default=True) is_active = Column(Boolean, default=True)
created_at = Column(DateTime(timezone=True), server_default=func.now())
last_login = Column(DateTime(timezone=True), nullable=True) last_login = Column(DateTime(timezone=True), nullable=True)
def __repr__(self): # Relationships
return f"<User(id={self.id}, email={self.email})>" chart_images = relationship("ChartImageModel", back_populates="user")
conversations = relationship("ConversationModel", back_populates="user")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment