"""
Persistent Classifier Client - Communicates with persistent classifier process
Provides fast image classification with minimal overhead
"""
|
|
import json
import logging
import os
import queue
import subprocess
import threading
import time
from pathlib import Path
from typing import Any, Dict, List
|
|
|
|
logger = logging.getLogger(__name__)


class PersistentClassifierClient:
    """Client for a persistent image classifier subprocess.

    The classifier is launched once as a child process and driven over its
    standard streams with newline-delimited JSON: each request is one JSON
    object written to the child's stdin, each reply is one JSON object read
    from its stdout.  The requests sent by this client are::

        {"action": "ping"}
        {"action": "classify", "image_paths": [...], "top_k": N}
        {"action": "exit"}

    Replies are expected to carry a ``success`` flag and, for ``classify``,
    ``results`` plus optional ``processing_time`` / ``error`` fields.

    Output from the child is collected by background daemon threads so that
    every wait in this class has a real deadline (a plain ``readline()``
    blocks indefinitely) and so that the child can never stall on a full
    stderr pipe.

    Attributes:
        venv_python: Path of the Python interpreter used to run the child.
        script_path: Path of the classifier script run by that interpreter.
        process: The ``subprocess.Popen`` handle, or ``None`` when stopped.
        available: ``True`` once the child has answered a ping successfully.
    """

    DEFAULT_VENV_PYTHON = "openclip_gpu_env\\Scripts\\python.exe"
    DEFAULT_SCRIPT = "persistent_classifier.py"

    STARTUP_DELAY = 5       # seconds to let the child initialise before pinging
    PING_TIMEOUT = 5        # seconds to wait for a ping reply
    CLASSIFY_TIMEOUT = 30   # default seconds to wait for a classify reply
    SHUTDOWN_TIMEOUT = 5    # seconds to wait for the child to exit

    # Class-level defaults keep __del__ and the helpers safe on instances
    # whose __init__ did not run to completion.
    process = None
    available = False
    _responses = None       # queue of stdout lines; None marks end-of-stream

    def __init__(self, venv_python: str = DEFAULT_VENV_PYTHON,
                 script_path: str = DEFAULT_SCRIPT):
        """Create the client and immediately try to start the child process.

        Args:
            venv_python: Interpreter used to launch the classifier.
            script_path: Classifier script passed to that interpreter.

        Start-up failures are logged rather than raised; check ``available``
        afterwards.
        """
        self.venv_python = venv_python
        self.script_path = script_path
        self.process = None
        self.available = False
        self._responses = None
        self._start_process()

    def __enter__(self):
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the child process when leaving a ``with`` block."""
        self.close()
        return False

    def close(self):
        """Stop the classifier process (public alias of ``_stop_process``)."""
        self._stop_process()

    def _start_process(self):
        """Start the persistent classifier process and verify it responds.

        Sets ``available`` to ``True`` only after a successful ping.  Every
        failure is logged and leaves the client in the stopped state.
        """
        try:
            if not os.path.exists(self.venv_python):
                logger.error("Virtual environment not found: %s", self.venv_python)
                return

            # Line-buffered text pipes: one JSON document per line each way.
            self.process = subprocess.Popen(
                [self.venv_python, self.script_path],
                stdin=subprocess.PIPE,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                bufsize=1,
            )
            self._start_readers()

            # Wait for process to initialize
            time.sleep(self.STARTUP_DELAY)

            # Test connection
            if self._ping():
                self.available = True
                logger.info("✅ Persistent classifier process started successfully")
            else:
                logger.error("❌ Failed to start persistent classifier process")
                self._stop_process()

        except Exception as e:
            logger.error("Failed to start persistent classifier: %s", e)
            self._stop_process()

    def _start_readers(self):
        """Start daemon threads that collect the child's stdout and stderr.

        stdout lines are pushed onto ``self._responses`` followed by a single
        ``None`` once the stream ends.  stderr is drained to the debug log so
        the child cannot block on a full, never-read pipe.
        """
        process = self.process
        responses = queue.Queue()

        def pump_stdout():
            try:
                for line in iter(process.stdout.readline, ""):
                    responses.put(line)
            except (OSError, ValueError):
                pass  # pipe was closed underneath us during shutdown
            finally:
                responses.put(None)

        def drain_stderr():
            try:
                for line in iter(process.stderr.readline, ""):
                    logger.debug("classifier stderr: %s", line.rstrip())
            except (OSError, ValueError):
                pass  # pipe was closed underneath us during shutdown

        self._responses = responses
        threading.Thread(target=pump_stdout, daemon=True,
                         name="classifier-stdout").start()
        if process.stderr is not None:
            threading.Thread(target=drain_stderr, daemon=True,
                             name="classifier-stderr").start()

    def _request(self, payload: Dict[str, Any], timeout: float):
        """Send one JSON request and wait for the next JSON reply.

        Args:
            payload: Request object; serialised as a single line.
            timeout: Maximum number of seconds to wait for the reply.

        Returns:
            The decoded reply, or ``None`` if nothing arrived in time.

        Raises:
            ConnectionError: The child closed its output stream.
            OSError: Writing to the child's stdin failed.
            ValueError: The reply line was not valid JSON.
        """
        if self._responses is None:
            self._start_readers()
        responses = self._responses

        # Anything already queued belongs to an earlier request that timed
        # out (or is unsolicited output); it must not be taken for this reply.
        while True:
            try:
                leftover = responses.get_nowait()
            except queue.Empty:
                break
            if leftover is None:
                raise ConnectionError("classifier process closed its output stream")
            logger.debug("Discarding stale classifier output: %s", leftover.strip())

        self.process.stdin.write(json.dumps(payload) + '\n')
        self.process.stdin.flush()

        deadline = time.monotonic() + timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            try:
                line = responses.get(timeout=remaining)
            except queue.Empty:
                return None
            if line is None:
                raise ConnectionError("classifier process closed its output stream")
            line = line.strip()
            if line:  # blank lines are skipped, not treated as replies
                return json.loads(line)

    def _ping(self):
        """Check if the classifier process is responsive.

        Returns:
            The reply's ``success`` value, or ``False`` when the process is
            not running, does not answer within ``PING_TIMEOUT`` seconds, or
            the exchange fails.
        """
        try:
            if not self.process or self.process.poll() is not None:
                return False

            response = self._request({"action": "ping"}, self.PING_TIMEOUT)
            if response is None:
                return False
            return response.get('success', False)

        except Exception as e:
            logger.error("Ping failed: %s", e)
            return False

    @staticmethod
    def _placeholder(label: str, image_paths: List[str]) -> List[List[Dict[str, Any]]]:
        """Build one zero-confidence sentinel result per requested image."""
        return [[{"label": label, "confidence": 0.0}] for _ in image_paths]

    def classify_images_batch(self, image_paths: List[str], top_k: int = 3,
                              timeout: float = CLASSIFY_TIMEOUT) -> List[List[Dict[str, Any]]]:
        """Classify multiple images using persistent process.

        Args:
            image_paths: Paths of the images to classify.
            top_k: Value forwarded to the classifier as ``top_k``.
            timeout: Maximum number of seconds to wait for the reply.

        Returns:
            The reply's ``results`` on success.  Otherwise one sentinel entry
            per input image with ``confidence`` 0.0 and a ``label`` of:

            * ``service_unavailable`` - the client has no running process;
            * ``timeout_error`` - no reply arrived within ``timeout``;
            * ``request_error`` - the exchange raised, or the classifier
              replied without a truthy ``success``.

            An empty ``image_paths`` yields ``[]``.
        """
        if not self.available or not self.process:
            logger.error("Persistent classifier not available")
            return self._placeholder("service_unavailable", image_paths)

        if not image_paths:
            return []  # nothing to classify; skip the round trip

        try:
            request = {
                "action": "classify",
                "image_paths": image_paths,
                "top_k": top_k,
            }
            response = self._request(request, timeout)

            if response is None:
                logger.error("Classification timeout")
                return self._placeholder("timeout_error", image_paths)

            if response.get('success'):
                logger.info(
                    "✅ Batch classification completed: %d images in %.3fs",
                    len(image_paths), response.get('processing_time', 0),
                )
                return response.get('results', [])

            # An explicit failure reply is not a timeout; report it as such.
            logger.error("Classification failed: %s", response.get('error'))
            return self._placeholder("request_error", image_paths)

        except Exception as e:
            logger.error("Classification request failed: %s", e)
            return self._placeholder("request_error", image_paths)

    def classify_image(self, image_path: str, top_k: int = 3) -> List[Dict[str, Any]]:
        """Classify a single image (convenience wrapper).

        Returns the first entry of the batch result, or a single
        ``{"label": "error", "confidence": 0.0}`` entry if the batch result
        is empty.
        """
        results = self.classify_images_batch([image_path], top_k)
        return results[0] if results else [{"label": "error", "confidence": 0.0}]

    @staticmethod
    def _force_stop(process, grace: float):
        """Terminate ``process``; kill it if it is still alive after ``grace`` s."""
        process.terminate()
        try:
            process.wait(timeout=grace)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    def _stop_process(self):
        """Stop the persistent classifier process.

        Sends the ``exit`` request and waits up to ``SHUTDOWN_TIMEOUT``
        seconds; if that fails the child is terminated (then killed) and
        reaped.  The client is always left with ``process`` set to ``None``
        and ``available`` set to ``False``.
        """
        process = self.process
        self.process = None
        self.available = False
        self._responses = None
        if process is None:
            return

        if process.poll() is None:
            try:
                # Send exit command
                process.stdin.write(json.dumps({"action": "exit"}) + '\n')
                process.stdin.flush()
                process.wait(timeout=self.SHUTDOWN_TIMEOUT)
            except Exception as e:
                logger.warning(
                    "Graceful classifier shutdown failed (%s); terminating", e)
                try:
                    self._force_stop(process, self.SHUTDOWN_TIMEOUT)
                except OSError as stop_error:
                    logger.error("Could not terminate classifier process: %s",
                                 stop_error)

        # Only stdin is closed here: the reader threads own stdout/stderr and
        # finish on their own once the child's end of those pipes closes.
        try:
            process.stdin.close()
        except (OSError, ValueError):
            pass  # already closed or broken; nothing left to release

    def __del__(self):
        """Cleanup when object is destroyed"""
        try:
            self._stop_process()
        except Exception:
            # A finalizer must never raise, e.g. during interpreter shutdown
            # when module globals may already have been torn down.
            pass
|
|
|
|
|
|
def test_persistent_classifier():
    """Manually exercise the persistent classifier and print timings.

    Starts a ``PersistentClassifierClient``, classifies three generated
    solid-colour PNG images (one at a time, then as a batch) and prints the
    elapsed times.  If ``test.docx`` exists in the working directory, its
    images are extracted with ``word_image_extractor`` and the first one is
    classified and checked for a label containing "bee".

    The temporary images are removed and the classifier process is stopped
    even if a step raises.  Requires Pillow; ``word_image_extractor`` is only
    needed when ``test.docx`` is present.
    """
    print("🧪 TESTING PERSISTENT CLASSIFIER")
    print("=" * 40)

    client = PersistentClassifierClient()

    if not client.available:
        print("❌ Persistent classifier not available")
        return

    try:
        print("✅ Persistent classifier available")

        # Imported here so that importing this module does not need Pillow.
        from PIL import Image
        import tempfile

        colors = ['red', 'green', 'blue']
        test_images = []
        try:
            # Create different colored test images
            for color in colors:
                with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
                    img_path = f.name
                # Track the path before saving so a failed save is still cleaned up.
                test_images.append(img_path)
                Image.new('RGB', (224, 224), color=color).save(img_path)

            image_count = len(test_images)

            # Test single classification
            print("Testing single classification...")
            start_time = time.perf_counter()
            results = client.classify_image(test_images[0])
            single_time = time.perf_counter() - start_time
            print(f"Single classification: {single_time:.3f}s")
            print(f"Results: {results}")

            # Test batch classification
            print(f"Testing batch classification ({image_count} images)...")
            start_time = time.perf_counter()
            client.classify_images_batch(test_images)
            batch_time = time.perf_counter() - start_time
            print(f"Batch classification ({image_count} images): {batch_time:.3f}s")
            print(f"Per image: {batch_time / image_count:.3f}s")

            if batch_time > 0:
                speedup = single_time * image_count / batch_time
                print(f"Performance improvement: {speedup:.1f}x faster")
        finally:
            # Cleanup
            for img_path in test_images:
                if os.path.exists(img_path):
                    os.unlink(img_path)

        # Test with actual test.docx bee image (if extracted)
        print("\nTesting with actual test.docx bee image...")
        if not os.path.exists("test.docx"):
            print("⚠️ test.docx not found for bee image test")
            return

        # Extract images from test.docx
        from word_image_extractor import extract_images_from_docx
        extracted_images = extract_images_from_docx("test.docx", "extracted_images_test")

        if not extracted_images:
            print("⚠️ No images extracted from test.docx")
            return

        print(f"Found {len(extracted_images)} images in test.docx")

        # Test classification on first image
        start_time = time.perf_counter()
        bee_results = client.classify_image(extracted_images[0])
        bee_time = time.perf_counter() - start_time
        print(f"Bee image classification: {bee_time:.3f}s")
        print(f"Results: {bee_results}")

        # Check for bee detection; tolerate entries without a "label" key.
        bee_found = any(
            "bee" in str(result.get("label", "")).lower() for result in bee_results
        )
        if bee_found:
            print("🎯 BEE DETECTED SUCCESSFULLY!")
        else:
            print("⚠️ Bee not detected in results")
    finally:
        # Shut the child process down deterministically instead of relying
        # on garbage collection to trigger __del__.
        client._stop_process()
|
|
|
|
|
|
# Run the manual classifier check when this file is executed as a script.
if __name__ == "__main__":
    test_persistent_classifier()