85 lines
2.6 KiB
Python
85 lines
2.6 KiB
Python
|
|
import sys
|
|
import os
|
|
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
def classify_image(image_path):
    """Classify an image against a fixed set of text labels using OpenCLIP.

    The image is embedded with the ``ViT-B-32`` model (``laion2b_s34b_b79k``
    weights) and compared with the embeddings of a built-in list of candidate
    labels. Scores are the softmax of the 100x-scaled cosine similarities,
    and the three highest-scoring labels are reported.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        A JSON-serialisable dict. On success::

            {"success": True,
             "predictions": [{"label": str, "score": float}, ...],
             "device": "cuda" or "cpu"}

        On any failure (missing dependency, unreadable image, model error)::

            {"success": False, "error": str, "predictions": []}

        No exception derived from ``Exception`` propagates to the caller.
    """
    try:
        # Imported inside the try block so that a missing package is reported
        # through the error dict instead of crashing the caller.
        import open_clip
        import torch
        from PIL import Image

        # Decode the image before the (expensive) model load so that a bad
        # path fails fast. The context manager closes the file handle;
        # convert() returns an independent, fully loaded copy.
        with Image.open(image_path) as source_image:
            rgb_image = source_image.convert("RGB")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Diagnostics go to stderr: stdout is reserved for the JSON result
        # printed by the command-line entry point.
        print(f"🔍 Using device: {device}", file=sys.stderr)

        model, _, preprocess = open_clip.create_model_and_transforms(
            "ViT-B-32", pretrained="laion2b_s34b_b79k"
        )
        model = model.to(device)
        # Inference only: switch off training-mode behaviour.
        model.eval()

        image_batch = preprocess(rgb_image).unsqueeze(0).to(device)

        # Candidate labels (including bee).
        candidate_labels = [
            "a bee", "an insect", "an animal", "a flower", "a plant",
            "a bird", "a butterfly", "a dragonfly", "a bug", "a honeybee",
            "clipart", "cartoon", "illustration", "drawing", "logo",
        ]
        text_tokens = open_clip.tokenize(candidate_labels).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image_batch)
            text_features = model.encode_text(text_tokens)

            # L2-normalise so the dot product is a cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Top three labels for the single image in the batch.
        top_scores, top_indices = similarity[0].topk(3)

        # tolist() yields plain Python floats/ints, safe for list indexing
        # and JSON serialisation.
        predictions = [
            {"label": candidate_labels[index], "score": round(score, 3)}
            for score, index in zip(top_scores.tolist(), top_indices.tolist())
        ]

        return {
            "success": True,
            "predictions": predictions,
            "device": device,
        }

    except Exception as e:
        # Deliberate catch-all: callers expect an error dict, not a traceback.
        return {
            "success": False,
            "error": str(e),
            "predictions": [],
        }
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the first argument is the image path.
    # Exactly one JSON document describing the outcome is printed.
    cli_args = sys.argv[1:]
    if cli_args:
        outcome = classify_image(cli_args[0])
    else:
        # No path supplied: report it in the same shape as a failed run.
        outcome = {
            "success": False,
            "error": "No image path provided",
            "predictions": [],
        }
    print(json.dumps(outcome))
|