Files
railseek6/LightRAG-main/lightrag/optimized_ocr_processor.py
2026-01-13 19:10:24 +08:00

590 lines
21 KiB
Python

"""
Optimized OCR Processor with Batch Processing and Async Support
Replaces the inefficient process-per-request approach with shared model instance
"""
import os
import logging
import asyncio
import concurrent.futures
import threading
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass
import tempfile
from pathlib import Path
import json
import time
# Configure logging
logger = logging.getLogger(__name__)
@dataclass
class BatchOCRResult:
    """Batch OCR processing result for a single image."""
    image_path: str              # path of the image this result belongs to
    text: str                    # recognized text, one OCR line per "\n"-joined line
    confidence: float            # average per-line confidence (0.0 when no lines found)
    bboxes: List                 # one bounding box (list of corner points) per recognized line
    line_count: int              # number of recognized text lines
    processing_time: float       # wall-clock seconds spent processing this image
    error: Optional[str] = None  # error message when processing failed, else None
class OptimizedOCRProcessor:
    """
    Optimized OCR processor with batch processing, shared model instance, and async support.

    A single PaddleOCR engine is loaded lazily in a background thread and then
    shared by every request, replacing a process-per-request approach.
    """

    def __init__(self, use_gpu: bool = True, languages: Optional[List[str]] = None,
                 batch_size: int = 4, max_workers: int = 2):
        """
        Initialize optimized OCR processor.

        Args:
            use_gpu: Whether to use GPU acceleration
            languages: List of languages for OCR (defaults to ['en', 'ch'])
            batch_size: Number of images to process in each batch
            max_workers: Maximum number of parallel workers
        """
        self.use_gpu = use_gpu
        self.languages = languages or ['en', 'ch']
        self.batch_size = batch_size
        self.max_workers = max_workers
        self.ocr_available = False   # becomes True only after successful engine init
        self._ocr_engine = None      # shared PaddleOCR instance
        self._model_loaded = False
        self._temp_dir = None        # scratch directory for batch processing
        self._executor = None        # thread pool used by the async API
        self._initialization_lock = threading.Lock()
        self._initialization_thread = None
        self._initialization_started = False
        # Performance metrics accumulated over the processor's lifetime
        self.metrics = {
            "total_images_processed": 0,
            "total_processing_time": 0.0,
            "batch_processing_times": [],
            "errors": []
        }
        # Start lazy initialization in a background thread so construction
        # returns immediately even though model loading is slow.
        self._start_lazy_initialization()

    def _start_lazy_initialization(self):
        """Start OCR initialization in a background thread (at most once)."""
        with self._initialization_lock:
            if self._initialization_started:
                return
            self._initialization_started = True
            # Daemon thread: must not keep the interpreter alive on exit.
            self._initialization_thread = threading.Thread(
                target=self._initialize_ocr,
                name="OCRInitializer",
                daemon=True
            )
            self._initialization_thread.start()
            logger.info("Started lazy OCR initialization in background thread")

    def _ensure_ocr_initialized(self, timeout: Optional[float] = None):
        """
        Block until OCR initialization is complete.

        If timeout is None, wait indefinitely.
        Returns True if OCR is available, False otherwise.
        """
        if self.ocr_available:
            return True
        if not self._initialization_started:
            self._start_lazy_initialization()
        if self._initialization_thread is not None:
            self._initialization_thread.join(timeout=timeout)
        # After join (or timeout), report whatever state init reached.
        return self.ocr_available

    def _initialize_ocr(self):
        """Initialize PaddleOCR with shared model instance (runs in background thread)."""
        try:
            logger.info("Initializing optimized OCR processor with batch support")
            # Test if PaddleOCR can be imported and constructed; failure leaves
            # the processor in a usable "OCR unavailable" state.
            try:
                from paddleocr import PaddleOCR
                self._ocr_engine = PaddleOCR(
                    use_gpu=self.use_gpu,
                    use_angle_cls=True,
                    lang='en',
                    show_log=False,
                    gpu_mem=2000
                )
                self._model_loaded = True
                self.ocr_available = True
                logger.info("PaddleOCR initialized successfully with shared model instance")
                # Warm up the model with a dummy image
                self._warm_up_model()
            except ImportError as e:
                logger.error(f"Failed to import PaddleOCR: {e}")
                self.ocr_available = False
                return
            except Exception as e:
                logger.error(f"Failed to initialize PaddleOCR: {e}")
                self.ocr_available = False
                return
            # Create temporary directory for batch processing
            self._temp_dir = tempfile.mkdtemp(prefix="optimized_ocr_")
            # Create thread pool executor for async operations
            self._executor = concurrent.futures.ThreadPoolExecutor(
                max_workers=self.max_workers,
                thread_name_prefix="ocr_worker"
            )
            logger.info(f"Optimized OCR processor ready (GPU: {self.use_gpu}, "
                        f"batch_size: {self.batch_size}, workers: {self.max_workers})")
        except Exception as e:
            logger.error(f"Failed to initialize optimized OCR processor: {e}")
            self.ocr_available = False

    def _warm_up_model(self):
        """Warm up the OCR model with a dummy image to reduce first-call latency."""
        try:
            # A blank white 100x100 image is enough to trigger model loading.
            import numpy as np
            test_image = np.ones((100, 100, 3), dtype=np.uint8) * 255
            self._ocr_engine.ocr(test_image, cls=True)
            logger.info("OCR model warmed up successfully")
        except Exception as e:
            # Warm-up is best-effort; a failure here must not disable OCR.
            logger.warning(f"Model warm-up failed: {e}")

    def extract_text_from_image(self, image_path: str) -> Dict[str, Any]:
        """
        Extract text from a single image (backward compatibility).

        Args:
            image_path: Path to image file

        Returns:
            OCR result dictionary with keys "text", "confidence", "bboxes",
            "line_count" (and "processing_time" on success).
        """
        start_time = time.time()
        # Ensure OCR is initialized (wait up to 30 seconds)
        if not self._ensure_ocr_initialized(timeout=30.0):
            logger.warning("OCR not available after waiting")
            return {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}
        try:
            # Process single image using batch method for consistency
            batch_result = self._process_batch_internal([image_path])
            if batch_result and len(batch_result) > 0:
                result = batch_result[0]
                processing_time = time.time() - start_time
                # Update metrics
                self.metrics["total_images_processed"] += 1
                self.metrics["total_processing_time"] += processing_time
                return {
                    "text": result.text,
                    "confidence": result.confidence,
                    "bboxes": result.bboxes,
                    "line_count": result.line_count,
                    "processing_time": processing_time
                }
            else:
                return {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}
        except Exception as e:
            logger.error(f"Error processing image {image_path}: {e}")
            self.metrics["errors"].append(str(e))
            return {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}

    def extract_text_from_images_batch(self, image_paths: List[str]) -> List[BatchOCRResult]:
        """
        Extract text from multiple images in a batch.

        Args:
            image_paths: List of image file paths

        Returns:
            List of BatchOCRResult objects, one per input path.
        """
        if not image_paths:
            return []
        batch_start_time = time.time()
        # Ensure OCR is initialized (wait up to 30 seconds)
        if not self._ensure_ocr_initialized(timeout=30.0):
            logger.warning("OCR not available for batch processing")
            return [BatchOCRResult(
                image_path=path,
                text="",
                confidence=0.0,
                bboxes=[],
                line_count=0,
                processing_time=0.0,
                error="OCR not available"
            ) for path in image_paths]
        try:
            # Process images in chunks of self.batch_size
            all_results = []
            for i in range(0, len(image_paths), self.batch_size):
                batch = image_paths[i:i + self.batch_size]
                batch_results = self._process_batch_internal(batch)
                all_results.extend(batch_results)
            batch_time = time.time() - batch_start_time
            # Update metrics
            self.metrics["total_images_processed"] += len(image_paths)
            self.metrics["total_processing_time"] += batch_time
            self.metrics["batch_processing_times"].append({
                "batch_size": len(image_paths),
                "processing_time": batch_time,
                "avg_time_per_image": batch_time / len(image_paths) if image_paths else 0
            })
            logger.info(f"Processed {len(image_paths)} images in {batch_time:.2f}s "
                        f"({batch_time/len(image_paths):.2f}s per image)")
            return all_results
        except Exception as e:
            logger.error(f"Error in batch processing: {e}")
            self.metrics["errors"].append(str(e))
            # Fall back to individual processing so one bad image cannot
            # poison the whole batch.
            results = []
            for path in image_paths:
                try:
                    single_result = self.extract_text_from_image(path)
                    results.append(BatchOCRResult(
                        image_path=path,
                        text=single_result["text"],
                        confidence=single_result["confidence"],
                        bboxes=single_result["bboxes"],
                        line_count=single_result["line_count"],
                        processing_time=single_result.get("processing_time", 0.0)
                    ))
                except Exception as img_error:
                    results.append(BatchOCRResult(
                        image_path=path,
                        text="",
                        confidence=0.0,
                        bboxes=[],
                        line_count=0,
                        processing_time=0.0,
                        error=str(img_error)
                    ))
            return results

    def _process_batch_internal(self, image_paths: List[str]) -> List[BatchOCRResult]:
        """
        Internal batch processing method using the shared OCR engine.

        Args:
            image_paths: List of image file paths

        Returns:
            List of BatchOCRResult objects (per-image errors are captured in
            the result's `error` field rather than raised).
        """
        results = []
        try:
            for image_path in image_paths:
                image_start_time = time.time()
                try:
                    # Use the shared OCR engine
                    ocr_result = self._ocr_engine.ocr(image_path, cls=True)
                    # Parse OCR result: PaddleOCR returns a list whose first
                    # element is the list of (bbox, (text, confidence)) lines.
                    text_lines = []
                    bboxes = []
                    total_confidence = 0.0
                    line_count = 0
                    if ocr_result and ocr_result[0]:
                        for line in ocr_result[0]:
                            if line and len(line) >= 2:
                                bbox, (text, confidence) = line
                                text_lines.append(str(text))
                                bboxes.append(bbox)
                                total_confidence += float(confidence) if confidence else 0.0
                                line_count += 1
                    avg_confidence = total_confidence / line_count if line_count > 0 else 0.0
                    full_text = "\n".join(text_lines)
                    processing_time = time.time() - image_start_time
                    results.append(BatchOCRResult(
                        image_path=image_path,
                        text=full_text,
                        confidence=avg_confidence,
                        bboxes=bboxes,
                        line_count=line_count,
                        processing_time=processing_time
                    ))
                except Exception as e:
                    processing_time = time.time() - image_start_time
                    logger.error(f"Error processing {image_path}: {e}")
                    results.append(BatchOCRResult(
                        image_path=image_path,
                        text="",
                        confidence=0.0,
                        bboxes=[],
                        line_count=0,
                        processing_time=processing_time,
                        error=str(e)
                    ))
            return results
        except Exception as e:
            logger.error(f"Error in internal batch processing: {e}")
            raise

    async def extract_text_from_images_batch_async(self, image_paths: List[str]) -> List[BatchOCRResult]:
        """
        Async version of batch processing.

        Args:
            image_paths: List of image file paths

        Returns:
            List of BatchOCRResult objects (timeout errors are reported per
            image via the `error` field, not raised).
        """
        # Bug fix: asyncio.get_event_loop() is deprecated inside coroutines;
        # get_running_loop() is the correct call when a loop is running.
        loop = asyncio.get_running_loop()
        # If initialization has not finished, self._executor may still be None;
        # run_in_executor(None, ...) then uses the loop's default executor.
        future = loop.run_in_executor(
            self._executor,
            self.extract_text_from_images_batch,
            image_paths
        )
        try:
            results = await asyncio.wait_for(future, timeout=300)  # 5 minute timeout
            return results
        except asyncio.TimeoutError:
            logger.error("OCR batch processing timeout")
            return [BatchOCRResult(
                image_path=path,
                text="",
                confidence=0.0,
                bboxes=[],
                line_count=0,
                processing_time=0.0,
                error="Processing timeout"
            ) for path in image_paths]

    def extract_tables_from_image(self, image_path: str) -> List[Dict[str, Any]]:
        """
        Extract tables from image using OCR and layout analysis.

        Args:
            image_path: Path to image file

        Returns:
            List of table dictionaries (empty list on failure).
        """
        try:
            # Use OCR to get text with bounding boxes
            ocr_result = self.extract_text_from_image(image_path)
            # Simple table detection based on text alignment
            tables = self._detect_tables_from_bboxes(ocr_result["bboxes"], ocr_result["text"])
            return tables
        except Exception as e:
            logger.error(f"Table extraction failed: {e}")
            return []

    def _detect_tables_from_bboxes(self, bboxes: List, text: str) -> List[Dict[str, Any]]:
        """
        Detect tables from OCR bounding boxes (compatible with original implementation).

        Lines are grouped into rows by their vertical centre (10-pixel bands),
        sorted left-to-right within each row, and reported as one table when
        at least two rows exist.
        """
        tables = []
        if not bboxes:
            return tables
        # Group text by rows based on y-coordinates
        rows = {}
        text_lines = text.split('\n') if text else []
        for i, bbox in enumerate(bboxes):
            try:
                if not bbox:
                    continue
                # Calculate y-center of bounding box; malformed points fall
                # back to 0.0 rather than aborting the whole detection.
                y_values = []
                for point in bbox:
                    if point and len(point) >= 2:
                        try:
                            y_val = point[1]
                            if isinstance(y_val, (int, float)):
                                y_values.append(float(y_val))
                            elif isinstance(y_val, str):
                                y_values.append(float(y_val))
                            else:
                                y_values.append(0.0)
                        except (TypeError, ValueError):
                            y_values.append(0.0)
                    else:
                        y_values.append(0.0)
                if y_values:
                    y_center = sum(y_values) / len(y_values)
                else:
                    y_center = 0.0
                row_key = round(y_center / 10)  # Group by 10-pixel rows
                if row_key not in rows:
                    rows[row_key] = []
                row_text = text_lines[i] if i < len(text_lines) else ""
                rows[row_key].append((bbox, row_text))
            except Exception as e:
                logger.warning(f"Error processing bbox {i}: {e}")
                continue

        # Perf fix: key function is loop-invariant, so define it once instead
        # of re-creating it for every row.
        def get_x_coordinate(item):
            """Left x of the item's first bbox point; 0.0 if malformed."""
            try:
                if (item[0] and len(item[0]) > 0 and
                        item[0][0] and len(item[0][0]) > 0):
                    x_val = item[0][0][0]
                    return float(x_val) if x_val is not None else 0.0
                return 0.0
            except (TypeError, ValueError, IndexError):
                return 0.0

        # Sort rows top-to-bottom and items left-to-right to build the table
        sorted_rows = sorted(rows.keys())
        table_data = []
        for row_key in sorted_rows:
            try:
                row_items = sorted(rows[row_key], key=get_x_coordinate)
                row_text = [item[1] for item in row_items]
                table_data.append(row_text)
            except Exception as e:
                logger.warning(f"Error sorting row {row_key}: {e}")
                continue
        if len(table_data) > 1:  # At least 2 rows for a table
            tables.append({
                "data": table_data,
                "rows": len(table_data),
                "columns": max(len(row) for row in table_data) if table_data else 0
            })
        return tables

    def get_metrics(self) -> Dict[str, Any]:
        """
        Get performance metrics.

        Returns:
            Dictionary of performance metrics plus current processor state.
        """
        avg_time_per_image = 0.0
        if self.metrics["total_images_processed"] > 0:
            avg_time_per_image = self.metrics["total_processing_time"] / self.metrics["total_images_processed"]
        return {
            **self.metrics,
            "avg_time_per_image": avg_time_per_image,
            "ocr_available": self.ocr_available,
            "model_loaded": self._model_loaded,
            "batch_size": self.batch_size,
            "max_workers": self.max_workers,
            "use_gpu": self.use_gpu
        }

    def close(self):
        """Clean up resources (idempotent, and safe on partially-built instances)."""
        # Bug fix: use getattr so close() cannot raise AttributeError when
        # called from __del__ on an instance whose __init__ did not finish.
        executor = getattr(self, "_executor", None)
        if executor:
            executor.shutdown(wait=True)
            self._executor = None
        temp_dir = getattr(self, "_temp_dir", None)
        if temp_dir and os.path.exists(temp_dir):
            import shutil
            try:
                shutil.rmtree(temp_dir)
            except Exception as e:
                logger.warning(f"Failed to delete temporary directory: {e}")
        self._ocr_engine = None
        self._model_loaded = False
        self.ocr_available = False
        logger.info("Optimized OCR processor closed")

    def __del__(self):
        """Destructor to ensure cleanup."""
        # Bug fix: never let exceptions escape a destructor — at interpreter
        # shutdown, modules used by close() may already be torn down.
        try:
            self.close()
        except Exception:
            pass
