# NOTE: stray page metadata ("420 lines / 16 KiB / Python") removed from the
# top of this file — it was scraper residue, not source code.
"""
Optimization 1: Proper LLM formatting integration for entity extraction.

This optimization improves the entity extraction formatting to reduce LLM
retries and improve parsing accuracy.
"""
|
|
import asyncio
import json
from typing import Any, Callable, Dict, List, Optional, Tuple

from lightrag.prompt import PROMPTS
from lightrag.utils import logger
|
|
|
|
class EntityExtractionOptimizer:
    """Optimize entity extraction formatting to reduce LLM retries and improve
    parsing accuracy.

    Wraps the lightrag extraction pipeline for a single chunk: renders the
    extraction prompts, calls the LLM through the shared caching helper,
    validates the parsed entities/relationships, and optionally runs one
    "gleaning" pass whose results are merged back in.  When disabled, or on
    any failure, it falls back to the original ``_process_single_content``
    implementation.
    """

    def __init__(self) -> None:
        # Master switch: when False every call delegates to the original code.
        self.optimization_enabled = True
        # Placeholder for saving/restoring original formatting hooks
        # (currently never assigned elsewhere in this module).
        self.original_formatting = None

    async def optimize_extraction_formatting(
        self,
        content: str,
        context_base: Dict[str, Any],
        use_llm_func: Callable,
        llm_response_cache: Any = None,
        chunk_key: Optional[str] = None,
    ) -> Tuple[Dict, Dict]:
        """Run the optimized entity extraction for one chunk of text.

        Args:
            content: Text content to extract entities from.
            context_base: Prompt-template parameters (delimiters, entity
                types, language, examples, ``entity_extract_max_gleaning``).
            use_llm_func: LLM function to call.
            llm_response_cache: Optional cache for LLM responses.
            chunk_key: Chunk identifier used for caching and logging.

        Returns:
            Tuple of (entities_dict, relationships_dict).
        """
        if not self.optimization_enabled:
            return await self._fallback_extraction(chunk_key, content)

        try:
            # Both prompts are rendered from the same variables; the system
            # prompt template also receives input_text (harmless if unused).
            prompt_vars = {**context_base, "input_text": content}
            entity_extraction_system_prompt = PROMPTS[
                "entity_extraction_system_prompt"
            ].format(**prompt_vars)
            entity_extraction_user_prompt = PROMPTS[
                "entity_extraction_user_prompt"
            ].format(**prompt_vars)

            # Initial extraction pass through the cached LLM call.
            final_result, timestamp = await self._optimized_llm_call(
                entity_extraction_user_prompt,
                use_llm_func,
                system_prompt=entity_extraction_system_prompt,
                llm_response_cache=llm_response_cache,
                chunk_key=chunk_key,
            )

            # Parse and validate the LLM output.
            maybe_nodes, maybe_edges = await self._process_extraction_with_validation(
                final_result, chunk_key, timestamp, context_base
            )

            # Optional second ("gleaning") pass to pick up missed items.
            if context_base.get("entity_extract_max_gleaning", 0) > 0:
                maybe_nodes, maybe_edges = await self._apply_gleaning_optimized(
                    content,
                    context_base,
                    use_llm_func,
                    llm_response_cache,
                    chunk_key,
                    maybe_nodes,
                    maybe_edges,
                    entity_extraction_system_prompt,
                    final_result,
                )

            return maybe_nodes, maybe_edges

        except Exception as e:
            logger.error(f"Optimized extraction failed for chunk {chunk_key}: {e}")
            return await self._fallback_extraction(chunk_key, content)

    @staticmethod
    async def _fallback_extraction(
        chunk_key: Optional[str], content: str
    ) -> Tuple[Dict, Dict]:
        """Delegate one chunk to the original lightrag implementation.

        Extracted helper: this fallback was previously duplicated verbatim in
        two places inside optimize_extraction_formatting.
        """
        # Imported lazily to avoid a circular import at module load time.
        from lightrag.operate import _process_single_content

        return await _process_single_content((chunk_key, {"content": content}))

    async def _optimized_llm_call(
        self,
        prompt: str,
        use_llm_func: Callable,
        system_prompt: Optional[str] = None,
        llm_response_cache: Any = None,
        chunk_key: Optional[str] = None,
        cache_keys_collector: Optional[List] = None,
    ) -> Tuple[str, int]:
        """Call the LLM through the shared cache helper.

        Returns:
            Tuple of (raw LLM response text, timestamp) as produced by
            ``use_llm_func_with_cache``.

        Raises:
            Exception: re-raised after logging when the underlying call fails.
        """
        from lightrag.utils import use_llm_func_with_cache

        try:
            result, timestamp = await use_llm_func_with_cache(
                prompt,
                use_llm_func,
                system_prompt=system_prompt,
                llm_response_cache=llm_response_cache,
                cache_type="extract",
                chunk_id=chunk_key,
                cache_keys_collector=cache_keys_collector,
            )
            return result, timestamp
        except Exception as e:
            logger.error(f"LLM call failed for chunk {chunk_key}: {e}")
            raise

    async def _process_extraction_with_validation(
        self,
        result: str,
        chunk_key: Optional[str],
        timestamp: int,
        context_base: Dict[str, Any],
    ) -> Tuple[Dict, Dict]:
        """Parse an extraction result and validate the entities/relations.

        Returns ({}, {}) instead of raising when parsing fails, so a single
        bad chunk does not abort the whole pipeline.
        """
        from lightrag.operate import _process_extraction_result

        try:
            maybe_nodes, maybe_edges = await _process_extraction_result(
                result,
                chunk_key,
                timestamp,
                tuple_delimiter=context_base["tuple_delimiter"],
                completion_delimiter=context_base["completion_delimiter"],
            )

            validated_nodes = self._validate_entities(maybe_nodes, chunk_key)
            validated_edges = self._validate_relationships(maybe_edges, chunk_key)

            return validated_nodes, validated_edges

        except Exception as e:
            logger.error(f"Extraction processing failed for chunk {chunk_key}: {e}")
            return {}, {}

    def _validate_entities(self, entities: Dict, chunk_key: Optional[str]) -> Dict:
        """Drop entities with empty names or missing required fields.

        Keys whose record list becomes empty are dropped entirely.
        """
        validated: Dict = {}

        for entity_name, entity_list in entities.items():
            if not entity_name or not entity_name.strip():
                logger.warning(f"Empty entity name in chunk {chunk_key}")
                continue

            valid_entities = []
            for entity in entity_list:
                if self._is_valid_entity(entity):
                    valid_entities.append(entity)
                else:
                    logger.warning(f"Invalid entity data in chunk {chunk_key}: {entity}")

            if valid_entities:
                validated[entity_name] = valid_entities

        return validated

    def _validate_relationships(self, relationships: Dict, chunk_key: Optional[str]) -> Dict:
        """Drop relationships with malformed keys or missing required fields.

        Edge keys must be (src, tgt) 2-tuples; keys whose record list becomes
        empty are dropped entirely.
        """
        validated: Dict = {}

        for edge_key, edge_list in relationships.items():
            if not isinstance(edge_key, tuple) or len(edge_key) != 2:
                logger.warning(f"Invalid edge key format in chunk {chunk_key}: {edge_key}")
                continue

            valid_edges = []
            for edge in edge_list:
                if self._is_valid_relationship(edge):
                    valid_edges.append(edge)
                else:
                    logger.warning(f"Invalid relationship data in chunk {chunk_key}: {edge}")

            if valid_edges:
                validated[edge_key] = valid_edges

        return validated

    def _is_valid_entity(self, entity: Dict) -> bool:
        """Return True when the entity has non-empty name, type, description."""
        required_fields = ["entity_name", "entity_type", "description"]
        return all(field in entity and entity[field] for field in required_fields)

    def _is_valid_relationship(self, relationship: Dict) -> bool:
        """Return True when the relationship has non-empty src, tgt, description."""
        required_fields = ["src_id", "tgt_id", "description"]
        return all(
            field in relationship and relationship[field] for field in required_fields
        )

    async def _apply_gleaning_optimized(
        self,
        content: str,
        context_base: Dict[str, Any],
        use_llm_func: Callable,
        llm_response_cache: Any,
        chunk_key: Optional[str],
        current_nodes: Dict,
        current_edges: Dict,
        system_prompt: str,
        initial_result: str,
    ) -> Tuple[Dict, Dict]:
        """Run one gleaning pass and merge its results into the current ones.

        The first-pass user prompt and LLM answer are replayed as chat history
        so the model continues the extraction rather than restarting it.  On
        any failure the current results are returned unchanged.
        """
        from lightrag.utils import pack_user_ass_to_openai_messages, use_llm_func_with_cache

        try:
            entity_continue_extraction_user_prompt = PROMPTS[
                "entity_continue_extraction_user_prompt"
            ].format(**{**context_base, "input_text": content})

            history = pack_user_ass_to_openai_messages(
                PROMPTS["entity_extraction_user_prompt"].format(
                    **{**context_base, "input_text": content}
                ),
                initial_result,
            )

            glean_result, timestamp = await use_llm_func_with_cache(
                entity_continue_extraction_user_prompt,
                use_llm_func,
                system_prompt=system_prompt,
                llm_response_cache=llm_response_cache,
                history_messages=history,
                cache_type="extract",
                chunk_id=chunk_key,
            )

            glean_nodes, glean_edges = await self._process_extraction_with_validation(
                glean_result, chunk_key, timestamp, context_base
            )

            merged_nodes = self._merge_entities(current_nodes, glean_nodes)
            merged_edges = self._merge_relationships(current_edges, glean_edges)

            return merged_nodes, merged_edges

        except Exception as e:
            logger.error(f"Gleaning failed for chunk {chunk_key}: {e}")
            return current_nodes, current_edges

    @staticmethod
    def _merge_keyed_records(current: Dict, new: Dict) -> Dict:
        """Merge two {key: [records]} dicts.

        Shared helper: _merge_entities and _merge_relationships previously
        duplicated this logic verbatim.  For keys present in both dicts, the
        record list whose FIRST record has the longer "description" wins and
        replaces the other list wholesale (original conflict policy kept).
        """
        merged = current.copy()

        for key, new_records in new.items():
            if key not in merged:
                merged[key] = new_records
            else:
                current_desc = merged[key][0].get("description", "")
                new_desc = new_records[0].get("description", "")
                if len(new_desc) > len(current_desc):
                    merged[key] = new_records

        return merged

    def _merge_entities(self, current: Dict, new: Dict) -> Dict:
        """Merge entities with improved conflict resolution."""
        return self._merge_keyed_records(current, new)

    def _merge_relationships(self, current: Dict, new: Dict) -> Dict:
        """Merge relationships with improved conflict resolution."""
        return self._merge_keyed_records(current, new)

    def enable_optimization(self) -> None:
        """Enable the optimization."""
        self.optimization_enabled = True
        logger.info("Entity extraction optimization enabled")

    def disable_optimization(self) -> None:
        """Disable the optimization."""
        self.optimization_enabled = False
        logger.info("Entity extraction optimization disabled")
