222 lines
7.9 KiB
Python
222 lines
7.9 KiB
Python
"""
|
|
Search route aliases for LightRAG API.
|
|
This module provides search endpoint aliases that map to the existing query functionality.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any, Dict, List, Literal, Optional
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from lightrag.base import QueryParam
|
|
from ..utils_api import get_combined_auth_dependency
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from ascii_colors import trace_exception
|
|
|
|
# Module-level router; create_search_routes() attaches the search endpoints to it.
router = APIRouter(
    tags=["search"],
)
|
|
|
|
|
|
class SearchRequest(BaseModel):
    # Request body shared by POST /search and POST /api/search.

    query: str = Field(
        min_length=1,
        description="The search query text",
    )

    mode: Literal["local", "global", "hybrid", "naive", "mix", "bypass"] = Field(
        default="mix",
        description="Search mode",
    )

    # Optional: an explicit null is accepted; the value is forwarded to the query
    # parameters and, in /search, also used to slice the merged result list.
    top_k: Optional[int] = Field(
        ge=1,
        default=10,
        description="Number of top results to return",
    )

    chunk_top_k: Optional[int] = Field(
        ge=1,
        default=None,
        description="Number of text chunks to retrieve initially from vector search",
    )

    @field_validator("query", mode="after")
    @classmethod
    def query_strip_after(cls, query: str) -> str:
        """Strip surrounding whitespace from the query and reject blank queries.

        ``min_length=1`` is enforced before this "after" validator runs, so a
        query made only of whitespace would pass that check and then be
        stripped to ``""``. It is rejected here instead.

        Raises:
            ValueError: if ``query`` is empty after stripping (surfaced by
                pydantic as a validation error).
        """
        stripped = query.strip()
        if not stripped:
            raise ValueError("query must contain at least one non-whitespace character")
        return stripped
|
|
|
|
|
|
class SearchResponse(BaseModel):
    # Response body of POST /search; every field is required (explicit `...`).
    results: List[Dict[str, Any]] = Field(
        ..., description="Search results with documents and metadata"
    )
    query: str = Field(..., description="The original search query")
    # /search fills this with the number of merged hits before the top_k cut.
    total_results: int = Field(..., description="Total number of results found")
|
|
|
|
|
|
class SearchDataResponse(BaseModel):
    # Response body of POST /api/search; all four sections are required (explicit `...`).
    entities: List[Dict[str, Any]] = Field(
        ..., description="Retrieved entities from knowledge graph"
    )
    relationships: List[Dict[str, Any]] = Field(
        ..., description="Retrieved relationships from knowledge graph"
    )
    chunks: List[Dict[str, Any]] = Field(
        ..., description="Retrieved text chunks from documents"
    )
    metadata: Dict[str, Any] = Field(
        ...,
        description="Search metadata including mode, keywords, and processing information",
    )
|
|
|
|
|
|
def _build_query_param(request: "SearchRequest"):
    """Convert a SearchRequest into the param object passed to ``rag.aquery_data``.

    The conversion goes through ``QueryRequest.to_query_params`` so both search
    aliases share the query routes' parameter handling.
    """
    # Imported lazily, as the original endpoints did; presumably this avoids a
    # circular import with query_routes -- confirm before hoisting it.
    from lightrag.api.routers.query_routes import QueryRequest

    query_request = QueryRequest(
        query=request.query,
        mode=request.mode,
        top_k=request.top_k,
        chunk_top_k=request.chunk_top_k,
        only_need_context=True,  # Only return context for search
    )
    # NOTE(review): the positional False is presumably ``is_stream`` -- confirm.
    return query_request.to_query_params(False)


def _extract_sections(response: Dict[str, Any]):
    """Return ``(entities, relationships, chunks, metadata)`` from an aquery_data dict.

    A missing key, or a value of the wrong container type (e.g. ``None``), is
    replaced with an empty list/dict so callers can iterate without checks.
    """
    entities = response.get("entities", [])
    relationships = response.get("relationships", [])
    chunks = response.get("chunks", [])
    metadata = response.get("metadata", {})
    return (
        entities if isinstance(entities, list) else [],
        relationships if isinstance(relationships, list) else [],
        chunks if isinstance(chunks, list) else [],
        metadata if isinstance(metadata, dict) else {},
    )


def _score_key(result: Dict[str, Any]) -> float:
    """Sort key for merged results: the numeric ``score``, or 0.0 otherwise.

    Guards against ``None``/non-numeric scores, which would make ``list.sort``
    raise ``TypeError`` when compared against floats.
    """
    score = result.get("score")
    return score if isinstance(score, (int, float)) else 0.0


def _merge_results(entities, relationships, chunks) -> List[Dict[str, Any]]:
    """Flatten entities, relationships and chunks into one list sorted by score.

    Each item becomes ``{"type", "content", "score", "metadata"}``. The sort is
    descending and stable, so equal scores keep entity -> relationship -> chunk
    order.
    """
    # NOTE(review): the item keys read below ("name", "source", "target", "text",
    # "document_id", ...) are assumed to match aquery_data's output; a missing key
    # silently falls back to the default -- confirm against lightrag's schema.
    merged: List[Dict[str, Any]] = [
        {
            "type": "entity",
            "content": entity.get("name", ""),
            "score": entity.get("score", 0.0),
            "metadata": entity,
        }
        for entity in entities
    ]
    merged.extend(
        {
            "type": "relationship",
            "content": f"{relation.get('source', '')} -> {relation.get('target', '')}",
            "score": relation.get("score", 0.0),
            "metadata": relation,
        }
        for relation in relationships
    )
    merged.extend(
        {
            "type": "chunk",
            "content": chunk.get("text", ""),
            "score": chunk.get("score", 0.0),
            "metadata": {
                "document_id": chunk.get("document_id"),
                "chunk_id": chunk.get("chunk_id"),
                "page": chunk.get("page"),
                "source": chunk.get("source"),
            },
        }
        for chunk in chunks
    )
    merged.sort(key=_score_key, reverse=True)
    return merged


def create_search_routes(rag, api_key: Optional[str] = None, top_k: int = 60):
    """Attach the /search and /api/search endpoints to the module-level router.

    Args:
        rag: The RAG object; only its ``aquery_data`` coroutine is called here.
        api_key: Passed to ``get_combined_auth_dependency`` to build the auth
            dependency applied to both endpoints.
        top_k: Not referenced by these endpoints (each request carries its own
            ``top_k``); presumably kept for signature parity with the other
            route factories -- confirm.

    Returns:
        The module-level ``router``. Because the router is shared, calling this
        function more than once registers the endpoints again.
    """
    combined_auth = get_combined_auth_dependency(api_key)

    @router.post(
        "/search", response_model=SearchResponse, dependencies=[Depends(combined_auth)]
    )
    async def search_documents(request: SearchRequest):
        """
        Handle a POST request at the /search endpoint to search documents using RAG capabilities.

        This is an alias for the /query endpoint with simplified response format.
        """
        try:
            response = await rag.aquery_data(
                request.query, param=_build_query_param(request)
            )

            if not isinstance(response, dict):
                # Fallback for unexpected response format
                return SearchResponse(results=[], query=request.query, total_results=0)

            entities, relationships, chunks, _ = _extract_sections(response)
            all_results = _merge_results(entities, relationships, chunks)
            return SearchResponse(
                # A top_k of None slices nothing off, i.e. returns every hit.
                results=all_results[: request.top_k],
                query=request.query,
                total_results=len(all_results),
            )
        except Exception as e:
            # NOTE(review): every failure, including a validation error raised
            # while building QueryRequest, is reported as HTTP 500 with str(e).
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @router.post(
        "/api/search",
        response_model=SearchDataResponse,
        dependencies=[Depends(combined_auth)],
    )
    async def search_data(request: SearchRequest):
        """
        API search endpoint that returns structured data without LLM generation.

        This is an alias for the /query/data endpoint.
        """
        try:
            response = await rag.aquery_data(
                request.query, param=_build_query_param(request)
            )

            if not isinstance(response, dict):
                # Fallback for unexpected response format
                return SearchDataResponse(
                    entities=[],
                    relationships=[],
                    chunks=[],
                    metadata={
                        "error": "Unexpected response format",
                        "raw_response": str(response),
                    },
                )

            entities, relationships, chunks, metadata = _extract_sections(response)
            return SearchDataResponse(
                entities=entities,
                relationships=relationships,
                chunks=chunks,
                metadata=metadata,
            )
        except Exception as e:
            # NOTE(review): as above, every failure is reported as HTTP 500.
            trace_exception(e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    return router