Files
railseek6/LightRAG-main/lightrag/api/routers/search_routes.py
2026-01-12 22:31:11 +08:00

257 lines
9.6 KiB
Python

"""
Search route aliases for LightRAG API.
This module provides search endpoint aliases that map to the existing query functionality.
"""
import logging
from typing import Any, Dict, List, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from lightrag.base import QueryParam
from ..utils_api import get_combined_auth_dependency
from pydantic import BaseModel, Field, field_validator
from ascii_colors import trace_exception
# Module-level router; create_search_routes() attaches the search endpoints
# to it and returns it. The "search" tag groups them in the OpenAPI docs.
router = APIRouter(
    tags=["search"],
)
class SearchRequest(BaseModel):
    """Request body shared by the /search and /api/search endpoints."""

    query: str = Field(
        min_length=1,
        description="The search query text",
    )
    mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
        default="mix",
        description="Search mode",
    )
    top_k: Optional[int] = Field(
        ge=1,
        default=10,
        description="Number of top results to return",
    )
    chunk_top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of text chunks to retrieve initially from vector search",
    )

    # The name is kept for backward compatibility even though the validator
    # now runs *before* field validation.
    @field_validator("query", mode="before")
    @classmethod
    def query_strip_after(cls, query: Any) -> Any:
        """Strip surrounding whitespace from ``query`` before it is validated.

        Running before the ``min_length=1`` check means a whitespace-only
        query is rejected with a validation error. Previously it passed the
        check and was then stripped to an empty string and forwarded.
        Non-string input is returned unchanged so pydantic reports the usual
        type error for it.
        """
        if isinstance(query, str):
            return query.strip()
        return query
class SearchResponse(BaseModel):
    # Response model of the /search endpoint: a single flattened list of
    # hits plus the query that produced it. All three fields are required.

    # Each hit is a plain dict (type / content / score / metadata entries).
    results: List[Dict[str, Any]] = Field(
        ...,
        description="Search results with documents and metadata",
    )
    # The query text that was searched.
    query: str = Field(
        ...,
        description="The original search query",
    )
    # Number of hits found.
    total_results: int = Field(
        ...,
        description="Total number of results found",
    )
class SearchDataResponse(BaseModel):
    # Response model of the /api/search endpoint: the raw retrieval sections,
    # kept separate rather than merged. All four fields are required.

    # Knowledge-graph entity records.
    entities: List[Dict[str, Any]] = Field(
        ...,
        description="Retrieved entities from knowledge graph",
    )
    # Knowledge-graph relationship records.
    relationships: List[Dict[str, Any]] = Field(
        ...,
        description="Retrieved relationships from knowledge graph",
    )
    # Document text-chunk records.
    chunks: List[Dict[str, Any]] = Field(
        ...,
        description="Retrieved text chunks from documents",
    )
    # Free-form metadata dictionary accompanying the results.
    metadata: Dict[str, Any] = Field(
        ...,
        description="Search metadata including mode, keywords, and processing information",
    )
def _dict_items(value: Any) -> List[Dict[str, Any]]:
    """Return the dict elements of *value* if it is a list, else ``[]``.

    Non-dict elements are dropped so that one malformed record cannot turn the
    whole request into a 500 (via AttributeError or response validation).
    """
    if not isinstance(value, list):
        return []
    return [item for item in value if isinstance(item, dict)]


def _extract_query_sections(response: Dict[str, Any]) -> tuple:
    """Split an ``aquery_data`` result into ``(entities, relationships, chunks, metadata)``.

    The three result sections are normalised to lists of dicts and metadata to
    a dict, so callers never need to re-check types.
    """
    sections = response
    nested = response.get("data")
    # NOTE(review): newer upstream LightRAG releases appear to return
    # {"status", "message", "data": {...}, "metadata"} from aquery_data, with
    # the sections nested under "data". Use that shape only when none of the
    # flat keys are present — confirm against the installed LightRAG version.
    has_flat_sections = any(
        key in response for key in ("entities", "relationships", "chunks")
    )
    if not has_flat_sections and isinstance(nested, dict):
        sections = nested
    metadata = response.get("metadata", {})
    return (
        _dict_items(sections.get("entities", [])),
        _dict_items(sections.get("relationships", [])),
        _dict_items(sections.get("chunks", [])),
        metadata if isinstance(metadata, dict) else {},
    )


def _score_of(item: Dict[str, Any]) -> float:
    """Return ``item["score"]`` as a float, or 0.0 when missing or non-numeric.

    Coercing here keeps the descending-score sort from raising TypeError when
    a record carries e.g. ``"score": None``.
    """
    score = item.get("score")
    return float(score) if isinstance(score, (int, float)) else 0.0


def _build_search_results(entities, relationships, chunks) -> List[Dict[str, Any]]:
    """Flatten the retrieval sections into one list sorted by descending score.

    Each entry is ``{"type", "content", "score", "metadata"}``. Python's sort
    is stable (also with ``reverse=True``), so records with equal scores keep
    entity -> relationship -> chunk order.
    """
    # NOTE(review): the primary keys ("name", "source"/"target", "text") are the
    # ones this module has always read; the fallbacks ("entity_name",
    # "src_id"/"tgt_id", "content") are the field names used by upstream
    # LightRAG's aquery_data — confirm which the installed version emits.
    results: List[Dict[str, Any]] = []
    for entity in entities:
        results.append(
            {
                "type": "entity",
                "content": entity.get("name") or entity.get("entity_name", ""),
                "score": _score_of(entity),
                "metadata": entity,
            }
        )
    for relation in relationships:
        source = relation.get("source") or relation.get("src_id", "")
        target = relation.get("target") or relation.get("tgt_id", "")
        results.append(
            {
                "type": "relationship",
                "content": f"{source} -> {target}",
                "score": _score_of(relation),
                "metadata": relation,
            }
        )
    for chunk in chunks:
        results.append(
            {
                "type": "chunk",
                "content": chunk.get("text") or chunk.get("content", ""),
                "score": _score_of(chunk),
                "metadata": {
                    "document_id": chunk.get("document_id"),
                    "chunk_id": chunk.get("chunk_id"),
                    "page": chunk.get("page"),
                    "source": chunk.get("source"),
                },
            }
        )
    results.sort(key=lambda r: r["score"], reverse=True)
    return results


