"""
GPU acceleration utilities for LightRAG embedding functions.

This module provides GPU-accelerated embedding computation with support for:
- CUDA (NVIDIA GPUs)
- ROCm (AMD GPUs, exposed through PyTorch's ``torch.cuda`` API)
- Apple MPS (Metal Performance Shaders)
- CPU fallback
- Batch processing optimization
- Memory management
- Automatic device detection

Configuration is read once at import time from the environment:

- ``LIGHTRAG_GPU_ENABLED``: "1"/"true"/"yes"/"on" enable acceleration (default "true").
- ``LIGHTRAG_GPU_BATCH_SIZE``: texts per batch (default 64).
- ``LIGHTRAG_GPU_MAX_CONCURRENT_BATCHES``: concurrently awaited API batches (default 4).
- ``LIGHTRAG_GPU_PRECISION``: compute precision, "fp16", "fp32" or "bf16" (default "fp16").

Importing this module calls :func:`enable_gpu_for_embedding_functions`, which
turns acceleration off when no GPU is detected.
"""

import asyncio
import functools
import logging
import os
from contextlib import nullcontext
from typing import Any, Awaitable, Callable, List, Optional

import numpy as np

try:
    import torch
    import torch.nn as nn  # noqa: F401  (kept for modules importing it from here)
    from torch.cuda.amp import autocast  # noqa: F401  (legacy; torch.autocast is used below)

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False

logger = logging.getLogger("lightrag")

_TRUTHY = {"1", "true", "yes", "on"}


def _env_flag(name: str, default: str) -> bool:
    """Return True when environment variable ``name`` (or ``default``) is a truthy string."""
    return os.getenv(name, default).strip().lower() in _TRUTHY


# Global GPU configuration
GPU_CONFIG = {
    "enabled": _env_flag("LIGHTRAG_GPU_ENABLED", "true"),
    "batch_size": int(os.getenv("LIGHTRAG_GPU_BATCH_SIZE", "64")),
    "max_concurrent_batches": int(os.getenv("LIGHTRAG_GPU_MAX_CONCURRENT_BATCHES", "4")),
    "precision": os.getenv("LIGHTRAG_GPU_PRECISION", "fp16"),  # fp16, fp32, bf16
}


def _cuda_free_memory(index: int = 0) -> int:
    """
    Return the free memory, in bytes, of CUDA device ``index``.

    Prefers ``torch.cuda.mem_get_info``, which reports device-wide free memory
    (including memory held by other processes and the caching allocator). On
    PyTorch builds without it, falls back to total memory minus this process's
    allocated memory.

    Raises:
        Whatever the underlying ``torch.cuda`` queries raise (e.g. NameError when
        torch is not installed, RuntimeError when CUDA is unusable).
    """
    try:
        free, _total = torch.cuda.mem_get_info(index)
        return int(free)
    except (AttributeError, RuntimeError):
        props = torch.cuda.get_device_properties(index)
        return props.total_memory - torch.cuda.memory_allocated(index)


