"""
GPU acceleration utilities for LightRAG embedding functions.

This module provides GPU-accelerated embedding computation with support for:

- CUDA (NVIDIA GPUs)
- ROCm (AMD GPUs)
- CPU fallback
- Batch processing optimization
- Memory management
- Automatic device detection
"""
|
|
|
|
import asyncio
import logging
import os
from typing import Any, Callable, List, Optional

import numpy as np

# PyTorch is optional: the module degrades to CPU/API-only behavior when it
# is missing, and TORCH_AVAILABLE gates every torch code path below.
try:
    import torch
    import torch.nn as nn
    from torch.cuda.amp import autocast

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

logger = logging.getLogger("lightrag")
|
|
|
|
# Global GPU configuration, read from the environment once at import time.
_ENV = os.getenv
GPU_CONFIG = {
    # Master switch; any value other than "true" (case-insensitive) disables GPU use.
    "enabled": _ENV("LIGHTRAG_GPU_ENABLED", "true").lower() == "true",
    # Number of texts per forward pass / API request.
    "batch_size": int(_ENV("LIGHTRAG_GPU_BATCH_SIZE", "64")),
    # Upper bound on simultaneously in-flight batches.
    "max_concurrent_batches": int(_ENV("LIGHTRAG_GPU_MAX_CONCURRENT_BATCHES", "4")),
    # Inference precision: fp16, fp32, bf16
    "precision": _ENV("LIGHTRAG_GPU_PRECISION", "fp16"),
}
|
|
|
|
|
|
def detect_gpu_device() -> Optional[str]:
    """
    Probe for an available accelerator and return its device string.

    Returns:
        "cuda" or "mps" when an accelerator is usable, "cpu" when probing
        succeeds but finds none (or the probe raises), or None when PyTorch
        itself is not installed.
    """
    if not TORCH_AVAILABLE:
        return None

    try:
        if torch.cuda.is_available():
            # Check CUDA device count and memory before committing to "cuda".
            if torch.cuda.device_count() > 0:
                gpu_name = torch.cuda.get_device_name(0)
                props = torch.cuda.get_device_properties(0)
                free_memory = props.total_memory - torch.cuda.memory_allocated(0)
                logger.info(f"CUDA GPU detected: {gpu_name}, Free memory: {free_memory / 1024**3:.1f} GB")
                return "cuda"
        else:
            # MPS is only checked when CUDA is entirely unavailable.
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is not None and mps_backend.is_available():
                logger.info("Apple MPS (Metal Performance Shaders) detected")
                return "mps"
    except Exception as e:
        # A failed probe must never crash the caller; fall through to CPU.
        logger.warning(f"GPU detection failed: {e}")

    return "cpu"
|
|
|
|
|
|
def get_optimal_batch_size(embedding_dim: int, device: str) -> int:
    """
    Calculate optimal batch size based on embedding dimension and device.

    Args:
        embedding_dim: Dimension of the embeddings
        device: Device type ("cuda", "mps", "cpu")

    Returns:
        Optimal batch size (always at least 1)
    """
    base_batch_size = GPU_CONFIG["batch_size"]

    if device == "cuda":
        # Adjust for GPU memory; if the CUDA query fails for any reason,
        # fall back to the configured size.
        try:
            free_memory = (
                torch.cuda.get_device_properties(0).total_memory
                - torch.cuda.memory_allocated(0)
            )
            # Rough estimate: 4 bytes per float * dimensions * batch size
            memory_per_batch = 4 * embedding_dim * base_batch_size
            if free_memory > 2 * memory_per_batch:  # Leave room for other operations
                return base_batch_size
            # Reduce batch size based on available memory
            reduced_batch = max(1, int((free_memory * 0.5) / (4 * embedding_dim)))
            logger.info(f"Reducing batch size from {base_batch_size} to {reduced_batch} due to memory constraints")
            return reduced_batch
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; Exception is broad enough for this fallback.
            return base_batch_size

    if device == "mps":
        # MPS typically has less memory than dedicated GPUs
        return min(base_batch_size, 32)

    # CPU - use smaller batches
    return min(base_batch_size, 16)
|
|
|
|
|
|
async def gpu_embedding_wrapper(
    embedding_func: Callable[[List[str]], np.ndarray],
    texts: List[str],
    device: Optional[str] = None,
    model: Optional[Any] = None,
    tokenizer: Optional[Any] = None,
    **kwargs
) -> np.ndarray:
    """
    Dispatch embedding work to the fastest available execution path.

    Args:
        embedding_func: Original embedding function to wrap
        texts: List of texts to embed
        device: Target device ("cuda", "mps", "cpu"); auto-detected if None
        model: Optional model for local GPU inference
        tokenizer: Optional tokenizer for local models
        **kwargs: Additional arguments for the embedding function

    Returns:
        Numpy array of embeddings
    """
    # GPU path disabled, or nothing to embed: delegate unchanged.
    if not GPU_CONFIG["enabled"] or not texts:
        return await embedding_func(texts, **kwargs)

    target_device = device or detect_gpu_device() or "cpu"

    # No accelerator available: plain CPU processing.
    if target_device == "cpu":
        return await embedding_func(texts, **kwargs)

    # A Hugging Face model + tokenizer pair means we can run inference
    # locally on the GPU instead of going through the wrapped function.
    local_inference = model is not None and tokenizer is not None and TORCH_AVAILABLE
    if local_inference:
        return await _hf_gpu_embedding(texts, model, tokenizer, target_device)

    # API-backed embedding function: use concurrent batch optimization.
    return await _batch_optimized_embedding(embedding_func, texts, target_device, **kwargs)
