mr_mvp_dev / vector_db /qdrant_crud.py
srivatsavdamaraju's picture
Upload 173 files
b2315b1 verified
# requirements.txt dependencies:
# fastapi
# uvicorn[standard]
# langchain-qdrant
# langchain-openai
# langchain-community
# qdrant-client
# fastembed
# langextract
# python-multipart
# python-dotenv
# pypdf
# pandas
from fastapi import FastAPI, File, UploadFile, HTTPException, APIRouter, Form
from fastapi import FastAPI, File, UploadFile, HTTPException, APIRouter, Form, Query
from fastapi.responses import JSONResponse
from typing import Optional, List, Dict, Any
from pydantic import BaseModel
import tempfile
import os
from dotenv import load_dotenv
from langchain_qdrant import QdrantVectorStore, RetrievalMode, FastEmbedSparse
from langchain_openai import OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_community.document_loaders import (
PyPDFLoader,
CSVLoader,
TextLoader,
UnstructuredWordDocumentLoader,
UnstructuredExcelLoader
)
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
import langextract as lx
from retrieve_secret import *
# Load variables from a local .env file into the process environment.
load_dotenv()
# Configuration
# The environment-based settings below are disabled. QDRANT_URL, QDRANT_API_KEY
# and OPENAI_API_KEY are still referenced in this module, so they are presumably
# provided by `from retrieve_secret import *` — TODO confirm.
# QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
# QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", None)
# OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# Module-level embedding models shared by every endpoint:
# - `embeddings`: dense OpenAI embeddings.
# - `sparse_embeddings`: FastEmbed sparse model "Qdrant/bm25", used for the
#   sparse and hybrid retrieval modes.
embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
sparse_embeddings = FastEmbedSparse(model_name="Qdrant/bm25")
# Shared Qdrant client used directly by the endpoints in this module.
# NOTE(review): this client connects through QDRANT_HOST/QDRANT_PORT, while the
# QdrantVectorStore factory calls in this file connect through
# QDRANT_URL/QDRANT_API_KEY — confirm both address the same Qdrant instance.
# URL-based alternative, kept for reference:
# client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
client = QdrantClient(
    host=QDRANT_HOST,
    port=QDRANT_PORT
)
# Fail fast at import time: listing collections is the connectivity probe, so a
# misconfigured or unreachable server stops the application from starting.
try:
    collections = client.get_collections()
    print("✅ Connected to Qdrant. Collections:", collections)
except Exception as e:
    # Chain the original error so the root cause stays visible in the traceback.
    raise Exception(f"❌ Failed to connect to Qdrant: {e}") from e
# FastAPI router
# Every endpoint in this module is registered on this single router, which is
# served under the "/qdrant" prefix. Earlier drafts used separate prefixes
# ("/collections", "/documents", "/search"); they are no longer used.
Qdrant_router = APIRouter(prefix="/qdrant", tags=["Qdrant_Collections"])
# ==================== PYDANTIC MODELS ====================
class CollectionCreate(BaseModel):
    """Request body for ``POST /create_collection``."""

    # Name of the collection to create.
    name: str
    # NOTE(review): create_collection in this file does not read this field;
    # 1536 presumably mirrors the default OpenAI embedding size — confirm.
    vector_size: int = 1536
    # Distance metric name; create_collection's distance map knows
    # "Cosine", "Euclid" and "Dot".
    distance: str = "Cosine"
    # Optional metadata; create_collection attaches it to the placeholder
    # document and echoes it back in the response.
    metadata: Optional[Dict[str, Any]] = None
class CollectionUpdate(BaseModel):
    """Request body for ``PUT /{collection_name}``."""

    # Metadata supplied by the caller; the update endpoint only echoes it back.
    metadata: Optional[Dict[str, Any]] = None
class DocumentCreate(BaseModel):
    """A single text document submitted to the text and batch endpoints."""

    # Text stored as the document's page content.
    content: str
    # Optional metadata stored with the document; treated as {} when omitted.
    metadata: Optional[Dict[str, Any]] = None
class SearchQuery(BaseModel):
    """Search request carrying the query, target collection and filter.

    NOTE(review): no endpoint in the visible part of this file uses this
    model (the search endpoint takes ``QueryRequest``) — confirm whether it
    is still needed.
    """

    # Free-text query.
    query: str
    # Collection to search in.
    collection_name: str
    # Number of results requested.
    k: int = 3
    # Optional filter; its expected schema is not shown in this file.
    filter: Optional[Dict[str, Any]] = None