def detect_gpu_device() -> Optional[str]:
    """
    Detect the best available compute device.

    Returns:
        "cuda" if a CUDA/ROCm device is usable, "mps" if Apple MPS is usable,
        otherwise "cpu". Returns None only when PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        return None

    try:
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 0:
                device_name = torch.cuda.get_device_name(0)
                free_memory = _cuda_free_memory(0)
                logger.info(f"CUDA GPU detected: {device_name}, Free memory: {free_memory / 1024**3:.1f} GB")
                return "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            logger.info("Apple MPS (Metal Performance Shaders) detected")
            return "mps"
    except Exception as e:
        # Detection is best-effort: any failure degrades to CPU rather than crashing.
        logger.warning(f"GPU detection failed: {e}")

    return "cpu"


@functools.lru_cache(maxsize=1)
def _cached_default_device() -> Optional[str]:
    """Run :func:`detect_gpu_device` once; device availability is assumed not to change at runtime."""
    return detect_gpu_device()


def get_optimal_batch_size(embedding_dim: int, device: str) -> int:
    """
    Calculate a batch size for ``device`` based on the embedding dimension.

    Args:
        embedding_dim: Dimension of the embeddings
        device: Device type ("cuda", "mps", "cpu")

    Returns:
        For "cuda": ``GPU_CONFIG["batch_size"]``, reduced when free memory is low.
        For "mps": at most 32. For any other device string: at most 16.
    """
    base_batch_size = GPU_CONFIG["batch_size"]

    if device == "cuda":
        try:
            free_memory = _cuda_free_memory(0)
        except Exception as e:  # torch missing or CUDA query failed
            logger.debug(f"Could not query CUDA memory ({e}); using configured batch size")
            return base_batch_size

        # Rough estimate: counts only the fp32 output embeddings (4 bytes per
        # value), not model activations, so real usage is higher.
        memory_per_batch = 4 * embedding_dim * base_batch_size
        if free_memory > 2 * memory_per_batch:  # leave room for other operations
            return base_batch_size

        reduced_batch = max(1, int((free_memory * 0.5) / (4 * embedding_dim)))
        logger.info(f"Reducing batch size from {base_batch_size} to {reduced_batch} due to memory constraints")
        return reduced_batch

    if device == "mps":
        # MPS typically has less memory than dedicated GPUs
        return min(base_batch_size, 32)

    # CPU (or any unrecognized device string) - use smaller batches
    return min(base_batch_size, 16)


async def gpu_embedding_wrapper(
    embedding_func: Callable[..., Awaitable[np.ndarray]],
    texts: List[str],
    device: Optional[str] = None,
    model: Optional[Any] = None,
    tokenizer: Optional[Any] = None,
    **kwargs,
) -> np.ndarray:
    """
    GPU-accelerated embedding wrapper with batch processing optimization.

    Args:
        embedding_func: Async embedding function to wrap; called as
            ``await embedding_func(texts, **kwargs)``
        texts: List of texts to embed
        device: Target device ("cuda", "mps", "cpu"); auto-detected when None
        model: Optional Hugging Face model for local GPU inference
        tokenizer: Optional tokenizer for ``model``
        **kwargs: Additional arguments for ``embedding_func``

    Returns:
        Numpy array of embeddings, one row per input text (same order).
    """
    if not GPU_CONFIG["enabled"] or not texts:
        return await embedding_func(texts, **kwargs)

    target_device = device or _cached_default_device() or "cpu"
    if target_device == "cpu":
        # Nothing to accelerate: delegate unchanged.
        return await embedding_func(texts, **kwargs)

    # Local Hugging Face model: run inference directly on the device.
    if model is not None and tokenizer is not None and TORCH_AVAILABLE:
        return await _hf_gpu_embedding(texts, model, tokenizer, target_device)

    # API-based embedding function: only batching/concurrency is optimized.
    return await _batch_optimized_embedding(embedding_func, texts, target_device, **kwargs)


def _autocast_dtype(device: str) -> Optional["torch.dtype"]:
    """Map ``GPU_CONFIG["precision"]`` to a CUDA autocast dtype, or None to compute in fp32."""
    if torch.device(device).type != "cuda":
        # Only CUDA autocast is used here; other devices compute in fp32.
        return None
    precision = GPU_CONFIG["precision"]
    if precision == "fp16":
        return torch.float16
    if precision == "bf16":
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        logger.warning("bf16 precision requested but unsupported on this CUDA device; using fp32")
    return None


def _masked_mean_pool(last_hidden_state: "torch.Tensor", attention_mask: Optional["torch.Tensor"]) -> "torch.Tensor":
    """Average token states over non-padding tokens only, returning float32."""
    # Pool in fp32 regardless of the autocast dtype to avoid fp16 overflow/precision loss.
    hidden = last_hidden_state.float()
    if attention_mask is None:
        return hidden.mean(dim=1)
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1.0)  # guard against all-padding rows
    return summed / counts


async def _hf_gpu_embedding(
    texts: List[str],
    model: Any,
    tokenizer: Any,
    device: str,
) -> np.ndarray:
    """
    GPU-accelerated embedding for Hugging Face models.

    Texts are tokenized in batches (truncated to 512 tokens), run through
    ``model`` under autocast according to ``GPU_CONFIG["precision"]``, and
    mean-pooled over real (non-padding) tokens so a text's embedding does not
    depend on which other texts share its batch.

    NOTE(review): inference runs synchronously inside this coroutine and blocks
    the event loop for its duration.

    Args:
        texts: List of texts to embed
        model: Hugging Face model (must expose ``config.hidden_size``)
        tokenizer: Hugging Face tokenizer
        device: Target device

    Returns:
        float32 numpy array of shape (len(texts), hidden_size).

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch is required for GPU-accelerated Hugging Face embeddings")

    try:
        model = model.to(device)

        embedding_dim = model.config.hidden_size
        batch_size = get_optimal_batch_size(embedding_dim, device)
        amp_dtype = _autocast_dtype(device)

        embeddings = []
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]

            encoded = tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(device)

            amp_context = (
                torch.autocast(device_type="cuda", dtype=amp_dtype)
                if amp_dtype is not None
                else nullcontext()
            )
            with torch.no_grad():
                with amp_context:
                    outputs = model(**encoded)
                batch_embeddings = _masked_mean_pool(
                    outputs.last_hidden_state, encoded.get("attention_mask")
                )

            # Always float32 on the host: numpy has no bfloat16, and reduced
            # precision is only useful for the forward pass itself.
            embeddings.append(batch_embeddings.cpu().numpy())

        return np.concatenate(embeddings)

    except Exception as e:
        logger.error(f"GPU embedding failed: {e}")
        # Fall back to LightRAG's reference implementation. It runs on whatever
        # device the model currently lives on, which may still be the GPU.
        from lightrag.llm.hf import hf_embed

        return await hf_embed(texts, tokenizer, model)


