# Source file: railseek6/LightRAG-main/optimize_entity_extraction.py
"""
Optimization 1: Proper LLM formatting integration for entity extraction
This optimization improves the entity extraction formatting to reduce LLM retries and improve parsing accuracy.
"""
import asyncio
import json
from typing import Any, Callable, Dict, List, Optional, Tuple

from lightrag.prompt import PROMPTS
from lightrag.utils import logger
class EntityExtractionOptimizer:
    """Optimize entity extraction formatting to reduce LLM retries and improve parsing accuracy.

    Wraps lightrag's entity-extraction flow: formats the extraction prompts,
    calls the LLM (with response caching), validates the parsed entities and
    relationships, and optionally applies a gleaning pass.  When disabled,
    every entry point falls back to lightrag's original implementation.
    """

    def __init__(self):
        # Master switch; when False all calls delegate to the original lightrag path.
        self.optimization_enabled = True
        # Placeholder for a saved copy of the original formatting (currently unused).
        self.original_formatting = None

    async def optimize_extraction_formatting(
        self,
        content: str,
        context_base: Dict[str, Any],
        use_llm_func: Callable,
        llm_response_cache: Any = None,
        chunk_key: Optional[str] = None,
    ) -> Tuple[Dict, Dict]:
        """
        Optimized entity extraction with improved formatting and error handling.

        Args:
            content: Text content to extract entities from
            context_base: Context parameters for the prompt
            use_llm_func: LLM function to call
            llm_response_cache: Cache for LLM responses
            chunk_key: Chunk identifier for caching

        Returns:
            Tuple of (entities_dict, relationships_dict)
        """
        if not self.optimization_enabled:
            # Fall back to original implementation
            from lightrag.operate import _process_single_content

            return await _process_single_content((chunk_key, {"content": content}))
        try:
            # Build system/user prompts with the chunk text merged into the
            # shared context parameters.
            entity_extraction_system_prompt = PROMPTS[
                "entity_extraction_system_prompt"
            ].format(**{**context_base, "input_text": content})
            entity_extraction_user_prompt = PROMPTS[
                "entity_extraction_user_prompt"
            ].format(**{**context_base, "input_text": content})

            # Use the optimized LLM call with better formatting
            final_result, timestamp = await self._optimized_llm_call(
                entity_extraction_user_prompt,
                use_llm_func,
                system_prompt=entity_extraction_system_prompt,
                llm_response_cache=llm_response_cache,
                chunk_key=chunk_key,
            )

            # Parse the raw LLM output and drop malformed entries.
            maybe_nodes, maybe_edges = await self._process_extraction_with_validation(
                final_result, chunk_key, timestamp, context_base
            )

            # Apply gleaning if enabled
            if context_base.get("entity_extract_max_gleaning", 0) > 0:
                maybe_nodes, maybe_edges = await self._apply_gleaning_optimized(
                    content,
                    context_base,
                    use_llm_func,
                    llm_response_cache,
                    chunk_key,
                    maybe_nodes,
                    maybe_edges,
                    entity_extraction_system_prompt,
                    final_result,
                )
            return maybe_nodes, maybe_edges
        except Exception as e:
            logger.error(f"Optimized extraction failed for chunk {chunk_key}: {e}")
            # Fall back to original implementation
            from lightrag.operate import _process_single_content

            return await _process_single_content((chunk_key, {"content": content}))

    async def _optimized_llm_call(
        self,
        prompt: str,
        use_llm_func: Callable,
        system_prompt: Optional[str] = None,
        llm_response_cache: Any = None,
        chunk_key: Optional[str] = None,
        cache_keys_collector: Optional[List] = None,
    ) -> Tuple[str, int]:
        """
        Optimized LLM call with improved error handling and caching.

        Delegates to lightrag's cached LLM helper; failures are logged with
        the chunk id and re-raised for the caller's fallback handling.
        """
        from lightrag.utils import use_llm_func_with_cache

        try:
            result, timestamp = await use_llm_func_with_cache(
                prompt,
                use_llm_func,
                system_prompt=system_prompt,
                llm_response_cache=llm_response_cache,
                cache_type="extract",
                chunk_id=chunk_key,
                cache_keys_collector=cache_keys_collector,
            )
            return result, timestamp
        except Exception as e:
            logger.error(f"LLM call failed for chunk {chunk_key}: {e}")
            raise

    async def _process_extraction_with_validation(
        self,
        result: str,
        chunk_key: Optional[str],
        timestamp: int,
        context_base: Dict[str, Any],
    ) -> Tuple[Dict, Dict]:
        """
        Process extraction result with improved validation and error handling.

        Parses the raw LLM text with lightrag's delimiter-based parser, then
        filters out entities/relationships missing required fields.  Returns
        empty dicts (rather than raising) when parsing fails.
        """
        from lightrag.operate import _process_extraction_result

        try:
            # Use the existing processing function but with additional validation
            maybe_nodes, maybe_edges = await _process_extraction_result(
                result,
                chunk_key,
                timestamp,
                tuple_delimiter=context_base["tuple_delimiter"],
                completion_delimiter=context_base["completion_delimiter"],
            )
            # Validate the extracted data
            validated_nodes = self._validate_entities(maybe_nodes, chunk_key)
            validated_edges = self._validate_relationships(maybe_edges, chunk_key)
            return validated_nodes, validated_edges
        except Exception as e:
            logger.error(f"Extraction processing failed for chunk {chunk_key}: {e}")
            return {}, {}

    def _validate_entities(self, entities: Dict, chunk_key: Optional[str]) -> Dict:
        """Validate extracted entities for consistency and quality.

        Drops entries with empty names and entity dicts missing required
        fields; names whose entity list becomes empty are omitted entirely.
        """
        validated = {}
        for entity_name, entity_list in entities.items():
            if not entity_name or not entity_name.strip():
                logger.warning(f"Empty entity name in chunk {chunk_key}")
                continue
            valid_entities = []
            for entity in entity_list:
                if self._is_valid_entity(entity):
                    valid_entities.append(entity)
                else:
                    logger.warning(f"Invalid entity data in chunk {chunk_key}: {entity}")
            if valid_entities:
                validated[entity_name] = valid_entities
        return validated

    def _validate_relationships(self, relationships: Dict, chunk_key: Optional[str]) -> Dict:
        """Validate extracted relationships for consistency and quality.

        Keys must be (src, tgt) 2-tuples; edge dicts missing required fields
        are dropped, and keys whose edge list becomes empty are omitted.
        """
        validated = {}
        for edge_key, edge_list in relationships.items():
            if not isinstance(edge_key, tuple) or len(edge_key) != 2:
                logger.warning(f"Invalid edge key format in chunk {chunk_key}: {edge_key}")
                continue
            valid_edges = []
            for edge in edge_list:
                if self._is_valid_relationship(edge):
                    valid_edges.append(edge)
                else:
                    logger.warning(f"Invalid relationship data in chunk {chunk_key}: {edge}")
            if valid_edges:
                validated[edge_key] = valid_edges
        return validated

    def _is_valid_entity(self, entity: Dict) -> bool:
        """Check if entity data is valid (all required fields present and truthy)."""
        required_fields = ["entity_name", "entity_type", "description"]
        return all(field in entity and entity[field] for field in required_fields)

    def _is_valid_relationship(self, relationship: Dict) -> bool:
        """Check if relationship data is valid (all required fields present and truthy)."""
        required_fields = ["src_id", "tgt_id", "description"]
        return all(field in relationship and relationship[field] for field in required_fields)

    async def _apply_gleaning_optimized(
        self,
        content: str,
        context_base: Dict[str, Any],
        use_llm_func: Callable,
        llm_response_cache: Any,
        chunk_key: Optional[str],
        current_nodes: Dict,
        current_edges: Dict,
        system_prompt: str,
        initial_result: str,
    ) -> Tuple[Dict, Dict]:
        """
        Apply gleaning with improved error handling and result merging.

        Replays the initial user prompt + LLM answer as chat history, asks the
        model to continue extracting, validates the continuation, then merges
        it into the current results.  On any failure the current results are
        returned unchanged.

        NOTE(review): exactly one gleaning pass is performed even when
        entity_extract_max_gleaning > 1 — confirm this is intended.
        """
        from lightrag.utils import pack_user_ass_to_openai_messages, use_llm_func_with_cache

        try:
            entity_continue_extraction_user_prompt = PROMPTS[
                "entity_continue_extraction_user_prompt"
            ].format(**{**context_base, "input_text": content})
            history = pack_user_ass_to_openai_messages(
                PROMPTS["entity_extraction_user_prompt"].format(
                    **{**context_base, "input_text": content}
                ),
                initial_result,
            )
            glean_result, timestamp = await use_llm_func_with_cache(
                entity_continue_extraction_user_prompt,
                use_llm_func,
                system_prompt=system_prompt,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
                chunk_id=chunk_key,
            )
            # Process gleaning result
            glean_nodes, glean_edges = await self._process_extraction_with_validation(
                glean_result, chunk_key, timestamp, context_base
            )
            # Merge results with improved conflict resolution
            merged_nodes = self._merge_entities(current_nodes, glean_nodes)
            merged_edges = self._merge_relationships(current_edges, glean_edges)
            return merged_nodes, merged_edges
        except Exception as e:
            logger.error(f"Gleaning failed for chunk {chunk_key}: {e}")
            return current_nodes, current_edges

    def _merge_entities(self, current: Dict, new: Dict) -> Dict:
        """Merge entities with improved conflict resolution (longer description wins)."""
        return self._merge_keeping_longer_description(current, new)

    def _merge_relationships(self, current: Dict, new: Dict) -> Dict:
        """Merge relationships with improved conflict resolution (longer description wins)."""
        return self._merge_keeping_longer_description(current, new)

    @staticmethod
    def _merge_keeping_longer_description(current: Dict, new: Dict) -> Dict:
        """Merge two key -> list-of-records dicts.

        New keys are adopted as-is; for keys present in both, the candidate
        whose first record has the longer "description" replaces the current
        one.  Shared logic for both entity and relationship merging.
        """
        merged = current.copy()
        for key, new_items in new.items():
            if not new_items:
                # Defensive: validation never yields empty lists, but guard anyway.
                continue
            if key not in merged or not merged[key]:
                merged[key] = new_items
            else:
                # Compare and keep the better description
                current_desc = merged[key][0].get("description", "")
                new_desc = new_items[0].get("description", "")
                if len(new_desc) > len(current_desc):
                    merged[key] = new_items
        return merged

    def enable_optimization(self):
        """Enable the optimization."""
        self.optimization_enabled = True
        logger.info("Entity extraction optimization enabled")

    def disable_optimization(self):
        """Disable the optimization."""
        self.optimization_enabled = False
        logger.info("Entity extraction optimization disabled")
