jina rerank working
This commit is contained in:
@@ -262,7 +262,7 @@ def parse_args() -> argparse.Namespace:
|
||||
"--rerank-binding",
|
||||
type=str,
|
||||
default=get_env_value("RERANK_BINDING", DEFAULT_RERANK_BINDING),
|
||||
choices=["null", "cohere", "jina", "aliyun"],
|
||||
choices=["null", "cohere", "jina", "aliyun", "ollama"],
|
||||
help=f"Rerank binding type (default: from env or {DEFAULT_RERANK_BINDING})",
|
||||
)
|
||||
|
||||
|
||||
@@ -538,13 +538,14 @@ def create_app(args):
|
||||
    # Configure rerank function based on the args.rerank_binding parameter
|
||||
rerank_model_func = None
|
||||
if args.rerank_binding != "null":
|
||||
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank
|
||||
from lightrag.rerank import cohere_rerank, jina_rerank, ali_rerank, ollama_rerank
|
||||
|
||||
# Map rerank binding to corresponding function
|
||||
rerank_functions = {
|
||||
"cohere": cohere_rerank,
|
||||
"jina": jina_rerank,
|
||||
"aliyun": ali_rerank,
|
||||
"ollama": ollama_rerank,
|
||||
}
|
||||
|
||||
# Select the appropriate rerank function based on binding
|
||||
|
||||
@@ -290,6 +290,99 @@ async def ali_rerank(
|
||||
)
|
||||
|
||||
|
||||
async def ollama_rerank(
    query: str,
    documents: List[str],
    top_n: Optional[int] = None,
    api_key: Optional[str] = None,
    model: str = "jina-reranker-v2:latest",
    base_url: str = "http://localhost:11434",
    extra_body: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
    """
    Rerank documents via Ollama embeddings + cosine similarity.

    NOTE(review): this is an *approximation* of reranking — it embeds the
    query and every document with Ollama's embedding API and scores by
    cosine similarity, rather than calling a true cross-encoder reranker.
    The default model name refers to a reranker; confirm it is served as an
    embedding model on the target Ollama instance.

    Args:
        query: The search query.
        documents: List of document strings to rerank.
        top_n: If given and > 0, return only the top_n highest-scoring results.
        api_key: API key (unused by local Ollama; kept for interface parity
            with the other rerank bindings).
        model: Ollama model name used to produce embeddings.
        base_url: Ollama server URL.
        extra_body: Extra options forwarded to the Ollama embed call.

    Returns:
        List of {"index": int, "relevance_score": float} dicts, sorted by
        relevance_score descending. Scores are cosine similarity normalized
        from [-1, 1] into [0, 1]. Returns [] for empty input or on an
        embedding-count mismatch.

    Raises:
        Any exception raised by the underlying Ollama embed call (logged,
        then re-raised).
    """
    import numpy as np
    from lightrag.llm.ollama import ollama_embed

    if not documents:
        return []

    # One batched call: query first, then all documents.
    all_texts = [query] + documents

    try:
        embeddings = await ollama_embed(
            texts=all_texts,
            embed_model=model,
            host=base_url,
            api_key=api_key,
            options=extra_body or {},
        )

        if len(embeddings) != len(all_texts):
            logger.error(
                f"Embedding count mismatch: expected {len(all_texts)}, got {len(embeddings)}"
            )
            return []

        emb = np.asarray(embeddings, dtype=float)
        query_embedding = emb[0]
        doc_embeddings = emb[1:]

        # Vectorized cosine similarity. The query norm is loop-invariant,
        # so compute it (and all document norms) exactly once instead of
        # per-document inside a Python loop.
        query_norm = np.linalg.norm(query_embedding)
        doc_norms = np.linalg.norm(doc_embeddings, axis=1)
        denom = query_norm * doc_norms
        # Zero-norm vectors get similarity 0.0 (guard against divide-by-zero).
        safe_denom = np.where(denom == 0.0, 1.0, denom)
        similarities = np.where(
            denom == 0.0, 0.0, (doc_embeddings @ query_embedding) / safe_denom
        )

        # Map cosine similarity from [-1, 1] into a [0, 1] relevance score.
        relevance_scores = (similarities + 1.0) / 2.0

        # Sort descending; stable sort preserves original document order on
        # ties, matching the previous list.sort() behavior.
        order = np.argsort(-relevance_scores, kind="stable")

        if top_n is not None and top_n > 0:
            order = order[:top_n]

        results = [
            {"index": int(idx), "relevance_score": float(relevance_scores[idx])}
            for idx in order
        ]

        logger.debug(f"Ollama rerank completed: {len(results)} results")
        return results

    except Exception as e:
        logger.error(f"Error in ollama_rerank: {str(e)}")
        raise
|
||||
|
||||
|
||||
"""Please run this test as a module:
|
||||
python -m lightrag.rerank
|
||||
"""
|
||||
|
||||
@@ -6,10 +6,13 @@ os.environ['OPENAI_API_KEY'] = 'sk-55f6e57f1d834b0e93ceaf98cc2cb715'
|
||||
os.environ['DEEPSEEK_API_KEY'] = 'sk-55f6e57f1d834b0e93ceaf98cc2cb715'
|
||||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||||
os.environ['OLLAMA_EMBEDDING_MODEL'] = 'snowflake-arctic-embed:latest'
|
||||
os.environ['OLLAMA_RERANKER_MODEL'] = 'jina-reranker:latest'
|
||||
os.environ['OLLAMA_RERANKER_MODEL'] = 'jina-reranker-v2:latest' # Updated to v2 model
|
||||
os.environ['OPENAI_API_MODEL'] = 'deepseek-chat'
|
||||
os.environ['OPENAI_API_BASE'] = 'https://api.deepseek.com/v1'
|
||||
os.environ['LLM_BINDING_HOST'] = 'https://api.deepseek.com/v1'
|
||||
# Ollama rerank configuration - using local Ollama server
|
||||
os.environ['RERANK_BINDING_HOST'] = 'http://localhost:11434' # Local Ollama server
|
||||
os.environ['RERANK_BINDING_API_KEY'] = '' # No API key needed for local Ollama
|
||||
|
||||
# Set database environment variables
|
||||
os.environ['REDIS_URI'] = 'redis://localhost:6379'
|
||||
@@ -29,7 +32,7 @@ cmd = [
|
||||
'--auto-scan-at-startup',
|
||||
'--llm-binding', 'openai',
|
||||
'--embedding-binding', 'ollama',
|
||||
'--rerank-binding', 'null',
|
||||
'--rerank-binding', 'ollama', # Changed from 'jina' to 'ollama' for local Ollama rerank
|
||||
'--summary-max-tokens', '0', # Disable entity extraction by setting summary tokens to 0
|
||||
'--timeout', '600' # Increase server timeout to 600 seconds to avoid nginx 504
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@ set OPENAI_API_KEY=sk-55f6e57f1d834b0e93ceaf98cc2cb715
|
||||
set OPENAI_BASE_URL=https://api.deepseek.com/v1
|
||||
set LLM_MODEL=deepseek-chat
|
||||
set OLLAMA_EMBEDDING_MODEL=snowflake-arctic-embed:latest
|
||||
set OLLAMA_RERANKER_MODEL=jina-reranker:latest
|
||||
set OLLAMA_RERANKER_MODEL=jina-reranker-v2:latest
|
||||
set PYTHONIOENCODING=utf-8
|
||||
|
||||
echo Setting GPU processing environment...
|
||||
@@ -37,6 +37,6 @@ set QDRANT_URI=http://localhost:6333/
|
||||
set POSTGRES_URI=postgresql://jleu3482:jleu1212@localhost:5432/rag_anything
|
||||
|
||||
echo Starting LightRAG server on port 3015 with enhanced document processing...
|
||||
python -m lightrag.api.lightrag_server --port 3015 --working-dir rag_storage --input-dir inputs --key jleu1212 --auto-scan-at-startup --llm-binding openai --embedding-binding ollama --rerank-binding jina --summary-max-tokens 1200
|
||||
python -m lightrag.api.lightrag_server --port 3015 --working-dir rag_storage --input-dir inputs --key jleu1212 --auto-scan-at-startup --llm-binding openai --embedding-binding ollama --rerank-binding ollama --summary-max-tokens 1200
|
||||
|
||||
pause
|
||||
Reference in New Issue
Block a user