table detection enhanced

This commit is contained in:
2026-01-14 15:15:01 +08:00
parent e7256a10ea
commit 1838c37302
14 changed files with 18065490 additions and 71 deletions

View File

@@ -412,22 +412,47 @@ class OptimizedOCRProcessor:
def _detect_tables_from_bboxes(self, bboxes: List, text: str) -> List[Dict[str, Any]]:
"""
Detect tables from OCR bounding boxes (compatible with original implementation)
Enhanced table detection from OCR bounding boxes with improved accuracy
Features:
1. Adaptive row grouping based on text height
2. Column alignment detection using common x-coordinates
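3. Header row detection based on text length (first row compared with second row)
4. Table validation (minimum row/column counts and cell fill ratio)
5. Multi-table detection in single image (NOTE: the visible implementation returns at most one table)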
"""
tables = []
if not bboxes:
if not bboxes or len(bboxes) < 4: # Need at least 4 text elements for a table
return tables
# Group text by rows based on y-coordinates
rows = {}
text_lines = text.split('\n') if text else []
# Step 1: Calculate text height statistics for adaptive row grouping
text_heights = []
for bbox in bboxes:
if not bbox or len(bbox) < 4:
continue
try:
# Get min and max y coordinates
y_coords = [float(point[1]) for point in bbox if point and len(point) >= 2]
if y_coords:
height = max(y_coords) - min(y_coords)
if height > 0:
text_heights.append(height)
except (TypeError, ValueError, IndexError):
continue
avg_text_height = sum(text_heights) / len(text_heights) if text_heights else 20.0
row_tolerance = avg_text_height * 0.8 # 80% of text height for row grouping
# Step 2: Group text by rows with adaptive tolerance
rows = {}
for i, bbox in enumerate(bboxes):
try:
if not bbox:
if not bbox or len(bbox) < 4:
continue
# Calculate y-center of bounding box
y_values = []
for point in bbox:
@@ -445,52 +470,133 @@ class OptimizedOCRProcessor:
else:
y_values.append(0.0)
if y_values:
y_center = sum(y_values) / len(y_values)
else:
y_center = 0.0
if not y_values:
continue
y_center = sum(y_values) / len(y_values)
row_key = round(y_center / 10) # Group by 10-pixel rows
if row_key not in rows:
rows[row_key] = []
row_text = text_lines[i] if i < len(text_lines) else ""
rows[row_key].append((bbox, row_text))
# Find existing row or create new one
row_found = False
for row_key in list(rows.keys()):
if abs(y_center - row_key) <= row_tolerance:
rows[row_key].append((bbox, text_lines[i] if i < len(text_lines) else ""))
row_found = True
break
if not row_found:
rows[y_center] = [(bbox, text_lines[i] if i < len(text_lines) else "")]
except Exception as e:
logger.warning(f"Error processing bbox {i}: {e}")
logger.debug(f"Error processing bbox {i} for table detection: {e}")
continue
# Sort rows and create table structure
sorted_rows = sorted(rows.keys())
if len(rows) < 2: # Need at least 2 rows for a table
return tables
# Step 3: Sort rows by y-coordinate and process each row
sorted_row_keys = sorted(rows.keys())
sorted_rows = [rows[key] for key in sorted_row_keys]
# Step 4: Detect column positions using x-coordinate clustering
all_x_centers = []
for row in sorted_rows:
for bbox, _ in row:
try:
if bbox and len(bbox) >= 4:
x_coords = [float(point[0]) for point in bbox if point and len(point) >= 1]
if x_coords:
x_center = sum(x_coords) / len(x_coords)
all_x_centers.append(x_center)
except (TypeError, ValueError, IndexError):
continue
if not all_x_centers:
return tables
# Simple column clustering: sort x-centers and group by proximity
all_x_centers.sort()
column_positions = []
current_cluster = [all_x_centers[0]]
for x in all_x_centers[1:]:
if x - current_cluster[-1] <= avg_text_height * 1.5: # tolerance is 1.5x the average text height (used as a proxy for column spacing)
current_cluster.append(x)
else:
column_positions.append(sum(current_cluster) / len(current_cluster))
current_cluster = [x]
if current_cluster:
column_positions.append(sum(current_cluster) / len(current_cluster))
# Need at least 2 columns for a table
if len(column_positions) < 2:
return tables
# Step 5: Create table structure with proper cell alignment
column_positions.sort()
table_data = []
column_count = len(column_positions)
for row_key in sorted_rows:
try:
def get_x_coordinate(item):
try:
if (item[0] and len(item[0]) > 0 and
item[0][0] and len(item[0][0]) > 0):
x_val = item[0][0][0]
return float(x_val) if x_val is not None else 0.0
return 0.0
except (TypeError, ValueError, IndexError):
return 0.0
for row in sorted_rows:
# Sort row items by x-coordinate
def get_x_center(item):
try:
bbox = item[0]
if bbox and len(bbox) >= 4:
x_coords = [float(point[0]) for point in bbox if point and len(point) >= 1]
return sum(x_coords) / len(x_coords) if x_coords else 0.0
except (TypeError, ValueError, IndexError):
pass
return 0.0
sorted_row = sorted(row, key=get_x_center)
# Create row with cells aligned to columns
row_cells = [""] * column_count
for bbox, cell_text in sorted_row:
try:
x_center = get_x_center((bbox, cell_text))
# Find closest column
if column_positions:
closest_col = min(range(column_count),
key=lambda i: abs(x_center - column_positions[i]))
# Assign if the cell is still empty, or overwrite it when this text lies within half the average text height of the column center
if not row_cells[closest_col] or \
abs(x_center - column_positions[closest_col]) < avg_text_height * 0.5:
row_cells[closest_col] = cell_text
except Exception:
continue
# Only add row if it has meaningful content (not all empty)
if any(cell.strip() for cell in row_cells):
table_data.append(row_cells)
# Step 6: Validate table structure
if len(table_data) >= 2 and column_count >= 2:
# Calculate table consistency score
non_empty_cells = sum(1 for row in table_data for cell in row if cell.strip())
total_cells = len(table_data) * column_count
fill_ratio = non_empty_cells / total_cells if total_cells > 0 else 0
# Only accept tables with reasonable fill ratio (20-90%)
if 0.2 <= fill_ratio <= 0.9:
# Detect potential header row (first row often has different characteristics)
has_header = False
if len(table_data) >= 3:
# Heuristic: treat the first row as a header when its total text length exceeds 1.5x that of the second row
first_row_text_len = sum(len(cell) for cell in table_data[0])
second_row_text_len = sum(len(cell) for cell in table_data[1])
if first_row_text_len > second_row_text_len * 1.5:
has_header = True
row_items = sorted(rows[row_key], key=get_x_coordinate)
row_text = [item[1] for item in row_items]
table_data.append(row_text)
except Exception as e:
logger.warning(f"Error sorting row {row_key}: {e}")
continue
if len(table_data) > 1: # At least 2 rows for a table
tables.append({
"data": table_data,
"rows": len(table_data),
"columns": max(len(row) for row in table_data) if table_data else 0
})
tables.append({
"data": table_data,
"rows": len(table_data),
"columns": column_count,
"has_header": has_header,
"fill_ratio": fill_ratio,
"type": "detected_table"
})
return tables