async def _query_data(rag, request: "SearchRequest"):
    """Translate a SearchRequest into query params and await ``rag.aquery_data``."""
    # Imported at call time rather than at module load.
    from lightrag.api.routers.query_routes import QueryRequest

    query_request = QueryRequest(
        query=request.query,
        mode=request.mode,
        top_k=request.top_k,
        chunk_top_k=request.chunk_top_k,
        only_need_context=True,  # Only return context for search
    )
    param = query_request.to_query_params(False)
    return await rag.aquery_data(request.query, param=param)


def create_search_routes(rag_or_manager, api_key: Optional[str] = None, top_k: int = 60):
    """Attach the /search and /api/search endpoints to the module router.

    Args:
        rag_or_manager: Either a single ``LightRAG`` instance used for every
            request, or a ``WorkspaceManager`` from which a per-workspace
            instance is resolved using the ``X-Workspace`` request header.
        api_key: Optional API key passed to ``get_combined_auth_dependency``.
        top_k: Accepted for signature compatibility; not used by these routes
            (each request carries its own ``top_k``).

    Returns:
        The module-level ``router`` with both endpoints registered.

    Raises:
        TypeError: If ``rag_or_manager`` is neither a ``LightRAG`` nor a
            ``WorkspaceManager``.
    """
    from fastapi import Request

    from lightrag import LightRAG
    from lightrag.api.workspace_manager import WorkspaceManager
    from lightrag.base import StoragesStatus
    from lightrag.utils import logger

    # Fail fast at setup time instead of raising TypeError on every request.
    if not isinstance(rag_or_manager, (WorkspaceManager, LightRAG)):
        raise TypeError(
            f"Expected LightRAG or WorkspaceManager, got {type(rag_or_manager)}"
        )

    combined_auth = get_combined_auth_dependency(api_key)

    async def get_workspace_rag(request: Request):
        """Return the LightRAG instance that should serve *request*.

        In WorkspaceManager mode the workspace name is the stripped
        ``X-Workspace`` header ("" when absent), and the instance's storages
        are initialized if they are not already. In single-instance mode the
        header is ignored.

        Raises:
            HTTPException: 500 if storage initialization fails.
        """
        if not isinstance(rag_or_manager, WorkspaceManager):
            # Single RAG instance mode - ignore workspace header
            return rag_or_manager

        workspace = request.headers.get("X-Workspace", "").strip()
        rag = rag_or_manager.get_rag(workspace)
        status = rag._storages_status
        logger.debug("Workspace '%s': storages status = %s", workspace, status)
        # Only initialize when storages are not already INITIALIZED, instead of
        # calling initialize_storages() on every request.
        if status == StoragesStatus.INITIALIZED:
            return rag
        try:
            logger.info("Initializing storages for workspace '%s'", workspace)
            await rag.initialize_storages()
        except Exception as e:
            logger.error(
                "Failed to initialize storages for workspace '%s': %s", workspace, e
            )
            raise HTTPException(
                status_code=500, detail=f"Storage initialization failed: {e}"
            ) from e
        logger.info(
            "Storages initialized for workspace '%s', status now = %s",
            workspace,
            rag._storages_status,
        )
        return rag

    @router.post(
        "/search", response_model=SearchResponse, dependencies=[Depends(combined_auth)]
    )
    async def search_documents(
        request: SearchRequest,
        rag: LightRAG = Depends(get_workspace_rag),
    ):
        """
        Handle a POST request at the /search endpoint to search documents using RAG capabilities.
        This is an alias for the /query endpoint with simplified response format.
        """
        try:
            response = await _query_data(rag, request)
            if not isinstance(response, dict):
                # Fallback for unexpected response format
                return SearchResponse(results=[], query=request.query, total_results=0)
            entities, relationships, chunks, _metadata = _extract_query_sections(response)
            results = _build_search_results(entities, relationships, chunks)
            return SearchResponse(
                results=results[: request.top_k],
                query=request.query,
                total_results=len(results),
            )
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post(
        "/api/search",
        response_model=SearchDataResponse,
        dependencies=[Depends(combined_auth)],
    )
    async def search_data(
        request: SearchRequest,
        rag: LightRAG = Depends(get_workspace_rag),
    ):
        """
        API search endpoint that returns structured data without LLM generation.
        This is an alias for the /query/data endpoint.
        """
        try:
            response = await _query_data(rag, request)
            if not isinstance(response, dict):
                # Fallback for unexpected response format
                return SearchDataResponse(
                    entities=[],
                    relationships=[],
                    chunks=[],
                    metadata={
                        "error": "Unexpected response format",
                        "raw_response": str(response),
                    },
                )
            entities, relationships, chunks, metadata = _extract_query_sections(response)
            return SearchDataResponse(
                entities=entities,
                relationships=relationships,
                chunks=chunks,
                metadata=metadata,
            )
        except Exception as e:
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    return router