Files
railseek6/LightRAG-main/lightrag/optimized_ocr_processor.py
2026-01-14 15:15:01 +08:00

696 lines
26 KiB
Python

"""
Optimized OCR Processor with Batch Processing and Async Support
Replaces the inefficient process-per-request approach with shared model instance
"""
import os
import logging
import asyncio
import concurrent.futures
import threading
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
import tempfile
from pathlib import Path
import json
import time
# Module-level logger for this file; no handlers or levels are configured here.
logger = logging.getLogger(name=__name__)


@dataclass
class BatchOCRResult:
    """Per-image outcome returned by the batch OCR APIs.

    A successful result leaves ``error`` as ``None``. A failed image is reported
    with empty ``text``, zeroed numeric fields, and the failure reason in ``error``.
    """

    # Path of the image this result describes.
    image_path: str
    # Recognized text lines joined with "\n".
    text: str
    # Average per-line recognition confidence (0.0 when nothing was recognized).
    confidence: float
    # Bounding boxes of the recognized lines, in the same order as the lines of ``text``.
    bboxes: List
    # Number of recognized text lines.
    line_count: int
    # Wall-clock seconds spent on this image.
    processing_time: float
    # Failure description, or None on success.
    error: Optional[str] = None
class OptimizedOCRProcessor:
    """
    Optimized OCR processor with batch processing, a shared model instance, and async support.

    A single PaddleOCR engine is loaded lazily on a background thread
    ("OCRInitializer") and reused by every call. The synchronous entry points
    wait up to 30 seconds for that initialization before processing. Calls into
    the shared engine are serialized with a lock because the engine is used from
    the init thread (warm-up), caller threads, and the async worker pool.
    """

    def __init__(self, use_gpu: bool = True, languages: Optional[List[str]] = None,
                 batch_size: int = 4, max_workers: int = 2):
        """
        Initialize the processor and start loading the OCR model in the background.

        Args:
            use_gpu: Whether to request GPU acceleration from PaddleOCR.
            languages: OCR language codes. The first entry is passed to PaddleOCR
                as its ``lang``. Defaults to ``['en', 'ch']``, which keeps the
                English model.
            batch_size: Number of images handed to the engine per chunk.
            max_workers: Size of the thread pool used by the async API.
        """
        self.use_gpu = use_gpu
        self.languages = languages or ['en', 'ch']
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.ocr_available = False
        self._ocr_engine = None
        self._model_loaded = False
        self._temp_dir = None
        self._executor = None
        self._initialization_lock = threading.Lock()
        # Serializes calls into the single shared engine instance.
        self._engine_lock = threading.Lock()
        self._initialization_thread = None
        self._initialization_started = False
        # Performance metrics
        self.metrics = {
            "total_images_processed": 0,
            "total_processing_time": 0.0,
            "batch_processing_times": [],
            "errors": []
        }
        # Start lazy initialization in background thread
        self._start_lazy_initialization()

    def _start_lazy_initialization(self) -> None:
        """Start OCR initialization in a background daemon thread (at most once)."""
        with self._initialization_lock:
            if self._initialization_started:
                return
            thread = threading.Thread(
                target=self._initialize_ocr,
                name="OCRInitializer",
                daemon=True
            )
            self._initialization_thread = thread
            thread.start()
            # Publish "started" only after the thread object is assigned and
            # running. Lock-free readers in _ensure_ocr_initialized can then
            # never see started=True with no thread to join.
            self._initialization_started = True
        logger.info("Started lazy OCR initialization in background thread")

    def _ensure_ocr_initialized(self, timeout: Optional[float] = None) -> bool:
        """
        Block until OCR initialization has finished or ``timeout`` expires.

        Args:
            timeout: Seconds to wait; ``None`` waits indefinitely.

        Returns:
            True if the OCR engine is ready, False otherwise.
        """
        if self.ocr_available:
            return True
        if not self._initialization_started:
            self._start_lazy_initialization()
        thread = self._initialization_thread
        # Joining the current thread would raise RuntimeError.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=timeout)
        return self.ocr_available

    def _primary_language(self) -> str:
        """Return the PaddleOCR ``lang`` code: the first configured language, or 'en'."""
        return self.languages[0] if self.languages else 'en'

    def _initialize_ocr(self) -> None:
        """
        Load the shared PaddleOCR engine, warm it up, and create supporting resources.

        Runs on the "OCRInitializer" thread. ``ocr_available`` is set to True
        only after everything is ready. Callers that observe the flag therefore
        never race the warm-up call or see a missing executor.
        """
        try:
            logger.info("Initializing optimized OCR processor with batch support")
            try:
                from paddleocr import PaddleOCR
                # NOTE(review): use_gpu/show_log/gpu_mem are PaddleOCR 2.x
                # keyword arguments; confirm against the pinned paddleocr version.
                self._ocr_engine = PaddleOCR(
                    use_gpu=self.use_gpu,
                    use_angle_cls=True,
                    lang=self._primary_language(),
                    show_log=False,
                    gpu_mem=2000
                )
                self._model_loaded = True
                logger.info("PaddleOCR initialized successfully with shared model instance")
                # Warm up the model with a dummy image
                self._warm_up_model()
            except ImportError as e:
                logger.error(f"Failed to import PaddleOCR: {e}")
                self.ocr_available = False
                return
            except Exception as e:
                logger.error(f"Failed to initialize PaddleOCR: {e}")
                self.ocr_available = False
                return
            # Create temporary directory for batch processing
            self._temp_dir = tempfile.mkdtemp(prefix="optimized_ocr_")
            # Create thread pool executor for async operations
            self._executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=self.max_workers,
                thread_name_prefix="ocr_worker"
            )
            # Publish readiness last (see docstring).
            self.ocr_available = True
            logger.info(f"Optimized OCR processor ready (GPU: {self.use_gpu}, "
                        f"batch_size: {self.batch_size}, workers: {self.max_workers})")
        except Exception as e:
            logger.error(f"Failed to initialize optimized OCR processor: {e}")
            self.ocr_available = False

    def _warm_up_model(self) -> None:
        """Run one OCR call on a blank 100x100 white image to reduce first-call latency.

        Failures are logged as warnings and otherwise ignored.
        """
        try:
            import numpy as np
            test_image = np.ones((100, 100, 3), dtype=np.uint8) * 255
            with self._engine_lock:
                self._ocr_engine.ocr(test_image, cls=True)
            logger.info("OCR model warmed up successfully")
        except Exception as e:
            logger.warning(f"Model warm-up failed: {e}")

    def extract_text_from_image(self, image_path: str) -> Dict[str, Any]:
        """
        Extract text from a single image (backward-compatible API).

        Args:
            image_path: Path to the image file.

        Returns:
            Dict with "text", "confidence", "bboxes", and "line_count". On
            success it also contains "processing_time". If OCR is unavailable
            or fails, the text is empty and the numeric fields are zero.
        """
        start_time = time.time()
        empty_result = {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}
        # Ensure OCR is initialized (wait up to 30 seconds)
        if not self._ensure_ocr_initialized(timeout=30.0):
            logger.warning("OCR not available after waiting")
            return empty_result
        try:
            # Reuse the batch path so single and batch results are parsed identically.
            batch_result = self._process_batch_internal([image_path])
            if not batch_result:
                return empty_result
            result = batch_result[0]
            processing_time = time.time() - start_time
            self.metrics["total_images_processed"] += 1
            self.metrics["total_processing_time"] += processing_time
            return {
                "text": result.text,
                "confidence": result.confidence,
                "bboxes": result.bboxes,
                "line_count": result.line_count,
                "processing_time": processing_time
            }
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            self.metrics["errors"].append(str(e))
            return empty_result

    def extract_text_from_images_batch(self, image_paths: List[str]) -> List["BatchOCRResult"]:
        """
        Extract text from multiple images, processing them in chunks of ``batch_size``.

        Args:
            image_paths: Image file paths.

        Returns:
            One BatchOCRResult per input path, in input order. If OCR is
            unavailable, every result carries ``error="OCR not available"``.
            If chunked processing raises, each image is retried individually
            through ``extract_text_from_image``.
        """
        if not image_paths:
            return []
        batch_start_time = time.time()
        # Ensure OCR is initialized (wait up to 30 seconds)
        if not self._ensure_ocr_initialized(timeout=30.0):
            logger.warning("OCR not available for batch processing")
            return [BatchOCRResult(
                image_path=path,
                text="",
                confidence=0.0,
                bboxes=[],
                line_count=0,
                processing_time=0.0,
                error="OCR not available"
            ) for path in image_paths]
        try:
            all_results = []
            for i in range(0, len(image_paths), self.batch_size):
                chunk = image_paths[i:i + self.batch_size]
                all_results.extend(self._process_batch_internal(chunk))
            batch_time = time.time() - batch_start_time
            image_count = len(image_paths)  # > 0, guarded above
            self.metrics["total_images_processed"] += image_count
            self.metrics["total_processing_time"] += batch_time
            self.metrics["batch_processing_times"].append({
                "batch_size": image_count,
                "processing_time": batch_time,
                "avg_time_per_image": batch_time / image_count
            })
            logger.info(f"Processed {image_count} images in {batch_time:.2f}s "
                        f"({batch_time/image_count:.2f}s per image)")
            return all_results
        except Exception as e:
            logger.error(f"Error in batch processing: {e}")
            self.metrics["errors"].append(str(e))
            # Fall back to individual processing
            results = []
            for path in image_paths:
                try:
                    single_result = self.extract_text_from_image(path)
                    results.append(BatchOCRResult(
                        image_path=path,
                        text=single_result["text"],
                        confidence=single_result["confidence"],
                        bboxes=single_result["bboxes"],
                        line_count=single_result["line_count"],
                        processing_time=single_result.get("processing_time", 0.0)
                    ))
                except Exception as img_error:
                    results.append(BatchOCRResult(
                        image_path=path,
                        text="",
                        confidence=0.0,
                        bboxes=[],
                        line_count=0,
                        processing_time=0.0,
                        error=str(img_error)
                    ))
            return results

    def _process_batch_internal(self, image_paths: List[str]) -> List["BatchOCRResult"]:
        """
        OCR each path sequentially with the shared engine.

        Per-image failures, including a missing engine after ``close()``, are
        recorded in that image's ``error`` field instead of being raised. One bad
        image therefore does not abort the rest of the chunk.

        Args:
            image_paths: Image file paths.

        Returns:
            One BatchOCRResult per path, in input order.
        """
        results = []
        for image_path in image_paths:
            image_start_time = time.time()
            try:
                engine = self._ocr_engine
                if engine is None:
                    raise RuntimeError("OCR engine is not initialized")
                with self._engine_lock:
                    ocr_result = engine.ocr(image_path, cls=True)
                text_lines = []
                bboxes = []
                total_confidence = 0.0
                # Assumes the PaddleOCR 2.x layout: ocr_result[0] lists
                # [bbox, (text, confidence)] entries for this image (TODO confirm
                # for other versions). Index access tolerates extra elements.
                if ocr_result and ocr_result[0]:
                    for line in ocr_result[0]:
                        if not line or len(line) < 2:
                            continue
                        bbox, text_info = line[0], line[1]
                        text, confidence = text_info[0], text_info[1]
                        text_lines.append(str(text))
                        bboxes.append(bbox)
                        total_confidence += float(confidence) if confidence else 0.0
                line_count = len(text_lines)
                avg_confidence = total_confidence / line_count if line_count > 0 else 0.0
                results.append(BatchOCRResult(
                    image_path=image_path,
                    text="\n".join(text_lines),
                    confidence=avg_confidence,
                    bboxes=bboxes,
                    line_count=line_count,
                    processing_time=time.time() - image_start_time
                ))
            except Exception as e:
                processing_time = time.time() - image_start_time
                logger.error(f"Error processing {image_path}: {e}")
                results.append(BatchOCRResult(
                    image_path=image_path,
                    text="",
                    confidence=0.0,
                    bboxes=[],
                    line_count=0,
                    processing_time=processing_time,
                    error=str(e)
                ))
        return results

    async def extract_text_from_images_batch_async(self, image_paths: List[str],
                                                   timeout: float = 300.0) -> List["BatchOCRResult"]:
        """
        Async wrapper around ``extract_text_from_images_batch``.

        The work runs on this processor's thread pool. Before initialization
        has created that pool, ``self._executor`` is None and asyncio's default
        executor is used instead.

        Args:
            image_paths: Image file paths.
            timeout: Seconds to wait for the whole batch (default 300). On
                timeout, every path gets ``error="Processing timeout"``. The
                worker thread is not interrupted and finishes in the background.

        Returns:
            One BatchOCRResult per path.
        """
        loop = asyncio.get_running_loop()
        future = loop.run_in_executor(
            self._executor,
            self.extract_text_from_images_batch,
            image_paths
        )
        try:
            return await asyncio.wait_for(future, timeout=timeout)
        except asyncio.TimeoutError:
            logger.error("OCR batch processing timeout")
            return [BatchOCRResult(
                image_path=path,
                text="",
                confidence=0.0,
                bboxes=[],
                line_count=0,
                processing_time=0.0,
                error="Processing timeout"
            ) for path in image_paths]

    def extract_tables_from_image(self, image_path: str) -> List[Dict[str, Any]]:
        """
        Extract tables from an image by OCR followed by bounding-box layout analysis.

        Args:
            image_path: Path to the image file.

        Returns:
            Table dicts as produced by ``_detect_tables_from_bboxes``, or an
            empty list on any failure.
        """
        try:
            ocr_result = self.extract_text_from_image(image_path)
            return self._detect_tables_from_bboxes(ocr_result["bboxes"], ocr_result["text"])
        except Exception as e:
            logger.error(f"Table extraction failed: {e}")
            return []

    @staticmethod
    def _point_coord(point: Any, axis: int) -> Optional[float]:
        """Return ``point[axis]`` as a float, or None if the point is missing or non-numeric."""
        try:
            if point is None or len(point) <= axis:
                return None
            # float() also accepts numpy scalars and numeric strings.
            return float(point[axis])
        except (TypeError, ValueError, IndexError):
            return None

    @classmethod
    def _bbox_axis_values(cls, bbox: Any, axis: int) -> List[float]:
        """Return the usable coordinates along ``axis`` of a bbox that has at least 4 points."""
        try:
            # `is None` instead of truthiness so numpy-array bboxes work too.
            if bbox is None or len(bbox) < 4:
                return []
        except TypeError:
            return []
        values = []
        for point in bbox:
            value = cls._point_coord(point, axis)
            if value is not None:
                values.append(value)
        return values

    @classmethod
    def _bbox_center(cls, bbox: Any, axis: int) -> Optional[float]:
        """Return the mean coordinate of ``bbox`` along ``axis``, or None if no point is usable."""
        values = cls._bbox_axis_values(bbox, axis)
        return sum(values) / len(values) if values else None

    def _detect_tables_from_bboxes(self, bboxes: List, text: str) -> List[Dict[str, Any]]:
        """
        Heuristically detect a single table from OCR line bounding boxes.

        Steps:
        1. Average line height (default 20.0) sets the tolerances.
        2. Boxes whose vertical centers are within 0.8 x height of an existing
           row's first box join that row.
        3. Horizontal centers are clustered into columns (gap <= 1.5 x height).
        4. Each box goes to its nearest column. If two boxes compete for one
           cell, the one closer to the column center wins.
        5. The grid is accepted only if it has >= 2 rows and columns and a
           20-90% fill ratio. ``has_header`` is set when the first row's text is
           more than 1.5x longer than the second row's (only for >= 3 rows).

        Args:
            bboxes: One bounding box (sequence of >= 4 (x, y) points) per OCR line.
            text: OCR text whose "\\n"-separated lines align index-wise with ``bboxes``.

        Returns:
            An empty list, or a one-element list with a dict holding "data",
            "rows", "columns", "has_header", "fill_ratio", and "type".
        """
        tables: List[Dict[str, Any]] = []
        # Need at least 4 text elements for a table
        if bboxes is None or len(bboxes) < 4:
            return tables
        text_lines = text.split('\n') if text else []

        # Step 1: average text height drives the row/column tolerances.
        text_heights = []
        for bbox in bboxes:
            ys = self._bbox_axis_values(bbox, 1)
            if ys:
                height = max(ys) - min(ys)
                if height > 0:
                    text_heights.append(height)
        avg_text_height = sum(text_heights) / len(text_heights) if text_heights else 20.0
        row_tolerance = avg_text_height * 0.8

        # Step 2: group boxes into rows keyed by the first member's y-center.
        rows: Dict[float, List[Tuple[Any, str]]] = {}
        for i, bbox in enumerate(bboxes):
            y_center = self._bbox_center(bbox, 1)
            if y_center is None:
                continue
            cell_text = text_lines[i] if i < len(text_lines) else ""
            for row_key in rows:
                if abs(y_center - row_key) <= row_tolerance:
                    rows[row_key].append((bbox, cell_text))
                    break
            else:
                rows[y_center] = [(bbox, cell_text)]
        # Need at least 2 rows for a table
        if len(rows) < 2:
            return tables
        sorted_rows = [rows[key] for key in sorted(rows)]

        # Step 3: cluster x-centers; clusters come out in ascending order.
        x_centers = []
        for row in sorted_rows:
            for bbox, _ in row:
                center = self._bbox_center(bbox, 0)
                if center is not None:
                    x_centers.append(center)
        if not x_centers:
            return tables
        x_centers.sort()
        column_tolerance = avg_text_height * 1.5
        column_positions = []
        cluster = [x_centers[0]]
        for x in x_centers[1:]:
            if x - cluster[-1] <= column_tolerance:
                cluster.append(x)
            else:
                column_positions.append(sum(cluster) / len(cluster))
                cluster = [x]
        column_positions.append(sum(cluster) / len(cluster))
        # Need at least 2 columns for a table
        if len(column_positions) < 2:
            return tables
        column_count = len(column_positions)

        # Step 4: place each row's texts into their nearest columns.
        table_data = []
        for row in sorted_rows:
            row_cells = [""] * column_count
            cell_distance = [float("inf")] * column_count
            # Boxes without a usable x-center fall back to 0.0. sorted() is
            # stable, so ties keep OCR order.
            placed = sorted(
                ((self._bbox_center(bbox, 0) or 0.0, cell_text) for bbox, cell_text in row),
                key=lambda item: item[0],
            )
            for x_center, cell_text in placed:
                if not cell_text.strip():
                    continue  # blank text must never overwrite a real cell
                col = min(range(column_count), key=lambda c: abs(x_center - column_positions[c]))
                distance = abs(x_center - column_positions[col])
                if not row_cells[col] or distance < cell_distance[col]:
                    row_cells[col] = cell_text
                    cell_distance[col] = distance
            if any(cell.strip() for cell in row_cells):
                table_data.append(row_cells)

        # Step 5: validate the grid.
        if len(table_data) < 2:
            return tables
        non_empty_cells = sum(1 for row in table_data for cell in row if cell.strip())
        fill_ratio = non_empty_cells / (len(table_data) * column_count)
        # Accept only 20-90% filled grids. The upper bound presumably filters
        # perfectly aligned multi-column prose (TODO confirm intent).
        if not 0.2 <= fill_ratio <= 0.9:
            return tables
        has_header = False
        if len(table_data) >= 3:
            first_row_text_len = sum(len(cell) for cell in table_data[0])
            second_row_text_len = sum(len(cell) for cell in table_data[1])
            has_header = first_row_text_len > second_row_text_len * 1.5
        tables.append({
            "data": table_data,
            "rows": len(table_data),
            "columns": column_count,
            "has_header": has_header,
            "fill_ratio": fill_ratio,
            "type": "detected_table"
        })
        return tables

    def get_metrics(self) -> Dict[str, Any]:
        """
        Return a snapshot of performance metrics and configuration.

        The "batch_processing_times" and "errors" lists are copied, so callers
        cannot mutate the processor's internal state through the snapshot.
        """
        processed = self.metrics["total_images_processed"]
        avg_time_per_image = (self.metrics["total_processing_time"] / processed
                              if processed > 0 else 0.0)
        return {
            **self.metrics,
            "batch_processing_times": list(self.metrics["batch_processing_times"]),
            "errors": list(self.metrics["errors"]),
            "avg_time_per_image": avg_time_per_image,
            "ocr_available": self.ocr_available,
            "model_loaded": self._model_loaded,
            "batch_size": self.batch_size,
            "max_workers": self.max_workers,
            "use_gpu": self.use_gpu
        }

    def close(self) -> None:
        """
        Release the worker pool, the temporary directory, and the engine reference.

        Safe to call more than once, and on partially constructed instances.
        NOTE(review): this does not stop an initialization thread that is still
        running. If that thread finishes after close(), it can still create an
        engine and executor.
        """
        executor = getattr(self, "_executor", None)
        if executor is not None:
            executor.shutdown(wait=True)
            self._executor = None
        temp_dir = getattr(self, "_temp_dir", None)
        if temp_dir and os.path.exists(temp_dir):
            import shutil
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                logger.warning(f"Failed to delete temporary directory: {e}")
        self._temp_dir = None
        self._ocr_engine = None
        self._model_loaded = False
        self.ocr_available = False
        logger.info("Optimized OCR processor closed")

    def __del__(self):
        """Best-effort cleanup on garbage collection; never lets an error escape the finalizer."""
        try:
            self.close()
        except Exception:
            # Finalizers must not raise, for example during interpreter shutdown
            # or when __init__ never completed.
            pass
