Files
railseek6/openclip_env/Lib/site-packages/open_clip/factory.py

1088 lines
50 KiB
Python

import json
import logging
import os
import re
import warnings
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import torch
from .convert import convert_state_dict
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
from .coca_model import CoCa
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
list_pretrained_tags_by_model, download_pretrained_from_hf
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
from .tokenizer import HFTokenizer, SimpleTokenizer, SigLipTokenizer, DEFAULT_CONTEXT_LENGTH
# Schema prefix marking a model name as a Hugging Face Hub repo id.
HF_HUB_PREFIX = 'hf-hub:'
# Locations (directories or individual .json files) scanned for model architecture configs.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
# Registry mapping model_name -> architecture config dict; rebuilt by _rescan_model_configs().
_MODEL_CONFIGS = {}
def _natural_key(string_):
    """Build a sort key that orders embedded numbers numerically instead of lexically.

    The lower-cased string is split on runs of decimal digits; each digit run becomes
    an ``int`` and everything else stays a ``str``. For example 'ViT-B-16' sorts
    before 'ViT-B-32', which sorts before 'ViT-B-128'.

    Args:
        string_: The string to derive a key from.

    Returns:
        A list alternating between str and int pieces, e.g. ``['vit-b-', 32, '']``.
    """
    pieces = re.split(r'(\d+)', string_.lower())
    # str.isdecimal() accepts exactly the characters `\d` matches, so only the captured
    # digit runs are converted. str.isdigit() is also True for characters such as
    # superscript digits ("²"), which int() rejects with ValueError.
    return [int(piece) if piece.isdecimal() else piece for piece in pieces]
def _rescan_model_configs():
    """Rebuild the `_MODEL_CONFIGS` registry from every entry of `_MODEL_CONFIG_PATHS`.

    Each entry may be a single ``.json`` file or a directory whose ``*.json`` files are
    collected. A file is registered under its stem (file name without extension) only
    if its JSON content contains all of 'embed_dim', 'vision_cfg' and 'text_cfg'.
    The registry is then re-ordered by natural (number-aware) name order.
    """
    global _MODEL_CONFIGS

    config_ext = ('.json',)
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    for cf in config_files:
        # Read as UTF-8 explicitly (like the other config readers in this file) rather
        # than relying on the platform's locale-dependent default encoding.
        with open(cf, 'r', encoding='utf-8') as f:
            model_cfg = json.load(f)
        if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
            _MODEL_CONFIGS[cf.stem] = model_cfg

    # Rebinding (rather than mutating) is why the `global` declaration is needed.
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
_rescan_model_configs()  # populate the model config registry once, when this module is imported
def list_models():
    """Return the names of all model architectures currently held in the config registry."""
    return list(_MODEL_CONFIGS)
def add_model_config(path):
    """Register an additional config location (directory or file) and rebuild the registry.

    Args:
        path: A `Path` or anything `Path()` accepts; it is appended to the list of
            scanned locations before the registry is rescanned.
    """
    config_location = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(config_location)
    _rescan_model_configs()
# Schema prefixes that may start a model name; parse_model_name() strips them off.
# NOTE: HF_HUB_PREFIX is also assigned, with the same value, near the top of this file.
HF_HUB_PREFIX = 'hf-hub:'
LOCAL_DIR_PREFIX = 'local-dir:'
def parse_model_name(model_name: str) -> Tuple[Optional[str], str]:
    """
    Split a model name into an optional location schema and the identifier after it.

    Args:
        model_name: The model name string (e.g., 'ViT-B-32', 'hf-hub:org/repo',
            'local-dir:/path/to/dir', 'local-dir:./relative/path').

    Returns:
        A tuple (schema, identifier):
            - schema: 'local-dir', 'hf-hub', or None when no schema prefix is present.
            - identifier: The text following the prefix, or the unmodified input when
              there is no prefix. A 'local-dir' path is returned raw; it is neither
              resolved nor checked for existence here.

    Raises:
        ValueError: If a schema prefix is present but nothing follows it.
    """
    # Checked in order; 'local-dir:' is tested before 'hf-hub:'.
    known_schemas = (
        (LOCAL_DIR_PREFIX, 'local-dir', "Empty path specified after 'local-dir:' schema."),
        (HF_HUB_PREFIX, 'hf-hub', "Empty identifier specified after 'hf-hub:' schema."),
    )
    for prefix, schema, empty_message in known_schemas:
        if not model_name.startswith(prefix):
            continue
        remainder = model_name[len(prefix):]
        if remainder:
            return schema, remainder
        raise ValueError(empty_message)

    # No recognised prefix: the whole string is the identifier.
    return None, model_name
def _get_hf_config(
        model_id: str,
        cache_dir: Optional[str] = None,
):
    """Fetch `open_clip_config.json` for a Hugging Face Hub model and return the parsed JSON.

    Args:
        model_id: Hub repo id, passed straight to `download_pretrained_from_hf`.
        cache_dir: Optional cache directory forwarded to the download helper.

    Returns:
        The decoded JSON content of the config file.
    """
    local_config_file = download_pretrained_from_hf(
        model_id,
        filename='open_clip_config.json',
        cache_dir=cache_dir,
    )
    with open(local_config_file, 'r', encoding='utf-8') as config_stream:
        return json.load(config_stream)
