# Scraped file-viewer header (not part of the program): Spaces status "Sleeping",
# file size 16,757 bytes, revision b2315b1. The viewer's line-number gutter was dropped.
from fastapi import FastAPI, Body, HTTPException, UploadFile, File
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex, StorageContext, Settings, Document
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.openai import OpenAIEmbedding
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from dotenv import load_dotenv
import os
import shutil
import tempfile
from fastapi import APIRouter
# Load environment variables from the .env file so the os.getenv() lookups
# below can read the API keys and the Qdrant URL.
load_dotenv()

# Router that groups every endpoint in this module under /qdrant_llama_index.
qdrant_llama_index_router = APIRouter(prefix="/qdrant_llama_index", tags=["qdrant_llama_index"])

# Embedding model registered on the global LlamaIndex Settings object.
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-ada-002",
    api_key=os.getenv("OPENAI_API_KEY"),
)

# Module-level Qdrant client shared by all endpoints; the target instance is
# configured through QDRANT_URL and QDRANT_API_KEY.
try:
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )
except Exception as e:
    # Chain explicitly so the original error is recorded as the cause of the
    # wrapper exception instead of only appearing as implicit context.
    raise Exception(f"Failed to connect to Qdrant: {str(e)}") from e
# --- Helper function to get VectorStoreIndex ---
def get_vector_store_index(collection_name: str):
    """Return a VectorStoreIndex backed by the named Qdrant collection.

    The module-level ``client`` is wrapped in a ``QdrantVectorStore`` bound to
    ``collection_name`` and the index is built from that store.
    """
    return VectorStoreIndex.from_vector_store(
        vector_store=QdrantVectorStore(
            client=client,
            collection_name=collection_name,
        )
    )
# --- API Endpoints ---
@qdrant_llama_index_router.get("/")
async def read_root():
    """Return a static welcome message that points callers at /docs."""
    greeting = "Welcome to the Qdrant Collections API! Visit /docs for API documentation."
    return {"message": greeting}
@qdrant_llama_index_router.post("/qdrant/create_collection")
async def create_collection(
    collection_params: dict = Body(
        ...,
        example={
            "name": "my_new_collection",
            "vector_size": 1536,
            "distance": "Cosine"
        }
    )
):
    """
    Create a new collection with specified parameters.

    - name: Name of the collection (required, non-empty string)
    - vector_size: Size of the vectors (required, positive integer, e.g., 1536 for OpenAI embeddings)
    - distance: Distance metric (required string, e.g., "Cosine", "Euclid", "Dot")

    Returns a dict with "status" ("success", or "info" when the collection
    already exists) and a human-readable "message".

    Raises HTTPException 422 for an invalid payload and 500 when the Qdrant
    client call fails.
    """
    try:
        name = collection_params.get("name")
        vector_size = collection_params.get("vector_size")
        distance_str = collection_params.get("distance")

        if not name or vector_size is None or not distance_str:
            raise HTTPException(status_code=422, detail="Invalid payload: 'name', 'vector_size', and 'distance' are required.")
        if not isinstance(name, str) or not isinstance(distance_str, str):
            raise HTTPException(status_code=422, detail="Invalid payload: 'name' and 'distance' must be strings.")
        # bool is a subclass of int, so exclude it explicitly; zero or negative
        # sizes would otherwise only fail later inside the Qdrant call.
        if isinstance(vector_size, bool) or not isinstance(vector_size, int) or vector_size < 1:
            raise HTTPException(status_code=422, detail="Invalid payload: 'vector_size' must be a positive integer.")

        try:
            # Enum members are looked up by their upper-case names.
            distance = Distance[distance_str.upper()]
        except KeyError:
            # List the valid metrics from the enum so the message cannot drift
            # from what the installed client actually supports.
            allowed = ", ".join(f"'{member.value}'" for member in Distance)
            raise HTTPException(
                status_code=422,
                detail=f"Invalid distance metric: {distance_str}. Must be one of {allowed}.",
            ) from None

        # Check if collection already exists
        if client.collection_exists(collection_name=name):
            return {"status": "info", "message": f"Collection '{name}' already exists."}

        client.create_collection(
            collection_name=name,
            vectors_config=VectorParams(size=vector_size, distance=distance),
        )
        print(f"Collection '{name}' created successfully.")
        return {"status": "success", "message": f"Collection '{name}' created."}
    except HTTPException:
        # Validation errors already carry the right status code.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating collection: {str(e)}") from e