|
|
|
|
|
|
async def _hf_gpu_embedding(
    texts: List[str],
    model: Any,
    tokenizer: Any,
    device: str
) -> np.ndarray:
    """
    GPU-accelerated embedding for Hugging Face models.

    Args:
        texts: List of texts to embed
        model: Hugging Face model
        tokenizer: Hugging Face tokenizer
        device: Target device (e.g. "cuda", "mps")

    Returns:
        Numpy array of embeddings, one row per input text

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch is required for GPU-accelerated Hugging Face embeddings")

    try:
        # Move model to target device; eval() disables dropout so repeated
        # calls produce identical embeddings.
        model = model.to(device)
        model.eval()

        # Determine optimal batch size
        embedding_dim = model.config.hidden_size
        batch_size = get_optimal_batch_size(embedding_dim, device)

        embeddings = []

        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]

            # Tokenize batch
            encoded = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(device)

            # Generate embeddings with appropriate precision
            with torch.no_grad():
                if GPU_CONFIG["precision"] == "fp16":
                    # torch.autocast replaces the deprecated
                    # torch.cuda.amp.autocast and supports non-CUDA device
                    # types ("mps", "cpu"), not just CUDA.
                    with torch.autocast(device_type=device.split(":")[0]):
                        outputs = model(**encoded)
                else:
                    outputs = model(**encoded)

            # Use mean pooling of last hidden states
            batch_embeddings = outputs.last_hidden_state.mean(dim=1)

            # Convert to a numpy-compatible dtype. numpy has no bfloat16, so
            # Tensor.numpy() on a bf16 tensor raises TypeError (the previous
            # bf16 branch crashed here); use float32 for everything except
            # fp16, which numpy supports natively.
            if GPU_CONFIG["precision"] == "fp16":
                batch_embeddings = batch_embeddings.half()
            else:
                batch_embeddings = batch_embeddings.float()

            embeddings.append(batch_embeddings.cpu().numpy())

        return np.concatenate(embeddings)

    except Exception as e:
        logger.error(f"GPU embedding failed: {e}")
        # Fall back to CPU
        from lightrag.llm.hf import hf_embed
        return await hf_embed(texts, tokenizer, model)
|
|
|
|
|
|
async def _batch_optimized_embedding(
    embedding_func: Callable[[List[str]], np.ndarray],
    texts: List[str],
    device: str,
    **kwargs
) -> np.ndarray:
    """
    Fan an API-based embedding function out over concurrent batches.

    Args:
        embedding_func: API embedding function
        texts: List of texts to embed
        device: Target device (for logging)
        **kwargs: Additional arguments

    Returns:
        Numpy array of embeddings
    """
    batch_size = GPU_CONFIG["batch_size"]
    max_concurrent = GPU_CONFIG["max_concurrent_batches"]

    logger.debug(f"GPU-optimized batch processing: {len(texts)} texts, batch_size={batch_size}, device={device}")

    # Bound the number of simultaneously in-flight API requests.
    semaphore = asyncio.Semaphore(max_concurrent)

    async def run_one(chunk):
        async with semaphore:
            return await embedding_func(chunk, **kwargs)

    # Slice the input into batches and launch them all; the semaphore
    # enforces the concurrency ceiling.
    tasks = [
        run_one(texts[start:start + batch_size])
        for start in range(0, len(texts), batch_size)
    ]
    results = await asyncio.gather(*tasks)

    return np.concatenate(results)
|
|
|
|
|
|
def wrap_with_gpu_acceleration(
    embedding_func: Callable[[List[str]], np.ndarray],
    embedding_dim: int,
    model: Optional[Any] = None,
    tokenizer: Optional[Any] = None
) -> Callable[[List[str]], np.ndarray]:
    """
    Wrap an embedding function with GPU acceleration.

    Args:
        embedding_func: Original embedding function
        embedding_dim: Dimension of the embeddings
        model: Optional model for local inference
        tokenizer: Optional tokenizer for local models

    Returns:
        Wrapped embedding function with GPU acceleration
    """
    from lightrag.utils import wrap_embedding_func_with_attrs

    # functools.partial objects and other callables may not define __name__;
    # fall back to the class name so wrapping never raises AttributeError.
    func_name = getattr(embedding_func, "__name__", type(embedding_func).__name__)

    async def gpu_accelerated_embedding(texts: List[str], **kwargs) -> np.ndarray:
        return await gpu_embedding_wrapper(
            embedding_func, texts, model=model, tokenizer=tokenizer, **kwargs
        )

    # Preserve original function attributes
    gpu_accelerated_embedding.__name__ = f"gpu_accelerated_{func_name}"
    gpu_accelerated_embedding.__doc__ = f"GPU-accelerated version of {func_name}"

    # Wrap with attributes
    return wrap_embedding_func_with_attrs(
        embedding_dim=embedding_dim,
        func=gpu_accelerated_embedding
    )
|
|
|
|
|
|
# Example usage for existing embedding functions
def enable_gpu_for_embedding_functions():
    """
    Enable GPU acceleration for all embedding functions in the system.

    This should be called during application initialization. Flips
    GPU_CONFIG["enabled"] off when no usable accelerator is found.
    """
    if not GPU_CONFIG["enabled"]:
        logger.info("GPU acceleration is disabled via configuration")
        return

    device = detect_gpu_device()
    gpu_usable = device is not None and device != "cpu"
    if gpu_usable:
        logger.info(f"GPU acceleration enabled using {device.upper()}")
        return

    logger.info("No GPU detected, using CPU fallback")
    GPU_CONFIG["enabled"] = False
|
|
|
|
|
|
# Initialize GPU configuration on import.
# NOTE(review): this runs GPU detection as an import-time side effect and may
# mutate GPU_CONFIG["enabled"]; confirm callers expect detection to happen on
# first import rather than at explicit initialization.
enable_gpu_for_embedding_functions()