# Performance test function
def test_optimized_ocr_performance():
    """
    Smoke-test the optimized OCR processor and print timing metrics.

    Returns:
        ``(single_results, batch_results)`` when OCR is available, otherwise None.
    """
    print("=== Testing Optimized OCR Processor ===")
    processor = OptimizedOCRProcessor(use_gpu=True, batch_size=4, max_workers=2)
    try:
        # Initialization runs on a background thread, so `ocr_available` is
        # almost always still False right after construction. Wait for the
        # init thread to finish instead of checking the flag immediately.
        if not processor._ensure_ocr_initialized(timeout=None):
            print("OCR not available, skipping test")
            return None
        # Generate dummy image paths (in real test, these would be actual images)
        test_images = [f"test_image_{i}.png" for i in range(8)]
        single_images = test_images[:2]

        print("\n1. Testing single image processing:")
        single_start = time.time()
        single_results = [processor.extract_text_from_image(img) for img in single_images]
        single_time = time.time() - single_start
        print(f"   Processed {len(single_images)} images in {single_time:.2f}s")

        print("\n2. Testing batch processing (4 images/batch):")
        batch_start = time.time()
        batch_results = processor.extract_text_from_images_batch(test_images)
        batch_time = time.time() - batch_start
        print(f"   Processed {len(test_images)} images in {batch_time:.2f}s")
        print(f"   Average time per image: {batch_time/len(test_images):.2f}s")

        print("\n3. Performance Metrics:")
        metrics = processor.get_metrics()
        for key, value in metrics.items():
            if key not in ["batch_processing_times", "errors", "bboxes"]:
                print(f"   {key}: {value}")
        return single_results, batch_results
    finally:
        # Release the executor and temp dir on every path, including skip and error.
        processor.close()


if __name__ == "__main__":
    # Run performance test
    print("Running optimized OCR processor performance test...")
    test_optimized_ocr_performance()