257 lines
9.6 KiB
Python
257 lines
9.6 KiB
Python
"""
Search route aliases for LightRAG API.

This module provides search endpoint aliases that map to the existing query functionality.
"""
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from lightrag.base import QueryParam
|
|
from ..utils_api import get_combined_auth_dependency
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from ascii_colors import trace_exception
|
|
|
|
# Module-level router that create_search_routes() attaches the search
# endpoints to and then returns; every route is grouped under the "search" tag.
router = APIRouter(tags=["search"])
|
|
|
|
|
|
class SearchRequest(BaseModel):
    """Request body accepted by the search endpoints.

    Attributes are validated by pydantic: ``query`` is required and must
    contain at least one non-whitespace character, ``mode`` must be one of
    the listed retrieval modes, and ``top_k`` / ``chunk_top_k`` must be
    ``>= 1`` when provided.
    """

    query: str = Field(
        min_length=1,
        description="The search query text",
    )

    mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
        default="mix",
        description="Search mode",
    )

    top_k: Optional[int] = Field(
        ge=1,
        default=10,
        description="Number of top results to return",
    )

    chunk_top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of text chunks to retrieve initially from vector search",
    )

    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        """Strip surrounding whitespace and reject queries that become empty.

        ``min_length=1`` is checked on the raw value, so a whitespace-only
        string would pass it and then be stripped down to ``""``. Raising
        here turns that case into a normal validation error instead of
        letting an empty query reach the backend.

        Raises:
            ValueError: If the query is empty after stripping.
        """
        stripped = query.strip()
        if not stripped:
            raise ValueError("query must not be empty or whitespace-only")
        return stripped
|
|
|
|
|
|
# Response body of the POST /search endpoint: a flat list of hits together
# with the query that produced them and the number of hits before truncation.
# (Documented with comments rather than a docstring so the generated schema
# is unchanged.)
class SearchResponse(BaseModel):
    # One dict per hit; all three fields are required (no defaults).
    results: List[Dict[str, Any]] = Field(
        ...,
        description="Search results with documents and metadata",
    )
    query: str = Field(..., description="The original search query")
    total_results: int = Field(..., description="Total number of results found")
|
|
|
|
|
|
# Response body of the POST /api/search endpoint: the retrieved entities,
# relationships and chunks kept in separate lists, plus a metadata dict.
# (Documented with comments rather than a docstring so the generated schema
# is unchanged.) All four fields are required.
class SearchDataResponse(BaseModel):
    entities: List[Dict[str, Any]] = Field(
        ...,
        description="Retrieved entities from knowledge graph",
    )
    relationships: List[Dict[str, Any]] = Field(
        ...,
        description="Retrieved relationships from knowledge graph",
    )
    chunks: List[Dict[str, Any]] = Field(
        ...,
        description="Retrieved text chunks from documents",
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Search metadata including mode, keywords, and processing information",
    )
|
|
|
|
|
|
def _list_or_empty(value: Any) -> List[Any]:
    """Return *value* unchanged when it is a list, otherwise an empty list."""
    return value if isinstance(value, list) else []


def _dict_items(value: Any) -> List[Dict[str, Any]]:
    """Return only the dict entries of *value*; anything else yields ``[]``."""
    return [item for item in _list_or_empty(value) if isinstance(item, dict)]


def _numeric_score(item: Dict[str, Any]) -> float:
    """Return the item's ``score`` when it is a real number, else ``0.0``.

    A ``None`` or non-numeric score would make the later sort raise
    ``TypeError``, so such values are treated like a missing score.
    """
    score = item.get("score", 0.0)
    if isinstance(score, bool) or not isinstance(score, (int, float)):
        return 0.0
    return score


