""" Optimization 1: Proper LLM formatting integration for entity extraction This optimization improves the entity extraction formatting to reduce LLM retries and improve parsing accuracy. """ import json import asyncio from typing import Dict, List, Tuple, Any from lightrag.utils import logger from lightrag.prompt import PROMPTS class EntityExtractionOptimizer: """Optimize entity extraction formatting to reduce LLM retries and improve parsing accuracy.""" def __init__(self): self.optimization_enabled = True self.original_formatting = None async def optimize_extraction_formatting( self, content: str, context_base: Dict[str, Any], use_llm_func: callable, llm_response_cache: Any = None, chunk_key: str = None ) -> Tuple[Dict, Dict]: """ Optimized entity extraction with improved formatting and error handling. Args: content: Text content to extract entities from context_base: Context parameters for the prompt use_llm_func: LLM function to call llm_response_cache: Cache for LLM responses chunk_key: Chunk identifier for caching Returns: Tuple of (entities_dict, relationships_dict) """ if not self.optimization_enabled: # Fall back to original implementation from lightrag.operate import _process_single_content return await _process_single_content((chunk_key, {"content": content})) try: # Get initial extraction with optimized formatting entity_extraction_system_prompt = PROMPTS[ "entity_extraction_system_prompt" ].format(**{**context_base, "input_text": content}) entity_extraction_user_prompt = PROMPTS["entity_extraction_user_prompt"].format( **{**context_base, "input_text": content} ) # Use the optimized LLM call with better formatting final_result, timestamp = await self._optimized_llm_call( entity_extraction_user_prompt, use_llm_func, system_prompt=entity_extraction_system_prompt, llm_response_cache=llm_response_cache, chunk_key=chunk_key ) # Process extraction with improved error handling maybe_nodes, maybe_edges = await self._process_extraction_with_validation( final_result, chunk_key, timestamp, context_base ) # Apply gleaning if enabled if context_base.get("entity_extract_max_gleaning", 0) > 0: maybe_nodes, maybe_edges = await self._apply_gleaning_optimized( content, context_base, use_llm_func, llm_response_cache, chunk_key, maybe_nodes, maybe_edges, entity_extraction_system_prompt, final_result ) return maybe_nodes, maybe_edges except Exception as e: logger.error(f"Optimized extraction failed for chunk {chunk_key}: {e}") # Fall back to original implementation from lightrag.operate import _process_single_content return await _process_single_content((chunk_key, {"content": content})) async def _optimized_llm_call( self, prompt: str, use_llm_func: callable, system_prompt: str = None, llm_response_cache: Any = None, chunk_key: str = None, cache_keys_collector: List = None ) -> Tuple[str, int]: """ Optimized LLM call with improved error handling and caching. """ from lightrag.utils import use_llm_func_with_cache try: result, timestamp = await use_llm_func_with_cache( prompt, use_llm_func, system_prompt=system_prompt, llm_response_cache=llm_response_cache, cache_type="extract", chunk_id=chunk_key, cache_keys_collector=cache_keys_collector, ) return result, timestamp except Exception as e: logger.error(f"LLM call failed for chunk {chunk_key}: {e}") raise async def _process_extraction_with_validation( self, result: str, chunk_key: str, timestamp: int, context_base: Dict[str, Any] ) -> Tuple[Dict, Dict]: """ Process extraction result with improved validation and error handling. """ from lightrag.operate import _process_extraction_result try: # Use the existing processing function but with additional validation maybe_nodes, maybe_edges = await _process_extraction_result( result, chunk_key, timestamp, tuple_delimiter=context_base["tuple_delimiter"], completion_delimiter=context_base["completion_delimiter"], ) # Validate the extracted data validated_nodes = self._validate_entities(maybe_nodes, chunk_key) validated_edges = self._validate_relationships(maybe_edges, chunk_key) return validated_nodes, validated_edges except Exception as e: logger.error(f"Extraction processing failed for chunk {chunk_key}: {e}") return {}, {} def _validate_entities(self, entities: Dict, chunk_key: str) -> Dict: """Validate extracted entities for consistency and quality.""" validated = {} for entity_name, entity_list in entities.items(): if not entity_name or not entity_name.strip(): logger.warning(f"Empty entity name in chunk {chunk_key}") continue valid_entities = [] for entity in entity_list: if self._is_valid_entity(entity): valid_entities.append(entity) else: logger.warning(f"Invalid entity data in chunk {chunk_key}: {entity}") if valid_entities: validated[entity_name] = valid_entities return validated def _validate_relationships(self, relationships: Dict, chunk_key: str) -> Dict: """Validate extracted relationships for consistency and quality.""" validated = {} for edge_key, edge_list in relationships.items(): if not isinstance(edge_key, tuple) or len(edge_key) != 2: logger.warning(f"Invalid edge key format in chunk {chunk_key}: {edge_key}") continue valid_edges = [] for edge in edge_list: if self._is_valid_relationship(edge): valid_edges.append(edge) else: logger.warning(f"Invalid relationship data in chunk {chunk_key}: {edge}") if valid_edges: validated[edge_key] = valid_edges return validated def _is_valid_entity(self, entity: Dict) -> bool: """Check if entity data is valid.""" required_fields = ["entity_name", "entity_type", "description"] return all(field in entity and entity[field] for field in required_fields) def _is_valid_relationship(self, relationship: Dict) -> bool: """Check if relationship data is valid.""" required_fields = ["src_id", "tgt_id", "description"] return all(field in relationship and relationship[field] for field in required_fields) async def _apply_gleaning_optimized( self, content: str, context_base: Dict[str, Any], use_llm_func: callable, llm_response_cache: Any, chunk_key: str, current_nodes: Dict, current_edges: Dict, system_prompt: str, initial_result: str ) -> Tuple[Dict, Dict]: """ Apply gleaning with improved error handling and result merging. """ from lightrag.utils import pack_user_ass_to_openai_messages, use_llm_func_with_cache try: entity_continue_extraction_user_prompt = PROMPTS[ "entity_continue_extraction_user_prompt" ].format(**{**context_base, "input_text": content}) history = pack_user_ass_to_openai_messages( PROMPTS["entity_extraction_user_prompt"].format(**{**context_base, "input_text": content}), initial_result ) glean_result, timestamp = await use_llm_func_with_cache( entity_continue_extraction_user_prompt, use_llm_func, system_prompt=system_prompt, llm_response_cache=llm_response_cache, history_messages=history, cache_type="extract", chunk_id=chunk_key, ) # Process gleaning result glean_nodes, glean_edges = await self._process_extraction_with_validation( glean_result, chunk_key, timestamp, context_base ) # Merge results with improved conflict resolution merged_nodes = self._merge_entities(current_nodes, glean_nodes) merged_edges = self._merge_relationships(current_edges, glean_edges) return merged_nodes, merged_edges except Exception as e: logger.error(f"Gleaning failed for chunk {chunk_key}: {e}") return current_nodes, current_edges def _merge_entities(self, current: Dict, new: Dict) -> Dict: """Merge entities with improved conflict resolution.""" merged = current.copy() for entity_name, new_entities in new.items(): if entity_name not in merged: merged[entity_name] = new_entities else: # Compare and keep the better description current_desc = merged[entity_name][0].get("description", "") new_desc = new_entities[0].get("description", "") if len(new_desc) > len(current_desc): merged[entity_name] = new_entities return merged def _merge_relationships(self, current: Dict, new: Dict) -> Dict: """Merge relationships with improved conflict resolution.""" merged = current.copy() for edge_key, new_edges in new.items(): if edge_key not in merged: merged[edge_key] = new_edges else: # Compare and keep the better description current_desc = merged[edge_key][0].get("description", "") new_desc = new_edges[0].get("description", "") if len(new_desc) > len(current_desc): merged[edge_key] = new_edges return merged def enable_optimization(self): """Enable the optimization.""" self.optimization_enabled = True logger.info("Entity extraction optimization enabled") def disable_optimization(self): """Disable the optimization.""" self.optimization_enabled = False logger.info("Entity extraction optimization disabled") # Global optimizer instance entity_extraction_optimizer = EntityExtractionOptimizer() async def extract_entities_optimized( chunks: Dict[str, Any], global_config: Dict[str, str], pipeline_status: Dict = None, pipeline_status_lock = None, llm_response_cache: Any = None, text_chunks_storage: Any = None, ) -> List: """ Optimized version of extract_entities with improved formatting and error handling. """ from lightrag.operate import extract_entities if not entity_extraction_optimizer.optimization_enabled: # Fall back to original implementation return await extract_entities( chunks, global_config, pipeline_status, pipeline_status_lock, llm_response_cache, text_chunks_storage ) # Extract configuration use_llm_func = global_config["llm_model_func"] entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"] language = global_config["addon_params"].get("language", "English") entity_types = global_config["addon_params"].get("entity_types", []) # Build context base examples = "\n".join(PROMPTS["entity_extraction_examples"]) example_context_base = { "tuple_delimiter": PROMPTS["DEFAULT_TUPLE_DELIMITER"], "completion_delimiter": PROMPTS["DEFAULT_COMPLETION_DELIMITER"], "entity_types": ", ".join(entity_types), "language": language, } examples = examples.format(**example_context_base) context_base = { "tuple_delimiter": PROMPTS["DEFAULT_TUPLE_DELIMITER"], "completion_delimiter": PROMPTS["DEFAULT_COMPLETION_DELIMITER"], "entity_types": ",".join(entity_types), "examples": examples, "language": language, "entity_extract_max_gleaning": entity_extract_max_gleaning, } ordered_chunks = list(chunks.items()) total_chunks = len(ordered_chunks) processed_chunks = 0 async def _process_single_chunk_optimized(chunk_key_dp): nonlocal processed_chunks chunk_key = chunk_key_dp[0] chunk_dp = chunk_key_dp[1] content = chunk_dp["content"] file_path = chunk_dp.get("file_path", "unknown_source") try: # Use optimized extraction maybe_nodes, maybe_edges = await entity_extraction_optimizer.optimize_extraction_formatting( content=content, context_base=context_base, use_llm_func=use_llm_func, llm_response_cache=llm_response_cache, chunk_key=chunk_key ) processed_chunks += 1 entities_count = len(maybe_nodes) relations_count = len(maybe_edges) log_message = f"Chunk {processed_chunks} of {total_chunks} extracted {entities_count} Ent + {relations_count} Rel (optimized)" logger.info(log_message) if pipeline_status is not None and pipeline_status_lock is not None: async with pipeline_status_lock: pipeline_status["latest_message"] = log_message pipeline_status["history_messages"].append(log_message) return maybe_nodes, maybe_edges except Exception as e: logger.error(f"Optimized extraction failed for chunk {chunk_key}: {e}") # Return empty results on failure return {}, {} # Get max async tasks limit chunk_max_async = global_config.get("llm_model_max_async", 4) semaphore = asyncio.Semaphore(chunk_max_async) async def _process_with_semaphore(chunk): async with semaphore: return await _process_single_chunk_optimized(chunk) # Process chunks concurrently tasks = [] for chunk in ordered_chunks: task = asyncio.create_task(_process_with_semaphore(chunk)) tasks.append(task) # Wait for all tasks to complete chunk_results = await asyncio.gather(*tasks, return_exceptions=True) # Filter out exceptions and return valid results valid_results = [] for result in chunk_results: if not isinstance(result, Exception): valid_results.append(result) else: logger.error(f"Chunk processing failed: {result}") valid_results.append(({}, {})) # Empty result for failed chunks return valid_results