106 lines
4.3 KiB
Python
106 lines
4.3 KiB
Python
"""
|
|
Factory for creating LightRAG instances with shared configuration.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Callable, Optional
|
|
from lightrag import LightRAG
|
|
from lightrag.utils import EmbeddingFunc
|
|
from lightrag.api.config import global_args
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LightRAGFactory:
    """Factory that builds and memoizes LightRAG instances sharing one configuration.

    All LLM, embedding, chunking, and storage settings are captured once at
    construction time.  ``create()`` / ``get()`` then hand out one ``LightRAG``
    instance per ``(working_dir, workspace)`` pair, cached for the lifetime of
    the factory so repeated requests for the same workspace reuse the same
    instance (and therefore the same storage connections).

    NOTE(review): the cache is not guarded by a lock — assumes callers create
    instances from a single thread / event loop; confirm before sharing a
    factory across threads.
    """

    def __init__(
        self,
        llm_model_func: Callable,
        llm_model_name: str,
        llm_model_max_async: int,
        summary_max_tokens: int,
        summary_context_size: int,
        chunk_token_size: int,
        chunk_overlap_token_size: int,
        llm_model_kwargs: dict,
        embedding_func: EmbeddingFunc,
        default_llm_timeout: int,
        default_embedding_timeout: int,
        kv_storage: str,
        graph_storage: str,
        vector_storage: str,
        doc_status_storage: str,
        vector_db_storage_cls_kwargs: dict,
        enable_llm_cache_for_entity_extract: bool,
        enable_llm_cache: bool,
        rerank_model_func: Optional[Callable],
        max_parallel_insert: int,
        max_graph_nodes: int,
        addon_params: dict,
        ollama_server_infos,  # opaque server-info object; stored and forwarded as-is
    ):
        """Capture the shared configuration forwarded to every LightRAG instance.

        Every parameter mirrors the keyword argument of the same name on
        ``LightRAG`` one-to-one; values are stored verbatim and passed through
        unchanged in :meth:`create`.  Dict/callable arguments are kept by
        reference, not copied — later mutation by the caller affects all
        instances created afterwards.
        """
        self.llm_model_func = llm_model_func
        self.llm_model_name = llm_model_name
        self.llm_model_max_async = llm_model_max_async
        self.summary_max_tokens = summary_max_tokens
        self.summary_context_size = summary_context_size
        self.chunk_token_size = chunk_token_size
        self.chunk_overlap_token_size = chunk_overlap_token_size
        self.llm_model_kwargs = llm_model_kwargs
        self.embedding_func = embedding_func
        self.default_llm_timeout = default_llm_timeout
        self.default_embedding_timeout = default_embedding_timeout
        self.kv_storage = kv_storage
        self.graph_storage = graph_storage
        self.vector_storage = vector_storage
        self.doc_status_storage = doc_status_storage
        self.vector_db_storage_cls_kwargs = vector_db_storage_cls_kwargs
        self.enable_llm_cache_for_entity_extract = enable_llm_cache_for_entity_extract
        self.enable_llm_cache = enable_llm_cache
        self.rerank_model_func = rerank_model_func
        self.max_parallel_insert = max_parallel_insert
        self.max_graph_nodes = max_graph_nodes
        self.addon_params = addon_params
        self.ollama_server_infos = ollama_server_infos
        # Memoization cache: (working_dir, workspace) -> LightRAG instance.
        self._cache = {}

    def create(self, working_dir: str, workspace: str = "") -> LightRAG:
        """Return the LightRAG instance for ``(working_dir, workspace)``.

        Despite its name this method is memoized: a new ``LightRAG`` is built
        only on the first call for a given ``(working_dir, workspace)`` key;
        subsequent calls return the cached instance.  All configuration beyond
        the two key arguments comes from the values captured in ``__init__``.

        Args:
            working_dir: Directory where the instance persists its data.
            workspace: Logical workspace name; defaults to the empty string
                (the shared/default workspace).

        Returns:
            The cached or newly constructed ``LightRAG`` instance.
        """
        key = (working_dir, workspace)
        if key in self._cache:
            return self._cache[key]

        rag = LightRAG(
            working_dir=working_dir,
            workspace=workspace,
            llm_model_func=self.llm_model_func,
            llm_model_name=self.llm_model_name,
            llm_model_max_async=self.llm_model_max_async,
            summary_max_tokens=self.summary_max_tokens,
            summary_context_size=self.summary_context_size,
            chunk_token_size=self.chunk_token_size,
            chunk_overlap_token_size=self.chunk_overlap_token_size,
            llm_model_kwargs=self.llm_model_kwargs,
            embedding_func=self.embedding_func,
            default_llm_timeout=self.default_llm_timeout,
            default_embedding_timeout=self.default_embedding_timeout,
            kv_storage=self.kv_storage,
            graph_storage=self.graph_storage,
            vector_storage=self.vector_storage,
            doc_status_storage=self.doc_status_storage,
            vector_db_storage_cls_kwargs=self.vector_db_storage_cls_kwargs,
            enable_llm_cache_for_entity_extract=self.enable_llm_cache_for_entity_extract,
            enable_llm_cache=self.enable_llm_cache,
            rerank_model_func=self.rerank_model_func,
            max_parallel_insert=self.max_parallel_insert,
            max_graph_nodes=self.max_graph_nodes,
            addon_params=self.addon_params,
            ollama_server_infos=self.ollama_server_infos,
        )
        self._cache[key] = rag
        return rag

    def get(self, working_dir: str, workspace: str = "") -> LightRAG:
        """Alias for :meth:`create` — return the cached instance, building it on first use."""
        return self.create(working_dir, workspace)