mr_mvp_dev / vector_db /v2.py.txt
srivatsavdamaraju's picture
Upload 173 files
b2315b1 verified
from fastapi import FastAPI, Body, HTTPException, UploadFile, File
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext, Settings, Document
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.openai import OpenAIEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from dotenv import load_dotenv
import os
import shutil
import tempfile
from fastapi import APIRouter
# Read the local .env file so the os.getenv() lookups below can see its values.
load_dotenv()

# Router that groups every endpoint in this module under /qdrant_llama_index.
qdrant_llama_index_router = APIRouter(
    prefix="/qdrant_llama_index",
    tags=["qdrant_llama_index"],
)

# Register the OpenAI embedding model on the global LlamaIndex Settings object.
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-ada-002",
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Module-wide Qdrant client; its URL and API key are taken from the environment.
try:
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )
except Exception as connect_error:
    # Fail the import with a descriptive message if the client cannot be built.
    raise Exception(f"Failed to connect to Qdrant: {str(connect_error)}")
# --- Helper function to get VectorStoreIndex ---
def get_vector_store_index(collection_name: str):
    """Build a VectorStoreIndex backed by the named Qdrant collection.

    Wraps the module-level Qdrant ``client`` in a ``QdrantVectorStore`` for
    ``collection_name`` and returns the index created from that store.
    """
    return VectorStoreIndex.from_vector_store(
        vector_store=QdrantVectorStore(
            client=client,
            collection_name=collection_name,
        )
    )
# --- API Endpoints ---
@qdrant_llama_index_router.get("/")
async def read_root():
    """Return a static welcome message that points callers at /docs."""
    welcome = "Welcome to the Qdrant Collections API! Visit /docs for API documentation."
    return {"message": welcome}
@qdrant_llama_index_router.post("/qdrant/create_collection")
async def create_collection(
    collection_params: dict = Body(
        ...,
        example={
            "name": "my_new_collection",
            "vector_size": 1536,
            "distance": "Cosine"
        }
    )
):
    """
    Create a new collection with specified parameters.

    Payload keys:
    - name: Name of the collection (required, non-empty string)
    - vector_size: Size of the vectors (required, positive integer,
      e.g., 1536 for OpenAI embeddings)
    - distance: Distance metric (required, e.g., "Cosine", "Euclid", "Dot");
      matched case-insensitively against the member names of ``Distance``

    Returns a ``{"status", "message"}`` dict: status ``"info"`` when the
    collection already exists, ``"success"`` after it has been created.

    Raises:
        HTTPException 422: a required key is missing or has an invalid value.
        HTTPException 500: any other failure while talking to Qdrant.
    """
    try:
        name = collection_params.get("name")
        vector_size = collection_params.get("vector_size")
        distance_str = collection_params.get("distance")

        # bool is a subclass of int, so it has to be excluded explicitly;
        # a zero or negative size can never describe a vector either.
        size_is_valid = (
            isinstance(vector_size, int)
            and not isinstance(vector_size, bool)
            and vector_size > 0
        )
        name_is_valid = isinstance(name, str) and bool(name)
        if not name_is_valid or not size_is_valid or not distance_str:
            raise HTTPException(status_code=422, detail="Invalid payload: 'name', 'vector_size', and 'distance' are required.")

        # Look the metric up by enum member name. A non-string value is
        # rejected here instead of crashing on .upper() and surfacing as a 500.
        distance = (
            Distance.__members__.get(distance_str.upper())
            if isinstance(distance_str, str)
            else None
        )
        if distance is None:
            raise HTTPException(status_code=422, detail=f"Invalid distance metric: {distance_str}. Must be one of 'Cosine', 'Euclid', 'Dot'.")

        # Creating an existing collection is reported, not treated as an error.
        if client.collection_exists(collection_name=name):
            return {"status": "info", "message": f"Collection '{name}' already exists."}

        client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=vector_size, distance=distance),
        )
        print(f"Collection '{name}' created successfully.")
        return {"status": "success", "message": f"Collection '{name}' created."}
    except HTTPException:
        # Validation errors raised above must reach the client unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating collection: {str(e)}") from e
