"""
|
||
LightRAG FastAPI Server
|
||
"""
|
||
|
||
import asyncio
import configparser
import inspect
import logging
import logging.config
import os
import signal
import sys
from contextlib import asynccontextmanager
from pathlib import Path

import pipmaster as pm
import uvicorn
from ascii_colors import ASCIIColors
from dotenv import load_dotenv
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles

from lightrag import LightRAG, __version__ as core_version
from lightrag.api import __api_version__
from lightrag.api.auth import auth_handler
from lightrag.api.routers.document_routes import (
    DocumentManager,
    create_document_routes,
    run_scanning_process,
)
from lightrag.api.routers.graph_routes import create_graph_routes
from lightrag.api.routers.ollama_api import OllamaAPI
from lightrag.api.routers.query_routes import create_query_routes
from lightrag.api.routers.search_routes import create_search_routes
from lightrag.api.routers.workspace_routes import router as workspace_router
from lightrag.api.utils_api import (
    get_combined_auth_dependency,
    display_splash_screen,
    check_env_file,
)
from lightrag.constants import (
    DEFAULT_LOG_MAX_BYTES,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_FILENAME,
    DEFAULT_LLM_TIMEOUT,
    DEFAULT_EMBEDDING_TIMEOUT,
)
from lightrag.kg.shared_storage import (
    get_namespace_data,
    get_pipeline_status_lock,
    initialize_pipeline_status,
    cleanup_keyed_lock,
    finalize_share_data,
)
from lightrag.types import GPTKeywordExtractionFormat
from lightrag.utils import (
    EmbeddingFunc,
    get_env_value,
    logger,
    set_verbose_debug,
)
from .config import (
    global_args,
    update_uvicorn_mode_config,
    get_default_host,
)

# Load the .env file from the current working directory.
# This allows each LightRAG instance to use its own .env file.
# OS environment variables take precedence over values defined in .env.
load_dotenv(dotenv_path=".env", override=False)


webui_title = os.getenv("WEBUI_TITLE")
webui_description = os.getenv("WEBUI_DESCRIPTION")

# Initialize config parser
config = configparser.ConfigParser()
config.read("config.ini")

# Global authentication configuration
auth_configured = bool(auth_handler.accounts)

def setup_signal_handlers():
    """Setup signal handlers for graceful shutdown"""

    def signal_handler(sig, frame):
        print(f"\n\nReceived signal {sig}, shutting down gracefully...")
        print(f"Process ID: {os.getpid()}")

        # Release shared resources
        finalize_share_data()

        # Exit with success status
        sys.exit(0)

    # Register signal handlers
    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # kill command

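# Usage note (illustrative): once these handlers are installed, stopping the
# server with Ctrl+C or a TERM signal releases shared storage before exiting,
# e.g.:
#   kill -TERM <server-pid>
# (The exact way you look up the server PID is deployment-specific.)

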
class LLMConfigCache:
    """Smart LLM and Embedding configuration cache class"""

    def __init__(self, args):
        self.args = args

        # Initialize configurations based on binding conditions
        self.openai_llm_options = None
        self.ollama_llm_options = None
        self.ollama_embedding_options = None

        # Only initialize and log OpenAI options when using OpenAI-related bindings
        if args.llm_binding in ["openai", "azure_openai"]:
            from lightrag.llm.binding_options import OpenAILLMOptions

            self.openai_llm_options = OpenAILLMOptions.options_dict(args)
            logger.info(f"OpenAI LLM Options: {self.openai_llm_options}")

        # Only initialize and log Ollama LLM options when using the Ollama LLM binding
        if args.llm_binding == "ollama":
            try:
                from lightrag.llm.binding_options import OllamaLLMOptions

                self.ollama_llm_options = OllamaLLMOptions.options_dict(args)
                logger.info(f"Ollama LLM Options: {self.ollama_llm_options}")
            except ImportError:
                logger.warning(
                    "OllamaLLMOptions not available, using default configuration"
                )
                self.ollama_llm_options = {}

        # Only initialize and log Ollama Embedding options when using the Ollama Embedding binding
        if args.embedding_binding == "ollama":
            try:
                from lightrag.llm.binding_options import OllamaEmbeddingOptions

                self.ollama_embedding_options = OllamaEmbeddingOptions.options_dict(
                    args
                )
                logger.info(
                    f"Ollama Embedding Options: {self.ollama_embedding_options}"
                )
            except ImportError:
                logger.warning(
                    "OllamaEmbeddingOptions not available, using default configuration"
                )
                self.ollama_embedding_options = {}


def create_app(args):
    # Setup logging
    logger.setLevel(args.log_level)
    set_verbose_debug(args.verbose)

    # Create configuration cache (this will output configuration logs)
    config_cache = LLMConfigCache(args)

    # Verify that bindings are correctly set up
    if args.llm_binding not in [
        "lollms",
        "ollama",
        "openai",
        "azure_openai",
        "aws_bedrock",
    ]:
        raise Exception("llm binding not supported")

    if args.embedding_binding not in [
        "lollms",
        "ollama",
        "openai",
        "azure_openai",
        "aws_bedrock",
        "jina",
    ]:
        raise Exception("embedding binding not supported")

    # Set default hosts if not provided
    if args.llm_binding_host is None:
        args.llm_binding_host = get_default_host(args.llm_binding, "llm")

    if args.embedding_binding_host is None:
        args.embedding_binding_host = get_default_host(
            args.embedding_binding, "embedding"
        )

    # Add SSL validation
    if args.ssl:
        if not args.ssl_certfile or not args.ssl_keyfile:
            raise Exception(
                "SSL certificate and key files must be provided when SSL is enabled"
            )
        if not os.path.exists(args.ssl_certfile):
            raise Exception(f"SSL certificate file not found: {args.ssl_certfile}")
        if not os.path.exists(args.ssl_keyfile):
            raise Exception(f"SSL key file not found: {args.ssl_keyfile}")

    # Check if API key is provided either through env var or args
    api_key = os.getenv("LIGHTRAG_API_KEY") or args.key

    # Create workspace manager for dynamic workspace management.
    # Note: this instance is recreated further below with the LightRAG factory
    # attached; the name is rebound before the lifespan handler ever runs.
    from lightrag.api.workspace_manager import WorkspaceManager

    workspace_manager = WorkspaceManager(args)

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup and shutdown events"""
        # Store background tasks
        app.state.background_tasks = set()

        try:
            # Initialize default workspace if specified
            if args.workspace:
                # Ensure default workspace exists
                if not workspace_manager.workspace_exists(args.workspace):
                    workspace_manager.create_workspace(args.workspace)

                # Get default workspace RAG instance and initialize it
                default_rag = workspace_manager.get_rag(args.workspace)
                await default_rag.initialize_storages()

                # Data migration for default workspace
                await default_rag.check_and_migrate_data()

            await initialize_pipeline_status()

            pipeline_status = await get_namespace_data("pipeline_status")

            should_start_autoscan = False
            async with get_pipeline_status_lock():
                # Auto scan documents if enabled
                if args.auto_scan_at_startup:
                    if not pipeline_status.get("autoscanned", False):
                        pipeline_status["autoscanned"] = True
                        should_start_autoscan = True

            # Only run auto scan when no other process started it first
            if should_start_autoscan and args.workspace:
                # Get document manager for default workspace
                default_doc_manager = workspace_manager.get_document_manager(
                    args.workspace
                )
                default_rag = workspace_manager.get_rag(args.workspace)
                # Create background task
                task = asyncio.create_task(
                    run_scanning_process(default_rag, default_doc_manager)
                )
                app.state.background_tasks.add(task)
                task.add_done_callback(app.state.background_tasks.discard)
                logger.info(
                    f"Process {os.getpid()} auto scan task started at startup "
                    f"for workspace '{args.workspace}'."
                )

            # Warm up OCR processor in background to avoid cold-start delay on first upload
            async def warm_up_ocr_processor():
                try:
                    logger.info("Starting OCR processor warm-up...")
                    # Import inside function to avoid unnecessary dependency if OCR not used
                    from lightrag.document_processor import get_document_processor

                    # This will initialize OptimizedOCRProcessor (≈9 seconds)
                    get_document_processor()
                    logger.info("OCR processor warmed up successfully")
                except Exception as e:
                    logger.warning(f"OCR warm-up failed (non-critical): {e}")

            # Schedule warm-up as a background task (non-blocking)
            warm_up_task = asyncio.create_task(warm_up_ocr_processor())
            app.state.background_tasks.add(warm_up_task)
            warm_up_task.add_done_callback(app.state.background_tasks.discard)

            ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")

            yield

        finally:
            # Clean up all workspace RAG instances
            for workspace_name, rag_instance in workspace_manager._rag_instances.items():
                try:
                    await rag_instance.finalize_storages()
                except Exception as e:
                    logger.error(
                        f"Error finalizing storages for workspace "
                        f"'{workspace_name}': {e}"
                    )

            # Clean up shared data
            finalize_share_data()

    # Initialize FastAPI
    app_kwargs = {
        "title": "Server API",
        # Note: the suffix is appended only when an API key is configured;
        # parenthesizing the conditional avoids a precedence bug that emptied
        # the whole description when no key was set.
        "description": (
            "Providing API for LightRAG core, Web UI and Ollama Model Emulation"
            + (" (With authentication)" if api_key else "")
        ),
        "version": __api_version__,
        "openapi_url": "/openapi.json",  # Explicitly set OpenAPI schema URL
        "docs_url": "/docs",  # Explicitly set docs URL
        "redoc_url": "/redoc",  # Explicitly set redoc URL
        "lifespan": lifespan,
    }

    # Configure Swagger UI parameters
    # Enable persistAuthorization and tryItOutEnabled for a better user experience
    app_kwargs["swagger_ui_parameters"] = {
        "persistAuthorization": True,
        "tryItOutEnabled": True,
    }

    app = FastAPI(**app_kwargs)

    def get_cors_origins():
        """Get allowed origins from global_args

        Returns a list of allowed origins, defaults to ["*"] if not set
        """
        origins_str = global_args.cors_origins
        if origins_str == "*":
            return ["*"]
        return [origin.strip() for origin in origins_str.split(",")]

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=get_cors_origins(),
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

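    # Illustrative example (assuming the config layer maps a CORS_ORIGINS
    # environment variable onto global_args.cors_origins): with
    #   CORS_ORIGINS=http://localhost:3000,https://example.com
    # get_cors_origins() returns ["http://localhost:3000", "https://example.com"];
    # leaving it as "*" allows all origins.
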
    # Create combined auth dependency for all endpoints
    combined_auth = get_combined_auth_dependency(api_key)

    # Create working directory if it doesn't exist
    Path(args.working_dir).mkdir(parents=True, exist_ok=True)

    def create_optimized_openai_llm_func(
        config_cache: LLMConfigCache, args, llm_timeout: int
    ):
        """Create optimized OpenAI LLM function with pre-processed configuration"""

        async def optimized_openai_alike_model_complete(
            prompt,
            system_prompt=None,
            history_messages=None,
            keyword_extraction=False,
            **kwargs,
        ) -> str:
            from lightrag.llm.openai import openai_complete_if_cache

            # A keyword_extraction flag passed via kwargs overrides the named parameter
            keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
            if keyword_extraction:
                kwargs["response_format"] = GPTKeywordExtractionFormat
            if history_messages is None:
                history_messages = []

            # Use pre-processed configuration to avoid repeated parsing
            kwargs["timeout"] = llm_timeout
            if config_cache.openai_llm_options:
                kwargs.update(config_cache.openai_llm_options)

            return await openai_complete_if_cache(
                args.llm_model,
                prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                base_url=args.llm_binding_host,
                api_key=args.llm_binding_api_key,
                **kwargs,
            )

        return optimized_openai_alike_model_complete

    def create_optimized_azure_openai_llm_func(
        config_cache: LLMConfigCache, args, llm_timeout: int
    ):
        """Create optimized Azure OpenAI LLM function with pre-processed configuration"""

        async def optimized_azure_openai_model_complete(
            prompt,
            system_prompt=None,
            history_messages=None,
            keyword_extraction=False,
            **kwargs,
        ) -> str:
            from lightrag.llm.azure_openai import azure_openai_complete_if_cache

            # A keyword_extraction flag passed via kwargs overrides the named parameter
            keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
            if keyword_extraction:
                kwargs["response_format"] = GPTKeywordExtractionFormat
            if history_messages is None:
                history_messages = []

            # Use pre-processed configuration to avoid repeated parsing
            kwargs["timeout"] = llm_timeout
            if config_cache.openai_llm_options:
                kwargs.update(config_cache.openai_llm_options)

            return await azure_openai_complete_if_cache(
                args.llm_model,
                prompt,
                system_prompt=system_prompt,
                history_messages=history_messages,
                base_url=args.llm_binding_host,
                api_key=os.getenv("AZURE_OPENAI_API_KEY", args.llm_binding_api_key),
                api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-08-01-preview"),
                **kwargs,
            )

        return optimized_azure_openai_model_complete

    def create_llm_model_func(binding: str):
        """
        Create LLM model function based on binding type.
        Uses optimized functions for OpenAI bindings and lazy imports for others.
        """
        try:
            if binding == "lollms":
                from lightrag.llm.lollms import lollms_model_complete

                return lollms_model_complete
            elif binding == "ollama":
                from lightrag.llm.ollama import ollama_model_complete

                return ollama_model_complete
            elif binding == "aws_bedrock":
                return bedrock_model_complete  # Defined later inside create_app
            elif binding == "azure_openai":
                # Use optimized function with pre-processed configuration
                return create_optimized_azure_openai_llm_func(
                    config_cache, args, llm_timeout
                )
            else:  # openai and compatible
                # Use optimized function with pre-processed configuration
                return create_optimized_openai_llm_func(config_cache, args, llm_timeout)
        except ImportError as e:
            raise Exception(f"Failed to import {binding} LLM binding: {e}") from e

    def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
        """
        Create LLM model kwargs based on binding type.
        Uses lazy import for binding-specific options.
        """
        if binding in ["lollms", "ollama"]:
            try:
                from lightrag.llm.binding_options import OllamaLLMOptions

                return {
                    "host": args.llm_binding_host,
                    "timeout": llm_timeout,
                    "options": OllamaLLMOptions.options_dict(args),
                    "api_key": args.llm_binding_api_key,
                }
            except ImportError as e:
                raise Exception(f"Failed to import {binding} options: {e}") from e
        return {}

    def create_optimized_embedding_function(
        config_cache: LLMConfigCache, binding, model, host, api_key, dimensions, args
    ):
        """
        Create optimized embedding function with pre-processed configuration for
        applicable bindings. Uses lazy imports for all bindings and avoids
        repeated configuration parsing.
        """

        async def optimized_embedding_function(texts):
            try:
                if binding == "lollms":
                    from lightrag.llm.lollms import lollms_embed

                    return await lollms_embed(
                        texts, embed_model=model, host=host, api_key=api_key
                    )
                elif binding == "ollama":
                    from lightrag.llm.ollama import ollama_embed

                    # Use pre-processed configuration if available, otherwise fall back to dynamic parsing
                    if config_cache.ollama_embedding_options is not None:
                        ollama_options = config_cache.ollama_embedding_options
                    else:
                        # Fallback for cases where the config cache wasn't initialized properly
                        from lightrag.llm.binding_options import OllamaEmbeddingOptions

                        ollama_options = OllamaEmbeddingOptions.options_dict(args)

                    return await ollama_embed(
                        texts,
                        embed_model=model,
                        host=host,
                        api_key=api_key,
                        options=ollama_options,
                    )
                elif binding == "azure_openai":
                    from lightrag.llm.azure_openai import azure_openai_embed

                    return await azure_openai_embed(texts, model=model, api_key=api_key)
                elif binding == "aws_bedrock":
                    from lightrag.llm.bedrock import bedrock_embed

                    return await bedrock_embed(texts, model=model)
                elif binding == "jina":
                    from lightrag.llm.jina import jina_embed

                    return await jina_embed(
                        texts, dimensions=dimensions, base_url=host, api_key=api_key
                    )
                else:  # openai and compatible
                    from lightrag.llm.openai import openai_embed

                    return await openai_embed(
                        texts, model=model, base_url=host, api_key=api_key
                    )
            except ImportError as e:
                raise Exception(f"Failed to import {binding} embedding: {e}") from e

        return optimized_embedding_function

    llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
    embedding_timeout = get_env_value(
        "EMBEDDING_TIMEOUT", DEFAULT_EMBEDDING_TIMEOUT, int
    )

    async def bedrock_model_complete(
        prompt,
        system_prompt=None,
        history_messages=None,
        keyword_extraction=False,
        **kwargs,
    ) -> str:
        # Lazy import
        from lightrag.llm.bedrock import bedrock_complete_if_cache

        # A keyword_extraction flag passed via kwargs overrides the named parameter
        keyword_extraction = kwargs.pop("keyword_extraction", keyword_extraction)
        if keyword_extraction:
            kwargs["response_format"] = GPTKeywordExtractionFormat
        if history_messages is None:
            history_messages = []

        # Use global temperature for Bedrock
        kwargs["temperature"] = get_env_value("BEDROCK_LLM_TEMPERATURE", 1.0, float)

        return await bedrock_complete_if_cache(
            args.llm_model,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages,
            **kwargs,
        )

    # Create embedding function with optimized configuration
    embedding_func = EmbeddingFunc(
        embedding_dim=args.embedding_dim,
        func=create_optimized_embedding_function(
            config_cache=config_cache,
            binding=args.embedding_binding,
            model=args.embedding_model,
            host=args.embedding_binding_host,
            api_key=args.embedding_binding_api_key,
            dimensions=args.embedding_dim,
            args=args,  # Pass args object for fallback option generation
        ),
    )

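    # Illustrative usage (not executed here): the wrapped coroutine maps a list
    # of texts to one vector per text, so inside an async context something like
    #   vectors = await embedding_func(["hello", "world"])
    # is expected to return two vectors of length args.embedding_dim. The exact
    # return type (list vs. numpy array) depends on the selected binding.
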
    # Configure rerank function based on the args.rerank_binding parameter
    rerank_model_func = None
    if args.rerank_binding != "null":
        from lightrag.rerank import (
            cohere_rerank,
            jina_rerank,
            ali_rerank,
            ollama_rerank,
        )

        # Map rerank binding to corresponding function
        rerank_functions = {
            "cohere": cohere_rerank,
            "jina": jina_rerank,
            "aliyun": ali_rerank,
            "ollama": ollama_rerank,
        }

        # Select the appropriate rerank function based on binding
        selected_rerank_func = rerank_functions.get(args.rerank_binding)
        if not selected_rerank_func:
            logger.error(f"Unsupported rerank binding: {args.rerank_binding}")
            raise ValueError(f"Unsupported rerank binding: {args.rerank_binding}")

        # Get default values from selected_rerank_func if args values are None
        if args.rerank_model is None or args.rerank_binding_host is None:
            sig = inspect.signature(selected_rerank_func)

            # Set default model if args.rerank_model is None
            if args.rerank_model is None and "model" in sig.parameters:
                default_model = sig.parameters["model"].default
                if default_model != inspect.Parameter.empty:
                    args.rerank_model = default_model

            # Set default base_url if args.rerank_binding_host is None
            if args.rerank_binding_host is None and "base_url" in sig.parameters:
                default_base_url = sig.parameters["base_url"].default
                if default_base_url != inspect.Parameter.empty:
                    args.rerank_binding_host = default_base_url

        async def server_rerank_func(
            query: str, documents: list, top_n: int = None, extra_body: dict = None
        ):
            """Server rerank function with configuration from environment variables"""
            return await selected_rerank_func(
                query=query,
                documents=documents,
                top_n=top_n,
                api_key=args.rerank_binding_api_key,
                model=args.rerank_model,
                base_url=args.rerank_binding_host,
                extra_body=extra_body,
            )

        rerank_model_func = server_rerank_func
        logger.info(
            f"Reranking is enabled: {args.rerank_model or 'default model'} "
            f"using {args.rerank_binding} provider"
        )
    else:
        logger.info("Reranking is disabled")

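    # Illustrative configuration (the environment variable names below are an
    # assumed mapping onto the args fields used above; check
    # lightrag/api/config.py for the exact names):
    #   RERANK_BINDING=cohere
    #   RERANK_MODEL=rerank-english-v3.0
    #   RERANK_BINDING_API_KEY=<your key>
    # With RERANK_BINDING=null, reranking stays disabled.
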
    # Create ollama_server_infos from command line arguments
    from lightrag.api.config import OllamaServerInfos

    ollama_server_infos = OllamaServerInfos(
        name=args.simulated_model_name, tag=args.simulated_model_tag
    )

    # Create a factory function for creating LightRAG instances with the given configuration
    def create_lightrag_factory():
        """Factory function to create LightRAG instances with server configuration"""

        def factory(working_dir: str, workspace: str):
            return LightRAG(
                working_dir=working_dir,
                workspace=workspace,
                llm_model_func=create_llm_model_func(args.llm_binding),
                llm_model_name=args.llm_model,
                llm_model_max_async=args.max_async,
                summary_max_tokens=args.summary_max_tokens,
                summary_context_size=args.summary_context_size,
                chunk_token_size=int(args.chunk_size),
                chunk_overlap_token_size=int(args.chunk_overlap_size),
                llm_model_kwargs=create_llm_model_kwargs(
                    args.llm_binding, args, llm_timeout
                ),
                embedding_func=embedding_func,
                default_llm_timeout=llm_timeout,
                default_embedding_timeout=embedding_timeout,
                kv_storage=args.kv_storage,
                graph_storage=args.graph_storage,
                vector_storage=args.vector_storage,
                doc_status_storage=args.doc_status_storage,
                vector_db_storage_cls_kwargs={
                    "cosine_better_than_threshold": args.cosine_threshold
                },
                enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
                enable_llm_cache=args.enable_llm_cache,
                rerank_model_func=rerank_model_func,
                max_parallel_insert=args.max_parallel_insert,
                max_graph_nodes=args.max_graph_nodes,
                addon_params={
                    "language": args.summary_language,
                    "entity_types": args.entity_types,
                },
                ollama_server_infos=ollama_server_infos,
            )

        return factory

    # Recreate the workspace manager, now with the LightRAG factory attached
    workspace_manager = WorkspaceManager(
        args, lightrag_factory=create_lightrag_factory()
    )
    app.state.workspace_manager = workspace_manager

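    # Usage sketch (hypothetical, assuming WorkspaceManager exposes get_rag as
    # used in the lifespan handler above): request handlers can fetch a
    # per-workspace LightRAG instance via
    #   rag = app.state.workspace_manager.get_rag("my_workspace")
    # where each instance is built by the factory with the shared server config.
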
    # Add routes with workspace manager
    app.include_router(
        create_document_routes(
            workspace_manager,
            api_key,
        )
    )
    app.include_router(create_query_routes(workspace_manager, api_key, args.top_k))
    app.include_router(create_graph_routes(workspace_manager, api_key))
    app.include_router(create_search_routes(workspace_manager, api_key, args.top_k))

    # Add Ollama API routes with workspace manager
    ollama_api = OllamaAPI(workspace_manager, top_k=args.top_k, api_key=api_key)
    app.include_router(ollama_api.router, prefix="/api")

    # Add workspace routes
    logger.info("Including workspace router")
    app.include_router(workspace_router)

@app.get("/")
|
||
async def redirect_to_webui():
|
||
"""Redirect root path to /webui"""
|
||
return RedirectResponse(url="/webui")
|
||
|
||
@app.get("/auth-status")
|
||
async def get_auth_status():
|
||
"""Get authentication status and guest token if auth is not configured"""
|
||
|
||
if not auth_handler.accounts:
|
||
# Authentication not configured, return guest token
|
||
guest_token = auth_handler.create_token(
|
||
username="guest", role="guest", metadata={"auth_mode": "disabled"}
|
||
)
|
||
return {
|
||
"auth_configured": False,
|
||
"access_token": guest_token,
|
||
"token_type": "bearer",
|
||
"auth_mode": "disabled",
|
||
"message": "Authentication is disabled. Using guest access.",
|
||
"core_version": core_version,
|
||
"api_version": __api_version__,
|
||
"webui_title": webui_title,
|
||
"webui_description": webui_description,
|
||
}
|
||
|
||
return {
|
||
"auth_configured": True,
|
||
"auth_mode": "enabled",
|
||
"core_version": core_version,
|
||
"api_version": __api_version__,
|
||
"webui_title": webui_title,
|
||
"webui_description": webui_description,
|
||
}
|
||
|
||
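    # Illustrative check (host and port are assumptions based on your deployment):
    #   curl http://localhost:9621/auth-status
    # returns auth_configured=false plus a guest token when no accounts are set.
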
@app.post("/login")
|
||
async def login(form_data: OAuth2PasswordRequestForm = Depends()):
|
||
if not auth_handler.accounts:
|
||
# Authentication not configured, return guest token
|
||
guest_token = auth_handler.create_token(
|
||
username="guest", role="guest", metadata={"auth_mode": "disabled"}
|
||
)
|
||
return {
|
||
"access_token": guest_token,
|
||
"token_type": "bearer",
|
||
"auth_mode": "disabled",
|
||
"message": "Authentication is disabled. Using guest access.",
|
||
"core_version": core_version,
|
||
"api_version": __api_version__,
|
||
"webui_title": webui_title,
|
||
"webui_description": webui_description,
|
||
}
|
||
username = form_data.username
|
||
if auth_handler.accounts.get(username) != form_data.password:
|
||
raise HTTPException(status_code=401, detail="Incorrect credentials")
|
||
|
||
# Regular user login
|
||
user_token = auth_handler.create_token(
|
||
username=username, role="user", metadata={"auth_mode": "enabled"}
|
||
)
|
||
return {
|
||
"access_token": user_token,
|
||
"token_type": "bearer",
|
||
"auth_mode": "enabled",
|
||
"core_version": core_version,
|
||
"api_version": __api_version__,
|
||
"webui_title": webui_title,
|
||
"webui_description": webui_description,
|
||
}
|
||
|
||
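    # Illustrative login request (OAuth2 password form; host/port are assumptions):
    #   curl -X POST http://localhost:9621/login \
    #        -H "Content-Type: application/x-www-form-urlencoded" \
    #        -d "username=alice&password=secret"
    # The response carries a bearer token for subsequent authenticated requests.
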
@app.get("/health", dependencies=[Depends(combined_auth)])
|
||
async def get_status():
|
||
"""Get current system status"""
|
||
try:
|
||
pipeline_status = await get_namespace_data("pipeline_status")
|
||
|
||
if not auth_configured:
|
||
auth_mode = "disabled"
|
||
else:
|
||
auth_mode = "enabled"
|
||
|
||
# Cleanup expired keyed locks and get status
|
||
keyed_lock_info = cleanup_keyed_lock()
|
||
|
||
return {
|
||
"status": "healthy",
|
||
"working_directory": str(args.working_dir),
|
||
"input_directory": str(args.input_dir),
|
||
"configuration": {
|
||
# LLM configuration binding/host address (if applicable)/model (if applicable)
|
||
"llm_binding": args.llm_binding,
|
||
"llm_binding_host": args.llm_binding_host,
|
||
"llm_model": args.llm_model,
|
||
# embedding model configuration binding/host address (if applicable)/model (if applicable)
|
||
"embedding_binding": args.embedding_binding,
|
||
"embedding_binding_host": args.embedding_binding_host,
|
||
"embedding_model": args.embedding_model,
|
||
"summary_max_tokens": args.summary_max_tokens,
|
||
"summary_context_size": args.summary_context_size,
|
||
"kv_storage": args.kv_storage,
|
||
"doc_status_storage": args.doc_status_storage,
|
||
"graph_storage": args.graph_storage,
|
||
"vector_storage": args.vector_storage,
|
||
"enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
|
||
"enable_llm_cache": args.enable_llm_cache,
|
||
"workspace": args.workspace,
|
||
"max_graph_nodes": args.max_graph_nodes,
|
||
# Rerank configuration
|
||
"enable_rerank": rerank_model_func is not None,
|
||
"rerank_binding": args.rerank_binding,
|
||
"rerank_model": args.rerank_model if rerank_model_func else None,
|
||
"rerank_binding_host": args.rerank_binding_host
|
||
if rerank_model_func
|
||
else None,
|
||
# Environment variable status (requested configuration)
|
||
"summary_language": args.summary_language,
|
||
"force_llm_summary_on_merge": args.force_llm_summary_on_merge,
|
||
"max_parallel_insert": args.max_parallel_insert,
|
||
"cosine_threshold": args.cosine_threshold,
|
||
"min_rerank_score": args.min_rerank_score,
|
||
"related_chunk_number": args.related_chunk_number,
|
||
"max_async": args.max_async,
|
||
"embedding_func_max_async": args.embedding_func_max_async,
|
||
"embedding_batch_num": args.embedding_batch_num,
|
||
},
|
||
"auth_mode": auth_mode,
|
||
"pipeline_busy": pipeline_status.get("busy", False),
|
||
"keyed_locks": keyed_lock_info,
|
||
"core_version": core_version,
|
||
"api_version": __api_version__,
|
||
"webui_title": webui_title,
|
||
"webui_description": webui_description,
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"Error getting health status: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
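    # Illustrative health check. The endpoint sits behind the combined auth
    # dependency, so pass either a bearer token or the API key; the header name
    # below is an assumption to verify against get_combined_auth_dependency:
    #   curl http://localhost:9621/health -H "X-API-Key: <LIGHTRAG_API_KEY>"
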
    # Custom StaticFiles class for smart caching
    class SmartStaticFiles(StaticFiles):  # Renamed from NoCacheStaticFiles
        async def get_response(self, path: str, scope):
            response = await super().get_response(path, scope)

            if path.endswith(".html"):
                response.headers["Cache-Control"] = (
                    "no-cache, no-store, must-revalidate"
                )
                response.headers["Pragma"] = "no-cache"
                response.headers["Expires"] = "0"
            elif (
                "/assets/" in path
            ):  # Assets (JS, CSS, images, fonts) generated by Vite with a hash in the filename
                response.headers["Cache-Control"] = (
                    "public, max-age=31536000, immutable"
                )
            # Add other rules here if needed for non-HTML, non-asset files

            # Ensure correct Content-Type
            if path.endswith(".js"):
                response.headers["Content-Type"] = "application/javascript"
            elif path.endswith(".css"):
                response.headers["Content-Type"] = "text/css"

            return response

    # Mount the web UI (webui/index.html) at /webui
    static_dir = Path(__file__).parent / "webui"
    static_dir.mkdir(exist_ok=True)
    app.mount(
        "/webui",
        SmartStaticFiles(
            directory=static_dir, html=True, check_dir=True
        ),  # Use SmartStaticFiles
        name="webui",
    )

    return app


def get_application(args=None):
    """Factory function for creating the FastAPI application"""
    if args is None:
        args = global_args
    return create_app(args)


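# Illustrative deployment note: because get_application() is a zero-argument
# factory, an ASGI server can build the app at worker startup. For example,
# with Gunicorn (the module path shown is an assumption; adjust it to where
# this file actually lives in your installation):
#   gunicorn "lightrag.api.lightrag_server:get_application()" \
#       -k uvicorn.workers.UvicornWorker

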
def configure_logging():
    """Configure logging for uvicorn startup"""

    # Reset any existing handlers to ensure a clean configuration
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        logger_instance = logging.getLogger(logger_name)
        logger_instance.handlers = []
        logger_instance.filters = []

    # Resolve the project root (three levels up from this file,
    # i.e. lightrag/api/<this file> -> project root) and write logs there
    project_root = Path(__file__).parent.parent.parent
    log_dir = project_root / "logs"
    log_file_path = log_dir / DEFAULT_LOG_FILENAME

    print(f"\nLightRAG log file: {log_file_path}\n")
    os.makedirs(log_dir, exist_ok=True)

    # Get log file max size and backup count from environment variables
    log_max_bytes = get_env_value("LOG_MAX_BYTES", DEFAULT_LOG_MAX_BYTES, int)
    log_backup_count = get_env_value("LOG_BACKUP_COUNT", DEFAULT_LOG_BACKUP_COUNT, int)

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                # Configure all uvicorn-related loggers
                "uvicorn": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
                "uvicorn.access": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                    "filters": ["path_filter"],
                },
                "uvicorn.error": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                    "filters": ["path_filter"],
                },
            },
            "filters": {
                "path_filter": {
                    "()": "lightrag.utils.LightragPathFilter",
                },
            },
        }
    )


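# Illustrative rotation settings (these are the variable names read above):
#   LOG_MAX_BYTES=10485760     # rotate after ~10 MiB
#   LOG_BACKUP_COUNT=5         # keep five rotated files

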
def check_and_install_dependencies():
    """Check and install required dependencies"""
    required_packages = [
        "uvicorn",
        "tiktoken",
        "fastapi",
        # Add other required packages here
    ]

    for package in required_packages:
        if not pm.is_installed(package):
            print(f"Installing {package}...")
            pm.install(package)
            print(f"{package} installed successfully")


def main():
    # Check if running under Gunicorn
    if "GUNICORN_CMD_ARGS" in os.environ:
        # If started with Gunicorn, return directly; Gunicorn will call get_application
        print("Running under Gunicorn - worker management handled by Gunicorn")
        return

    # Check .env file
    if not check_env_file():
        sys.exit(1)

    # Check and install dependencies
    check_and_install_dependencies()

    from multiprocessing import freeze_support

    freeze_support()

    # Configure logging before starting the server
    configure_logging()
    update_uvicorn_mode_config()
    display_splash_screen(global_args)

    # Setup signal handlers for graceful shutdown
    setup_signal_handlers()

    # Create the application instance directly instead of using the factory function
    app = create_app(global_args)

    # Start Uvicorn in single-process mode
    uvicorn_config = {
        "app": app,  # Pass the application instance directly instead of a string path
        "host": global_args.host,
        "port": global_args.port,
        "log_config": None,  # Disable uvicorn's default logging config
    }

    if global_args.ssl:
        uvicorn_config.update(
            {
                "ssl_certfile": global_args.ssl_certfile,
                "ssl_keyfile": global_args.ssl_keyfile,
            }
        )

    print(
        f"Starting Uvicorn server in single-process mode on "
        f"{global_args.host}:{global_args.port}"
    )
    uvicorn.run(**uvicorn_config)


if __name__ == "__main__":
|
||
main()
|