# Source: railseek6/optimized_image_classifier.py (353 lines, 13 KiB, Python)

"""
Optimized Image Classifier with Persistent GPU Service
Eliminates subprocess overhead by running a persistent classification service
"""
import os
import json
import tempfile
import subprocess
import logging
import threading
import queue
import time
from pathlib import Path
from typing import List, Dict, Any
logger = logging.getLogger(__name__)
class OptimizedImageClassifier:
    """Image classifier with a persistent GPU service for maximum performance.

    Instead of paying subprocess startup cost per image, a single worker
    process (running inside the isolated OpenCLIP virtualenv) is started once
    and kept alive; requests/responses are exchanged as JSON lines over its
    stdin/stdout.
    """

    # Python interpreter of the isolated OpenCLIP virtualenv.
    # NOTE(review): Windows-specific, relative path — assumes the process CWD
    # is the project root; confirm against deployment layout.
    _VENV_PYTHON = "openclip_gpu_env\\Scripts\\python.exe"

    def __init__(self):
        self.available = False
        self.service_process = None
        # Reserved for a future threaded request pipeline; currently requests
        # go directly over the service's stdin/stdout (see classify_image).
        self.request_queue = queue.Queue()
        self.response_queue = queue.Queue()
        self.service_thread = None
        self._script_path = None  # temp file holding the generated service script
        self._check_availability()

    def _check_availability(self):
        """Probe the virtualenv for an importable ``open_clip``; start the service.

        Raises:
            RuntimeError: if the virtualenv is missing, the probe times out,
                or ``open_clip`` cannot be imported inside it.
        """
        try:
            if not os.path.exists(self._VENV_PYTHON):
                raise RuntimeError(f"Virtual environment not found: {self._VENV_PYTHON}")
            # Quick availability check: the probe prints SUCCESS/ERROR to stdout.
            result = subprocess.run([
                self._VENV_PYTHON, '-c',
                'try: import open_clip; print("SUCCESS"); exit(0)\nexcept Exception as e: print(f"ERROR: {e}"); exit(1)'
            ], capture_output=True, text=True, timeout=30)
            if result.returncode == 0:
                self.available = True
                logger.info("OpenCLIP is available - starting optimized service")
                self._start_service()
            else:
                # BUG FIX: the probe writes its "ERROR: ..." message to stdout,
                # not stderr — inspect both so the real cause is reported.
                error_msg = result.stderr or result.stdout or "Unknown error"
                raise RuntimeError(f"OpenCLIP check failed: {error_msg}")
        except Exception as e:
            logger.error(f"OpenCLIP availability check failed: {e}")
            raise RuntimeError(f"OpenCLIP availability check failed: {e}")

    def _start_service(self):
        """Write the service script to a temp file and launch the worker."""
        try:
            service_script = self._create_service_script()
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(service_script)
                self._script_path = f.name  # removed again in __del__
            self.service_process = subprocess.Popen(
                [self._VENV_PYTHON, self._script_path],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,  # line-buffered: one JSON object per line
            )
            # BUG FIX: the service logs progress to stderr; if that PIPE is
            # never read its OS buffer can fill and deadlock the child.
            # Drain it in a daemon thread and forward lines to the logger.
            self.service_thread = threading.Thread(
                target=self._drain_stderr, daemon=True
            )
            self.service_thread.start()
            # Give the model time to load before the first request.
            # NOTE(review): a fixed sleep is fragile — the script emits
            # SERVICE_READY on stderr, which could be awaited instead.
            time.sleep(3)
            logger.info("Persistent classification service started")
        except Exception as e:
            logger.error(f"Failed to start classification service: {e}")
            self.available = False

    def _drain_stderr(self):
        """Forward the service's stderr lines to the logger (prevents pipe deadlock)."""
        try:
            for line in self.service_process.stderr:
                logger.debug("classifier service: %s", line.rstrip())
        except Exception:
            # Process exited / pipe closed — nothing more to drain.
            pass

    def _create_service_script(self):
        """Return the source of the persistent worker process.

        The worker loads OpenCLIP once, precomputes text features, then reads
        JSON requests ({"type": "classify" | "classify_batch" | "ping"}) from
        stdin and writes one JSON response per line to stdout. Status messages
        go to stderr so they never corrupt the stdout protocol.
        """
        return '''
import sys
import json
import torch
import open_clip
from PIL import Image
import os
import time

class ClassificationService:
    def __init__(self):
        self.model = None
        self.processor = None
        self.text_features = None
        self.text_labels = None
        self.initialized = False

    def initialize(self):
        """Initialize model and precompute text features"""
        if self.initialized:
            return
        print("INITIALIZING: Loading OpenCLIP model...", file=sys.stderr)
        start_time = time.time()
        # Use smaller model for faster inference
        self.model, _, self.processor = open_clip.create_model_and_transforms(
            model_name="ViT-B-16",  # Smaller than ViT-B-32
            pretrained="laion2b_s34b_b79k"
        )
        # Move to GPU and enable half-precision for speed
        if torch.cuda.is_available():
            self.model = self.model.half().cuda()  # FP16 for speed
            print(f"INITIALIZED: Model on GPU (FP16) - {time.time()-start_time:.2f}s", file=sys.stderr)
        else:
            print("WARNING: Using CPU - slower performance", file=sys.stderr)
        # Reduced label set for document processing
        self.text_labels = [
            "a photo of a bee", "a photo of a flower", "a photo of a document",
            "a photo of a chart", "a photo of a diagram", "a photo of a table",
            "a photo of a graph", "a photo of a screenshot", "a photo of a logo",
            "a photo of text"
        ]
        # Precompute text features once
        with torch.no_grad():
            text_tokens = open_clip.tokenize(self.text_labels)
            if torch.cuda.is_available():
                text_tokens = text_tokens.cuda()
            self.text_features = self.model.encode_text(text_tokens)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
        self.initialized = True
        print("READY: Service initialized and ready", file=sys.stderr)

    def classify_image(self, image_path):
        """Classify a single image"""
        if not self.initialized:
            self.initialize()
        try:
            # Load and process image
            image = Image.open(image_path).convert("RGB")
            image_tensor = self.processor(image).unsqueeze(0)
            # Move to GPU if available
            if torch.cuda.is_available():
                image_tensor = image_tensor.half().cuda()  # FP16
            # Encode image and compute similarity
            with torch.no_grad():
                image_features = self.model.encode_image(image_tensor)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                similarity = (100.0 * image_features @ self.text_features.T).softmax(dim=-1)
                values, indices = similarity[0].topk(3)
            results = []
            for value, index in zip(values, indices):
                results.append({
                    "label": self.text_labels[index],
                    "confidence": float(value)
                })
            return {"success": True, "results": results}
        except Exception as e:
            return {"success": False, "error": str(e)}

    def classify_batch(self, image_paths):
        """Classify multiple images in batch (even faster)"""
        if not self.initialized:
            self.initialize()
        try:
            # Process all images in batch
            batch_tensors = []
            for image_path in image_paths:
                image = Image.open(image_path).convert("RGB")
                image_tensor = self.processor(image).unsqueeze(0)
                batch_tensors.append(image_tensor)
            # Stack and move to GPU
            batch_tensor = torch.cat(batch_tensors, dim=0)
            if torch.cuda.is_available():
                batch_tensor = batch_tensor.half().cuda()
            # Batch encode and compute similarities
            with torch.no_grad():
                image_features = self.model.encode_image(batch_tensor)
                image_features /= image_features.norm(dim=-1, keepdim=True)
                batch_results = []
                for i in range(len(image_paths)):
                    similarity = (100.0 * image_features[i] @ self.text_features.T).softmax(dim=-1)
                    values, indices = similarity.topk(3)
                    results = []
                    for value, index in zip(values, indices):
                        results.append({
                            "label": self.text_labels[index],
                            "confidence": float(value)
                        })
                    batch_results.append(results)
            return {"success": True, "batch_results": batch_results}
        except Exception as e:
            return {"success": False, "error": str(e)}

# Start service
service = ClassificationService()
service.initialize()
print("SERVICE_READY", file=sys.stderr, flush=True)

# Main service loop
for line in sys.stdin:
    try:
        request = json.loads(line.strip())
        if request["type"] == "classify":
            result = service.classify_image(request["image_path"])
            print(json.dumps(result), flush=True)
        elif request["type"] == "classify_batch":
            result = service.classify_batch(request["image_paths"])
            print(json.dumps(result), flush=True)
        elif request["type"] == "ping":
            print(json.dumps({"success": True, "message": "pong"}), flush=True)
    except Exception as e:
        print(json.dumps({"success": False, "error": str(e)}), flush=True)
'''

    def classify_image(self, image_path: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Classify one image via the persistent service.

        Args:
            image_path: path to the image file (readable by the worker process).
            top_k: maximum number of labels to return (worker computes top 3).

        Returns:
            List of {"label", "confidence"} dicts; a single sentinel entry
            ("service_unavailable" / "classification_error" / "service_error")
            on failure — this method never raises.
        """
        if not self.available or not self.service_process:
            return [{"label": "service_unavailable", "confidence": 0.0}]
        try:
            # Send classification request as one JSON line.
            request = {
                "type": "classify",
                "image_path": image_path
            }
            self.service_process.stdin.write(json.dumps(request) + '\n')
            self.service_process.stdin.flush()
            # Read the single-line JSON response (blocks until the worker replies).
            response_line = self.service_process.stdout.readline()
            response = json.loads(response_line)
            if response.get("success"):
                return response["results"][:top_k]
            else:
                logger.error(f"Classification failed: {response.get('error')}")
                return [{"label": "classification_error", "confidence": 0.0}]
        except Exception as e:
            # Includes json.JSONDecodeError on an empty line (worker died).
            logger.error(f"Service communication failed: {e}")
            return [{"label": "service_error", "confidence": 0.0}]

    def classify_images_batch(self, image_paths: List[str], top_k: int = 5) -> List[List[Dict[str, Any]]]:
        """Classify multiple images in one round-trip (much faster).

        Returns one result list per input path; on failure every entry is the
        same sentinel list (see classify_image). Never raises.
        """
        if not self.available or not self.service_process:
            return [[{"label": "service_unavailable", "confidence": 0.0}] for _ in image_paths]
        try:
            # Send batch classification request as one JSON line.
            request = {
                "type": "classify_batch",
                "image_paths": image_paths
            }
            self.service_process.stdin.write(json.dumps(request) + '\n')
            self.service_process.stdin.flush()
            # Read the single-line JSON response for the whole batch.
            response_line = self.service_process.stdout.readline()
            response = json.loads(response_line)
            if response.get("success"):
                return [results[:top_k] for results in response["batch_results"]]
            else:
                logger.error(f"Batch classification failed: {response.get('error')}")
                return [[{"label": "classification_error", "confidence": 0.0}] for _ in image_paths]
        except Exception as e:
            logger.error(f"Batch service communication failed: {e}")
            return [[{"label": "service_error", "confidence": 0.0}] for _ in image_paths]

    def __del__(self):
        """Best-effort cleanup of the worker process and the temp script file.

        Guarded with getattr/try-except because __del__ may run on a
        partially-initialized instance (e.g. when __init__ raised).
        """
        proc = getattr(self, 'service_process', None)
        if proc:
            try:
                if proc.stdin:
                    proc.stdin.close()  # EOF lets the worker's stdin loop exit
                proc.terminate()
                proc.wait(timeout=5)  # bounded: never hang interpreter shutdown
            except Exception:
                pass
        script = getattr(self, '_script_path', None)
        if script:
            try:
                os.unlink(script)  # remove the NamedTemporaryFile(delete=False)
            except OSError:
                pass
# Test the optimized classifier
def test_optimized_classifier():
    """Smoke-test the optimized classifier and print timing numbers.

    Creates a solid-red temp image, runs one single and one 8-image batch
    classification, and reports the per-image speedup. Prints a failure
    message instead of crashing when the classifier cannot start.
    """
    try:
        classifier = OptimizedImageClassifier()
    except RuntimeError as e:
        # BUG FIX: __init__ raises (rather than returning an instance with
        # available=False) when the OpenCLIP check fails, so the original
        # unguarded call crashed and the "not available" branch below was
        # unreachable.
        print(f"❌ Optimized classifier not available: {e}")
        return
    if classifier.available:
        print("✅ Optimized classifier available")
        # Test with a simple image
        from PIL import Image
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
            img_path = f.name
        try:
            # Create test image
            Image.new('RGB', (224, 224), color='red').save(img_path)
            # Test single classification
            print("Testing single classification...")
            start_time = time.time()
            results = classifier.classify_image(img_path)
            single_time = time.time() - start_time
            print(f"Single classification: {single_time:.3f}s")
            print(f"Results: {results}")
            # Test batch classification (simulate 8 images like test.docx)
            test_paths = [img_path] * 8
            print("Testing batch classification (8 images)...")
            start_time = time.time()
            batch_results = classifier.classify_images_batch(test_paths)
            batch_time = time.time() - start_time
            print(f"Batch classification (8 images): {batch_time:.3f}s")
            print(f"Per image: {batch_time/8:.3f}s")
            # Guard against division by zero on a (near-)instant batch.
            if batch_time > 0:
                print(f"Performance improvement: {single_time*8/batch_time:.1f}x faster")
        finally:
            # Cleanup even if a classification call raised.
            os.unlink(img_path)
    else:
        print("❌ Optimized classifier not available")


if __name__ == "__main__":
    test_optimized_classifier()