def _flatten_search_results(response: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Merge entities, relationships and chunks into one score-sorted list.

    NOTE(review): the item keys read here ("name", "source", "target",
    "text", "document_id", "page", ...) are taken as-is from the original
    implementation; confirm they match what ``aquery_data`` actually emits.
    """
    results: List[Dict[str, Any]] = [
        {
            "type": "entity",
            "content": entity.get("name", ""),
            "score": _numeric_score(entity),
            "metadata": entity,
        }
        for entity in _dict_items(response.get("entities"))
    ]
    results.extend(
        {
            "type": "relationship",
            "content": f"{relation.get('source', '')} -> {relation.get('target', '')}",
            "score": _numeric_score(relation),
            "metadata": relation,
        }
        for relation in _dict_items(response.get("relationships"))
    )
    results.extend(
        {
            "type": "chunk",
            "content": chunk.get("text", ""),
            "score": _numeric_score(chunk),
            "metadata": {
                "document_id": chunk.get("document_id"),
                "chunk_id": chunk.get("chunk_id"),
                "page": chunk.get("page"),
                "source": chunk.get("source"),
            },
        }
        for chunk in _dict_items(response.get("chunks"))
    )
    # Highest score first; the sort is stable, so ties keep the
    # entity -> relationship -> chunk insertion order.
    results.sort(key=lambda result: result["score"], reverse=True)
    return results


async def _run_search_query(rag: Any, request: "SearchRequest") -> Any:
    """Translate a search request into a query and run ``rag.aquery_data``.

    Shared by both search endpoints so they always build the query the same
    way (context-only retrieval).
    """
    # Imported lazily, as in the original endpoints.
    from lightrag.api.routers.query_routes import QueryRequest

    query_request = QueryRequest(
        query=request.query,
        mode=request.mode,
        top_k=request.top_k,
        chunk_top_k=request.chunk_top_k,
        only_need_context=True,  # Only return context for search
    )
    param = query_request.to_query_params(False)
    return await rag.aquery_data(request.query, param=param)


def create_search_routes(rag_or_manager, api_key: Optional[str] = None, top_k: int = 60):
    """Register the ``/search`` and ``/api/search`` endpoints and return the router.

    Args:
        rag_or_manager: Either a ``LightRAG`` instance (single-instance mode)
            or a ``WorkspaceManager`` from which a per-request instance is
            selected via the ``X-Workspace`` header.
        api_key: Passed to ``get_combined_auth_dependency`` to build the auth
            dependency attached to both endpoints.
        top_k: Currently not used by this function; kept for interface
            compatibility.

    Returns:
        The module-level ``router``. Each call adds the two routes to that
        same router object.
    """
    # Accept either a LightRAG instance or a WorkspaceManager
    from lightrag.api.workspace_manager import WorkspaceManager
    from lightrag import LightRAG
    from fastapi import Request
    from lightrag.base import StoragesStatus
    from lightrag.utils import logger

    combined_auth = get_combined_auth_dependency(api_key)

    # Define dependency to get workspace-specific RAG instance
    async def get_workspace_rag(request: Request):
        """Resolve the RAG instance that should serve this request.

        Raises:
            HTTPException: 500 when storage initialization fails.
            TypeError: When ``rag_or_manager`` is of an unsupported type.
        """
        if isinstance(rag_or_manager, WorkspaceManager):
            workspace = request.headers.get("X-Workspace", "").strip()
            rag = rag_or_manager.get_rag(workspace)
            # Ensure storages are initialized for this workspace
            try:
                logger.info(f"Workspace '{workspace}': storages status = {rag._storages_status}")
                # Skip when already initialized (avoids re-initializing and
                # logging on every request) and, as before, when finalized.
                if rag._storages_status not in (
                    StoragesStatus.INITIALIZED,
                    StoragesStatus.FINALIZED,
                ):
                    logger.info(f"Initializing storages for workspace '{workspace}'")
                    await rag.initialize_storages()
                    logger.info(f"Storages initialized, status now = {rag._storages_status}")
            except Exception as e:
                logger.error(f"Failed to initialize storages for workspace '{workspace}': {e}")
                raise HTTPException(
                    status_code=500, detail=f"Storage initialization failed: {e}"
                ) from e
            return rag
        elif isinstance(rag_or_manager, LightRAG):
            # Single RAG instance mode - ignore workspace header
            return rag_or_manager
        else:
            raise TypeError(f"Expected LightRAG or WorkspaceManager, got {type(rag_or_manager)}")

    @router.post(
        "/search", response_model=SearchResponse, dependencies=[Depends(combined_auth)]
    )
    async def search_documents(
        request: SearchRequest,
        rag: LightRAG = Depends(get_workspace_rag),
    ):
        """
        Handle a POST request at the /search endpoint to search documents using RAG capabilities.
        This is an alias for the /query endpoint with simplified response format.

        Entities, relationships and chunks are merged into one list sorted by
        score (descending) and truncated to ``top_k``; ``total_results`` is the
        count before truncation.
        """
        try:
            response = await _run_search_query(rag, request)

            if not isinstance(response, dict):
                # Fallback for unexpected response format
                return SearchResponse(results=[], query=request.query, total_results=0)

            all_results = _flatten_search_results(response)
            return SearchResponse(
                # top_k may be None, in which case the slice keeps everything.
                results=all_results[: request.top_k],
                query=request.query,
                total_results=len(all_results),
            )
        except HTTPException:
            # Never rewrap an explicit HTTP error as a generic 500.
            raise
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post(
        "/api/search",
        response_model=SearchDataResponse,
        dependencies=[Depends(combined_auth)],
    )
    async def search_data(
        request: SearchRequest,
        rag: LightRAG = Depends(get_workspace_rag),
    ):
        """
        API search endpoint that returns structured data without LLM generation.
        This is an alias for the /query/data endpoint.

        Sections of the backend response that do not have the expected
        container type are replaced by empty containers.
        """
        try:
            response = await _run_search_query(rag, request)

            if not isinstance(response, dict):
                # Fallback for unexpected response format
                return SearchDataResponse(
                    entities=[],
                    relationships=[],
                    chunks=[],
                    metadata={
                        "error": "Unexpected response format",
                        "raw_response": str(response),
                    },
                )

            metadata = response.get("metadata", {})
            return SearchDataResponse(
                entities=_list_or_empty(response.get("entities", [])),
                relationships=_list_or_empty(response.get("relationships", [])),
                chunks=_list_or_empty(response.get("chunks", [])),
                metadata=metadata if isinstance(metadata, dict) else {},
            )
        except HTTPException:
            # Never rewrap an explicit HTTP error as a generic 500.
            raise
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    return router