415 lines
15 KiB
Python
415 lines
15 KiB
Python
"""NaFlex data loader for dynamic sequence length training.
|
|
|
|
This module provides a specialized data loader for Vision Transformer models that supports:
|
|
- Dynamic sequence length sampling during training for improved efficiency
|
|
- Variable patch size training with probabilistic selection
|
|
- Patch-level random erasing augmentation
|
|
- Efficient GPU prefetching with normalization
|
|
|
|
Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
|
|
"""
|
|
|
|
import math
|
|
from contextlib import suppress
|
|
from functools import partial
|
|
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
|
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
from .loader import _worker_init, adapt_to_chs
|
|
from .naflex_dataset import NaFlexMapDatasetWrapper, NaFlexCollator
|
|
from .naflex_random_erasing import PatchRandomErasing
|
|
from .transforms_factory import create_transform
|
|
|
|
|
|
class NaFlexPrefetchLoader:
    """Data prefetcher for NaFlex format which normalizes patches.

    Wraps a loader yielding ``(input_dict, target)`` batches. For each batch it:

    - moves every tensor in ``input_dict`` (``'patches'`` cast to ``img_dtype``) and the
      target to ``device``;
    - normalizes ``input_dict['patches']`` with ``mean``/``std`` (both scaled by 255);
    - optionally applies patch-level random erasing.

    On CUDA/NPU the transfer runs on a side stream, and batches are yielded with one
    batch of lookahead. The next batch is therefore issued to the device before the
    current one is handed to the caller.
    """

    def __init__(
            self,
            loader: torch.utils.data.DataLoader,
            mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
            std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
            channels: int = 3,
            device: Union[str, torch.device] = torch.device('cuda'),
            img_dtype: Optional[torch.dtype] = None,
            re_prob: float = 0.,
            re_mode: str = 'const',
            re_count: int = 1,
            re_num_splits: int = 0,
    ) -> None:
        """Initialize NaFlexPrefetchLoader.

        Args:
            loader: DataLoader to prefetch from.
            mean: Mean values for normalization (adapted to ``channels``).
            std: Standard deviation values for normalization (adapted to ``channels``).
            channels: Number of image channels.
            device: Device to move tensors to; a ``torch.device`` or a device string
                such as ``'cuda:0'``.
            img_dtype: Data type for the patches tensor (defaults to ``torch.float32``).
            re_prob: Random erasing probability; 0 disables random erasing.
            re_mode: Random erasing mode.
            re_count: Maximum number of erasing rectangles.
            re_num_splits: Number of augmentation splits passed to random erasing.
        """
        self.loader = loader
        # Normalise to torch.device: `.type` is read below, and a plain str (which
        # create_naflex_loader's signature allows) has no such attribute.
        self.device = torch.device(device)
        self.img_dtype = img_dtype or torch.float32
        self.channels = channels

        # Mean/std tensors for normalization, applied to the patches.
        # The x255 scaling presumes patches arrive in [0, 255] pixel range
        # (e.g. uint8 when prefetching). NOTE(review): confirm against the transform.
        # Shape (1, 1, C) broadcasts over the trailing dims of [B, N, P*P, C].
        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, 1, channels)
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=self.device, dtype=self.img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=self.device, dtype=self.img_dtype).view(normalization_shape)

        if re_prob > 0.:
            self.random_erasing = PatchRandomErasing(
                erase_prob=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=self.device,
            )
        else:
            self.random_erasing = None

        # Check for CUDA/NPU availability (short-circuits so torch.npu is only touched for npu devices)
        self.is_cuda = self.device.type == 'cuda' and torch.cuda.is_available()
        self.is_npu = self.device.type == 'npu' and torch.npu.is_available()

    def _normalize_patches(self, input_dict: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Normalize (and optionally random-erase) ``input_dict['patches']``.

        Accepts both ``[B, N, P*P*C]`` (flattened) and ``[B, N, Ph, Pw, C]`` (unflattened,
        variable patch size mode) layouts. Works on a ``[B, N, P*P, C]`` view and
        returns the result reshaped back to the input's original shape.

        Raises:
            ValueError: If the patches tensor is neither 3-D nor 5-D.
        """
        patches_tensor = input_dict['patches']
        original_shape = patches_tensor.shape

        if patches_tensor.ndim == 5:
            channels = original_shape[-1]
            assert channels == self.channels, f"Expected {self.channels} channels, got {channels}"
        elif patches_tensor.ndim != 3:
            raise ValueError(f"Unexpected patches tensor dimensions: {patches_tensor.ndim}. Expected 3 or 5.")

        # Both layouts -> [B, N, pixels_per_patch, C] for normalization and erasing
        batch_size, num_patches = original_shape[0], original_shape[1]
        patches = patches_tensor.view(batch_size, num_patches, -1, self.channels)

        patches = patches.sub(self.mean).div(self.std)

        if self.random_erasing is not None:
            patches = self.random_erasing(
                patches,
                patch_coord=input_dict['patch_coord'],
                patch_valid=input_dict.get('patch_valid', None),
            )

        # Reshape back to original format
        return patches.view(original_shape)

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        """Iterate through the loader with prefetching and normalization.

        Batch i is yielded only after batch i+1 has been issued to the device.
        An empty underlying loader yields nothing.

        Yields:
            Tuple of (input_dict, targets) with normalized patches.
        """
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream(device=self.device)
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream(device=self.device)
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress  # suppress() with no args is a no-op context manager

        input_dict = target = None
        for next_input_dict, next_target in self.loader:
            with stream_context():
                # Move all tensors in input_dict to device (only 'patches' is cast)
                for k, v in next_input_dict.items():
                    if isinstance(v, torch.Tensor):
                        dtype = self.img_dtype if k == 'patches' else None
                        next_input_dict[k] = v.to(
                            device=self.device,
                            non_blocking=True,
                            dtype=dtype,
                        )

                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input_dict['patches'] = self._normalize_patches(next_input_dict)

            if not first:
                yield input_dict, target
            else:
                first = False

            # Make the consumer's stream wait for the side-stream work on the next batch
            if stream is not None:
                if self.is_cuda:
                    torch.cuda.current_stream(device=self.device).wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream(device=self.device).wait_stream(stream)

            input_dict = next_input_dict
            target = next_target

        # Previously an empty loader hit UnboundLocalError here; yield only if a batch was seen.
        if not first:
            yield input_dict, target

    def __len__(self) -> int:
        """Get length of underlying loader.

        Returns:
            Number of batches in the loader.
        """
        return len(self.loader)

    @property
    def sampler(self):
        """Get sampler from underlying loader.

        Returns:
            Sampler from the underlying DataLoader.
        """
        return self.loader.sampler

    @property
    def dataset(self):
        """Get dataset from underlying loader.

        Returns:
            Dataset from the underlying DataLoader.
        """
        return self.loader.dataset
|
|
|
|
|
|
def create_naflex_loader(
        dataset,
        patch_size: Optional[Union[Tuple[int, int], int]] = None,
        patch_size_choices: Optional[List[int]] = None,
        patch_size_choice_probs: Optional[List[float]] = None,
        train_seq_lens: Tuple[int, ...] = (128, 256, 576, 784, 1024),
        max_seq_len: int = 576,
        batch_size: int = 32,
        is_training: bool = False,
        mixup_fn: Optional[Callable] = None,

        no_aug: bool = False,
        re_prob: float = 0.,
        re_mode: str = 'const',
        re_count: int = 1,
        re_split: bool = False,
        train_crop_mode: Optional[str] = None,
        scale: Optional[Tuple[float, float]] = None,
        ratio: Optional[Tuple[float, float]] = None,
        hflip: float = 0.5,
        vflip: float = 0.,
        color_jitter: float = 0.4,
        color_jitter_prob: Optional[float] = None,
        grayscale_prob: float = 0.,
        gaussian_blur_prob: float = 0.,
        auto_augment: Optional[str] = None,
        num_aug_repeats: int = 0,
        num_aug_splits: int = 0,
        interpolation: str = 'bilinear',
        mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        crop_pct: Optional[float] = None,
        crop_mode: Optional[str] = None,
        crop_border_pixels: Optional[int] = None,

        num_workers: int = 4,
        distributed: bool = False,
        rank: int = 0,
        world_size: int = 1,
        seed: int = 42,
        epoch: int = 0,
        use_prefetcher: bool = True,
        pin_memory: bool = True,
        img_dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = torch.device('cuda'),
        persistent_workers: bool = True,
        worker_seeding: str = 'all',
) -> Union[torch.utils.data.DataLoader, NaFlexPrefetchLoader]:
    """Create a data loader with dynamic sequence length sampling for training.

    Training wraps ``dataset`` in a ``NaFlexMapDatasetWrapper``, which produces whole
    batches itself (hence ``batch_size=None`` on the DataLoader). Validation sets
    ``dataset.transform`` in place to a fixed-``max_seq_len`` NaFlex transform and
    batches with ``NaFlexCollator``.

    Args:
        dataset: Dataset to load from (map-style; IterableDataset is not yet supported for training).
        patch_size: Single patch size to use.
        patch_size_choices: List of patch sizes for variable patch size training.
        patch_size_choice_probs: Probabilities for each patch size choice.
        train_seq_lens: Training sequence lengths for dynamic batching.
        max_seq_len: Fixed sequence length for validation.
        batch_size: Batch size for validation and max training sequence length.
        is_training: Whether this is for training (enables dynamic batching).
        mixup_fn: Optional mixup function (training only).
        no_aug: Disable augmentation.
        re_prob: Random erasing probability.
        re_mode: Random erasing mode.
        re_count: Maximum number of erasing rectangles.
        re_split: If True, the prefetcher's random erasing is configured with
            ``num_aug_splits`` splits (2 if ``num_aug_splits`` is 0).
        train_crop_mode: Training crop mode.
        scale: Scale range for random resize crop.
        ratio: Aspect ratio range for random resize crop.
        hflip: Horizontal flip probability.
        vflip: Vertical flip probability.
        color_jitter: Color jitter factor.
        color_jitter_prob: Color jitter probability.
        grayscale_prob: Grayscale conversion probability.
        gaussian_blur_prob: Gaussian blur probability.
        auto_augment: AutoAugment policy.
        num_aug_repeats: Number of augmentation repeats (must be 0; not supported).
        num_aug_splits: Number of augmentation splits; used to size random erasing splits when ``re_split``.
        interpolation: Interpolation method.
        mean: Normalization mean values.
        std: Normalization standard deviation values.
        crop_pct: Crop percentage for validation.
        crop_mode: Crop mode.
        crop_border_pixels: Crop border pixels.
        num_workers: Number of data loading workers.
        distributed: Whether using distributed training.
        rank: Process rank for distributed training.
        world_size: Total number of processes.
        seed: Random seed.
        epoch: Starting epoch.
        use_prefetcher: Whether to use prefetching.
        pin_memory: Whether to pin memory.
        img_dtype: Image data type.
        device: Device to move tensors to (``torch.device`` or device string).
        persistent_workers: Whether to use persistent workers (training only; ignored when ``num_workers == 0``).
        worker_seeding: Worker seeding mode.

    Returns:
        DataLoader or NaFlexPrefetchLoader instance.
    """

    if is_training:
        # For training, use the dynamic sequence length mechanism
        assert num_aug_repeats == 0, 'Augmentation repeats not currently supported in NaFlex loader'

        transform_factory = partial(
            create_transform,
            is_training=True,
            no_aug=no_aug,
            train_crop_mode=train_crop_mode,
            scale=scale,
            ratio=ratio,
            hflip=hflip,
            vflip=vflip,
            color_jitter=color_jitter,
            color_jitter_prob=color_jitter_prob,
            grayscale_prob=grayscale_prob,
            gaussian_blur_prob=gaussian_blur_prob,
            auto_augment=auto_augment,
            interpolation=interpolation,
            mean=mean,
            std=std,
            crop_pct=crop_pct,
            crop_mode=crop_mode,
            crop_border_pixels=crop_border_pixels,
            re_prob=re_prob,
            re_mode=re_mode,
            re_count=re_count,
            use_prefetcher=use_prefetcher,
            naflex=True,
        )

        # Token budget per batch is sized for the longest training sequence length
        max_train_seq_len = max(train_seq_lens)
        max_tokens_per_batch = batch_size * max_train_seq_len

        if isinstance(dataset, torch.utils.data.IterableDataset):
            assert False, "IterableDataset Wrapper is a WIP"

        naflex_dataset = NaFlexMapDatasetWrapper(
            dataset,
            transform_factory=transform_factory,
            patch_size=patch_size,
            patch_size_choices=patch_size_choices,
            patch_size_choice_probs=patch_size_choice_probs,
            seq_lens=train_seq_lens,
            max_tokens_per_batch=max_tokens_per_batch,
            mixup_fn=mixup_fn,
            seed=seed,
            distributed=distributed,
            rank=rank,
            world_size=world_size,
            shuffle=True,
            epoch=epoch,
        )

        # NOTE: Collation is handled by the dataset wrapper for training
        loader = torch.utils.data.DataLoader(
            naflex_dataset,
            batch_size=None,
            shuffle=False,
            num_workers=num_workers,
            sampler=None,
            pin_memory=pin_memory,
            worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
            # DataLoader raises ValueError if persistent_workers is set with num_workers == 0
            persistent_workers=persistent_workers and num_workers > 0,
        )

        if use_prefetcher:
            # re_split was previously accepted but ignored; line the erasing splits up
            # with the augmentation splits, or default to 2 splits.
            re_num_splits = (num_aug_splits or 2) if re_split else 0
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=torch.device(device),
                re_prob=re_prob,
                re_mode=re_mode,
                re_count=re_count,
                re_num_splits=re_num_splits,
            )

    else:
        # For validation, use fixed sequence length (unchanged)
        dataset.transform = create_transform(
            is_training=False,
            interpolation=interpolation,
            mean=mean,
            std=std,
            # FIXME add crop args when sequence transforms support crop modes
            use_prefetcher=use_prefetcher,
            naflex=True,
            patch_size=patch_size,
            max_seq_len=max_seq_len,
            patchify=True,
        )

        # Create the collator
        collate_fn = NaFlexCollator(max_seq_len=max_seq_len)

        # Handle distributed training
        sampler = None
        if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
            # For validation, use OrderedDistributedSampler
            from timm.data.distributed_sampler import OrderedDistributedSampler
            sampler = OrderedDistributedSampler(dataset)

        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=False,
        )

        if use_prefetcher:
            loader = NaFlexPrefetchLoader(
                loader,
                mean=mean,
                std=std,
                img_dtype=img_dtype,
                device=torch.device(device),
            )

    return loader
|