|
|
|
|
|
|
# Global optimizer instance shared by extract_entities_optimized(); toggle it
# at runtime via enable_optimization() / disable_optimization().
entity_extraction_optimizer = EntityExtractionOptimizer()
|
|
|
|
|
|
async def extract_entities_optimized(
    chunks: Dict[str, Any],
    global_config: Dict[str, str],
    pipeline_status: Optional[Dict] = None,
    pipeline_status_lock: Any = None,
    llm_response_cache: Any = None,
    text_chunks_storage: Any = None,
) -> List:
    """Optimized version of extract_entities with improved formatting and
    error handling.

    Args:
        chunks: Mapping of chunk_key -> chunk dict containing a "content" key.
        global_config: Pipeline configuration; must provide "llm_model_func"
            and "entity_extract_max_gleaning"; may provide "addon_params" and
            "llm_model_max_async".
        pipeline_status: Optional shared status dict that receives progress
            messages (only updated when pipeline_status_lock is also given).
        pipeline_status_lock: Async lock guarding pipeline_status.
        llm_response_cache: Optional cache for LLM responses.
        text_chunks_storage: Unused here; accepted for signature compatibility
            with the original extract_entities (and forwarded on fallback).

    Returns:
        List of (entities_dict, relationships_dict) tuples, one per chunk;
        failed chunks contribute ({}, {}) so the list stays positional.
    """
    from lightrag.operate import extract_entities

    if not entity_extraction_optimizer.optimization_enabled:
        # Fall back to the original implementation wholesale.
        return await extract_entities(
            chunks, global_config, pipeline_status, pipeline_status_lock,
            llm_response_cache, text_chunks_storage
        )

    # Extract configuration; addon_params may be absent, so default to {}.
    use_llm_func = global_config["llm_model_func"]
    entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
    addon_params = global_config.get("addon_params", {})
    language = addon_params.get("language", "English")
    entity_types = addon_params.get("entity_types", [])

    # Render the few-shot examples with the same delimiters as the prompts.
    examples = "\n".join(PROMPTS["entity_extraction_examples"])
    example_context_base = {
        "tuple_delimiter": PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        "completion_delimiter": PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        "entity_types": ", ".join(entity_types),
        "language": language,
    }
    examples = examples.format(**example_context_base)

    # NOTE(review): entity_types is joined with ", " for the examples above
    # but "," here; preserved as-is — confirm whether that is intentional.
    context_base = {
        "tuple_delimiter": PROMPTS["DEFAULT_TUPLE_DELIMITER"],
        "completion_delimiter": PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
        "entity_types": ",".join(entity_types),
        "examples": examples,
        "language": language,
        "entity_extract_max_gleaning": entity_extract_max_gleaning,
    }

    ordered_chunks = list(chunks.items())
    total_chunks = len(ordered_chunks)
    processed_chunks = 0

    async def _process_single_chunk_optimized(chunk_key_dp):
        """Extract one chunk and report progress; returns ({}, {}) on failure."""
        nonlocal processed_chunks
        chunk_key, chunk_dp = chunk_key_dp
        content = chunk_dp["content"]
        # (The previously-read chunk_dp["file_path"] was never used; removed.)

        try:
            maybe_nodes, maybe_edges = await entity_extraction_optimizer.optimize_extraction_formatting(
                content=content,
                context_base=context_base,
                use_llm_func=use_llm_func,
                llm_response_cache=llm_response_cache,
                chunk_key=chunk_key,
            )

            processed_chunks += 1
            entities_count = len(maybe_nodes)
            relations_count = len(maybe_edges)

            log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel (optimized)"
            logger.info(log_message)

            if pipeline_status is not None and pipeline_status_lock is not None:
                async with pipeline_status_lock:
                    pipeline_status["latest_message"] = log_message
                    pipeline_status["history_messages"].append(log_message)

            return maybe_nodes, maybe_edges

        except Exception as e:
            logger.error(f"Optimized extraction failed for chunk {chunk_key}: {e}")
            # Return empty results on failure so the output stays positional.
            return {}, {}

    # Bound LLM concurrency with a semaphore.
    chunk_max_async = global_config.get("llm_model_max_async", 4)
    semaphore = asyncio.Semaphore(chunk_max_async)

    async def _process_with_semaphore(chunk):
        async with semaphore:
            return await _process_single_chunk_optimized(chunk)

    # Process all chunks concurrently (bounded by the semaphore).
    tasks = [
        asyncio.create_task(_process_with_semaphore(chunk))
        for chunk in ordered_chunks
    ]
    chunk_results = await asyncio.gather(*tasks, return_exceptions=True)

    # Replace any raised exception with an empty result pair.
    valid_results = []
    for result in chunk_results:
        if isinstance(result, Exception):
            logger.error(f"Chunk processing failed: {result}")
            valid_results.append(({}, {}))
        else:
            valid_results.append(result)

    return valid_results