async def _batch_optimized_embedding(
    embedding_func: Callable[..., Awaitable[np.ndarray]],
    texts: List[str],
    device: str,
    **kwargs,
) -> np.ndarray:
    """
    Optimize batch processing for API-based embedding functions.

    Splits ``texts`` into ``GPU_CONFIG["batch_size"]`` chunks and awaits at most
    ``GPU_CONFIG["max_concurrent_batches"]`` of them at once.

    Args:
        embedding_func: Async API embedding function
        texts: List of texts to embed
        device: Target device (used for logging only)
        **kwargs: Additional arguments for ``embedding_func``

    Returns:
        Numpy array of embeddings in input order (``asyncio.gather`` preserves order).
    """
    batch_size = GPU_CONFIG["batch_size"]
    max_concurrent = GPU_CONFIG["max_concurrent_batches"]

    logger.debug(f"GPU-optimized batch processing: {len(texts)} texts, batch_size={batch_size}, device={device}")

    batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]

    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_batch(batch):
        async with semaphore:
            return await embedding_func(batch, **kwargs)

    embeddings_list = await asyncio.gather(*(process_batch(batch) for batch in batches))
    return np.concatenate(embeddings_list)


def wrap_with_gpu_acceleration(
    embedding_func: Callable[..., Awaitable[np.ndarray]],
    embedding_dim: int,
    model: Optional[Any] = None,
    tokenizer: Optional[Any] = None,
) -> Callable[..., Awaitable[np.ndarray]]:
    """
    Wrap an embedding function with GPU acceleration.

    Args:
        embedding_func: Original async embedding function (or async callable object)
        embedding_dim: Dimension of the embeddings
        model: Optional model for local inference
        tokenizer: Optional tokenizer for local models

    Returns:
        The GPU-accelerated function wrapped by LightRAG's
        ``wrap_embedding_func_with_attrs`` with ``embedding_dim`` attached.
    """
    from lightrag.utils import wrap_embedding_func_with_attrs

    async def gpu_accelerated_embedding(texts: List[str], **kwargs) -> np.ndarray:
        return await gpu_embedding_wrapper(
            embedding_func, texts, model=model, tokenizer=tokenizer, **kwargs
        )

    # Callable objects (e.g. already-wrapped embedding funcs) may lack __name__.
    func_name = getattr(embedding_func, "__name__", type(embedding_func).__name__)
    gpu_accelerated_embedding.__name__ = f"gpu_accelerated_{func_name}"
    gpu_accelerated_embedding.__doc__ = f"GPU-accelerated version of {func_name}"

    # wrap_embedding_func_with_attrs(**attrs) returns a decorator; apply it to
    # the function instead of passing the function as a ``func=`` attribute.
    return wrap_embedding_func_with_attrs(embedding_dim=embedding_dim)(gpu_accelerated_embedding)


def enable_gpu_for_embedding_functions():
    """
    Check for a GPU and disable ``GPU_CONFIG["enabled"]`` when none is found.

    Called automatically at import time; safe to call again.
    """
    if not GPU_CONFIG["enabled"]:
        logger.info("GPU acceleration is disabled via configuration")
        return

    device = _cached_default_device()
    if device and device != "cpu":
        logger.info(f"GPU acceleration enabled using {device.upper()}")
    else:
        logger.info("No GPU detected, using CPU fallback")
        GPU_CONFIG["enabled"] = False


# Initialize GPU configuration on import
enable_gpu_for_embedding_functions()