87 lines
3.1 KiB
Python
87 lines
3.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify Ollama rerank functionality
|
|
"""
|
|
import asyncio
|
|
import sys
|
|
import os
|
|
|
|
# Add LightRAG to path
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'LightRAG-main'))
|
|
|
|
from lightrag.rerank import ollama_rerank
|
|
|
|
async def test_ollama_rerank(
    *,
    model: str = "jina-reranker-v2:latest",
    base_url: str = "http://localhost:11434",
    top_n: int = 3,
) -> bool:
    """Run one rerank request through ``ollama_rerank`` and report the ranking.

    A fixed query and five short AI-related documents are sent to
    ``ollama_rerank``. Each returned entry's ``index`` is mapped back to the
    submitted document and printed together with its ``relevance_score``.

    Args:
        model: Model name forwarded to ``ollama_rerank``.
        base_url: Ollama server URL forwarded to ``ollama_rerank``.
        top_n: Number of results requested from ``ollama_rerank``.

    Returns:
        True when the call completes, returns at least one result, and every
        returned index refers to a submitted document. False otherwise. Any
        exception is printed with its traceback instead of being propagated.
    """
    print("Testing Ollama rerank function...")

    # Test query and documents
    query = "What is artificial intelligence?"
    documents = [
        "Artificial intelligence is the simulation of human intelligence processes by machines.",
        "Machine learning is a subset of AI that enables systems to learn from data.",
        "Deep learning uses neural networks with multiple layers to analyze data.",
        "Natural language processing allows computers to understand human language.",
        "Computer vision enables machines to interpret visual information.",
    ]

    try:
        print(f"Query: {query}")
        print(f"Number of documents: {len(documents)}")

        results = await ollama_rerank(
            query=query,
            documents=documents,
            top_n=top_n,
            model=model,
            base_url=base_url,
        )

        # Without this check an empty result would print "top 0" and still
        # count as a pass.
        if not results:
            print("Error: ollama_rerank returned no results")
            return False

        print(f"\nRerank results (top {len(results)}):")
        all_indices_valid = True
        for rank, result in enumerate(results, start=1):
            idx = result["index"]
            score = result["relevance_score"]
            # Reject negative indices as well: documents[-1] would otherwise
            # silently display the wrong document instead of flagging it.
            if 0 <= idx < len(documents):
                text = documents[idx]
            else:
                text = "Unknown"
                all_indices_valid = False
            # Only mark the text as truncated when it actually was.
            snippet = text if len(text) <= 80 else f"{text[:80]}..."
            print(f"{rank}. Index: {idx}, Score: {score:.4f}")
            print(f" Text: {snippet}")

        if not all_indices_valid:
            print("Error: ollama_rerank returned out-of-range document indices")
        return all_indices_valid

    except Exception as e:
        # This is the test's outermost boundary: report the failure and return
        # False rather than letting the exception escape.
        print(f"Error testing Ollama rerank: {e}")
        import traceback

        traceback.print_exc()
        return False
|
|
|
|
# Server and model probed by the pre-flight check. RERANK_MODEL should match
# the model that test_ollama_rerank() requests.
OLLAMA_BASE_URL = "http://localhost:11434"
RERANK_MODEL = "jina-reranker-v2:latest"


def _check_ollama_server(base_url=OLLAMA_BASE_URL, model=RERANK_MODEL):
    """Probe the Ollama server and report whether the rerank model is installed.

    Exits the process with status 1 if the server cannot be reached. A non-200
    status or an unparsable model list is reported and the check returns, so
    the caller still runs the test.

    Args:
        base_url: Ollama server URL; ``/api/tags`` is appended to list models.
        model: Exact model name the rerank test will request.
    """
    # Imported here so that merely importing this module does not require
    # ``requests``; it is only needed when the script is executed.
    import requests

    try:
        response = requests.get(f"{base_url}/api/tags", timeout=5)
    except requests.RequestException as e:
        print(f"Cannot connect to Ollama server: {e}")
        print(f"Make sure Ollama is running on {base_url}")
        sys.exit(1)

    if response.status_code != 200:
        print(f"Ollama server returned status {response.status_code}")
        return

    print("Ollama server is running")
    try:
        # ``or []`` also covers an explicit ``"models": null`` in the payload.
        models = response.json().get("models") or []
    except ValueError as e:
        # A reachable server with a malformed body is not a connection error.
        print(f"Could not parse model list from Ollama: {e}")
        return

    names = [m.get("name", "") for m in models]
    print(f"Available models: {names}")

    jina_names = [name for name in names if "jina-reranker" in name]
    if jina_names:
        print(f"Found Jina rerank models: {jina_names}")
    else:
        print("Warning: No Jina rerank models found in Ollama")

    # A different jina-reranker variant being installed is not enough: the
    # test requests this exact model name.
    if model not in names:
        print(f"Warning: {model} is not installed in Ollama")
        print(f"You may need to pull the model: ollama pull {model}")


def main():
    """Run the server pre-flight check, then the rerank test; exit 1 on failure."""
    _check_ollama_server()

    success = asyncio.run(test_ollama_rerank())
    if success:
        print("\n✅ Ollama rerank test passed!")
    else:
        print("\n❌ Ollama rerank test failed!")
        sys.exit(1)


if __name__ == "__main__":
    main()