# ==================== HELPER FUNCTIONS ====================
def get_file_loader(file_path: str, file_type: str):
    """Build the document loader that matches a file extension.

    Args:
        file_path: Path of the file the loader should read.
        file_type: File extension without the leading dot; matched
            case-insensitively (e.g. ``"pdf"`` or ``"XLSX"``).

    Returns:
        A loader instance constructed with ``file_path``.

    Raises:
        ValueError: If no loader is registered for ``file_type``.
    """
    # Each loader class paired with every extension it handles.
    loader_groups = (
        (PyPDFLoader, ('pdf',)),
        (CSVLoader, ('csv',)),
        (TextLoader, ('txt',)),
        (UnstructuredWordDocumentLoader, ('docx', 'doc')),
        (UnstructuredExcelLoader, ('xlsx', 'xls')),
    )
    extension = file_type.lower()
    for loader_cls, extensions in loader_groups:
        if extension in extensions:
            return loader_cls(file_path)
    raise ValueError(f"Unsupported file type: {file_type}")
def extract_with_langextract(file_path: str) -> List[Document]:
    """Load a PDF page by page and enrich each page with LangExtract output.

    Every page is sent to ``lx.extract``; the resulting extractions are stored
    under the ``'extractions'`` metadata key as dicts with ``'class'``,
    ``'text'`` and ``'attributes'``. The page text itself is left unchanged.

    Args:
        file_path: Path to the PDF file.

    Returns:
        One ``Document`` per PDF page. Pages whose extraction fails are
        returned as loaded by ``PyPDFLoader``, without ``'extractions'``.
    """
    # The prompt is the same for every page, so build it once.
    prompt = """
    Extract key entities, concepts, and relationships from the text.
    Focus on important information, facts, and contextual details.
    """
    pages = PyPDFLoader(file_path).load()
    documents: List[Document] = []
    for page in pages:
        try:
            # NOTE(review): LangExtract's documented usage also passes
            # `examples=`; confirm this call succeeds without them, otherwise
            # every page silently takes the fallback path below.
            result = lx.extract(
                text_or_documents=page.page_content,
                prompt_description=prompt,
                model_id="gpt-4o-mini",
                api_key=OPENAI_API_KEY,
                fence_output=True,
                use_schema_constraints=False
            )
            # Copy so the loader's original page metadata is not mutated.
            metadata = page.metadata.copy()
            metadata['extractions'] = [
                {
                    'class': e.extraction_class,
                    'text': e.extraction_text,
                    'attributes': e.attributes or {}
                }
                for e in (result.extractions or [])
            ]
            documents.append(Document(
                page_content=page.page_content,
                metadata=metadata
            ))
        except Exception as e:
            # Best effort: one failing page must not abort the whole file.
            print(f"LangExtract failed for page, using raw content: {e}")
            documents.append(page)
    return documents
# ==================== COLLECTION CRUD ENDPOINTS ====================
@Qdrant_router.post("/create_collection", status_code=201)
async def create_collection(collection: CollectionCreate):
    """Create a new hybrid (dense + sparse) collection.

    The collection is created by inserting one placeholder document through
    ``QdrantVectorStore.from_documents``; ``collection.metadata`` is attached
    to that placeholder.

    Args:
        collection: Name, distance metric and optional metadata.
            NOTE(review): ``collection.vector_size`` is not applied here —
            confirm whether it should be.

    Returns:
        A confirmation dict with the collection name and the metadata.

    Raises:
        HTTPException: 400 if the distance metric is unknown or the
            collection already exists; 500 for any other failure.
    """
    # Resolve the metric before touching Qdrant so a bad value is reported as
    # a client error. Matching is case-insensitive.
    distance_map = {
        "cosine": Distance.COSINE,
        "euclid": Distance.EUCLID,
        "dot": Distance.DOT,
    }
    distance = distance_map.get(collection.distance.strip().lower())
    if distance is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported distance '{collection.distance}'. Choose one of: Cosine, Euclid, Dot.",
        )
    try:
        # Check if collection exists
        existing = client.get_collections().collections
        if any(c.name == collection.name for c in existing):
            raise HTTPException(status_code=400, detail=f"Collection '{collection.name}' already exists")
        # Seed document whose insertion creates the collection.
        initial_doc = Document(
            page_content="Initial placeholder document",
            metadata=collection.metadata or {}
        )
        QdrantVectorStore.from_documents(
            [initial_doc],
            embedding=embeddings,
            sparse_embedding=sparse_embeddings,
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
            collection_name=collection.name,
            retrieval_mode=RetrievalMode.HYBRID,
            distance=distance,
        )
        return {
            "message": f"Collection '{collection.name}' created successfully",
            "collection_name": collection.name,
            "metadata": collection.metadata
        }
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 above) intact instead of
        # letting the generic handler rewrap them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@Qdrant_router.get("/list_all_collections")