def get_model_config(model_name):
    """Look up a model config by name, honouring 'local-dir:' and 'hf-hub:' schemas.

    Returns:
        For schema-prefixed names, the 'model_cfg' entry of the loaded
        `open_clip_config.json` (or the whole JSON object when that key is absent).
        For plain names, a deep copy of the registered built-in config, or None when
        the name is not registered.
    """
    loc, model_id = parse_model_name(model_name)

    if loc == 'local-dir':
        config_file = Path(model_id) / 'open_clip_config.json'
        with open(config_file, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    elif loc == 'hf-hub':
        loaded = _get_hf_config(model_id)
    else:
        # Built-in lookup; copy so callers can modify the result without touching the registry.
        if model_name not in _MODEL_CONFIGS:
            return None
        return deepcopy(_MODEL_CONFIGS[model_name])

    return loaded.get('model_cfg', loaded)
def load_state_dict(
        checkpoint_path: str,
        device='cpu',
        weights_only=True,
):
    """Load a checkpoint file and return its state dict.

    Args:
        checkpoint_path: Path to the checkpoint. Names ending in '.safetensors' are read
            with `safetensors.torch.load_file`, everything else with `torch.load`.
        device: Device the tensors are mapped to while loading.
        weights_only: Forwarded to `torch.load`.

    Returns:
        The state dict. If the checkpoint is a dict with a 'state_dict' entry, that
        entry is used; if it is a TorchScript module, its `state_dict()` minus the
        'input_resolution', 'context_length' and 'vocab_size' entries is used. A
        leading 'module.' on the keys is removed when the first key carries it.
    """
    # Check if safetensors or not and load weights accordingly
    if str(checkpoint_path).endswith(".safetensors"):
        from safetensors.torch import load_file
        checkpoint = load_file(checkpoint_path, device=device)
    else:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
        except TypeError:
            # Presumably a torch version whose torch.load() lacks `weights_only`; retry without it.
            checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint

    # Strip a wrapper prefix from the keys. The first key decides whether stripping
    # happens at all; an empty state dict is returned unchanged. The match requires
    # the trailing dot so names such as 'module_x.weight' are not truncated, and
    # only keys that actually carry the prefix are shortened.
    prefix = 'module.'
    first_key = next(iter(state_dict), None)
    if isinstance(first_key, str) and first_key.startswith(prefix):
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
    return state_dict
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
        weights_only: bool = True,
        device='cpu',
):
    """Load a checkpoint into `model`, adapting the state dict to the model first.

    Args:
        model: The model that receives the weights.
        checkpoint_path: Checkpoint file. '.npz' / '.npy' files are handed to
            `load_big_vision_weights`; anything else goes through `load_state_dict`.
        strict: Forwarded to `model.load_state_dict`.
        weights_only: Forwarded to `load_state_dict`.
        device: Device the checkpoint tensors are mapped to while loading.

    Returns:
        The result of `model.load_state_dict(...)`, or an empty dict for the numpy path.
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights.
        # Relative import, consistent with the other intra-package imports of this module.
        from .convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        return {}

    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # correct if logit_scale differs in being scalar vs 1d param
    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

    # A model without a logit bias (attribute missing or None) needs no adjustment here;
    # a stray 'logit_bias' checkpoint entry is then left for load_state_dict to report
    # (strict) or ignore (non-strict) instead of crashing on `None.ndim`.
    model_logit_bias = getattr(model, 'logit_bias', None)
    if model_logit_bias is not None:
        if 'logit_bias' in state_dict:
            # correct if logit_bias differs in being scalar vs 1d param
            if model_logit_bias.ndim != state_dict['logit_bias'].ndim:
                state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model_logit_bias.shape)
        else:
            # If loading a non-SigLIP model for SigLIP training.
            # See https://github.com/mlfoundations/open_clip/issues/712
            if 'logit_scale' in state_dict:
                state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
            else:
                # No logit_scale to take the shape from; use the model's own bias shape.
                state_dict["logit_bias"] = torch.zeros_like(model_logit_bias)

    # Certain text transformers no longer expect position_ids after transformers==4.31.
    # Drop the entry only when the model's own state dict has no such key. A dotted
    # name passed to hasattr() never resolves, and a non-persistent buffer exists as
    # an attribute without being part of the state dict, so membership is the test.
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and position_id_key not in model.state_dict():
        del state_dict[position_id_key]

    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
def _find_checkpoint_in_dir(dir_path: Path) -> Optional[str]:
    """Pick the checkpoint file to load from a model directory.

    Candidates are regular files directly inside `dir_path` matching '*.safetensors',
    '*.bin' or '*.pth'. A candidate whose name appears in the preferred-name list wins
    (earlier list entries first). Otherwise the first candidate in alphabetical order
    is used, with '.safetensors' files ranked ahead of the other formats.

    Args:
        dir_path: Directory to search (not searched recursively).

    Returns:
        The chosen checkpoint path as a string, or None when there is no candidate.
    """
    checkpoints = [
        candidate
        for pattern in ('*.safetensors', '*.bin', '*.pth')
        for candidate in dir_path.glob(pattern)
        if candidate.is_file()  # ignore directories that merely look like checkpoints
    ]
    if not checkpoints:
        return None

    # Alphabetical order, then a stable sort that moves .safetensors files to the front.
    checkpoints.sort()
    checkpoints.sort(key=lambda x: x.suffix == '.safetensors', reverse=True)

    preferred_order = [
        "open_clip_model.safetensors", "open_clip_pytorch_model.safetensors",
        "open_clip_pytorch_model.bin", "open_clip_pytorch_model.pth",
        "model.safetensors", "pytorch_model.bin", "pytorch_model.pth", "model.pth"
    ]
    preferred_checkpoints = [c for c in checkpoints if c.name in preferred_order]
    if preferred_checkpoints:
        chosen = min(preferred_checkpoints, key=lambda x: preferred_order.index(x.name))
        logging.info(f"Found preferred checkpoint file: {chosen.name} in {dir_path}")
        return str(chosen)

    chosen = checkpoints[0]
    if len(checkpoints) > 1:
        # The choice is ambiguous only when there really are several candidates.
        logging.warning(
            f"Multiple checkpoints found in {dir_path}: {[c.name for c in checkpoints]}. Using '{chosen.name}'.")
    else:
        logging.info(f"Found checkpoint file: {chosen.name} in {dir_path}")
    return str(chosen)
def create_model(
        model_name: str,  # Can contain schemas 'hf-hub:' or 'local-dir:'
        pretrained: Optional[str] = None,  # Used ONLY if model_name has NO schema
        load_weights: bool = True,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        force_context_length: Optional[int] = None,
        pretrained_image: bool = False,  # Load default base image weights (at creation, if no CLIP weights)
        pretrained_text: bool = True,  # Load default base text weights (at creation, if no CLIP weights)
        pretrained_image_path: Optional[str] = None,  # Load specific image weights from file (after creation)
        pretrained_text_path: Optional[str] = None,  # Load specific text weights from file (after creation)
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        weights_only: bool = True,
        **model_kwargs,
) -> torch.nn.Module:
    """
    Creates and configures a contrastive vision-language model.

    `model_name` specifies architecture/config source:
    - 'ViT-B-32': Built-in model name. `pretrained` specifies CLIP weights source (tag or file path).
    - 'hf-hub:org/repo': Loads config/weights from HF Hub. `pretrained` is IGNORED.
    - 'local-dir:/path/to/folder': Loads config/weights from local dir. `pretrained` is IGNORED.

    Base tower weights loading controlled by `pretrained_image` and `pretrained_text` flags,
    only effective if no full CLIP checkpoint (`pretrained` or schema source) is loaded.

    Tower-specific weights can be loaded *after* creation via `pretrained_image_path`
    and `pretrained_text_path`.

    Args:
        model_name: Model identifier, potentially with schema ('hf-hub:', 'local-dir:').
        pretrained: Source for CLIP weights (tag or file path) ONLY if model_name has no schema.
        load_weights: Load the resolved pretrained weights if True, otherwise random init or tower overrides only.
        precision: Model precision ('fp32', 'fp16', 'bf16', ...).
        device: Device ('cpu', 'cuda', ...).
        jit: If True, try to script the model with `torch.jit.script`; on failure a warning is
            logged and the unscripted model is returned.
        force_quick_gelu: Force use of QuickGELU activation in model config.
        force_custom_text: Force use of custom text encoder architecture.
        force_patch_dropout: Override patch dropout value in model config.
        force_image_size: Override image size in model config.
        force_preprocess_cfg: Dict to override specific FINAL preprocessing parameters.
        force_context_length: Override context length in model config.
        pretrained_image: Load default base weights for image tower at creation if no CLIP weights loaded.
        pretrained_text: Load default base weights for text tower at creation if no CLIP weights loaded (default: True).
        pretrained_image_path: Path to load weights specifically into image tower after creation.
        pretrained_text_path: Path to load weights specifically into text tower after creation.
        cache_dir: Cache directory for downloads.
        output_dict: If True and model supports it, return dict output.
        require_pretrained: Raise error if no `pretrained` CLIP weights loaded when required.
        weights_only: Use weights_only=True for torch.load (safer).
        **model_kwargs: Additional keyword arguments for model constructor (highest override priority).
            A legacy 'pretrained_hf' entry is removed and used as the value of `pretrained_text`.

    Returns:
        The created model instance.

    Raises:
        FileNotFoundError: A 'local-dir:' directory or its 'open_clip_config.json' does not exist.
        ValueError: The 'local-dir:' config could not be loaded or lacks 'model_cfg'
            (also raised by `parse_model_name` for an empty identifier after a schema prefix).
        RuntimeError: The HF Hub config could not be loaded, the built-in model name or the
            `pretrained` tag/path is unknown, a tag's weights could not be downloaded, or
            `require_pretrained` is set and no full checkpoint was loaded.
    """
    schema, identifier = parse_model_name(model_name)
    if 'pretrained_hf' in model_kwargs:
        # for backwards compat, override pretrained_text
        pretrained_text = model_kwargs.pop('pretrained_hf')
    if isinstance(device, str):
        device = torch.device(device)

    model_cfg = None
    preprocess_cfg = asdict(PreprocessCfg())  # Populate with defaults
    checkpoint_path = None  # Final path for full CLIP weights
    pretrained_cfg_for_tag = None  # Store tag config if pretrained is a tag and schema is None

    logging.info(f"Parsing model identifier. Schema: {schema}, Identifier: {identifier}")
    if schema and pretrained:
        logging.warning(f"Ignoring `pretrained='{pretrained}'` because `model_name` has '{schema}' schema.")
        pretrained = None  # Nullify pretrained as it's ignored

    # Handle schemas first - these ignore the `pretrained` argument
    if schema == 'local-dir':
        # Handle local directory schema
        local_path = Path(identifier)
        if not local_path.is_dir():
            raise FileNotFoundError(f"Directory specified via 'local-dir:' schema not found: {local_path}")

        local_config_path = local_path / 'open_clip_config.json'
        logging.info(f"Attempting to load config from local dir: {local_config_path}")
        if local_config_path.is_file():
            try:
                # Try loading and parsing the JSON config
                with open(local_config_path, 'r', encoding='utf-8') as f:
                    local_json_config = json.load(f)
                # Check if the required 'model_cfg' key is present
                if 'model_cfg' in local_json_config:
                    # Load model config and merge preprocess config
                    model_cfg = local_json_config['model_cfg']
                    preprocess_cfg = merge_preprocess_dict(preprocess_cfg, local_json_config.get('preprocess_cfg', {}))
                    logging.info(f"Loaded model config and preprocess from: {local_config_path}")
                    # Look for weights checkpoint in the same directory
                    checkpoint_path = _find_checkpoint_in_dir(local_path)
                    if checkpoint_path:
                        logging.info(f"Found CLIP weights in local folder: {checkpoint_path}")
                    else:
                        logging.warning(f"Local config loaded, but no CLIP weights found in {local_path}")
                else:
                    # Config file exists but lacks the necessary key
                    # (this ValueError is itself caught and re-wrapped by the handler below)
                    raise ValueError(f"Local config {local_config_path} missing 'model_cfg'.")
            except Exception as e:
                # Handle JSON parsing errors or other exceptions during config load
                raise ValueError(f"Could not load valid config from specified 'local-dir:{identifier}': {e}") from e
        else:
            # Directory exists but the config file is missing
            raise FileNotFoundError(f"'local-dir:' specified, but config file missing: {local_config_path}")

    elif schema == 'hf-hub':
        # Handle Hugging Face Hub schema
        model_id = identifier
        logging.info(f"Attempting to load config from HF Hub: {model_id}")
        try:
            # Fetch configuration from Hugging Face Hub
            hf_config = _get_hf_config(model_id, cache_dir=cache_dir)
            if 'model_cfg' not in hf_config:
                # (caught and re-wrapped by the outer handler below)
                raise RuntimeError(f"'model_cfg' not found in config from {model_id}")
            # Load model config and merge preprocess config
            model_cfg = hf_config['model_cfg']
            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, hf_config.get('preprocess_cfg', {}))
            logging.info(f"Loaded model config from HF Hub: {model_id}")
            # Attempt find default weights file from the Hub repo
            try:
                checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
                logging.info(f"Found default weights file on HF Hub: {checkpoint_path}")
            except Exception as e_weights:
                # Log warning if weights download fails, but proceed (might only need config)
                logging.warning(f"Could not find/download default weights on HF Hub for {model_id}: {e_weights}")
        except Exception as e_config:
            # Handle errors during config fetching from HF Hub
            raise RuntimeError(f"Failed initial config/weights load from HF Hub {model_id}: {e_config}") from e_config

    # No Schema Prefix - Use built-in name + pretrained arg (tag or file)
    elif schema is None:
        # Handle model names without schema prefix
        # Use identifier (original model_name) and clean it for lookup
        model_name_cleaned = identifier.replace('/', '-')
        # Get base config from built-in name using the cleaned identifier
        model_cfg = get_model_config(model_name_cleaned)
        if model_cfg is None:
            # Raise error if no matching built-in config found
            raise RuntimeError(
                f"Model config for '{model_name_cleaned}' not found in built-ins. Available: {list_models()}")
        logging.info(f"Loaded built-in {model_name_cleaned} model config.")

        # Determine checkpoint path and update preprocess_cfg based on `pretrained` arg (tag or file)
        if pretrained:
            # Check if `pretrained` is a known tag
            pretrained_cfg_for_tag = get_pretrained_cfg(model_name_cleaned, pretrained)
            if pretrained_cfg_for_tag:
                try:
                    # Download weights associated with the tag
                    checkpoint_path = download_pretrained(pretrained_cfg_for_tag, cache_dir=cache_dir)
                    preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg_for_tag)
                    # The QuickGELU compatibility check happens later, after the force overrides
                except Exception as e:
                    logging.error(f"Failed to download weights for tag '{pretrained}': {e}")
                    raise RuntimeError(f"Failed to download weights for tag '{pretrained}': {e}")
            elif os.path.isfile(pretrained):
                # Handle pretrained file path
                logging.info(f"`pretrained` specifies file path: {pretrained}")
                checkpoint_path = pretrained
            else:
                logging.error(
                    f"Pretrained tag or path ({pretrained}) for '{model_name_cleaned}' not found. "
                    f"Available tags: {list_pretrained_tags_by_model(model_name_cleaned)}"
                )
                raise RuntimeError(f"Pretrained value '{pretrained}' is not a known tag or valid file path")

    # Apply model config overrides
    if model_cfg is None:
        raise RuntimeError("Model configuration could not be determined after Stage 1.")
    # These are references into model_cfg, so the assignments below modify model_cfg itself
    text_cfg = model_cfg['text_cfg']
    vision_cfg = model_cfg['vision_cfg']
    if force_quick_gelu:
        model_cfg["quick_gelu"] = True
    if force_patch_dropout is not None:
        vision_cfg["patch_dropout"] = force_patch_dropout
    if force_image_size is not None:
        vision_cfg["image_size"] = force_image_size
    if force_context_length is not None:
        text_cfg["context_length"] = force_context_length

    # Check compatibility (e.g., QuickGELU warning for tags)
    if schema is None and pretrained_cfg_for_tag:
        # Only perform check if config came from built-in and weights from a tag
        model_quick_gelu = model_cfg.get('quick_gelu', False)  # Check the potentially overridden value
        tag_quick_gelu = pretrained_cfg_for_tag.get('quick_gelu', False)
        if tag_quick_gelu != model_quick_gelu:
            # Warn if the final model config's GELU setting mismatches the tag's training setting
            warnings.warn(
                f"QuickGELU mismatch between final model config (quick_gelu={model_quick_gelu}) "
                f"and pretrained tag '{pretrained}' (quick_gelu={tag_quick_gelu}).",
                UserWarning
            )

    # Decide whether to use the checkpoint path based on load_weights
    if checkpoint_path is not None:
        if not load_weights:
            logging.info(
                f"Potential checkpoint path '{checkpoint_path}' found, but skipping assignment due to load_weights=False.")
            checkpoint_path = None
    else:
        logging.info("No potential checkpoint path found from config source or pretrained arg.")

    # Set default base weight loading flags for image and text towers
    # Only load base pretrained weights if other weights will not be loaded into respective towers
    # (checkpoint_path is already None here when load_weights=False)
    enable_default_image_weights = pretrained_image and pretrained_image_path is None and checkpoint_path is None
    enable_default_text_weights = pretrained_text and pretrained_text_path is None and checkpoint_path is None
    is_timm_model = 'timm_model_name' in model_cfg.get("vision_cfg", {})
    is_hf_text_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
    if is_timm_model:
        vision_cfg['timm_model_pretrained'] = enable_default_image_weights
    else:
        enable_default_image_weights = False  # for accurate logging
    if is_hf_text_model:
        text_cfg['hf_model_pretrained'] = enable_default_text_weights
    else:
        enable_default_text_weights = False  # for accurate logging

    # Determine model class (CLIP, CustomTextCLIP, CoCa)
    # 'custom_text' is popped so it is not passed on to the model constructor
    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_text_model
    if custom_text:
        # Use CustomTextCLIP (or CoCa if multimodal_cfg is present)
        if "multimodal_cfg" in model_cfg:
            model_class = CoCa
        else:
            model_class = CustomTextCLIP
    else:
        # Default to standard CLIP
        model_class = CLIP

    # Apply final **kwargs overrides (highest priority) to a copy of model_cfg
    final_model_cfg = deepcopy(model_cfg)
    final_model_cfg.update(model_kwargs)

    # Get casting dtype based on precision argument
    cast_dtype = get_cast_dtype(precision)

    # Instantiate the model
    logging.info(f"Instantiating model architecture: {model_class.__name__}")
    model = model_class(**final_model_cfg, cast_dtype=cast_dtype)
    _set_model_device_and_precision(model, device, precision, is_timm_model)

    # Load Full Pretrained CLIP Weights (if path exists)
    pretrained_loaded = False
    if checkpoint_path:
        logging.info(f'Loading full pretrained weights from: {checkpoint_path}')
        # Use the load_checkpoint helper which handles state dict loading, conversions, etc.
        # Use strict=True by default for full model loading to catch mismatches.
        load_checkpoint(
            model,
            checkpoint_path,
            strict=True,
            weights_only=weights_only,
            device='cpu'  # Load to CPU first
        )
        pretrained_loaded = True

    # Load tower-specific weights (image and text), after the full CLIP checkpoint, potentially overwriting parts.
    pretrained_image_loaded = False  # Track if specific image weights loaded
    if pretrained_image_path:
        if os.path.isfile(pretrained_image_path):
            logging.info(f"Attempting to load image tower weights from: {pretrained_image_path}")
            try:
                # Load the state dict from the file
                image_state_dict = load_state_dict(
                    pretrained_image_path,
                    device='cpu',
                    weights_only=weights_only
                )
                # Check if model has the 'visual' attribute
                if hasattr(model, 'visual'):
                    # Load into the visual tower, use strict=False for flexibility
                    incompatible_keys = model.visual.load_state_dict(image_state_dict, strict=False)
                    logging.info(
                        f"Loaded image tower weights from {pretrained_image_path}. Incompatible keys: {incompatible_keys}")
                    pretrained_image_loaded = True  # Mark specific image weights as loaded
                else:
                    # Model structure doesn't match expectation
                    logging.warning(
                        f"Model does not have a 'visual' attribute, cannot load image tower weights from {pretrained_image_path}")
            except Exception as e:
                # Errors while loading image tower weights are logged, not raised
                logging.error(f"Error loading image tower weights from {pretrained_image_path}: {e}")
        else:
            # Path provided is not a valid file
            logging.warning(f"Invalid file path specified for pretrained_image_path: {pretrained_image_path}")

    pretrained_text_loaded = False  # Track if specific text weights loaded
    if pretrained_text_path:
        if os.path.isfile(pretrained_text_path):
            logging.info(f"Attempting to load text tower weights from: {pretrained_text_path}")
            try:
                # Load the state dict from the file
                text_state_dict = load_state_dict(
                    pretrained_text_path,
                    device='cpu',
                    weights_only=weights_only
                )
                # Use model.text when the attribute exists; otherwise fall back to the model itself
                text_module = getattr(model, 'text', model)
                if text_module is not None:
                    # Load into the text tower, use strict=False for flexibility
                    incompatible_keys = text_module.load_state_dict(text_state_dict, strict=False)
                    logging.info(f"Loaded text tower weights from {pretrained_text_path}. Incompatible keys: {incompatible_keys}")
                    pretrained_text_loaded = True  # Mark specific text weights as loaded
                else:
                    # Only reached when the model has a `text` attribute that is None
                    logging.warning(f"Model does not have a standard 'text' attribute, cannot load text tower weights from {pretrained_text_path}")
            except Exception as e:
                # Errors while loading text tower weights are logged, not raised
                logging.error(f"Error loading text tower weights from {pretrained_text_path}: {e}")
        else:
            # Path provided is not a valid file
            logging.warning(f"Invalid file path specified for pretrained_text_path: {pretrained_text_path}")

    partially_loaded = enable_default_text_weights or enable_default_image_weights \
        or pretrained_image_loaded or pretrained_text_loaded
    if require_pretrained and not pretrained_loaded:
        # If CLIP weights were required but failed to load, raise an error.
        # Loading tower-specific weights does not satisfy `require_pretrained`.
        raise RuntimeError(
            f"Required pretrained weights (`model_name='{model_name}', pretrained='{pretrained}'`) could not be loaded. "
        )
    elif not pretrained_loaded and partially_loaded:
        # Some tower weights loaded
        logging.warning(f"Model {model_name} initialized partially.")
    elif not pretrained_loaded and not partially_loaded:
        # Absolutely no weights were loaded from any source
        logging.warning(f"No pretrained weights loaded for model '{model_name}'. Model initialized randomly.")

    if output_dict and hasattr(model, "output_dict"):
        # Enable dictionary output if model supports it
        model.output_dict = True

    # If force_image_size was specified and we have a timm model, call set_input_size after loading weights
    if force_image_size is not None and is_timm_model and hasattr(model.visual, 'set_input_size'):
        logging.info(f"Calling set_input_size({force_image_size}) on timm vision model.")
        model.visual.set_input_size(force_image_size)

    if jit:
        logging.info("Attempting JIT scripting...")
        try:
            model = torch.jit.script(model)
            logging.info("JIT scripting successful.")
        except Exception as e:
            logging.warning(f"JIT scripting failed: {e}. Returning non-JIT model.")

    # Prepare and set final preprocessing configuration on the model
    final_preprocess_cfg = deepcopy(preprocess_cfg)  # Start with config determined earlier
    # Ensure image_size in preprocess config matches the actual model's visual component size, if possible
    visual_module = getattr(model, 'visual', None)
    if visual_module is not None and hasattr(visual_module, 'image_size'):
        # Update preprocess size from the instantiated visual module
        final_preprocess_cfg['size'] = visual_module.image_size
    # Apply force_preprocess_cfg overrides (highest priority for preprocessing)
    final_preprocess_cfg = merge_preprocess_dict(final_preprocess_cfg, force_preprocess_cfg or {})
    # Attach the final config to the model
    set_model_preprocess_cfg(model, final_preprocess_cfg)
    logging.info(f"Final image preprocessing configuration set: {final_preprocess_cfg}")

    # Log completion and return the configured model
    logging.info(f"Model {model_name} creation process complete.")
    return model
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        cache_dir: Optional[str] = None,
        **kwargs,  # Additional tokenizer kwargs passed to constructor
):
    """
    Gets the appropriate tokenizer based on the model identifier schema or name.

    `model_name` can specify source via schema:
    - 'ViT-B-32': Looks up built-in config to determine tokenizer type.
    - 'hf-hub:org/repo': Loads config from HF Hub to determine tokenizer type.
    - 'local-dir:/path/to/folder': Loads config from local dir to determine tokenizer type.

    Args:
        model_name: Model identifier, optionally prefixed with a schema. For built-in
            names, '/' is treated like '-' (as `create_model` does), so 'ViT-B/32'
            resolves to the 'ViT-B-32' config.
        context_length: Explicit context length. When None, the value from the model's
            text config is used, falling back to DEFAULT_CONTEXT_LENGTH.
        cache_dir: Cache directory used for the HF Hub config and the HF tokenizer.
        **kwargs: Extra tokenizer constructor arguments; they override the
            'tokenizer_kwargs' found in the model's text config.

    Returns:
        An HFTokenizer, SigLipTokenizer or SimpleTokenizer instance.

    Raises:
        FileNotFoundError: A 'local-dir:' directory or its config file is missing.
        ValueError: The 'local-dir:' config cannot be loaded or lacks 'model_cfg'.
    """
    schema, identifier = parse_model_name(model_name)

    config = {}  # Stores the loaded model_cfg (its 'text_cfg' section is what is used below)
    local_dir_path = None  # Set only for the local-dir schema; used as the HF tokenizer source
    hf_fallback_id = None  # Hub repo id used as tokenizer name when the Hub config is unusable

    # Determine Configuration Source based on Schema
    logging.info(f"Parsing tokenizer identifier. Schema: {schema}, Identifier: {identifier}")

    if schema == 'local-dir':
        # Handle local directory schema
        local_dir_path = Path(identifier)
        if not local_dir_path.is_dir():
            raise FileNotFoundError(f"Directory specified via 'local-dir:' schema not found at {local_dir_path}")
        local_config_path = local_dir_path / 'open_clip_config.json'
        logging.info(f"Attempting to load config from local-dir: {local_config_path}")
        if local_config_path.is_file():
            try:
                # Load and parse the JSON config
                with open(local_config_path, 'r', encoding='utf-8') as f:
                    local_json_config = json.load(f)
                if 'model_cfg' in local_json_config:
                    config = local_json_config['model_cfg']
                else:
                    raise ValueError(f"Local config {local_config_path} missing 'model_cfg'.")
            except Exception as e:
                raise ValueError(f"Could not load valid config for 'local-dir:{identifier}' ({e}).") from e
        else:
            raise FileNotFoundError(f"'local-dir:' specified, but config file missing: {local_config_path}")

    elif schema == 'hf-hub':
        # Handle Hugging Face Hub schema
        model_id = identifier
        logging.info(f"Attempting to load config from hf-hub:{model_id}")
        config_err = ''
        try:
            # Fetch config from HF Hub
            hf_config = _get_hf_config(model_id, cache_dir=cache_dir)
            config = hf_config.get('model_cfg', None)
            if not config:
                config_err = 'model_cfg key not found'
        except Exception as e:
            config_err = str(e)
        if not config:
            # Best effort: treat the repo id itself as the HF tokenizer name.
            hf_fallback_id = model_id
            config = {}
            logging.warning(
                f"Could not load config from hf-hub:{model_id} ({config_err}). "
                f"Falling back to using model_id for tokenizer.")

    elif schema is None and identifier:
        # Built-in config lookup. Normalise '/' to '-' exactly like create_model() so both
        # functions resolve the same config for names such as 'ViT-B/32'.
        builtin_name = identifier.replace('/', '-')
        logging.info(f"Attempting to load config from built-in: {builtin_name}")
        config = get_model_config(builtin_name)

    # get_model_config() returns None for names that are not registered built-ins.
    if config is None:
        logging.warning("Model configuration not found, returning default SimpleTokenizer.")
        return SimpleTokenizer(context_length=context_length or DEFAULT_CONTEXT_LENGTH, **kwargs)

    # Safely access text_cfg even if config is {} (empty model_name or Hub fallback case)
    text_config = config.get('text_cfg', {})

    # Resolve context length: argument > config > default
    if context_length is None:
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    # Merge tokenizer kwargs on a copy, so the loaded config dict is never modified:
    # config values first, caller kwargs override them.
    tokenizer_kwargs = dict(text_config.get('tokenizer_kwargs', {}))
    tokenizer_kwargs.update(kwargs)

    # Get the specified HF tokenizer name from config, if any
    hf_tokenizer_name = text_config.get('hf_tokenizer_name', '')
    if not hf_tokenizer_name and hf_fallback_id:
        hf_tokenizer_name = hf_fallback_id

    if hf_tokenizer_name:
        # A non-empty HF tokenizer name (from text_cfg or the Hub fallback): use HFTokenizer.
        if schema == 'local-dir':
            # If config came from local-dir, ALWAYS use the local dir path for HFTokenizer.
            # This assumes the tokenizer files are inside that directory.
            tokenizer_source = local_dir_path
        else:
            tokenizer_source = hf_tokenizer_name
        tokenizer_mode = text_config.get('tokenizer_mode', None)

        logging.info(f"Using HFTokenizer with source: '{tokenizer_source}', mode: '{tokenizer_mode}'")
        tokenizer = HFTokenizer(
            tokenizer_source,
            context_length=context_length,
            cache_dir=cache_dir,
            tokenizer_mode=tokenizer_mode,
            **tokenizer_kwargs,
        )
    elif schema is None and 'siglip' in identifier.lower():
        # Check for SigLIP naming convention ONLY if no schema was present AND no hf_tokenizer_name found
        # Avoids misinterpreting 'local-dir:/path/with/siglip/in/name'
        tn_variant = 'gemma' if 'siglip2' in identifier.lower() else 'mc4' if 'i18n' in identifier.lower() else 'c4-en'
        logging.info(f"Using SigLipTokenizer variant: {tn_variant}")
        tokenizer = SigLipTokenizer(
            tn_variant,
            context_length=context_length,
        )
    else:
        # Default to SimpleTokenizer if no HF specified and not SigLIP name match
        logging.info("Using default SimpleTokenizer.")
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer
def _set_model_device_and_precision(
        model: torch.nn.Module,
        device: torch.device,
        precision: str,
        is_timm_model: bool = False
):
    """Move `model` to `device` and apply the requested precision in place.

    - 'fp16' / 'bf16': manual mixed precision. Native models are moved and then passed
      to `convert_weights_to_lp`; timm-based models are cast as a whole and their
      `LayerNormFp32` parameters are cast back to float32.
    - 'pure_fp16' / 'pure_bf16': the whole model is cast to the low-precision dtype.
    - anything else: the model is only moved to the device.
    """
    is_mixed = precision in ("fp16", "bf16")
    is_pure = precision in ("pure_fp16", "pure_bf16")

    if not is_mixed and not is_pure:
        model.to(device=device)
        return

    low_dtype = torch.float16 if 'fp16' in precision else torch.bfloat16

    if is_pure:
        model.to(device=device, dtype=low_dtype)
        return

    # manual mixed precision that matches original OpenAI behaviour
    if not is_timm_model:
        model.to(device=device)
        convert_weights_to_lp(model, dtype=low_dtype)
        return

    from .transformer import LayerNormFp32

    # FIXME this is a bit janky, create timm based model in low-precision and
    # then cast only LayerNormFp32 instances back to float32 so they don't break.
    # Why? The convert_weights_to_lp fn only works with native models.
    model.to(device=device, dtype=low_dtype)

    def _restore_fp32_layer_norm(module):
        if isinstance(module, LayerNormFp32):
            module.weight.data = module.weight.data.to(torch.float32)
            module.bias.data = module.bias.data.to(torch.float32)

    model.apply(_restore_fp32_layer_norm)
def create_loss(args):
    """Pick and construct the loss object that matches the given run arguments.

    Selection order (first match wins):
        1. `args.distill` is truthy          -> DistillClipLoss
        2. 'coca' in `args.model.lower()`    -> CoCaLoss
        3. `args.siglip` is truthy           -> SigLipLoss (asserts that horovod is off)
        4. otherwise                         -> ClipLoss

    Args:
        args: Namespace-like object. Reads `distill`, `model`, `siglip`, and, depending on the
            branch taken, `local_loss`, `gather_with_grad`, `rank`, `world_size`, `horovod`,
            `coca_caption_loss_weight`, `coca_contrastive_loss_weight`, `loss_dist_impl`.

    Returns:
        The constructed loss instance.
    """

    def _contrastive_kwargs():
        # Keyword set shared by the Distill / CoCa / Clip losses. Built lazily so the
        # SigLip branch never touches attributes it does not need.
        return dict(
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )

    if args.distill:
        return DistillClipLoss(**_contrastive_kwargs())

    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **_contrastive_kwargs(),
        )

    if args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        siglip_kwargs = dict(
            rank=args.rank,
            world_size=args.world_size,
            dist_impl=args.loss_dist_impl,  # siglip has multiple distributed implementations to choose from
        )
        return SigLipLoss(**siglip_kwargs)

    return ClipLoss(**_contrastive_kwargs())
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        load_weights: bool = True,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_context_length: Optional[int] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_text: bool = True,
        pretrained_image_path: Optional[str] = None,
        pretrained_text_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        weights_only: bool = True,
        **model_kwargs,
):
    """
    Build a contrastive vision-language model plus its training and validation image transforms.

    Convenience layer over `create_model` for training workflows: the `image_*` overrides are
    merged into one preprocess override dict, the model is created, and two `image_transform_v2`
    pipelines are derived from the resulting `model.visual.preprocess_cfg`.

    How `model_name` selects the architecture/config source:
        - 'ViT-B-32': Built-in model name. `pretrained` names the CLIP weights source (tag or file path).
        - 'hf-hub:org/repo': Config/weights are loaded from the HF Hub. `pretrained` is IGNORED.
        - 'local-dir:/path/to/folder': Config/weights are loaded from a local dir. `pretrained` is IGNORED.

    Args:
        model_name: Model identifier, optionally prefixed with a schema ('hf-hub:', 'local-dir:').
        pretrained: CLIP weights source (tag or file path); used ONLY when model_name has no schema.
        load_weights: If True, load the resolved pretrained weights; otherwise random init or
            tower overrides only.
        precision: Model precision ('fp32', 'fp16', 'bf16', ...).
        device: Device ('cpu', 'cuda', ...).
        jit: If True, JIT compile the model.
        force_quick_gelu: Force QuickGELU activation in the model config.
        force_custom_text: Force the custom text encoder architecture.
        force_patch_dropout: Override for the patch dropout value in the model config.
        force_image_size: Override for the image size in the model config.
        force_context_length: Override for the context length in the model config.
        image_mean: Override for the per-channel image normalization mean.
        image_std: Override for the per-channel image normalization std.
        image_interpolation: Override for the interpolation method used when resizing images.
        image_resize_mode: Override for the inference resize mode ('squash', 'longest', 'shortest').
        aug_cfg: Augmentation configuration for the training transform, as a dict or an
            AugmentationCfg. Controls random crop, color jitter, etc. If None, model defaults apply.
        pretrained_image: Load default (timm) base weights for the image tower at creation if no
            CLIP weights are loaded.
        pretrained_text: Load default (hf) base weights for the text tower at creation if no
            CLIP weights are loaded.
        pretrained_image_path: Path of weights to load specifically into the image tower after creation.
        pretrained_text_path: Path of weights to load specifically into the text tower after creation.
        cache_dir: Cache directory for downloads.
        output_dict: If True and the model supports it, return dict output.
        weights_only: Use weights_only=True for torch.load (safer).
        **model_kwargs: Extra keyword arguments for the model constructor (highest override priority).

    Returns:
        Tuple[torch.nn.Module, Callable, Callable]:
            - model: The created model instance.
            - preprocess_train: Image transform for training (includes augmentation).
            - preprocess_val: Image transform for validation/inference (no augmentation).

    Example:
        >>> # Built-in model with a pretrained tag
        >>> model, train_transform, val_transform = create_model_and_transforms('ViT-B-32', pretrained='openai')
        >>>
        >>> # Custom augmentation
        >>> aug_cfg = {'scale': (0.9, 1.0), 'ratio': (1.0, 1.0)}
        >>> model, train_transform, val_transform = create_model_and_transforms(
        ...     'ViT-L-14',
        ...     pretrained='datacomp_xl_s13b_b90k',
        ...     aug_cfg=aug_cfg
        ... )
        >>>
        >>> # Hugging Face Hub
        >>> model, train_transform, val_transform = create_model_and_transforms('hf-hub:org/model-repo')

    Note:
        The training transform applies data augmentation driven by `aug_cfg`; the validation
        transform performs only the necessary preprocessing (resize, center crop, normalize)
        with no random augmentation.
    """
    # Fold the explicit image_* arguments into a single override dict for create_model.
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    # Named options are kept separate from **model_kwargs so that a duplicated key
    # still raises a TypeError at the call site rather than being silently overridden.
    creation_options = dict(
        load_weights=load_weights,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        force_context_length=force_context_length,
        pretrained_image=pretrained_image,
        pretrained_text=pretrained_text,
        pretrained_image_path=pretrained_image_path,
        pretrained_text_path=pretrained_text_path,
        cache_dir=cache_dir,
        output_dict=output_dict,
        weights_only=weights_only,
    )
    model = create_model(model_name, pretrained, **creation_options, **model_kwargs)

    # Both pipelines are derived from the preprocess config attached to the visual tower.
    visual_pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
    train_transform = image_transform_v2(visual_pp_cfg, is_train=True, aug_cfg=aug_cfg)
    val_transform = image_transform_v2(visual_pp_cfg, is_train=False)

    return model, train_transform, val_transform
def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_context_length: Optional[int] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        return_transform: bool = True,
        cache_dir: Optional[str] = None,
        weights_only: bool = True,
        **model_kwargs,
):
    """
    Build a contrastive vision-language model from pretrained weights, optionally with its
    inference preprocessing transform.

    Convenience wrapper around `create_model` that always passes `require_pretrained=True` and,
    when `return_transform` is set, also builds the inference transform from
    `model.visual.preprocess_cfg`. Intended for use cases that need a pretrained model, such as
    feature extraction, zero-shot classification, or fine-tuning.

    How `model_name` selects the architecture/config source:
        - 'ViT-B-32': Built-in model name. `pretrained` names the CLIP weights source (tag or file path).
        - 'hf-hub:org/repo': Config/weights are loaded from the HF Hub. `pretrained` is IGNORED.
        - 'local-dir:/path/to/folder': Config/weights are loaded from a local dir. `pretrained` is IGNORED.

    Unlike `create_model`, this function raises an error if pretrained weights cannot be loaded.

    Args:
        model_name: Model identifier, optionally prefixed with a schema ('hf-hub:', 'local-dir:').
        pretrained: CLIP weights source (tag or file path); used ONLY when model_name has no schema.
            If None and the schema requires it, an error is raised.
        precision: Model precision ('fp32', 'fp16', 'bf16', ...).
        device: Device ('cpu', 'cuda', ...).
        jit: If True, JIT compile the model.
        force_quick_gelu: Force QuickGELU activation in the model config.
        force_custom_text: Force the custom text encoder architecture.
        force_image_size: Override for the image size in the model config. Useful for running
            models at different resolutions.
        force_context_length: Override for the context length in the model config.
        image_mean: Override for the per-channel image normalization mean.
        image_std: Override for the per-channel image normalization std.
        image_interpolation: Override for the interpolation method used when resizing images
            ('bicubic', 'bilinear', 'nearest').
        image_resize_mode: Override for the inference resize mode ('squash', 'longest', 'shortest').
            Only affects the returned preprocessing transform, not training.
        return_transform: If True, return (model, preprocess). If False, return only the model.
        cache_dir: Cache directory for downloads.
        weights_only: Use weights_only=True for torch.load (safer).
        **model_kwargs: Extra keyword arguments for the model constructor (highest override priority).

    Returns:
        Union[torch.nn.Module, Tuple[torch.nn.Module, Callable]]:
            - return_transform=False: just the model instance.
            - return_transform=True: (model, preprocess), where preprocess is the inference
              preprocessing transform.

    Raises:
        RuntimeError: If pretrained weights are required but cannot be loaded.

    Example:
        >>> # Model together with its preprocessing
        >>> model, preprocess = create_model_from_pretrained('ViT-B-32', pretrained='openai')
        >>>
        >>> # Model only (e.g. when using custom preprocessing)
        >>> model = create_model_from_pretrained('ViT-B-32', pretrained='openai', return_transform=False)
        >>>
        >>> # Hugging Face Hub
        >>> model, preprocess = create_model_from_pretrained('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')
        >>>
        >>> # Custom image size
        >>> model, preprocess = create_model_from_pretrained(
        ...     'ViT-L-14',
        ...     pretrained='openai',
        ...     force_image_size=336
        ... )

    Note:
        This function always requires pretrained weights to be available and loaded successfully.
        To create a model without pretrained weights, or with only partial weight loading, use
        `create_model` or `create_model_and_transforms` instead.
    """
    # Fold the explicit image_* arguments into a single override dict for create_model.
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    # Named options are kept separate from **model_kwargs so that a duplicated key
    # still raises a TypeError at the call site rather than being silently overridden.
    creation_options = dict(
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        force_context_length=force_context_length,
        cache_dir=cache_dir,
        require_pretrained=True,
        weights_only=weights_only,
    )
    model = create_model(model_name, pretrained, **creation_options, **model_kwargs)

    if return_transform:
        # Inference-only pipeline derived from the visual tower's preprocess config.
        inference_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
        preprocess = image_transform_v2(inference_cfg, is_train=False)
        return model, preprocess

    return model