# Global optimizer instance (module-level singleton) consumed by
# extract_entities_optimized; toggle behavior at runtime via its
# enable_optimization() / disable_optimization() methods.
entity_extraction_optimizer = EntityExtractionOptimizer()
async def extract_entities_optimized(
    chunks: Dict[str, Any],
    global_config: Dict[str, str],
    pipeline_status: Optional[Dict] = None,
    pipeline_status_lock=None,
    llm_response_cache: Any = None,
    text_chunks_storage: Any = None,
) -> List:
    """
    Optimized version of extract_entities with improved formatting and error handling.

    Processes every chunk concurrently (bounded by llm_model_max_async) through
    the global entity_extraction_optimizer.  When the optimizer is disabled,
    delegates wholesale to lightrag's original extract_entities.

    Args:
        chunks: Mapping of chunk_key -> chunk dict (must contain "content")
        global_config: Pipeline config; reads "llm_model_func",
            "entity_extract_max_gleaning", "addon_params", "llm_model_max_async"
        pipeline_status: Optional shared status dict updated with progress messages
        pipeline_status_lock: Async lock guarding pipeline_status updates
        llm_response_cache: Cache for LLM responses
        text_chunks_storage: Passed through to the fallback implementation only

    Returns:
        List of (entities_dict, relationships_dict) tuples, one per chunk, in
        input order; failed chunks yield ({}, {}).
    """
    from lightrag.operate import extract_entities

    if not entity_extraction_optimizer.optimization_enabled:
        # Fall back to original implementation
        return await extract_entities(
            chunks, global_config, pipeline_status, pipeline_status_lock,
            llm_response_cache, text_chunks_storage
        )

    # Extract configuration
    use_llm_func = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    language = global_config["addon_params"].get("language", "English")
    entity_types = global_config["addon_params"].get("entity_types", [])

    # Render the few-shot examples with their own context (note the examples
    # use ", " to join entity types while the main prompt uses ",").
    examples = "\n".join(PROMPTS["entity_extraction_examples"])
    example_context_base = {
        "tuple_delimiter": PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        "completion_delimiter": PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        "entity_types": ", ".join(entity_types),
        "language": language,
    }
    examples = examples.format(**example_context_base)
    context_base = {
        "tuple_delimiter": PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        "completion_delimiter": PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        "entity_types": ",".join(entity_types),
        "examples": examples,
        "language": language,
        "entity_extract_max_gleaning": entity_extract_max_gleaning,
    }

    ordered_chunks = list(chunks.items())
    total_chunks = len(ordered_chunks)
    processed_chunks = 0

    async def _process_single_chunk_optimized(chunk_key_dp):
        # Runs under the semaphore; returns ({}, {}) on failure so one bad
        # chunk never aborts the batch.
        nonlocal processed_chunks
        chunk_key, chunk_dp = chunk_key_dp
        content = chunk_dp["content"]
        # NOTE(review): chunk_dp may carry a "file_path", but it is not
        # propagated into the extraction results here — confirm against
        # lightrag's original extract_entities behavior.
        try:
            # Use optimized extraction
            maybe_nodes, maybe_edges = await entity_extraction_optimizer.optimize_extraction_formatting(
                content=content,
                context_base=context_base,
                use_llm_func=use_llm_func,
                llm_response_cache=llm_response_cache,
                chunk_key=chunk_key,
            )
            processed_chunks += 1
            entities_count = len(maybe_nodes)
            relations_count = len(maybe_edges)
            log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel (optimized)"
            logger.info(log_message)
            if pipeline_status is not None and pipeline_status_lock is not None:
                async with pipeline_status_lock:
                    pipeline_status["latest_message"] = log_message
                    pipeline_status["history_messages"].append(log_message)
            return maybe_nodes, maybe_edges
        except Exception as e:
            logger.error(f"Optimized extraction failed for chunk {chunk_key}: {e}")
            # Return empty results on failure
            return {}, {}

    # Bound concurrent LLM calls to the configured async limit.
    semaphore = asyncio.Semaphore(global_config.get("llm_model_max_async", 4))

    async def _process_with_semaphore(chunk):
        async with semaphore:
            return await _process_single_chunk_optimized(chunk)

    # Process chunks concurrently, preserving input order in the results.
    tasks = [
        asyncio.create_task(_process_with_semaphore(chunk))
        for chunk in ordered_chunks
    ]
    chunk_results = await asyncio.gather(*tasks, return_exceptions=True)

    # Replace any escaped exceptions with empty results (order preserved).
    valid_results = []
    for result in chunk_results:
        if isinstance(result, Exception):
            logger.error(f"Chunk processing failed: {result}")
            valid_results.append(({}, {}))  # Empty result for failed chunks
        else:
            valid_results.append(result)
    return valid_results