async def list_collections():
    """Return every collection name together with its point count.

    The value under ``"vectors_count"`` is the collection's ``points_count``,
    fetched with one ``get_collection`` call per collection.

    Raises:
        HTTPException: 500 if Qdrant cannot be queried.
    """
    print("** LISTING ALL COLLECTIONS **")
    try:
        found = client.get_collections().collections
        print("collections", found)
        summaries = []
        for entry in found:
            point_total = client.get_collection(entry.name).points_count
            summaries.append({"name": entry.name, "vectors_count": point_total})
        return {"collections": summaries}
    except Exception as err:
        print("this is excepetion ", err)
        raise HTTPException(status_code=500, detail=str(err))
@Qdrant_router.get("/{collection_name}")
async def get_collection_info(collection_name: str):
    """Get detailed information about a collection.

    Args:
        collection_name: Name of the collection to inspect.

    Returns:
        A dict with the name, point count, vector count (``None`` when the
        client does not report one), status and stringified config params.

    Raises:
        HTTPException: 404 if the collection information cannot be read.
    """
    try:
        info = client.get_collection(collection_name)
        return {
            "name": collection_name,
            "points_count": info.points_count,
            # NOTE(review): `vectors_count` is not guaranteed to exist on every
            # qdrant-client release — confirm. Reading it defensively stops a
            # missing attribute from being reported as "collection not found".
            "vectors_count": getattr(info, "vectors_count", None),
            "status": info.status.value,
            "config": {
                "params": str(info.config.params)
            }
        }
    except Exception as e:
        # NOTE(review): every failure (including connection errors) is mapped
        # to 404; the cause is chained so it at least shows up in logs.
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found") from e
@Qdrant_router.put("/{collection_name}")
async def update_collection_metadata(collection_name: str, update: CollectionUpdate):
    """Acknowledge a metadata update request for an existing collection.

    Nothing is written to Qdrant: the endpoint only checks that the collection
    can be fetched and echoes the submitted metadata back.

    Raises:
        HTTPException: 404 if the collection cannot be fetched.
    """
    acknowledgement = {
        "message": f"Collection '{collection_name}' metadata update acknowledged",
        "note": "To update point metadata, use document update endpoints",
        "metadata": update.metadata,
    }
    try:
        # Existence probe; the returned info is not needed.
        client.get_collection(collection_name)
        return acknowledgement
    except Exception:
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found")
@Qdrant_router.delete("/{collection_name}")
async def delete_collection(collection_name: str):
    """Remove a collection from Qdrant.

    Raises:
        HTTPException: 404 if the delete call raises for any reason.
    """
    not_found = f"Collection '{collection_name}' not found"
    try:
        client.delete_collection(collection_name)
    except Exception:
        raise HTTPException(status_code=404, detail=not_found)
    else:
        confirmation = f"Collection '{collection_name}' deleted successfully"
        return {"message": confirmation}