@qdrant_llama_index_router.get("/qdrant/list_all_collections")
async def list_all_collections():
    """
    Return the names of all collections reported by the Qdrant client.

    Any failure while fetching the list is converted into an HTTP 500.
    """
    try:
        listing = client.get_collections()
        names = [entry.name for entry in listing.collections]
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Error listing collections: {str(err)}")
    return {"status": "success", "collections": names}
@qdrant_llama_index_router.get("/qdrant/{collection_name}")
async def get_collection_info(collection_name: str):
    """
    Get information about a specific collection.

    - collection_name: The name of the collection

    Returns ``{"status": "success", "collection_info": {...}}`` where the
    inner dict is the client's collection description converted to a dict.

    Raises:
        HTTPException 404: the collection does not exist.
        HTTPException 500: any other failure while talking to Qdrant.
    """
    try:
        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found.")
        collection_info = client.get_collection(collection_name=collection_name)
        # Pydantic v2 deprecates .dict() in favour of .model_dump(); use the
        # new method when the model has it and fall back for pydantic v1.
        to_dict = getattr(collection_info, "model_dump", None) or collection_info.dict
        return {"status": "success", "collection_info": to_dict()}
    except HTTPException:
        # Keep the 404 raised above instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(e)}") from e
@qdrant_llama_index_router.put("/qdrant/{collection_name}")
async def update_collection_metadata(
    collection_name: str,
    metadata: dict = Body(
        ...,
        example={
            "field_name": "new_value"
        }
    )
):
    """
    Placeholder endpoint for updating collection metadata.

    Nothing is written to Qdrant: the handler only echoes the received
    payload back inside an informational message.
    - collection_name: The name of the collection
    - metadata: A dictionary of metadata to "update" (conceptual only)
    """
    # Qdrant's `update_collection` changes configuration, not arbitrary
    # collection-level metadata; per-point metadata would have to be updated
    # on the points themselves. A real implementation might instead update
    # collection aliases or other configuration.
    notice = (
        f"Collection '{collection_name}' metadata update endpoint (conceptual). "
        "Qdrant's update is for configuration, not arbitrary collection-level metadata. "
        f"Received: {metadata}"
    )
    return {"status": "info", "message": notice}
@qdrant_llama_index_router.delete("/qdrant/{collection_name}")
async def delete_collection(collection_name: str):
    """
    Remove a collection from Qdrant.

    - collection_name: The name of the collection to delete

    Responds with 404 when the collection does not exist and with 500 when
    any other error occurs during the operation.
    """
    try:
        exists = client.collection_exists(collection_name=collection_name)
        if exists:
            client.delete_collection(collection_name=collection_name)
            return {"status": "success", "message": f"Collection '{collection_name}' deleted."}
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found.")
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(err)}")
@qdrant_llama_index_router.post("/qdrant/{collection_name}/text")
async def add_text_document(
    collection_name: str,
    text_data: dict = Body(
        ...,
        example={
            "text": "This is a sample document about FastAPI and Qdrant.",
            "metadata": {"author": "GitHub Copilot", "source": "API"}
        }
    )
):
    """
    Add a single text document to a collection.

    - collection_name: The name of the collection
    - text_data: A dictionary containing 'text' (required, non-empty string)
      and optional 'metadata' (an object; omitted or null means no metadata)

    Raises:
        HTTPException 422: 'text' is missing/invalid or 'metadata' is not an object.
        HTTPException 404: the collection does not exist.
        HTTPException 500: any other failure during ingestion.
    """
    try:
        text = text_data.get("text")
        metadata = text_data.get("metadata")
        # An explicit JSON null must behave exactly like an omitted key.
        if metadata is None:
            metadata = {}

        # Reject bad payloads up front so they are reported as 422 rather
        # than failing later while the Document is built and becoming a 500.
        if not text or not isinstance(text, str):
            raise HTTPException(status_code=422, detail="Invalid payload: 'text' is required.")
        if not isinstance(metadata, dict):
            raise HTTPException(status_code=422, detail="Invalid payload: 'metadata' must be an object.")

        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Please create it first.")

        documents = [Document(text=text, metadata=metadata)]
        vector_store = QdrantVectorStore(client=client, collection_name=collection_name)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        # Called for its ingestion side effect; the returned index is not needed.
        VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            show_progress=True
        )
        return {"status": "success", "message": f"Text document added to collection '{collection_name}'."}
    except HTTPException:
        # Keep the 4xx responses raised above as they are.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error adding text document: {str(e)}") from e
