# railseek6/LightRAG-main/lightrag/utils_gpu.py
"""
GPU acceleration utilities for LightRAG embedding functions.
This module provides GPU-accelerated embedding computation with support for:
- CUDA (NVIDIA GPUs)
- ROCm (AMD GPUs)
- CPU fallback
- Batch processing optimization
- Memory management
- Automatic device detection
"""
import os
import logging
import asyncio
from typing import List, Optional, Callable, Any
import numpy as np
try:
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
logger = logging.getLogger("lightrag")
# Global GPU configuration
GPU_CONFIG = {
"enabled": os.getenv("LIGHTRAG_GPU_ENABLED", "true").lower() == "true",
"batch_size": int(os.getenv("LIGHTRAG_GPU_BATCH_SIZE", "64")),
"max_concurrent_batches": int(os.getenv("LIGHTRAG_GPU_MAX_CONCURRENT_BATCHES", "4")),
"precision": os.getenv("LIGHTRAG_GPU_PRECISION", "fp16"), # fp16, fp32, bf16
}
def detect_gpu_device() -> Optional[str]:
    """
    Probe the machine for a usable accelerator.

    Returns:
        "cuda" when a CUDA GPU is present, "mps" on Apple Silicon,
        "cpu" when neither is available (or probing raised an error),
        and None when PyTorch itself is not installed.
    """
    if not TORCH_AVAILABLE:
        return None
    try:
        if torch.cuda.is_available():
            # Log the first device's name and remaining memory before committing.
            if torch.cuda.device_count() > 0:
                gpu_name = torch.cuda.get_device_name(0)
                props = torch.cuda.get_device_properties(0)
                free_bytes = props.total_memory - torch.cuda.memory_allocated(0)
                logger.info(f"CUDA GPU detected: {gpu_name}, Free memory: {free_bytes / 1024**3:.1f} GB")
                return "cuda"
        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            logger.info("Apple MPS (Metal Performance Shaders) detected")
            return "mps"
    except Exception as e:
        # Probing must never crash callers; treat failures as "no GPU".
        logger.warning(f"GPU detection failed: {e}")
    return "cpu"
def get_optimal_batch_size(embedding_dim: int, device: str) -> int:
    """
    Calculate optimal batch size based on embedding dimension and device.

    Args:
        embedding_dim: Dimension of the embeddings.
        device: Device type ("cuda", "mps", "cpu").

    Returns:
        Optimal batch size (always >= 1).
    """
    base_batch_size = GPU_CONFIG["batch_size"]
    if device == "cuda":
        # Scale the batch down when GPU memory looks tight.
        try:
            free_memory = (
                torch.cuda.get_device_properties(0).total_memory
                - torch.cuda.memory_allocated(0)
            )
            # Rough estimate: 4 bytes per fp32 value * dimensions * batch size
            memory_per_batch = 4 * embedding_dim * base_batch_size
            if free_memory > 2 * memory_per_batch:  # leave room for other operations
                return base_batch_size
            # Reduce batch size based on available memory (never below 1)
            reduced_batch = max(1, int((free_memory * 0.5) / (4 * embedding_dim)))
            logger.info(f"Reducing batch size from {base_batch_size} to {reduced_batch} due to memory constraints")
            return reduced_batch
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; any CUDA query failure falls back to the default.
            return base_batch_size
    elif device == "mps":
        # MPS typically has less memory than dedicated GPUs
        return min(base_batch_size, 32)
    else:
        # CPU - use smaller batches
        return min(base_batch_size, 16)
async def gpu_embedding_wrapper(
    embedding_func: Callable[[List[str]], np.ndarray],
    texts: List[str],
    device: Optional[str] = None,
    model: Optional[Any] = None,
    tokenizer: Optional[Any] = None,
    **kwargs
) -> np.ndarray:
    """
    Route an embedding request through the fastest available path.

    Args:
        embedding_func: Original (async) embedding function to wrap.
        texts: List of texts to embed.
        device: Target device ("cuda", "mps", "cpu"); auto-detected when None.
        model: Optional Hugging Face model for local GPU inference.
        tokenizer: Optional tokenizer paired with `model`.
        **kwargs: Forwarded to `embedding_func`.

    Returns:
        Numpy array of embeddings.
    """
    # Acceleration disabled, or nothing to embed: delegate unchanged.
    if not GPU_CONFIG["enabled"] or not texts:
        return await embedding_func(texts, **kwargs)

    target_device = device or detect_gpu_device() or "cpu"

    # No accelerator found: plain CPU delegation.
    if target_device == "cpu":
        return await embedding_func(texts, **kwargs)

    # A local Hugging Face model + tokenizer pair can run directly on the GPU.
    if model is not None and tokenizer is not None and TORCH_AVAILABLE:
        return await _hf_gpu_embedding(texts, model, tokenizer, target_device)

    # API-backed function: the GPU cannot help, but concurrent batching can.
    return await _batch_optimized_embedding(embedding_func, texts, target_device, **kwargs)