# ==================== DOCUMENT CRUD ENDPOINTS ====================
@Qdrant_router.post("/{collection_name}/text")
async def add_text_document(collection_name: str, doc: DocumentCreate):
    """Add a single text document, with optional metadata, to a collection.

    Args:
        collection_name: Target collection; it must already exist.
        doc: Document text and optional metadata.

    Returns:
        A dict with a confirmation message, the ids assigned to the stored
        document and the submitted metadata.

    Raises:
        HTTPException: 500 if the collection cannot be opened or the insert
            fails.
    """
    try:
        # Open the store in HYBRID mode so the document gets both a dense and a
        # sparse vector, matching how create_collection builds collections and
        # what the sparse/hybrid search modes query.
        # NOTE(review): a dense-only collection created outside this API will
        # be rejected here — confirm none are in use.
        qdrant = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            sparse_embedding=sparse_embeddings,
            retrieval_mode=RetrievalMode.HYBRID,
            collection_name=collection_name,
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
        )
        document = Document(
            page_content=doc.content,
            metadata=doc.metadata or {}
        )
        ids = qdrant.add_documents([document])
        return {
            "message": "Document added successfully",
            "document_ids": ids,
            "metadata": doc.metadata
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@Qdrant_router.post("/{collection_name}/upload")
async def upload_file_to_collection(
    collection_name: str,
    file: UploadFile = File(...),
    use_langextract: bool = Form(False),
    metadata: Optional[str] = Form(None)
):
    """
    Upload and process a file (PDF, CSV, TXT, DOCX, XLSX) to a collection.
    For PDFs, optionally use LangExtract for advanced entity extraction.

    Args:
        collection_name: Target collection; it must already exist.
        file: Uploaded file; its extension selects the loader.
        use_langextract: Only honoured for PDF files.
        metadata: Optional JSON object (as a string) merged into every
            document's metadata; its keys override ``source_file`` and
            ``file_type``.

    Returns:
        A dict with the file type, number of documents added, their ids and
        whether LangExtract was used.

    Raises:
        HTTPException: 400 for malformed metadata or an unsupported file
            type; 500 for any processing or storage failure.
    """
    import json

    # ---- Request validation: client mistakes are reported as 400 ----
    try:
        file_metadata = json.loads(metadata) if metadata else {}
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"Invalid metadata JSON: {e}") from e
    if not isinstance(file_metadata, dict):
        raise HTTPException(status_code=400, detail="metadata must be a JSON object")

    # The filename may be missing; take the text after the last dot, if any.
    filename = file.filename or ""
    _, dot, extension = filename.rpartition('.')
    file_ext = extension.lower() if dot else ""
    # The extension is user-controlled and becomes part of a temp-file name,
    # so refuse anything that is not purely alphanumeric (this also covers
    # files without an extension).
    if not file_ext.isalnum():
        raise HTTPException(status_code=400, detail=f"Unsupported file type: {file_ext}")

    tmp_path = None
    try:
        # Save uploaded file temporarily; the loaders need a real path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_ext}') as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(await file.read())

        # Process based on file type
        if file_ext == 'pdf' and use_langextract:
            documents = extract_with_langextract(tmp_path)
        else:
            try:
                loader = get_file_loader(tmp_path, file_ext)
            except ValueError as e:
                # Unknown extension is a client error, not a server fault.
                raise HTTPException(status_code=400, detail=str(e)) from e
            documents = loader.load()

        # Add file metadata to all documents (caller-supplied keys win).
        extra_metadata = {
            'source_file': file.filename,
            'file_type': file_ext,
            **file_metadata
        }
        for doc in documents:
            doc.metadata.update(extra_metadata)

        # Open the store in HYBRID mode so uploaded documents get both dense
        # and sparse vectors, matching how create_collection builds collections.
        # NOTE(review): a dense-only collection created outside this API will
        # be rejected here — confirm none are in use.
        qdrant = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            sparse_embedding=sparse_embeddings,
            retrieval_mode=RetrievalMode.HYBRID,
            collection_name=collection_name,
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
        )
        ids = qdrant.add_documents(documents)
        return {
            "message": f"File '{file.filename}' processed successfully",
            "file_type": file_ext,
            "documents_added": len(documents),
            "document_ids": ids,
            "used_langextract": use_langextract and file_ext == 'pdf'
        }
    except HTTPException:
        # Preserve the deliberate 400 raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Clean up temp file
        if tmp_path is not None:
            os.unlink(tmp_path)