@qdrant_llama_index_router.post("/qdrant/{collection_name}/upload")
async def upload_file_to_collection(
    collection_name: str,
    file: UploadFile = File(...)
):
    """
    Upload a file (e.g., CSV, PDF, TXT) to be ingested into a collection.

    - collection_name: The name of the collection
    - file: The file to upload

    The upload is copied to a temporary file, loaded with
    ``SimpleDirectoryReader``, and ingested through a ``QdrantVectorStore``.
    The point count of the collection is read before and after ingestion and
    the difference is reported in the response message. The temporary file is
    always removed, whether ingestion succeeds or fails.

    Raises:
        HTTPException 404: the collection does not exist.
        HTTPException 400: no documents could be loaded from the file.
        HTTPException 500: any other failure during ingestion.
    """
    # NOTE(review): the Qdrant / LlamaIndex calls below look synchronous and
    # would block the event loop inside this async handler — confirm.
    try:
        print(f"Attempting to upload file '{file.filename}' to collection '{collection_name}'")
        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Please create it first.")
        print(f"Collection '{collection_name}' exists.")

        # The client-supplied filename is untrusted. Keep only its last path
        # component so separators cannot leak into the temp-file path; the
        # extension is still preserved in the suffix.
        safe_suffix = os.path.basename((file.filename or "").replace("\\", "/"))

        # Create the temp file first, then do everything (including the copy)
        # inside try/finally so the file is removed even if the copy fails.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=safe_suffix)
        temp_file_path = temp_file.name
        try:
            with temp_file:
                shutil.copyfileobj(file.file, temp_file)
            print(f"File '{file.filename}' saved temporarily to '{temp_file_path}'.")

            # Load the file into Document objects using SimpleDirectoryReader
            documents = SimpleDirectoryReader(input_files=[temp_file_path]).load_data()
            if not documents:
                raise HTTPException(status_code=400, detail=f"No documents could be loaded from the file '{file.filename}'. Check file format or content.")
            print(f"Loaded {len(documents)} document(s) from uploaded file: {file.filename}")
            # Print a snippet of the first document for verification
            print(f"First document text snippet: {documents[0].text[:200]}...")

            # Set up Qdrant vector store
            vector_store = QdrantVectorStore(
                client=client,
                collection_name=collection_name,
                enable_hybrid=False,  # Disable hybrid to avoid FastEmbed dependency unless explicitly configured
                batch_size=20
            )
            print(f"QdrantVectorStore initialized for collection '{collection_name}'.")

            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            print("StorageContext created.")

            # Point count before ingestion, used to report how many were added.
            initial_count = client.count(collection_name=collection_name, exact=True).count
            print(f"Initial document count in '{collection_name}': {initial_count}")

            # Called for its ingestion side effect; the index itself is unused.
            VectorStoreIndex.from_documents(
                documents,
                storage_context=storage_context,
                show_progress=True  # This should print progress to console
            )
            print("VectorStoreIndex created and documents ingested.")

            # Point count after ingestion.
            final_count = client.count(collection_name=collection_name, exact=True).count
            print(f"Final document count in '{collection_name}': {final_count}")

            added = final_count - initial_count
            if added > 0:
                print(f"Successfully added {added} new documents.")
            else:
                print("No new documents were added or count did not increase. This might indicate an issue.")
            return {"status": "success", "message": f"File '{file.filename}' ingested into '{collection_name}'. Added {added} documents."}
        finally:
            # Clean up the temporary file
            os.unlink(temp_file_path)
            print(f"Temporary file '{temp_file_path}' deleted.")
    except HTTPException as e:
        print(f"HTTP Exception during upload: {e.detail}")
        raise
    except Exception as e:
        print(f"Unhandled Exception during upload: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during file ingestion: {str(e)}") from e