@qdrant_llama_index_router.get("/qdrant/list_all_collections")
async def list_all_collections():
    """
    Return the names of every collection known to the Qdrant client.

    Any failure while listing is reported as an HTTP 500.
    """
    try:
        listing = client.get_collections()
        names = []
        for entry in listing.collections:
            names.append(entry.name)
        return {"status": "success", "collections": names}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error listing collections: {str(exc)}")
@qdrant_llama_index_router.get("/qdrant/{collection_name}")
async def get_collection_info(collection_name: str):
    """
    Return the Qdrant description of one collection.

    - collection_name: The name of the collection

    Responds with 404 when the collection does not exist and 500 on any
    other failure.
    """
    try:
        found = client.collection_exists(collection_name=collection_name)
        if found:
            details = client.get_collection(collection_name=collection_name)
            return {"status": "success", "collection_info": details.dict()}
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found.")
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error getting collection info: {str(exc)}")
@qdrant_llama_index_router.put("/qdrant/{collection_name}")
async def update_collection_metadata(
    collection_name: str,
    metadata: dict = Body(
        ...,
        example={
            "field_name": "new_value"
        }
    )
):
    """
    Placeholder endpoint for updating collection metadata.

    Nothing is written to Qdrant: the handler only echoes the received
    payload back inside an informational message.

    - collection_name: The name of the collection
    - metadata: A dictionary of metadata to "update" (conceptual only)
    """
    # No Qdrant call is made here; a real implementation would have to update
    # collection configuration or the payloads of individual points instead.
    notice = f"Collection '{collection_name}' metadata update endpoint (conceptual). Qdrant's update is for configuration, not arbitrary collection-level metadata. Received: {metadata}"
    return {"status": "info", "message": notice}
