# Source: railseek6/LightRAG-main/lightrag/api/lightrag_factory.py (106 lines, 4.3 KiB, Python)
"""
Factory for creating LightRAG instances with shared configuration.
"""
import logging
from typing import Callable, Optional
from lightrag import LightRAG
from lightrag.utils import EmbeddingFunc
from lightrag.api.config import global_args
logger = logging.getLogger(__name__)
class LightRAGFactory:
    """Factory that creates and caches LightRAG instances with shared configuration.

    All LLM/embedding/storage settings are captured once at construction time.
    ``create`` (and its alias ``get``) then stamp out one ``LightRAG`` instance
    per ``(working_dir, workspace)`` pair, memoizing the result so repeated
    calls for the same pair return the same object.
    """

    def __init__(
        self,
        llm_model_func: Callable,
        llm_model_name: str,
        llm_model_max_async: int,
        summary_max_tokens: int,
        summary_context_size: int,
        chunk_token_size: int,
        chunk_overlap_token_size: int,
        llm_model_kwargs: dict,
        embedding_func: EmbeddingFunc,
        default_llm_timeout: int,
        default_embedding_timeout: int,
        kv_storage: str,
        graph_storage: str,
        vector_storage: str,
        doc_status_storage: str,
        vector_db_storage_cls_kwargs: dict,
        enable_llm_cache_for_entity_extract: bool,
        enable_llm_cache: bool,
        rerank_model_func: Optional[Callable],
        max_parallel_insert: int,
        max_graph_nodes: int,
        addon_params: dict,
        ollama_server_infos,
    ):
        """Capture the shared configuration later forwarded verbatim to LightRAG.

        Every parameter maps one-to-one onto the keyword argument of the same
        name in the ``LightRAG`` constructor (see :meth:`create`).
        """
        # LLM configuration
        self.llm_model_func = llm_model_func
        self.llm_model_name = llm_model_name
        self.llm_model_max_async = llm_model_max_async
        self.summary_max_tokens = summary_max_tokens
        self.summary_context_size = summary_context_size
        # Chunking configuration
        self.chunk_token_size = chunk_token_size
        self.chunk_overlap_token_size = chunk_overlap_token_size
        self.llm_model_kwargs = llm_model_kwargs
        # Embedding configuration and timeouts
        self.embedding_func = embedding_func
        self.default_llm_timeout = default_llm_timeout
        self.default_embedding_timeout = default_embedding_timeout
        # Storage backend selection
        self.kv_storage = kv_storage
        self.graph_storage = graph_storage
        self.vector_storage = vector_storage
        self.doc_status_storage = doc_status_storage
        self.vector_db_storage_cls_kwargs = vector_db_storage_cls_kwargs
        # Caching / rerank / limits
        self.enable_llm_cache_for_entity_extract = enable_llm_cache_for_entity_extract
        self.enable_llm_cache = enable_llm_cache
        self.rerank_model_func = rerank_model_func
        self.max_parallel_insert = max_parallel_insert
        self.max_graph_nodes = max_graph_nodes
        self.addon_params = addon_params
        self.ollama_server_infos = ollama_server_infos
        # Instance cache keyed by (working_dir, workspace) tuples — note the
        # key is the pair, not the workspace alone, so the same workspace name
        # under different working directories yields distinct instances.
        self._cache: dict = {}

    def create(self, working_dir: str, workspace: str = "") -> LightRAG:
        """Return the LightRAG instance for ``(working_dir, workspace)``.

        Returns the cached instance when one exists; otherwise constructs a
        new ``LightRAG`` from the shared configuration, caches it, and
        returns it.

        Args:
            working_dir: Directory used by the instance for its storage files.
            workspace: Logical workspace name; defaults to the empty string.
        """
        key = (working_dir, workspace)
        if key in self._cache:
            return self._cache[key]
        rag = LightRAG(
            working_dir=working_dir,
            workspace=workspace,
            llm_model_func=self.llm_model_func,
            llm_model_name=self.llm_model_name,
            llm_model_max_async=self.llm_model_max_async,
            summary_max_tokens=self.summary_max_tokens,
            summary_context_size=self.summary_context_size,
            chunk_token_size=self.chunk_token_size,
            chunk_overlap_token_size=self.chunk_overlap_token_size,
            llm_model_kwargs=self.llm_model_kwargs,
            embedding_func=self.embedding_func,
            default_llm_timeout=self.default_llm_timeout,
            default_embedding_timeout=self.default_embedding_timeout,
            kv_storage=self.kv_storage,
            graph_storage=self.graph_storage,
            vector_storage=self.vector_storage,
            doc_status_storage=self.doc_status_storage,
            vector_db_storage_cls_kwargs=self.vector_db_storage_cls_kwargs,
            enable_llm_cache_for_entity_extract=self.enable_llm_cache_for_entity_extract,
            enable_llm_cache=self.enable_llm_cache,
            rerank_model_func=self.rerank_model_func,
            max_parallel_insert=self.max_parallel_insert,
            max_graph_nodes=self.max_graph_nodes,
            addon_params=self.addon_params,
            ollama_server_infos=self.ollama_server_infos,
        )
        self._cache[key] = rag
        return rag

    def get(self, working_dir: str, workspace: str = "") -> LightRAG:
        """Alias for :meth:`create`, which already returns cached instances."""
        return self.create(working_dir, workspace)