"""
Isolated OpenCLIP Image Classifier Module

Designed to work alongside PaddleOCR without dependency conflicts
"""
import os
|
|
import logging
|
|
import tempfile
|
|
from typing import List, Dict, Any, Optional
|
|
from pathlib import Path
|
|
import importlib.util
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class OpenCLIPClassifier:
    """OpenCLIP-based image classifier with isolated dependencies.

    All heavy imports (``open_clip``, ``torch``, ``PIL``) happen lazily inside
    methods so this module can be imported alongside PaddleOCR even when
    OpenCLIP is not installed. In that case ``self.available`` stays ``False``
    and the classify methods return placeholder results instead of raising.
    """

    # Candidate prompts for zero-shot classification in classify_image().
    _ZERO_SHOT_LABELS = [
        "a photo of a bee", "a photo of a flower", "a photo of a person",
        "a photo of a document", "a photo of a chart", "a photo of a diagram",
        "a photo of a table", "a photo of a graph", "a photo of a logo",
        "a photo of a signature", "a photo of a stamp", "a photo of a barcode",
        "a photo of a QR code", "a photo of a screenshot", "a photo of a landscape",
        "a photo of an animal", "a photo of a building", "a photo of a vehicle",
        "a photo of food", "a photo of clothing", "a photo of electronics",
        "a photo of furniture", "a photo of nature", "a photo of art",
        "a photo of text", "a photo of numbers", "a photo of symbols",
    ]

    def __init__(self, model_name: str = "ViT-B-32", pretrained: str = "laion2b_s34b_b79k"):
        """Create the classifier and eagerly try to load the model.

        Args:
            model_name: OpenCLIP architecture name (e.g. "ViT-B-32").
            pretrained: Pretrained-weight tag understood by open_clip.
        """
        self.model_name = model_name
        self.pretrained = pretrained
        self.model = None        # open_clip model, or None when unavailable
        self.processor = None    # eval preprocessing transform from open_clip
        self.available = False   # True only after successful initialization
        self._initialize_classifier()

    def _initialize_classifier(self):
        """Initialize the OpenCLIP model, leaving self.available False on any failure."""
        try:
            # Probe for the package first so a missing dependency is a logged
            # warning rather than an ImportError, and torch is never imported
            # needlessly.
            if importlib.util.find_spec("open_clip") is None:
                logger.warning("OpenCLIP not available, image classification will be disabled")
                return

            import open_clip
            import torch

            logger.info(f"Initializing OpenCLIP with model: {self.model_name}, pretrained: {self.pretrained}")

            # create_model_and_transforms returns (model, train_transform,
            # eval_transform); the eval transform is the inference preprocessor.
            self.model, _, self.processor = open_clip.create_model_and_transforms(
                model_name=self.model_name,
                pretrained=self.pretrained
            )

            # Move to GPU if available.
            if torch.cuda.is_available():
                self.model = self.model.cuda()
                logger.info("OpenCLIP model moved to GPU")
            else:
                logger.info("OpenCLIP model using CPU")

            self.available = True
            logger.info("OpenCLIP classifier initialized successfully")

        except ImportError as e:
            logger.warning(f"OpenCLIP dependencies not available: {e}")
        except Exception as e:
            logger.error(f"Failed to initialize OpenCLIP classifier: {e}")

    def classify_image(self, image_path: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Classify an image via zero-shot CLIP over _ZERO_SHOT_LABELS.

        Args:
            image_path: Path to the image file.
            top_k: Number of top predictions to return.

        Returns:
            List of {"label", "confidence"} dicts, best first. A single
            placeholder entry is returned when the model is unavailable or
            classification fails.
        """
        if not self.available or not self.model:
            return [{"label": "classification_unavailable", "confidence": 0.0}]

        try:
            # BUG FIX: open_clip must be imported here as well -- it was
            # previously imported only inside _initialize_classifier, so
            # open_clip.tokenize below raised NameError at runtime.
            import open_clip
            import torch
            from PIL import Image

            # Load and preprocess the image into a (1, C, H, W) batch.
            image = Image.open(image_path).convert("RGB")
            image_tensor = self.processor(image).unsqueeze(0)
            if torch.cuda.is_available():
                image_tensor = image_tensor.cuda()

            text_labels = self._ZERO_SHOT_LABELS

            # Inference only -- disable autograd for the whole forward pass.
            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                image_features /= image_features.norm(dim=-1, keepdim=True)

                text_tokens = open_clip.tokenize(text_labels)
                if torch.cuda.is_available():
                    text_tokens = text_tokens.cuda()
                text_features = self.model.encode_text(text_tokens)
                text_features /= text_features.norm(dim=-1, keepdim=True)

                # Cosine similarity scaled by 100, softmaxed into a distribution.
                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
                values, indices = similarity[0].topk(top_k)

            return [
                {"label": text_labels[index], "confidence": float(value)}
                for value, index in zip(values, indices)
            ]

        except Exception as e:
            logger.error(f"Image classification failed: {e}")
            return [{"label": "classification_error", "confidence": 0.0}]

    def get_detailed_classification(self, image_path: str) -> Dict[str, Any]:
        """Get a classification summary for an image.

        Returns:
            Dict with "available" plus either the standard top-3 predictions
            and "file_path", or an "error" description with empty predictions.
        """
        if not self.available:
            return {
                "available": False,
                "error": "OpenCLIP classifier not available",
                "predictions": []
            }

        try:
            standard_results = self.classify_image(image_path, top_k=3)
            # NOTE(review): an unused table of specialized label prompts
            # (document types / office objects) was removed here -- it was
            # built on every call but never referenced.
            return {
                "available": True,
                "standard_predictions": standard_results,
                "file_path": image_path
            }

        except Exception as e:
            logger.error(f"Detailed classification failed: {e}")
            return {
                "available": False,
                "error": str(e),
                "predictions": []
            }
# Module-level singleton, created lazily by get_image_classifier().
_classifier_instance = None


def get_image_classifier() -> OpenCLIPClassifier:
    """Return the shared OpenCLIPClassifier, constructing it on first call."""
    global _classifier_instance
    instance = _classifier_instance
    if instance is None:
        instance = OpenCLIPClassifier()
        _classifier_instance = instance
    return instance
def test_classifier():
    """Smoke-test the singleton classifier against local sample images."""
    classifier = get_image_classifier()

    if not classifier.available:
        print("OpenCLIP classifier is not available")
        return

    print("OpenCLIP classifier is available")
    # Classify the first sample image that exists on disk, if any.
    candidates = ["test_bee_image.png", "sample_image.jpg"]
    target = next((name for name in candidates if os.path.exists(name)), None)
    if target is None:
        print("No test images found for classification")
        return

    predictions = classifier.classify_image(target)
    print(f"Classification results for {target}:")
    for prediction in predictions:
        print(f"  {prediction['label']}: {prediction['confidence']:.4f}")


if __name__ == "__main__":
    test_classifier()