from fastapi import FastAPI, UploadFile, File, HTTPException, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from typing import List, Optional, Dict, AsyncGenerator
import os
from dotenv import load_dotenv
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel
from aimakerspace.text_utils import CharacterTextSplitter, PDFLoader
from aimakerspace.openai_utils.prompts import (
    UserRolePrompt,
    SystemRolePrompt,
    AssistantRolePrompt,
)
from aimakerspace.openai_utils.chatmodel import ChatOpenAI
import asyncio
import tempfile
import shutil
import json
from uuid import uuid4
# Load environment variables
load_dotenv()

app = FastAPI()
# Mount static files
# Note: Starlette matches routes in registration order, so mounting StaticFiles
# at "/" before the API routes can shadow them; mounting after the routes (or
# under a sub-path such as "/static") avoids this.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize components
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chat_openai = ChatOpenAI()
# Define prompts
system_template = """\
You are a helpful assistant that provides concise, direct answers based on the provided context.
If the answer cannot be found in the context, simply say "I don't know" or "The information is not available in the provided context."
Keep your answers brief and to the point."""
system_role_prompt = SystemRolePrompt(system_template)

user_prompt_template = """\
Context:
{context}

Question:
{question}

Answer the question concisely based on the context above."""
user_role_prompt = UserRolePrompt(user_prompt_template)
# Session management
sessions: Dict[str, Dict] = {}


class Query(BaseModel):
    text: str
    k: int = 4


class DocumentResponse(BaseModel):
    text: str
    type: str  # 'answer' or 'context'
    score: Optional[float] = None
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db_retriever: VectorDatabase) -> None:
        self.llm = llm
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str, k: int = 4) -> AsyncGenerator[str, None]:
        # Get top k most relevant chunks
        context_list = self.vector_db_retriever.search_by_text(user_query, k=k)

        # Format context
        context_prompt = ""
        for context in context_list:
            context_prompt += context[0] + "\n"

        # Format prompts
        formatted_system_prompt = system_role_prompt.create_message()
        formatted_user_prompt = user_role_prompt.create_message(
            question=user_query,
            context=context_prompt
        )

        # Stream only the LLM response
        async for chunk in self.llm.astream([formatted_system_prompt, formatted_user_prompt]):
            yield json.dumps({
                "type": "token",
                "text": chunk
            })

        # Send context information once at the end
        yield json.dumps({
            "type": "context",
            "context": [{"text": text, "score": score} for text, score in context_list]
        })
def process_file(file_path: str, file_name: str):
    if file_name.lower().endswith('.pdf'):
        loader = PDFLoader(file_path)
    else:
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    documents = loader.load_documents()
    texts = text_splitter.split_texts(documents)
    return texts
# NOTE: the route decorator was missing from the extracted source; "/upload" is an assumed path.
@app.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    if not file.filename.lower().endswith('.pdf'):
        raise HTTPException(status_code=400, detail="Only PDF files are supported")

    try:
        # Read the file content directly into memory
        content = await file.read()

        # Create a temporary file in a directory we know exists
        temp_dir = "/tmp"  # Using /tmp which is writable in most environments
        os.makedirs(temp_dir, exist_ok=True)
        temp_path = os.path.join(temp_dir, f"upload_{file.filename}")

        # Write the content to the temporary file
        with open(temp_path, 'wb') as temp_file:
            temp_file.write(content)

        try:
            # Process the file
            texts = process_file(temp_path, file.filename)

            # Create a new session
            session_id = str(uuid4())
            vector_db = VectorDatabase()
            await vector_db.abuild_from_list(texts)

            # Store session data
            sessions[session_id] = {
                "vector_db": vector_db,
                "texts": texts
            }

            return {
                "session_id": session_id,
                "message": f"Document processed successfully. Added {len(texts)} chunks to the database."
            }
        finally:
            # Clean up the temporary file
            try:
                if os.path.exists(temp_path):
                    os.unlink(temp_path)
            except Exception as e:
                print(f"Warning: Could not delete temporary file: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
# NOTE: the route decorator was missing from the extracted source; the path is an assumption.
@app.post("/query/{session_id}")
async def query_documents(session_id: str, query: Query):
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        session = sessions[session_id]
        vector_db = session["vector_db"]

        # Initialize RAG pipeline
        rag_pipeline = RetrievalAugmentedQAPipeline(
            llm=chat_openai,
            vector_db_retriever=vector_db
        )

        # Create streaming response (Server-Sent Events)
        async def generate():
            async for chunk in rag_pipeline.arun_pipeline(query.text, query.k):
                yield f"data: {chunk}\n\n"

        return StreamingResponse(
            generate(),
            media_type="text/event-stream"
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE: the route decorator was missing from the extracted source; the path is an assumption.
@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    await websocket.accept()

    if session_id not in sessions:
        await websocket.close(code=1008, reason="Session not found")
        return

    try:
        session = sessions[session_id]
        vector_db = session["vector_db"]

        while True:
            data = await websocket.receive_text()
            query = json.loads(data)

            # Initialize RAG pipeline
            rag_pipeline = RetrievalAugmentedQAPipeline(
                llm=chat_openai,
                vector_db_retriever=vector_db
            )

            # Stream response; each chunk is already a JSON-encoded
            # {"type": ..., ...} payload, so send it as-is rather than
            # wrapping it in a second json.dumps call.
            async for chunk in rag_pipeline.arun_pipeline(query["text"], query.get("k", 4)):
                await websocket.send_text(chunk)
    except Exception as e:
        await websocket.close(code=1011, reason=str(e))
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=9000)
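

# ---------------------------------------------------------------------------
# Illustrative client sketch (not called anywhere in this module). It shows
# how the SSE stream produced by the query endpoint could be consumed.
# Assumptions: the "/upload" and "/query/{session_id}" paths guessed above,
# and the `requests` package being available on the client side.
def _example_client(pdf_path: str, question: str, base_url: str = "http://localhost:9000"):
    import requests

    # Upload a PDF and obtain a session id
    with open(pdf_path, "rb") as f:
        resp = requests.post(f"{base_url}/upload", files={"file": f})
    resp.raise_for_status()
    session_id = resp.json()["session_id"]

    # Query the document and print answer tokens as they stream in
    with requests.post(
        f"{base_url}/query/{session_id}",
        json={"text": question, "k": 4},
        stream=True,
    ) as stream:
        for line in stream.iter_lines(decode_unicode=True):
            if not line or not line.startswith("data: "):
                continue
            payload = json.loads(line[len("data: "):])
            if payload["type"] == "token":
                print(payload["text"], end="", flush=True)
            elif payload["type"] == "context":
                print("\n\nRetrieved context:", payload["context"])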