Files
railseek6/openclip_classifier_fixed.py

85 lines
2.7 KiB
Python

import sys
import os
import json
import tempfile
from pathlib import Path
def classify_image(image_path):
    """
    Classify an image with OpenCLIP against a fixed set of candidate labels.

    Designed to run in an isolated environment: heavy dependencies are
    imported lazily inside the function so any failure (missing package,
    unreadable image, model download error) is reported as structured data
    instead of crashing the process.

    Args:
        image_path: Filesystem path of the image to classify.

    Returns:
        On success: {"success": True, "predictions": [{"label", "score"}...],
        "device": "cpu"} with the top-3 labels by CLIP similarity.
        On failure: {"success": False, "error": <message>, "predictions": []}.
    """
    try:
        # Deferred imports: this module stays importable (and can emit a
        # clean JSON error) even when these packages are unavailable.
        import open_clip
        import torch
        from PIL import Image

        # Force CPU to avoid CUDA conflicts with PaddleOCR in the same pipeline.
        device = "cpu"
        # BUGFIX: the caller parses this script's stdout as JSON (see the
        # __main__ block), so diagnostics must go to stderr, not stdout.
        print(f"Using device: {device}", file=sys.stderr)

        # Load the model and its matching preprocessing transform.
        model, _, preprocess = open_clip.create_model_and_transforms(
            'ViT-B-32', pretrained='laion2b_s34b_b79k')
        model = model.to(device)

        # Load the image and convert to a (1, C, H, W) tensor on the device.
        image = Image.open(image_path).convert('RGB')
        image = preprocess(image).unsqueeze(0).to(device)

        # Candidate labels for zero-shot classification (including bee).
        candidate_labels = [
            "a bee", "an insect", "an animal", "a flower", "a plant",
            "a bird", "a butterfly", "a dragonfly", "a bug", "a honeybee",
            "clipart", "cartoon", "illustration", "drawing", "logo",
        ]
        text = open_clip.tokenize(candidate_labels).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            # Normalize, then softmax over scaled cosine similarities.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Collect the top-3 labels with their probabilities.
        values, indices = similarity[0].topk(3)
        results = [
            {"label": candidate_labels[idx], "score": round(value.item(), 3)}
            for value, idx in zip(values, indices)
        ]
        return {
            "success": True,
            "predictions": results,
            "device": device,
        }
    except Exception as e:
        # Broad catch is deliberate: this is a subprocess boundary and the
        # caller expects a JSON error payload rather than a traceback.
        return {
            "success": False,
            "error": str(e),
            "predictions": [],
        }
if __name__ == "__main__":
    # CLI entry point: expects one argument (the image path) and emits
    # exactly one JSON object on stdout, success or failure alike.
    if len(sys.argv) <= 1:
        # Guard clause: no image path was supplied.
        error_payload = {
            "success": False,
            "error": "No image path provided",
            "predictions": []
        }
        print(json.dumps(error_payload))
    else:
        print(json.dumps(classify_image(sys.argv[1])))