@qdrant_llama_index_router.delete("/qdrant/{collection_name}")
async def delete_collection(collection_name: str):
    """
    Remove a collection from Qdrant.

    - collection_name: The name of the collection to delete

    Responds with 404 when the collection does not exist and 500 on any
    other failure.
    """
    try:
        present = client.collection_exists(collection_name=collection_name)
        if present:
            client.delete_collection(collection_name=collection_name)
            return {"status": "success", "message": f"Collection '{collection_name}' deleted."}
        raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found.")
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error deleting collection: {str(exc)}")
@qdrant_llama_index_router.post("/qdrant/{collection_name}/text")
async def add_text_document(
    collection_name: str,
    text_data: dict = Body(
        ...,
        example={
            "text": "This is a sample document about FastAPI and Qdrant.",
            "metadata": {"author": "GitHub Copilot", "source": "API"}
        }
    )
):
    """
    Add a single text document to a collection.

    - collection_name: The name of the collection
    - text_data: A dictionary containing 'text' (non-empty string) and
      optional 'metadata' (object; a missing or null value means no metadata)

    Returns a dict with "status" and "message" on success.

    Raises HTTPException 422 for an invalid payload, 404 when the collection
    does not exist, and 500 when ingestion fails.
    """
    try:
        text = text_data.get("text")
        metadata = text_data.get("metadata")

        if not text:
            raise HTTPException(status_code=422, detail="Invalid payload: 'text' is required.")
        if not isinstance(text, str):
            raise HTTPException(status_code=422, detail="Invalid payload: 'text' must be a string.")
        # An explicit JSON null is treated like an omitted field; anything
        # other than an object is rejected up front instead of surfacing as a
        # 500 from Document construction.
        if metadata is None:
            metadata = {}
        elif not isinstance(metadata, dict):
            raise HTTPException(status_code=422, detail="Invalid payload: 'metadata' must be an object.")

        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Please create it first.")

        documents = [Document(text=text, metadata=metadata)]
        vector_store = QdrantVectorStore(client=client, collection_name=collection_name)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        # Called for its ingestion side effect; the returned index is not needed.
        VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            show_progress=True
        )
        return {"status": "success", "message": f"Text document added to collection '{collection_name}'."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error adding text document: {str(e)}") from e
@qdrant_llama_index_router.post("/qdrant/{collection_name}/upload")
async def upload_file_to_collection(
    collection_name: str,
    file: UploadFile = File(...)
):
    """
    Upload a file (e.g., CSV, PDF, TXT) to be ingested into a collection.

    - collection_name: The name of the collection
    - file: The file to upload

    The upload is copied to a temporary file, loaded with
    SimpleDirectoryReader, ingested through a QdrantVectorStore, and the
    temporary file is removed afterwards. The reported number of added
    documents is the difference between the collection's point count before
    and after ingestion.

    Raises HTTPException 404 when the collection does not exist, 400 when no
    documents could be loaded from the file, and 500 on any other failure.
    """
    try:
        print(f"Attempting to upload file '{file.filename}' to collection '{collection_name}'")
        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Please create it first.")
        print(f"Collection '{collection_name}' exists.")

        # Only the extension of the client-supplied name is reused. The full
        # name is untrusted and may be missing or contain path separators,
        # which must never end up in the temporary file's path.
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]

        temp_file_path = None
        try:
            # Save the uploaded file temporarily. This happens inside the
            # try/finally so the file is removed even if the copy fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file_path = temp_file.name
                shutil.copyfileobj(file.file, temp_file)
            print(f"File '{file.filename}' saved temporarily to '{temp_file_path}'.")

            # Load the file into Document objects using SimpleDirectoryReader
            documents = SimpleDirectoryReader(input_files=[temp_file_path]).load_data()
            if not documents:
                raise HTTPException(status_code=400, detail=f"No documents could be loaded from the file '{file.filename}'. Check file format or content.")
            print(f"Loaded {len(documents)} document(s) from uploaded file: {file.filename}")
            # Print a snippet of the first document for verification
            print(f"First document text snippet: {documents[0].text[:200]}...")

            # Set up Qdrant vector store
            vector_store = QdrantVectorStore(
                client=client,
                collection_name=collection_name,
                enable_hybrid=False,  # Disable hybrid to avoid FastEmbed dependency unless explicitly configured
                batch_size=20
            )
            print(f"QdrantVectorStore initialized for collection '{collection_name}'.")

            # Create storage context and index
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            print("StorageContext created.")

            # Before ingestion, get current count
            initial_count = client.count(collection_name=collection_name, exact=True).count
            print(f"Initial document count in '{collection_name}': {initial_count}")

            # Called for its ingestion side effect; the returned index is not needed.
            VectorStoreIndex.from_documents(
                documents,
                storage_context=storage_context,
                show_progress=True  # This should print progress to console
            )
            print("VectorStoreIndex created and documents ingested.")

            # After ingestion, get new count
            final_count = client.count(collection_name=collection_name, exact=True).count
            print(f"Final document count in '{collection_name}': {final_count}")

            if final_count > initial_count:
                print(f"Successfully added {final_count - initial_count} new documents.")
            else:
                print("No new documents were added or count did not increase. This might indicate an issue.")

            return {"status": "success", "message": f"File '{file.filename}' ingested into '{collection_name}'. Added {final_count - initial_count} documents."}
        finally:
            # Clean up the temporary file (if it was created at all).
            if temp_file_path is not None:
                os.unlink(temp_file_path)
                print(f"Temporary file '{temp_file_path}' deleted.")
    except HTTPException as e:
        print(f"HTTP Exception during upload: {e.detail}")
        raise
    except Exception as e:
        print(f"Unhandled Exception during upload: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error during file ingestion: {str(e)}") from e