async def _hf_gpu_embedding(
    texts: List[str],
    model: Any,
    tokenizer: Any,
    device: str
) -> np.ndarray:
    """
    GPU-accelerated embedding for Hugging Face models.

    Args:
        texts: List of texts to embed.
        model: Hugging Face model (must expose `config.hidden_size`).
        tokenizer: Matching Hugging Face tokenizer.
        device: Target device ("cuda" or "mps").

    Returns:
        Numpy array of embeddings (mean-pooled last hidden states).

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch is required for GPU-accelerated Hugging Face embeddings")
    try:
        # Move model to target device and switch to inference mode so
        # dropout/batch-norm don't make the embeddings nondeterministic.
        model = model.to(device)
        model.eval()
        # Determine optimal batch size from the model's hidden width
        embedding_dim = model.config.hidden_size
        batch_size = get_optimal_batch_size(embedding_dim, device)
        # torch.cuda.amp.autocast only affects CUDA ops, so skip it elsewhere.
        use_fp16_autocast = GPU_CONFIG["precision"] == "fp16" and device == "cuda"
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            # Tokenize batch (padded/truncated to a shared length)
            encoded = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(device)
            # Generate embeddings with appropriate precision
            with torch.no_grad():
                if use_fp16_autocast:
                    with autocast():
                        outputs = model(**encoded)
                else:
                    outputs = model(**encoded)
            # Use mean pooling of last hidden states -> (batch, hidden_size)
            batch_embeddings = outputs.last_hidden_state.mean(dim=1)
            # Convert to the requested precision before copying to host
            if GPU_CONFIG["precision"] == "fp16":
                batch_embeddings = batch_embeddings.half()
            elif GPU_CONFIG["precision"] == "bf16":
                # NumPy has no bfloat16 dtype: round-trip through bf16 for the
                # precision reduction, then widen to fp32 for the host copy.
                # (The previous .bfloat16().cpu().numpy() raised TypeError.)
                batch_embeddings = batch_embeddings.bfloat16().float()
            embeddings.append(batch_embeddings.cpu().numpy())
        return np.concatenate(embeddings)
    except Exception as e:
        # Best-effort: any GPU failure falls back to the CPU embedding path.
        logger.error(f"GPU embedding failed: {e}")
        from lightrag.llm.hf import hf_embed
        return await hf_embed(texts, tokenizer, model)
async def _batch_optimized_embedding(
    embedding_func: Callable[[List[str]], np.ndarray],
    texts: List[str],
    device: str,
    **kwargs
) -> np.ndarray:
    """
    Fan batched calls to an API embedding function out concurrently.

    Args:
        embedding_func: API embedding function (async).
        texts: List of texts to embed.
        device: Target device (used for logging only).
        **kwargs: Forwarded to `embedding_func`.

    Returns:
        Numpy array of embeddings, in input order.
    """
    chunk = GPU_CONFIG["batch_size"]
    logger.debug(f"GPU-optimized batch processing: {len(texts)} texts, batch_size={chunk}, device={device}")

    # Semaphore caps how many batches are in flight at once.
    limiter = asyncio.Semaphore(GPU_CONFIG["max_concurrent_batches"])

    async def _run(piece):
        async with limiter:
            return await embedding_func(piece, **kwargs)

    # Slice texts into batches and launch them all; gather preserves order.
    tasks = [_run(texts[i:i + chunk]) for i in range(0, len(texts), chunk)]
    results = await asyncio.gather(*tasks)
    return np.concatenate(results)
def wrap_with_gpu_acceleration(
    embedding_func: Callable[[List[str]], np.ndarray],
    embedding_dim: int,
    model: Optional[Any] = None,
    tokenizer: Optional[Any] = None
) -> Callable[[List[str]], np.ndarray]:
    """
    Wrap an embedding function with GPU acceleration.

    Args:
        embedding_func: Original embedding function.
        embedding_dim: Dimension of the embeddings.
        model: Optional model for local inference.
        tokenizer: Optional tokenizer for local models.

    Returns:
        Wrapped embedding function with GPU acceleration and LightRAG
        embedding attributes attached.
    """
    from lightrag.utils import wrap_embedding_func_with_attrs

    async def gpu_accelerated_embedding(texts: List[str], **kwargs) -> np.ndarray:
        return await gpu_embedding_wrapper(
            embedding_func, texts, model=model, tokenizer=tokenizer, **kwargs
        )

    # Preserve original function attributes for logging/debugging
    gpu_accelerated_embedding.__name__ = f"gpu_accelerated_{embedding_func.__name__}"
    gpu_accelerated_embedding.__doc__ = f"GPU-accelerated version of {embedding_func.__name__}"

    # wrap_embedding_func_with_attrs is a decorator *factory*: it must be called
    # with the attributes, then applied to the function. The previous code passed
    # func= as a keyword and returned the un-applied decorator instead of the
    # wrapped embedding function.
    return wrap_embedding_func_with_attrs(embedding_dim=embedding_dim)(
        gpu_accelerated_embedding
    )
# Example usage for existing embedding functions
def enable_gpu_for_embedding_functions():
    """
    Enable GPU acceleration for all embedding functions in the system.

    Intended to run once during application initialization. When no
    accelerator is found, GPU_CONFIG["enabled"] is switched off so later
    wrappers skip the GPU path entirely.
    """
    if not GPU_CONFIG["enabled"]:
        logger.info("GPU acceleration is disabled via configuration")
        return
    device = detect_gpu_device()
    if device in (None, "cpu"):
        # Nothing usable: record the fallback and disable the feature flag.
        logger.info("No GPU detected, using CPU fallback")
        GPU_CONFIG["enabled"] = False
    else:
        logger.info(f"GPU acceleration enabled using {device.upper()}")
# Initialize GPU configuration on import
# NOTE(review): import-time side effect — probes hardware and logs; confirm
# this is acceptable for all consumers of this module.
enable_gpu_for_embedding_functions()