# Performance test function
def test_optimized_ocr_performance():
    """Test optimized OCR processor performance.

    Returns (single_results, batch_results) on success, or None when OCR
    is unavailable.
    """
    print("=== Testing Optimized OCR Processor ===")
    # Create test processor
    processor = OptimizedOCRProcessor(use_gpu=True, batch_size=4, max_workers=2)
    # Bug fix: initialization runs lazily in a background thread, so reading
    # processor.ocr_available right after construction almost always sees
    # False and the test always skips. Block until init actually completes.
    if not processor._ensure_ocr_initialized(timeout=60.0):
        print("OCR not available, skipping test")
        processor.close()  # bug fix: release resources on the skip path too
        return
    # Generate dummy image paths (in real test, these would be actual images)
    test_images = [f"test_image_{i}.png" for i in range(8)]
    # Test single processing
    print("\n1. Testing single image processing:")
    single_start = time.time()
    single_results = []
    for img in test_images[:2]:  # Test with 2 images
        result = processor.extract_text_from_image(img)
        single_results.append(result)
    single_time = time.time() - single_start
    print(f"   Processed {len(test_images[:2])} images in {single_time:.2f}s")
    # Test batch processing
    print("\n2. Testing batch processing (4 images/batch):")
    batch_start = time.time()
    batch_results = processor.extract_text_from_images_batch(test_images)
    batch_time = time.time() - batch_start
    print(f"   Processed {len(test_images)} images in {batch_time:.2f}s")
    print(f"   Average time per image: {batch_time/len(test_images):.2f}s")
    # Show metrics
    print("\n3. Performance Metrics:")
    metrics = processor.get_metrics()
    for key, value in metrics.items():
        if key not in ["batch_processing_times", "errors", "bboxes"]:
            print(f"   {key}: {value}")
    processor.close()
    return single_results, batch_results
if __name__ == "__main__":
    # Run performance test when executed as a script (no effect on import).
    print("Running optimized OCR processor performance test...")
    test_optimized_ocr_performance()