@qdrant_llama_index_router.post("/qdrant/{collection_name}/batch")
async def add_batch_documents(
    collection_name: str,
    documents_data: list[dict] = Body(
        ...,
        example=[
            {"text": "First document text.", "metadata": {"id": 1}},
            {"text": "Second document text.", "metadata": {"id": 2}}
        ]
    )
):
    """
    Add multiple text documents in a batch to a collection.

    - collection_name: The name of the collection
    - documents_data: A list of dictionaries, each containing 'text'
      (non-empty string) and optional 'metadata' (object; a missing or null
      value means no metadata)

    Every entry is validated before anything is ingested, so an invalid entry
    rejects the whole batch.

    Raises HTTPException 422 for an invalid payload, 404 when the collection
    does not exist, and 500 when ingestion fails.
    """
    try:
        if not documents_data:
            raise HTTPException(status_code=422, detail="Invalid payload: 'documents_data' cannot be empty.")

        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Please create it first.")

        llama_documents = []
        for doc_data in documents_data:
            text = doc_data.get("text")
            metadata = doc_data.get("metadata")
            if not text or not isinstance(text, str):
                raise HTTPException(status_code=422, detail="Each document in batch must have a 'text' field.")
            # An explicit JSON null is treated like an omitted field; anything
            # other than an object is rejected instead of surfacing as a 500
            # from Document construction.
            if metadata is None:
                metadata = {}
            elif not isinstance(metadata, dict):
                raise HTTPException(status_code=422, detail="Each document's 'metadata' must be an object when provided.")
            llama_documents.append(Document(text=text, metadata=metadata))

        vector_store = QdrantVectorStore(client=client, collection_name=collection_name)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        # Called for its ingestion side effect; the returned index is not needed.
        VectorStoreIndex.from_documents(
            llama_documents,
            storage_context=storage_context,
            show_progress=True
        )
        return {"status": "success", "message": f"{len(llama_documents)} documents added in batch to collection '{collection_name}'."}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error adding batch documents: {str(e)}") from e
@qdrant_llama_index_router.get("/qdrant/{collection_name}/count")
async def get_document_count(collection_name: str):
    """
    Report how many points a collection currently holds (exact count).

    - collection_name: The name of the collection

    Responds with 404 when the collection does not exist and 500 on any
    other failure.
    """
    try:
        present = client.collection_exists(collection_name=collection_name)
        if not present:
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found.")
        total = client.count(collection_name=collection_name, exact=True).count
        return {"status": "success", "collection_name": collection_name, "count": total}
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Error getting document count: {str(exc)}")
@qdrant_llama_index_router.post("/qdrant/search")
async def search_collection(
    search_params: dict = Body(
        ...,
        example={
            "collection_name": "visa_dataset_collection",
            "query": "What is the average salary?",
            "top_k": 5
        }
    )
):
    """
    Search a specific Qdrant collection.

    - collection_name: The name of the collection to search (required string)
    - query: The query string (required)
    - top_k: Number of results to return (optional positive integer, default 5)

    Returns a dict with "status", the echoed "collection_name" and "query",
    and the query engine's response rendered as a string.

    Raises HTTPException 422 for an invalid payload, 404 when the collection
    does not exist, and 500 when the query fails.
    """
    try:
        collection_name = search_params.get("collection_name")
        query = search_params.get("query")
        top_k = search_params.get("top_k")

        if not collection_name or not query:
            raise HTTPException(status_code=422, detail="Invalid payload: 'collection_name' and 'query' are required.")
        if not isinstance(collection_name, str) or not isinstance(query, str):
            raise HTTPException(status_code=422, detail="Invalid payload: 'collection_name' and 'query' must be strings.")

        # A missing or null top_k falls back to the documented default.
        if top_k is None:
            top_k = 5
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k < 1:
            raise HTTPException(status_code=422, detail="Invalid payload: 'top_k' must be a positive integer.")

        if not client.collection_exists(collection_name=collection_name):
            raise HTTPException(status_code=404, detail=f"Collection '{collection_name}' not found. Cannot perform search.")

        # Reuse the module's helper rather than rebuilding the store and index inline.
        index = get_vector_store_index(collection_name)
        query_engine = index.as_query_engine(similarity_top_k=top_k)
        response = query_engine.query(query)
        return {"status": "success", "collection_name": collection_name, "query": query, "response": str(response)}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during search: {str(e)}") from e
|