@Qdrant_router.post("/{collection_name}/batch")
async def add_batch_documents(collection_name: str, docs: List[DocumentCreate]):
    """Add multiple text documents to a collection in one call.

    Args:
        collection_name: Target collection; it must already exist.
        docs: Documents to store; missing metadata is stored as ``{}``.

    Returns:
        A dict with a confirmation message and the ids assigned to the
        stored documents.

    Raises:
        HTTPException: 500 if the collection cannot be opened or the insert
            fails.
    """
    try:
        # Open the store in HYBRID mode so every document gets both a dense and
        # a sparse vector, matching how create_collection builds collections
        # and what the sparse/hybrid search modes query.
        # NOTE(review): a dense-only collection created outside this API will
        # be rejected here — confirm none are in use.
        qdrant = QdrantVectorStore.from_existing_collection(
            embedding=embeddings,
            sparse_embedding=sparse_embeddings,
            retrieval_mode=RetrievalMode.HYBRID,
            collection_name=collection_name,
            url=QDRANT_URL,
            api_key=QDRANT_API_KEY,
        )
        documents = [
            Document(page_content=doc.content, metadata=doc.metadata or {})
            for doc in docs
        ]
        ids = qdrant.add_documents(documents)
        return {
            "message": f"Added {len(documents)} documents successfully",
            "document_ids": ids
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@Qdrant_router.get("/{collection_name}/count")
async def get_document_count(collection_name: str):
    """Report how many points the named collection currently holds.

    Raises:
        HTTPException: 404 if the collection information cannot be read.
    """
    try:
        point_total = client.get_collection(collection_name).points_count
        summary = {"collection_name": collection_name}
        summary["document_count"] = point_total
        return summary
    except Exception:
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found")
# ==================== SEARCH ENDPOINTS ====================
class QueryRequest(BaseModel):
    """Request body for ``POST /search``."""

    # Free-text query passed to the vector store's similarity search.
    query: str
    # Number of results to return.
    # NOTE(review): Optional also admits an explicit null — confirm the
    # search path copes with top_k=None.
    top_k: Optional[int] = 2
class QueryResponseItem(BaseModel):
    """One search hit returned by ``POST /search``."""

    # Page content of the matched document.
    content: str
    # Metadata of the matched document ({} when it has none).
    metadata: dict
class QueryResponse(BaseModel):
    """Response body of ``POST /search``."""

    # Collection that was searched.
    collection: str
    # Retrieval mode that was used: "dense", "sparse" or "hybrid".
    mode: str
    # Matched documents as returned by the similarity search.
    results: List[QueryResponseItem]
# --- Utility function ---
def get_vector_store(collection_name: str, mode: str):
    """Build a ``QdrantVectorStore`` on the shared client for one retrieval mode.

    Args:
        collection_name: Collection the store should read from.
        mode: ``"dense"``, ``"sparse"`` or ``"hybrid"``.

    Returns:
        A store configured with the dense embedding (unnamed vector), the
        sparse embedding (vector ``"langchain-sparse"``), or both.

    Raises:
        HTTPException: 400 if ``mode`` is not one of the three known values.
    """
    # Reject unknown modes up front.
    if mode not in ("dense", "sparse", "hybrid"):
        raise HTTPException(
            status_code=400,
            detail="Invalid mode. Choose one of: dense, sparse, hybrid.",
        )

    store_kwargs = {"client": client, "collection_name": collection_name}
    if mode != "sparse":
        # Dense side: unnamed dense vector.
        store_kwargs.update(embedding=embeddings, vector_name="")
    if mode != "dense":
        # Sparse side.
        store_kwargs.update(
            sparse_embedding=sparse_embeddings,
            sparse_vector_name="langchain-sparse",
        )
    store_kwargs["retrieval_mode"] = {
        "dense": RetrievalMode.DENSE,
        "sparse": RetrievalMode.SPARSE,
        "hybrid": RetrievalMode.HYBRID,
    }[mode]
    return QdrantVectorStore(**store_kwargs)
@Qdrant_router.post("/search", response_model=QueryResponse)
def search(
    request: QueryRequest,
    collection_name: str = Query(..., description="Name of the Qdrant collection"),
    mode: str = Query(..., regex="^(dense|sparse|hybrid)$", description="Retrieval mode"),
):
    """Perform a similarity search using the specified collection and mode.

    Args:
        request: Query text and the number of results wanted.
        collection_name: Collection to search (query parameter).
        mode: ``dense``, ``sparse`` or ``hybrid`` (query parameter).

    Returns:
        A ``QueryResponse`` with the content and metadata of each hit.

    Raises:
        HTTPException: Errors raised by ``get_vector_store`` are passed
            through unchanged; any other failure becomes a 500.
    """
    # The request model allows an explicit null for top_k; fall back to the
    # model's default of 2 instead of handing k=None to the vector store.
    top_k = request.top_k if request.top_k is not None else 2
    try:
        # Create vector store
        vector_store = get_vector_store(collection_name, mode)
        # Perform search
        results = vector_store.similarity_search(request.query, k=top_k)
        response_items = [
            QueryResponseItem(content=doc.page_content, metadata=doc.metadata or {})
            for doc in results
        ]
        return QueryResponse(collection=collection_name, mode=mode, results=response_items)
    except HTTPException:
        # Do not rewrap a deliberate HTTP error as "Search failed" / 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") from e
@Qdrant_router.get("/qdrant_db/health")
async def health_check():
    """Report whether the shared Qdrant client can list collections.

    Returns:
        A "healthy" dict with the configured URL and the collection count, or
        a 503 ``JSONResponse`` carrying the error text when the probe fails.
    """
    try:
        listing = client.get_collections()
        report = {"status": "healthy"}
        report["qdrant_url"] = QDRANT_URL
        report["collections_count"] = len(listing.collections)
        return report
    except Exception as err:
        failure = {"status": "unhealthy", "error": str(err)}
        return JSONResponse(status_code=503, content=failure)