# Persistent OpenCLIP image-classification worker (stdin/stdout JSON protocol).
# Standard library
import json
import logging
import sys
import time

# Third-party: OpenCLIP backbone + PIL for image decoding
import torch
import open_clip
from PIL import Image

# Module-wide logger: timestamped INFO-level messages on stderr by default.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
|
|
|
class PersistentClassifier:
    """Keep an OpenCLIP model resident in memory and classify images
    against a fixed set of text prompts.

    Loading the model once at start-up amortizes the multi-second load
    cost across the many classification requests served from stdin.
    """

    def __init__(self):
        # Populated by load_model(); held as attributes so repeated
        # classify_batch() calls reuse the same weights and text features.
        self.model = None
        self.processor = None       # image preprocessing transform from open_clip
        self.text_features = None   # precomputed, L2-normalized prompt embeddings
        self.text_labels = None
        # Query CUDA availability once instead of on every request.
        self.use_cuda = torch.cuda.is_available()
        self.load_model()

    def load_model(self):
        """Load the model once and precompute normalized text features.

        Side effects: sets self.model, self.processor, self.text_labels and
        self.text_features. On CUDA the model is cast to FP16 for speed.
        """
        logger.info("Loading OpenCLIP model...")
        start_time = time.time()

        # ViT-B-16 chosen as a speed/quality trade-off for document images.
        self.model, _, self.processor = open_clip.create_model_and_transforms(
            model_name="ViT-B-16",
            pretrained="laion2b_s34b_b88k"
        )
        # Inference only: disable dropout / train-time behavior.
        self.model.eval()

        # Optimized label set for document processing.
        self.text_labels = [
            "a photo of a bee", "a photo of a flower", "a photo of a document",
            "a photo of a chart", "a photo of a diagram", "a photo of a table",
            "a photo of a graph", "a photo of a screenshot", "a photo of a logo",
            "a photo of text", "a photo of a signature", "a photo of a barcode",
            "a photo of a qr code", "a photo of a person", "a photo of a building"
        ]

        if self.use_cuda:
            # FP16 halves memory traffic and speeds up GPU inference.
            self.model = self.model.half().cuda()
            logger.info(f"Model loaded on GPU (FP16) in {time.time()-start_time:.2f}s")
        else:
            logger.warning("Using CPU - slower performance")

        # Precompute and normalize the text features once; every image in
        # every later batch reuses them.
        with torch.no_grad():
            text_tokens = open_clip.tokenize(self.text_labels)
            if self.use_cuda:
                text_tokens = text_tokens.cuda()
            self.text_features = self.model.encode_text(text_tokens)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)

        logger.info("Model and text features loaded successfully")

    def classify_batch(self, image_paths, top_k=3):
        """Classify several images against the precomputed text prompts.

        Args:
            image_paths: iterable of filesystem paths to image files.
            top_k: number of top labels per image; clamped to the label
                count so oversized values no longer make topk() raise.

        Returns:
            A list parallel to image_paths; each entry is a list of
            {"label": str, "confidence": float} dicts, best first. A
            failing image yields [{"label": "processing_error",
            "confidence": 0.0}] instead of aborting the whole batch.
        """
        # Clamp so topk() cannot be asked for more entries than exist.
        top_k = min(top_k, len(self.text_labels))
        results = []

        for image_path in image_paths:
            try:
                # Context manager releases the file handle even when the
                # decode fails partway (PIL opens lazily).
                with Image.open(image_path) as img:
                    image = img.convert("RGB")
                image_tensor = self.processor(image).unsqueeze(0)

                if self.use_cuda:
                    # Match the FP16 model weights.
                    image_tensor = image_tensor.half().cuda()

                with torch.no_grad():
                    image_features = self.model.encode_image(image_tensor)
                    image_features /= image_features.norm(dim=-1, keepdim=True)

                    # Cosine similarity scaled by 100 (CLIP convention),
                    # softmaxed into a distribution over the labels.
                    similarity = (100.0 * image_features @ self.text_features.T).softmax(dim=-1)
                    values, indices = similarity[0].topk(top_k)

                results.append([
                    {"label": self.text_labels[index], "confidence": float(value)}
                    for value, index in zip(values, indices)
                ])

            except Exception as e:
                # Best-effort batch semantics: log and emit a sentinel entry
                # so result indices stay aligned with image_paths.
                logger.error(f"Error processing {image_path}: {e}")
                results.append([{"label": "processing_error", "confidence": 0.0}])

        return results
|
|
|
|
# Create the persistent classifier once; every request reuses it.
classifier = PersistentClassifier()

# Request/response loop: one JSON object per line on stdin, one JSON
# response per line on stdout.
while True:
    try:
        raw = sys.stdin.readline()
        if not raw:
            # EOF: the parent closed stdin. Exit instead of busy-looping
            # forever on empty reads (the original spun here after EOF,
            # because readline().strip() is '' for both EOF and blank lines).
            break
        line = raw.strip()
        if not line:
            # Blank line: ignore and wait for the next request.
            continue

        request = json.loads(line)
        action = request.get('action')

        if action == 'classify':
            image_paths = request['image_paths']
            top_k = request.get('top_k', 3)

            start_time = time.time()
            results = classifier.classify_batch(image_paths, top_k)
            processing_time = time.time() - start_time

            response = {
                'success': True,
                'results': results,
                'processing_time': processing_time,
                'images_processed': len(image_paths)
            }
            print(json.dumps(response))
            sys.stdout.flush()

        elif action == 'ping':
            # Health check used by the parent to verify the worker is alive.
            print(json.dumps({'success': True, 'message': 'alive'}))
            sys.stdout.flush()

        elif action == 'exit':
            break

    except Exception as e:
        # Report the failure to the caller rather than crashing the worker;
        # the loop stays up for subsequent requests.
        error_response = {
            'success': False,
            'error': str(e)
        }
        print(json.dumps(error_response))
        sys.stdout.flush()