"""
|
|
Optimized OCR Processor with Batch Processing and Async Support
|
|
Replaces the inefficient process-per-request approach with shared model instance
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import asyncio
|
|
import concurrent.futures
|
|
from typing import Dict, List, Any, Optional, Tuple
|
|
from dataclasses import dataclass
|
|
import tempfile
|
|
from pathlib import Path
|
|
import json
|
|
import time
|
|
|
|
# Configure logging
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class BatchOCRResult:
    """Result of OCR over a single image processed as part of a batch."""
    # Path of the source image this result corresponds to.
    image_path: str
    # All recognized text lines joined with "\n".
    text: str
    # Mean per-line recognition confidence (0.0 when no lines were found).
    confidence: float
    # One bounding box (list of corner points) per recognized line.
    bboxes: List[Any]
    # Number of recognized text lines.
    line_count: int
    # Wall-clock seconds spent on this image.
    processing_time: float
    # Human-readable failure reason; None on success.
    error: Optional[str] = None
|
|
|
|
|
|
class OptimizedOCRProcessor:
    """
    Optimized OCR processor with batch processing, shared model instance, and async support
    """

    # FIX: `languages` was annotated `List[str] = None`, which is not a valid
    # type for a None default; declare it Optional[List[str]] instead.
    def __init__(self, use_gpu: bool = True, languages: Optional[List[str]] = None,
                 batch_size: int = 4, max_workers: int = 2):
        """
        Initialize optimized OCR processor.

        Args:
            use_gpu: Whether to use GPU acceleration
            languages: List of languages for OCR (defaults to ['en', 'ch'])
            batch_size: Number of images to process in each batch
            max_workers: Maximum number of parallel workers
        """
        self.use_gpu = use_gpu
        self.languages = languages or ['en', 'ch']
        self.batch_size = batch_size
        self.max_workers = max_workers

        # Engine state; populated by _initialize_ocr() below. `ocr_available`
        # stays False if PaddleOCR cannot be imported or initialized, and all
        # public methods degrade gracefully in that case.
        self.ocr_available = False
        self._ocr_engine = None
        self._model_loaded = False
        self._temp_dir = None
        self._executor = None

        # Performance metrics accumulated across calls (see get_metrics()).
        self.metrics = {
            "total_images_processed": 0,
            "total_processing_time": 0.0,
            "batch_processing_times": [],
            "errors": []
        }

        self._initialize_ocr()
|
|
|
|
def _initialize_ocr(self):
|
|
"""Initialize PaddleOCR with shared model instance"""
|
|
try:
|
|
logger.info("Initializing optimized OCR processor with batch support")
|
|
|
|
# Test if PaddleOCR can be imported
|
|
try:
|
|
from paddleocr import PaddleOCR
|
|
self._ocr_engine = PaddleOCR(
|
|
use_gpu=self.use_gpu,
|
|
use_angle_cls=True,
|
|
lang='en',
|
|
show_log=False,
|
|
gpu_mem=2000
|
|
)
|
|
self._model_loaded = True
|
|
self.ocr_available = True
|
|
logger.info("PaddleOCR initialized successfully with shared model instance")
|
|
|
|
# Warm up the model with a dummy image
|
|
self._warm_up_model()
|
|
|
|
except ImportError as e:
|
|
logger.error(f"Failed to import PaddleOCR: {e}")
|
|
self.ocr_available = False
|
|
return
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize PaddleOCR: {e}")
|
|
self.ocr_available = False
|
|
return
|
|
|
|
# Create temporary directory for batch processing
|
|
self._temp_dir = tempfile.mkdtemp(prefix="optimized_ocr_")
|
|
|
|
# Create thread pool executor for async operations
|
|
self._executor = concurrent.futures.ThreadPoolExecutor(
|
|
max_workers=self.max_workers,
|
|
thread_name_prefix="ocr_worker"
|
|
)
|
|
|
|
logger.info(f"Optimized OCR processor ready (GPU: {self.use_gpu}, "
|
|
f"batch_size: {self.batch_size}, workers: {self.max_workers})")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to initialize optimized OCR processor: {e}")
|
|
self.ocr_available = False
|
|
|
|
def _warm_up_model(self):
|
|
"""Warm up the OCR model with a dummy image to reduce first-call latency"""
|
|
try:
|
|
# Create a simple test image
|
|
import numpy as np
|
|
test_image = np.ones((100, 100, 3), dtype=np.uint8) * 255
|
|
|
|
# Warm up with a small batch
|
|
self._ocr_engine.ocr(test_image, cls=True)
|
|
logger.info("OCR model warmed up successfully")
|
|
except Exception as e:
|
|
logger.warning(f"Model warm-up failed: {e}")
|
|
|
|
def extract_text_from_image(self, image_path: str) -> Dict[str, Any]:
|
|
"""
|
|
Extract text from a single image (backward compatibility)
|
|
|
|
Args:
|
|
image_path: Path to image file
|
|
|
|
Returns:
|
|
OCR result dictionary
|
|
"""
|
|
start_time = time.time()
|
|
|
|
if not self.ocr_available:
|
|
return {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}
|
|
|
|
try:
|
|
# Process single image using batch method for consistency
|
|
batch_result = self._process_batch_internal([image_path])
|
|
if batch_result and len(batch_result) > 0:
|
|
result = batch_result[0]
|
|
processing_time = time.time() - start_time
|
|
|
|
# Update metrics
|
|
self.metrics["total_images_processed"] += 1
|
|
self.metrics["total_processing_time"] += processing_time
|
|
|
|
return {
|
|
"text": result.text,
|
|
"confidence": result.confidence,
|
|
"bboxes": result.bboxes,
|
|
"line_count": result.line_count,
|
|
"processing_time": processing_time
|
|
}
|
|
else:
|
|
return {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error processing image {image_path}: {e}")
|
|
self.metrics["errors"].append(str(e))
|
|
return {"text": "", "confidence": 0.0, "bboxes": [], "line_count": 0}
|
|
|
|
def extract_text_from_images_batch(self, image_paths: List[str]) -> List[BatchOCRResult]:
|
|
"""
|
|
Extract text from multiple images in a batch
|
|
|
|
Args:
|
|
image_paths: List of image file paths
|
|
|
|
Returns:
|
|
List of BatchOCRResult objects
|
|
"""
|
|
if not image_paths:
|
|
return []
|
|
|
|
batch_start_time = time.time()
|
|
|
|
if not self.ocr_available:
|
|
return [BatchOCRResult(
|
|
image_path=path,
|
|
text="",
|
|
confidence=0.0,
|
|
bboxes=[],
|
|
line_count=0,
|
|
processing_time=0.0,
|
|
error="OCR not available"
|
|
) for path in image_paths]
|
|
|
|
try:
|
|
# Process images in batches
|
|
all_results = []
|
|
|
|
for i in range(0, len(image_paths), self.batch_size):
|
|
batch = image_paths[i:i + self.batch_size]
|
|
batch_results = self._process_batch_internal(batch)
|
|
all_results.extend(batch_results)
|
|
|
|
batch_time = time.time() - batch_start_time
|
|
|
|
# Update metrics
|
|
self.metrics["total_images_processed"] += len(image_paths)
|
|
self.metrics["total_processing_time"] += batch_time
|
|
self.metrics["batch_processing_times"].append({
|
|
"batch_size": len(image_paths),
|
|
"processing_time": batch_time,
|
|
"avg_time_per_image": batch_time / len(image_paths) if image_paths else 0
|
|
})
|
|
|
|
logger.info(f"Processed {len(image_paths)} images in {batch_time:.2f}s "
|
|
f"({batch_time/len(image_paths):.2f}s per image)")
|
|
|
|
return all_results
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in batch processing: {e}")
|
|
self.metrics["errors"].append(str(e))
|
|
|
|
# Fall back to individual processing
|
|
results = []
|
|
for path in image_paths:
|
|
try:
|
|
single_result = self.extract_text_from_image(path)
|
|
results.append(BatchOCRResult(
|
|
image_path=path,
|
|
text=single_result["text"],
|
|
confidence=single_result["confidence"],
|
|
bboxes=single_result["bboxes"],
|
|
line_count=single_result["line_count"],
|
|
processing_time=single_result.get("processing_time", 0.0)
|
|
))
|
|
except Exception as img_error:
|
|
results.append(BatchOCRResult(
|
|
image_path=path,
|
|
text="",
|
|
confidence=0.0,
|
|
bboxes=[],
|
|
line_count=0,
|
|
processing_time=0.0,
|
|
error=str(img_error)
|
|
))
|
|
|
|
return results
|
|
|
|
def _process_batch_internal(self, image_paths: List[str]) -> List[BatchOCRResult]:
|
|
"""
|
|
Internal batch processing method
|
|
|
|
Args:
|
|
image_paths: List of image file paths
|
|
|
|
Returns:
|
|
List of BatchOCRResult objects
|
|
"""
|
|
results = []
|
|
|
|
try:
|
|
# Process all images in the batch
|
|
for image_path in image_paths:
|
|
image_start_time = time.time()
|
|
|
|
try:
|
|
# Use the shared OCR engine
|
|
ocr_result = self._ocr_engine.ocr(image_path, cls=True)
|
|
|
|
# Parse OCR result
|
|
text_lines = []
|
|
bboxes = []
|
|
total_confidence = 0.0
|
|
line_count = 0
|
|
|
|
if ocr_result and ocr_result[0]:
|
|
for line in ocr_result[0]:
|
|
if line and len(line) >= 2:
|
|
bbox, (text, confidence) = line
|
|
text_lines.append(str(text))
|
|
bboxes.append(bbox)
|
|
total_confidence += float(confidence) if confidence else 0.0
|
|
line_count += 1
|
|
|
|
avg_confidence = total_confidence / line_count if line_count > 0 else 0.0
|
|
full_text = "\n".join(text_lines)
|
|
processing_time = time.time() - image_start_time
|
|
|
|
results.append(BatchOCRResult(
|
|
image_path=image_path,
|
|
text=full_text,
|
|
confidence=avg_confidence,
|
|
bboxes=bboxes,
|
|
line_count=line_count,
|
|
processing_time=processing_time
|
|
))
|
|
|
|
except Exception as e:
|
|
processing_time = time.time() - image_start_time
|
|
logger.error(f"Error processing {image_path}: {e}")
|
|
results.append(BatchOCRResult(
|
|
image_path=image_path,
|
|
text="",
|
|
confidence=0.0,
|
|
bboxes=[],
|
|
line_count=0,
|
|
processing_time=processing_time,
|
|
error=str(e)
|
|
))
|
|
|
|
return results
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in internal batch processing: {e}")
|
|
raise
|
|
|
|
async def extract_text_from_images_batch_async(self, image_paths: List[str]) -> List[BatchOCRResult]:
|
|
"""
|
|
Async version of batch processing
|
|
|
|
Args:
|
|
image_paths: List of image file paths
|
|
|
|
Returns:
|
|
List of BatchOCRResult objects
|
|
"""
|
|
loop = asyncio.get_event_loop()
|
|
|
|
# Run batch processing in thread pool
|
|
future = loop.run_in_executor(
|
|
self._executor,
|
|
self.extract_text_from_images_batch,
|
|
image_paths
|
|
)
|
|
|
|
try:
|
|
results = await asyncio.wait_for(future, timeout=300) # 5 minute timeout
|
|
return results
|
|
except asyncio.TimeoutError:
|
|
logger.error("OCR batch processing timeout")
|
|
return [BatchOCRResult(
|
|
image_path=path,
|
|
text="",
|
|
confidence=0.0,
|
|
bboxes=[],
|
|
line_count=0,
|
|
processing_time=0.0,
|
|
error="Processing timeout"
|
|
) for path in image_paths]
|
|
|
|
def extract_tables_from_image(self, image_path: str) -> List[Dict[str, Any]]:
|
|
"""
|
|
Extract tables from image using OCR and layout analysis
|
|
|
|
Args:
|
|
image_path: Path to image file
|
|
|
|
Returns:
|
|
List of table dictionaries
|
|
"""
|
|
try:
|
|
# Use OCR to get text with bounding boxes
|
|
ocr_result = self.extract_text_from_image(image_path)
|
|
|
|
# Simple table detection based on text alignment
|
|
tables = self._detect_tables_from_bboxes(ocr_result["bboxes"], ocr_result["text"])
|
|
return tables
|
|
except Exception as e:
|
|
logger.error(f"Table extraction failed: {e}")
|
|
return []
|
|
|
|
def _detect_tables_from_bboxes(self, bboxes: List, text: str) -> List[Dict[str, Any]]:
|
|
"""
|
|
Detect tables from OCR bounding boxes (compatible with original implementation)
|
|
"""
|
|
tables = []
|
|
|
|
if not bboxes:
|
|
return tables
|
|
|
|
# Group text by rows based on y-coordinates
|
|
rows = {}
|
|
text_lines = text.split('\n') if text else []
|
|
|
|
for i, bbox in enumerate(bboxes):
|
|
try:
|
|
if not bbox:
|
|
continue
|
|
|
|
# Calculate y-center of bounding box
|
|
y_values = []
|
|
for point in bbox:
|
|
if point and len(point) >= 2:
|
|
try:
|
|
y_val = point[1]
|
|
if isinstance(y_val, (int, float)):
|
|
y_values.append(float(y_val))
|
|
elif isinstance(y_val, str):
|
|
y_values.append(float(y_val))
|
|
else:
|
|
y_values.append(0.0)
|
|
except (TypeError, ValueError):
|
|
y_values.append(0.0)
|
|
else:
|
|
y_values.append(0.0)
|
|
|
|
if y_values:
|
|
y_center = sum(y_values) / len(y_values)
|
|
else:
|
|
y_center = 0.0
|
|
|
|
row_key = round(y_center / 10) # Group by 10-pixel rows
|
|
|
|
if row_key not in rows:
|
|
rows[row_key] = []
|
|
|
|
row_text = text_lines[i] if i < len(text_lines) else ""
|
|
rows[row_key].append((bbox, row_text))
|
|
|
|
except Exception as e:
|
|
logger.warning(f"Error processing bbox {i}: {e}")
|
|
continue
|
|
|
|
# Sort rows and create table structure
|
|
sorted_rows = sorted(rows.keys())
|
|
table_data = []
|
|
|
|
for row_key in sorted_rows:
|
|
try:
|
|
def get_x_coordinate(item):
|
|
try:
|
|
if (item[0] and len(item[0]) > 0 and
|
|
item[0][0] and len(item[0][0]) > 0):
|
|
x_val = item[0][0][0]
|
|
return float(x_val) if x_val is not None else 0.0
|
|
return 0.0
|
|
except (TypeError, ValueError, IndexError):
|
|
return 0.0
|
|
|
|
row_items = sorted(rows[row_key], key=get_x_coordinate)
|
|
row_text = [item[1] for item in row_items]
|
|
table_data.append(row_text)
|
|
except Exception as e:
|
|
logger.warning(f"Error sorting row {row_key}: {e}")
|
|
continue
|
|
|
|
if len(table_data) > 1: # At least 2 rows for a table
|
|
tables.append({
|
|
"data": table_data,
|
|
"rows": len(table_data),
|
|
"columns": max(len(row) for row in table_data) if table_data else 0
|
|
})
|
|
|
|
return tables
|
|
|
|
def get_metrics(self) -> Dict[str, Any]:
|
|
"""
|
|
Get performance metrics
|
|
|
|
Returns:
|
|
Dictionary of performance metrics
|
|
"""
|
|
avg_time_per_image = 0.0
|
|
if self.metrics["total_images_processed"] > 0:
|
|
avg_time_per_image = self.metrics["total_processing_time"] / self.metrics["total_images_processed"]
|
|
|
|
return {
|
|
**self.metrics,
|
|
"avg_time_per_image": avg_time_per_image,
|
|
"ocr_available": self.ocr_available,
|
|
"model_loaded": self._model_loaded,
|
|
"batch_size": self.batch_size,
|
|
"max_workers": self.max_workers,
|
|
"use_gpu": self.use_gpu
|
|
}
|
|
|
|
def close(self):
|
|
"""Clean up resources"""
|
|
if self._executor:
|
|
self._executor.shutdown(wait=True)
|
|
self._executor = None
|
|
|
|
if self._temp_dir and os.path.exists(self._temp_dir):
|
|
import shutil
|
|
try:
|
|
shutil.rmtree(self._temp_dir)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to delete temporary directory: {e}")
|
|
|
|
self._ocr_engine = None
|
|
self._model_loaded = False
|
|
self.ocr_available = False
|
|
|
|
logger.info("Optimized OCR processor closed")
|
|
|
|
def __del__(self):
|
|
"""Destructor to ensure cleanup"""
|
|
self.close()
|
|
|
|
|
|
# Performance test function
|
|
def test_optimized_ocr_performance():
    """Smoke/perf check: compare single-image vs batched processing and
    print the accumulated metrics. Skips gracefully without PaddleOCR."""
    print("=== Testing Optimized OCR Processor ===")

    processor = OptimizedOCRProcessor(use_gpu=True, batch_size=4, max_workers=2)

    if not processor.ocr_available:
        print("OCR not available, skipping test")
        return

    # Dummy paths; a real run would point at actual image files.
    test_images = [f"test_image_{i}.png" for i in range(8)]

    # 1) Single-image path, two images.
    print("\n1. Testing single image processing:")
    t0 = time.time()
    single_results = [processor.extract_text_from_image(img) for img in test_images[:2]]
    single_time = time.time() - t0
    print(f" Processed {len(test_images[:2])} images in {single_time:.2f}s")

    # 2) Batched path, all images.
    print("\n2. Testing batch processing (4 images/batch):")
    t1 = time.time()
    batch_results = processor.extract_text_from_images_batch(test_images)
    batch_time = time.time() - t1
    print(f" Processed {len(test_images)} images in {batch_time:.2f}s")
    print(f" Average time per image: {batch_time/len(test_images):.2f}s")

    # 3) Scalar metrics only; skip the bulky list-valued entries.
    print("\n3. Performance Metrics:")
    for key, value in processor.get_metrics().items():
        if key not in ["batch_processing_times", "errors", "bboxes"]:
            print(f" {key}: {value}")

    processor.close()

    return single_results, batch_results
|
|
|
|
|
|
if __name__ == "__main__":
    # Run performance test (degrades gracefully when PaddleOCR is absent).
    print("Running optimized OCR processor performance test...")
    test_optimized_ocr_performance()