@qdrant_llama_index_router.post("/qdrant/{collection_name}/batch")
async def add_batch_documents(
    collection_name: str,
    documents_data: list[dict] = Body(
        ...,
        example=[
            {"text": "First document text.", "metadata": {"id": 1}},
            {"text": "Second document text.", "metadata": {"id": 2}}
        ]
    )
):
    """
    Add multiple text documents in a batch to a collection.

    - collection_name: The name of the collection
    - documents_data: A list of dictionaries, each containing 'text'
      (required, non-empty string) and optional 'metadata' (an object;
      omitted or null means no metadata)

    Raises:
        HTTPException 422: the list is empty, or an entry has a missing/invalid
            'text' or a non-object 'metadata'.
        HTTPException 404: the collection does not exist.
        HTTPException 500: any other failure during ingestion.
    """
    try:
        if not documents_data:
            raise HTTPException(status_code=422, detail="Invalid payload: 'documents_data' cannot be empty.")
        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Please create it first.")

        llama_documents = []
        for doc_data in documents_data:
            text = doc_data.get("text")
            metadata = doc_data.get("metadata")
            # An explicit JSON null must behave exactly like an omitted key.
            if metadata is None:
                metadata = {}
            # Validate before building the Document so bad entries are
            # reported as 422 instead of failing later as a 500.
            if not text or not isinstance(text, str):
                raise HTTPException(status_code=422, detail="Each document in batch must have a 'text' field.")
            if not isinstance(metadata, dict):
                raise HTTPException(status_code=422, detail="Each document's 'metadata' must be an object.")
            llama_documents.append(Document(text=text, metadata=metadata))

        vector_store = QdrantVectorStore(client=client, collection_name=collection_name)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        # Called for its ingestion side effect; the returned index is not needed.
        VectorStoreIndex.from_documents(
            llama_documents,
            storage_context=storage_context,
            show_progress=True
        )
        return {"status": "success", "message": f"{len(llama_documents)} documents added in batch to collection '{collection_name}'."}
    except HTTPException:
        # Keep the 4xx responses raised above as they are.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error adding batch documents: {str(e)}") from e
@qdrant_llama_index_router.get("/qdrant/{collection_name}/count")
async def get_document_count(collection_name: str):
    """
    Report how many points the given collection holds (exact count).

    - collection_name: The name of the collection

    Responds with 404 when the collection does not exist and with 500 when
    any other error occurs while counting.
    """
    try:
        exists = client.collection_exists(collection_name=collection_name)
        if exists:
            tally = client.count(collection_name=collection_name, exact=True)
            return {
                "status": "success",
                "collection_name": collection_name,
                "count": tally.count,
            }
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found.")
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Error getting document count: {str(err)}")
@qdrant_llama_index_router.post("/qdrant/search")
async def search_collection(
    search_params: dict = Body(
        ...,
        example={
            "collection_name": "visa_dataset_collection",
            "query": "What is the average salary?",
            "top_k": 5
        }
    )
):
    """
    Search a specific Qdrant collection.

    - collection_name: The name of the collection to search (required)
    - query: The query string (required)
    - top_k: Number of results to return (optional, default 5; must be a
      positive integer)

    Returns ``{"status", "collection_name", "query", "response"}`` where
    ``response`` is the query engine's answer converted to a string.

    Raises:
        HTTPException 422: a required key is missing or a value is invalid.
        HTTPException 404: the collection does not exist.
        HTTPException 500: any other failure during the search.
    """
    try:
        collection_name = search_params.get("collection_name")
        query = search_params.get("query")
        top_k = search_params.get("top_k", 5)

        names_are_valid = (
            isinstance(collection_name, str) and bool(collection_name)
            and isinstance(query, str) and bool(query)
        )
        if not names_are_valid:
            raise HTTPException(status_code=422, detail="Invalid payload: 'collection_name' and 'query' are required.")
        # bool is a subclass of int, so exclude it explicitly. Rejecting bad
        # values here reports them as 422 instead of a 500 from the retriever.
        if not isinstance(top_k, int) or isinstance(top_k, bool) or top_k <= 0:
            raise HTTPException(status_code=422, detail="Invalid payload: 'top_k' must be a positive integer.")

        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Cannot perform search.")

        # Reuse the module's helper instead of rebuilding the store/index inline.
        index = get_vector_store_index(collection_name)
        query_engine = index.as_query_engine(similarity_top_k=top_k)
        response = query_engine.query(query)
        return {"status": "success", "collection_name": collection_name, "query": query, "response": str(response)}
    except HTTPException:
        # Keep the 4xx responses raised above as they are.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during search: {str(e)}") from e