1088 lines
50 KiB
Python
1088 lines
50 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import warnings
|
|
from copy import deepcopy
|
|
from dataclasses import asdict
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional, Tuple, Union
|
|
|
|
import torch
|
|
|
|
from .convert import convert_state_dict
|
|
from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
|
resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
|
|
from .coca_model import CoCa
|
|
from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
|
|
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
|
|
list_pretrained_tags_by_model, download_pretrained_from_hf
|
|
from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
|
|
from .tokenizer import HFTokenizer, SimpleTokenizer, SigLipTokenizer, DEFAULT_CONTEXT_LENGTH
|
|
|
|
HF_HUB_PREFIX = 'hf-hub:'
|
|
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
|
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
|
|
|
|
|
def _natural_key(string_):
|
|
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
|
|
|
|
|
def _rescan_model_configs():
    """Scan all registered config paths and (re)build the model config registry."""
    global _MODEL_CONFIGS

    config_ext = ('.json',)

    # Gather candidate config files from registered files and directories.
    config_files = []
    for config_path in _MODEL_CONFIG_PATHS:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(config_path)
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f'*{ext}'))

    # Register every config that carries the required top-level keys.
    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')
    for config_file in config_files:
        with open(config_file, 'r') as fh:
            candidate_cfg = json.load(fh)
        if all(key in candidate_cfg for key in required_keys):
            _MODEL_CONFIGS[config_file.stem] = candidate_cfg

    # Keep the registry ordered by natural sort of model name.
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))


_rescan_model_configs()  # initial populate of model config registry
|
|
|
|
|
|
def list_models():
    """Enumerate available model architectures based on config files."""
    return [name for name in _MODEL_CONFIGS]
|
|
|
|
|
|
def add_model_config(path):
    """Register a model config path or file and refresh the registry."""
    path = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(path)
    _rescan_model_configs()
|
|
|
|
|
|
# Define Schema Prefixes as constants
# NOTE(review): HF_HUB_PREFIX is also defined near the top of this module — confirm and deduplicate.
HF_HUB_PREFIX = 'hf-hub:'
LOCAL_DIR_PREFIX = 'local-dir:'
|
|
def parse_model_name(model_name: str) -> Tuple[Optional[str], str]:
    """
    Split a model name string into (schema, identifier).

    Recognized schemas are 'local-dir' (via the 'local-dir:' prefix) and
    'hf-hub' (via the 'hf-hub:' prefix). Names without a prefix are returned
    unchanged with a None schema.

    Args:
        model_name: The model name string (e.g., 'ViT-B-32',
            'hf-hub:org/repo', 'local-dir:/path/to/dir',
            'local-dir:./relative/path').

    Returns:
        A tuple (schema, identifier):
        - schema (Optional[str]): 'hf-hub', 'local-dir', or None if no schema detected.
        - identifier (str): The part after the schema prefix, or the original
          string if no schema was present. For 'local-dir' this is the raw,
          unresolved path string — resolution/validation is left to the caller.

    Raises:
        ValueError: If a schema prefix is present but the identifier part is empty.
    """
    # Table of (schema, prefix, error-on-empty); local-dir is checked first,
    # matching the original branch order.
    schema_table = (
        ('local-dir', LOCAL_DIR_PREFIX, "Empty path specified after 'local-dir:' schema."),
        ('hf-hub', HF_HUB_PREFIX, "Empty identifier specified after 'hf-hub:' schema."),
    )
    for schema, prefix, empty_msg in schema_table:
        if not model_name.startswith(prefix):
            continue
        identifier = model_name[len(prefix):]
        if not identifier:
            raise ValueError(empty_msg)
        return schema, identifier

    # No recognized prefix: the whole string is the identifier.
    return None, model_name
|
|
|
|
|
|
def _get_hf_config(
        model_id: str,
        cache_dir: Optional[str] = None,
):
    """Fetch and parse an open_clip model config from the HuggingFace Hub.

    Args:
        model_id: HF Hub repo id (e.g. 'org/repo').
        cache_dir: Optional download cache directory.

    Returns:
        The parsed 'open_clip_config.json' contents as a dict.
    """
    config_path = download_pretrained_from_hf(
        model_id,
        filename='open_clip_config.json',
        cache_dir=cache_dir,
    )
    with open(config_path, 'r', encoding='utf-8') as cfg_file:
        return json.load(cfg_file)
|
|
|
|
|
|
def get_model_config(model_name):
    """Fetch a model config from a schema-specified location or the built-in registry.

    Returns the 'model_cfg' section when present (falling back to the whole
    loaded config), a deep copy of a built-in config, or None if not found.
    """
    schema, model_id = parse_model_name(model_name)
    if schema == 'local-dir':
        # Read the config JSON sitting inside the local directory.
        with open(Path(model_id) / 'open_clip_config.json', 'r', encoding='utf-8') as cfg_file:
            config = json.load(cfg_file)
        return config.get('model_cfg', config)
    if schema == 'hf-hub':
        config = _get_hf_config(model_id)
        return config.get('model_cfg', config)
    if model_name in _MODEL_CONFIGS:
        # Copy so callers can mutate the returned dict freely.
        return deepcopy(_MODEL_CONFIGS[model_name])
    return None
|
|
|
|
|
|
def load_state_dict(
        checkpoint_path: str,
        device='cpu',
        weights_only=True,
):
    """Load a state dict from a checkpoint file (safetensors or torch pickle).

    Args:
        checkpoint_path: Path to the checkpoint file.
        device: Device to map tensors to while loading.
        weights_only: Use torch.load(weights_only=True) where supported (safer).

    Returns:
        A flat state dict with any DistributedDataParallel 'module.' prefix removed.
    """
    # Check if safetensors or not and load weights accordingly
    if str(checkpoint_path).endswith(".safetensors"):
        from safetensors.torch import load_file
        checkpoint = load_file(checkpoint_path, device=device)
    else:
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
        except TypeError:
            # Older torch versions do not accept the weights_only kwarg.
            checkpoint = torch.load(checkpoint_path, map_location=device)

    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, torch.jit.ScriptModule):
        state_dict = checkpoint.state_dict()
        # Drop non-parameter metadata entries stored in OpenAI JIT archives.
        for key in ["input_resolution", "context_length", "vocab_size"]:
            state_dict.pop(key, None)
    else:
        state_dict = checkpoint

    # Strip DistributedDataParallel 'module.' prefix if present.
    # FIX: the original checked startswith('module') and blindly stripped 7 chars
    # from every key, which corrupted keys like 'module_x'; it also crashed with
    # StopIteration on an empty state dict.
    if state_dict and next(iter(state_dict)).startswith('module.'):
        prefix_len = len('module.')
        state_dict = {
            k[prefix_len:] if k.startswith('module.') else k: v
            for k, v in state_dict.items()
        }

    return state_dict
|
|
|
|
|
|
def load_checkpoint(
        model: Union[CLIP, CustomTextCLIP],
        checkpoint_path: str,
        strict: bool = True,
        weights_only: bool = True,
        device='cpu',
):
    """Load weights from a checkpoint file into an instantiated model.

    Handles 3rd-party / legacy format conversion and resizes positional
    embeddings to fit the model before calling load_state_dict.

    Args:
        model: Target model instance (CLIP / CustomTextCLIP or compatible).
        checkpoint_path: Checkpoint path ('.safetensors', torch pickle, or
            '.npz'/'.npy' big_vision numpy weights).
        strict: Passed through to model.load_state_dict.
        weights_only: Use safer torch.load(weights_only=True) where supported.
        device: Device to map tensors to while loading.

    Returns:
        The incompatible-keys result of model.load_state_dict, or an empty
        dict for the numpy big_vision path.
    """
    if Path(checkpoint_path).suffix in ('.npz', '.npy'):
        # Separate path loading numpy big_vision (SigLIP) weights
        from open_clip.convert import load_big_vision_weights
        load_big_vision_weights(model, checkpoint_path)
        # big_vision loader mutates the model in place; nothing to report.
        return {}

    state_dict = load_state_dict(checkpoint_path, device=device, weights_only=weights_only)

    # Detect & convert 3rd party state_dicts -> open_clip
    state_dict = convert_state_dict(model, state_dict)

    # Detect old format and make compatible with new format
    if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
        state_dict = convert_to_custom_text_state_dict(state_dict)

    # correct if logit_scale differs in being scaler vs 1d param
    if 'logit_scale' in state_dict and model.logit_scale.ndim != state_dict['logit_scale'].ndim:
        state_dict['logit_scale'] = state_dict['logit_scale'].reshape(model.logit_scale.shape)

    # correct if logit_bias differs in being scaler vs 1d param
    # NOTE(review): assumes the model defines a logit_bias attribute (possibly None) — confirm on all model classes.
    if 'logit_bias' in state_dict and model.logit_bias.ndim != state_dict['logit_bias'].ndim:
        state_dict['logit_bias'] = state_dict['logit_bias'].reshape(model.logit_bias.shape)

    # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
    if 'logit_bias' not in state_dict and model.logit_bias is not None:
        state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])

    # Certain text transformers no longer expect position_ids after transformers==4.31
    position_id_key = 'text.transformer.embeddings.position_ids'
    if position_id_key in state_dict and not hasattr(model, position_id_key):
        del state_dict[position_id_key]

    # Resize image and text positional embeddings in the state dict to the model's sizes.
    resize_pos_embed(state_dict, model)
    resize_text_pos_embed(state_dict, model)

    # Finally, load the massaged state_dict into model
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys
|
|
|
|
|
|
def _find_checkpoint_in_dir(dir_path: Path) -> Optional[str]:
|
|
checkpoints = list(dir_path.glob('*.safetensors')) + list(dir_path.glob('*.bin')) + list(dir_path.glob('*.pth'))
|
|
if not checkpoints:
|
|
return None
|
|
checkpoints.sort()
|
|
checkpoints.sort(key=lambda x: x.suffix == '.safetensors', reverse=True)
|
|
preferred_order = [
|
|
"open_clip_model.safetensors", "open_clip_pytorch_model.safetensors",
|
|
"open_clip_pytorch_model.bin", "open_clip_pytorch_model.pth",
|
|
"model.safetensors", "pytorch_model.bin", "pytorch_model.pth", "model.pth"
|
|
]
|
|
preferred_checkpoints = [c for c in checkpoints if c.name in preferred_order]
|
|
if preferred_checkpoints:
|
|
preferred_checkpoints.sort(key=lambda x: preferred_order.index(x.name))
|
|
chosen = preferred_checkpoints[0]
|
|
logging.info(f"Found preferred checkpoint file: {chosen.name} in {dir_path}")
|
|
return str(chosen)
|
|
chosen = checkpoints[0]
|
|
logging.warning(
|
|
f"Multiple checkpoints found in {dir_path}: {[c.name for c in checkpoints]}. Using '{chosen.name}'.")
|
|
return str(chosen)
|
|
|
|
|
|
def create_model(
        model_name: str,  # Can contain schemas 'hf-hub:' or 'local-dir:'
        pretrained: Optional[str] = None,  # Used ONLY if model_name has NO schema
        load_weights: bool = True,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        force_context_length: Optional[int] = None,
        pretrained_image: bool = False,  # Load default base image weights (at creation, if no CLIP weights)
        pretrained_text: bool = True,  # Load default base text weights (at creation, if no CLIP weights) - NEW
        pretrained_image_path: Optional[str] = None,  # Load specific image weights from file (after creation)
        pretrained_text_path: Optional[str] = None,  # Load specific text weights from file (after creation)
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        weights_only: bool = True,
        **model_kwargs,
) -> torch.nn.Module:
    """
    Creates and configures a contrastive vision-language model.

    `model_name` specifies architecture/config source:
    - 'ViT-B-32': Built-in model name. `pretrained` specifies CLIP weights source (tag or file path).
    - 'hf-hub:org/repo': Loads config/weights from HF Hub. `pretrained` is IGNORED.
    - 'local-dir:/path/to/folder': Loads config/weights from local dir. `pretrained` is IGNORED.

    Base tower weights loading controlled by `pretrained_image` and `pretrained_text` flags,
    only effective if no full CLIP checkpoint (`pretrained` or schema source) is loaded.

    Tower-specific weights can be loaded *after* creation via `pretrained_image_path`
    and `pretrained_text_path`.

    Args:
        model_name: Model identifier, potentially with schema ('hf-hub:', 'local-dir:').
        pretrained: Source for CLIP weights (tag or file path) ONLY if model_name has no schema.
        load_weights: Load the resolved pretrained weights if True, otherwise random init or tower overrides only.
        precision: Model precision ('fp32', 'fp16', 'bf16', ...).
        device: Device ('cpu', 'cuda', ...).
        jit: If True, JIT compile the model.
        force_quick_gelu: Force use of QuickGELU activation in model config.
        force_custom_text: Force use of custom text encoder architecture.
        force_patch_dropout: Override patch dropout value in model config.
        force_image_size: Override image size in model config.
        force_preprocess_cfg: Dict to override specific FINAL preprocessing parameters.
        force_context_length: Override context length in model config.
        pretrained_image: Load default base weights for image tower at creation if no CLIP weights loaded.
        pretrained_text: Load default base weights for text tower at creation if no CLIP weights loaded (default: True).
        pretrained_image_path: Path to load weights specifically into image tower after creation.
        pretrained_text_path: Path to load weights specifically into text tower after creation.
        cache_dir: Cache directory for downloads.
        output_dict: If True and model supports it, return dict output.
        require_pretrained: Raise error if no `pretrained` CLIP weights loaded when required.
        weights_only: Use weights_only=True for torch.load (safer).
        **model_kwargs: Additional keyword arguments for model constructor (highest override priority).

    Returns:
        The created model instance.
    """
    schema, identifier = parse_model_name(model_name)
    if 'pretrained_hf' in model_kwargs:
        # for backwards compat, override pretrained_text
        pretrained_text = model_kwargs.pop('pretrained_hf')
    if isinstance(device, str):
        device = torch.device(device)

    # --- Stage 1: resolve model config, preprocess config, and candidate checkpoint path ---
    model_cfg = None
    preprocess_cfg = asdict(PreprocessCfg())  # Populate with defaults
    checkpoint_path = None  # Final path for full CLIP weights
    pretrained_cfg_for_tag = None  # Store tag config if pretrained is a tag and schema is None

    logging.info(f"Parsing model identifier. Schema: {schema}, Identifier: {identifier}")
    if schema and pretrained:
        logging.warning(f"Ignoring `pretrained='{pretrained}'` because `model_name` has '{schema}' schema.")
        pretrained = None  # Nullify pretrained as it's ignored

    # Handle schemas first - these ignore the `pretrained` argument
    if schema == 'local-dir':
        # Handle local directory schema
        local_path = Path(identifier)
        if not local_path.is_dir():
            raise FileNotFoundError(f"Directory specified via 'local-dir:' schema not found: {local_path}")

        local_config_path = local_path / 'open_clip_config.json'
        logging.info(f"Attempting to load config from local dir: {local_config_path}")
        if local_config_path.is_file():
            try:
                # Try loading and parsing the JSON config
                with open(local_config_path, 'r', encoding='utf-8') as f:
                    local_json_config = json.load(f)
                # Check if the required 'model_cfg' key is present
                if 'model_cfg' in local_json_config:
                    # Load model config and merge preprocess config
                    model_cfg = local_json_config['model_cfg']
                    preprocess_cfg = merge_preprocess_dict(preprocess_cfg, local_json_config.get('preprocess_cfg', {}))
                    logging.info(f"Loaded model config and preprocess from: {local_config_path}")
                    # Look for weights checkpoint in the same directory
                    checkpoint_path = _find_checkpoint_in_dir(local_path)
                    if checkpoint_path:
                        logging.info(f"Found CLIP weights in local folder: {checkpoint_path}")
                    else:
                        logging.warning(f"Local config loaded, but no CLIP weights found in {local_path}")
                else:
                    # Config file exists but lacks the necessary key
                    raise ValueError(f"Local config {local_config_path} missing 'model_cfg'.")
            except Exception as e:
                # Handle JSON parsing errors or other exceptions during config load
                raise ValueError(f"Could not load valid config from specified 'local-dir:{identifier}': {e}") from e
        else:
            # Directory exists but the config file is missing
            raise FileNotFoundError(f"'local-dir:' specified, but config file missing: {local_config_path}")

    elif schema == 'hf-hub':
        # Handle Hugging Face Hub schema
        model_id = identifier
        logging.info(f"Attempting to load config from HF Hub: {model_id}")
        try:
            # Fetch configuration from Hugging Face Hub
            hf_config = _get_hf_config(model_id, cache_dir=cache_dir)
            if 'model_cfg' not in hf_config:
                raise RuntimeError(f"'model_cfg' not found in config from {model_id}")
            # Load model config and merge preprocess config
            model_cfg = hf_config['model_cfg']
            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, hf_config.get('preprocess_cfg', {}))
            logging.info(f"Loaded model config from HF Hub: {model_id}")
            # Attempt find default weights file from the Hub repo
            try:
                checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
                logging.info(f"Found default weights file on HF Hub: {checkpoint_path}")
            except Exception as e_weights:
                # Log warning if weights download fails, but proceed (might only need config)
                logging.warning(f"Could not find/download default weights on HF Hub for {model_id}: {e_weights}")
        except Exception as e_config:
            # Handle errors during config fetching from HF Hub
            raise RuntimeError(f"Failed initial config/weights load from HF Hub {model_id}: {e_config}") from e_config

    # No Schema Prefix - Use built-in name + pretrained arg (tag or file)
    elif schema is None:
        # Handle model names without schema prefix
        # Use identifier (original model_name) and clean it for lookup
        model_name_cleaned = identifier.replace('/', '-')

        # Get base config from built-in name using the cleaned identifier
        model_cfg = get_model_config(model_name_cleaned)
        if model_cfg is None:
            # Raise error if no matching built-in config found
            raise RuntimeError(
                f"Model config for '{model_name_cleaned}' not found in built-ins. Available: {list_models()}")
        logging.info(f"Loaded built-in {model_name_cleaned} model config.")

        # Determine checkpoint path and update preprocess_cfg based on `pretrained` arg (tag or file)
        if pretrained:
            # Check if `pretrained` is a known tag
            pretrained_cfg_for_tag = get_pretrained_cfg(model_name_cleaned, pretrained)
            if pretrained_cfg_for_tag:
                try:
                    # Download weights associated with the tag
                    checkpoint_path = download_pretrained(pretrained_cfg_for_tag, cache_dir=cache_dir)
                    preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg_for_tag)
                    # QuickGELU compatibility check will happen after force overrides
                except Exception as e:
                    logging.error(f"Failed to download weights for tag '{pretrained}': {e}")
                    raise RuntimeError(f"Failed to download weights for tag '{pretrained}': {e}")
            elif os.path.isfile(pretrained):
                # Handle pretrained file path
                logging.info(f"`pretrained` specifies file path: {pretrained}")
                checkpoint_path = pretrained
            else:
                logging.error(
                    f"Pretrained tag or path ({pretrained}) for '{model_name_cleaned}' not found. "
                    f"Available tags: {list_pretrained_tags_by_model(model_name_cleaned)}"
                )
                raise RuntimeError(f"Pretrained value '{pretrained}' is not a known tag or valid file path")

    # --- Stage 2: apply model config overrides ---
    # Apply model config overrides
    if model_cfg is None:
        raise RuntimeError("Model configuration could not be determined after Stage 1.")
    # NOTE: text_cfg / vision_cfg alias sub-dicts of model_cfg; mutating them mutates model_cfg.
    text_cfg = model_cfg['text_cfg']
    vision_cfg = model_cfg['vision_cfg']
    if force_quick_gelu:
        model_cfg["quick_gelu"] = True
    if force_patch_dropout is not None:
        vision_cfg["patch_dropout"] = force_patch_dropout
    if force_image_size is not None:
        vision_cfg["image_size"] = force_image_size
    if force_context_length is not None:
        text_cfg["context_length"] = force_context_length

    # Check compatibility (e.g., QuickGELU warning for tags)
    if schema is None and pretrained_cfg_for_tag:
        # Only perform check if config came from built-in and weights from a tag
        model_quick_gelu = model_cfg.get('quick_gelu', False)  # Check the potentially overridden value
        tag_quick_gelu = pretrained_cfg_for_tag.get('quick_gelu', False)
        if tag_quick_gelu != model_quick_gelu:
            # Warn if the final model config's GELU setting mismatches the tag's training setting
            warnings.warn(
                f"QuickGELU mismatch between final model config (quick_gelu={model_quick_gelu}) "
                f"and pretrained tag '{pretrained}' (quick_gelu={tag_quick_gelu}).",
                UserWarning
            )

    # Decide whether to use the checkpoint path based on load_weights
    if checkpoint_path is not None:
        if not load_weights:
            logging.info(
                f"Potential checkpoint path '{checkpoint_path}' found, but skipping assignment due to load_weights=False.")
            checkpoint_path = None
    else:
        logging.info("No potential checkpoint path found from config source or pretrained arg.")

    # Set default base weight loading flags for image and text towers
    # Only load base pretrained weights if other weights will not be loaded into respective towers
    enable_default_image_weights = pretrained_image and pretrained_image_path is None and checkpoint_path is None
    enable_default_text_weights = pretrained_text and pretrained_text_path is None and checkpoint_path is None
    is_timm_model = 'timm_model_name' in model_cfg.get("vision_cfg", {})
    is_hf_text_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
    if is_timm_model:
        vision_cfg['timm_model_pretrained'] = enable_default_image_weights
    else:
        enable_default_image_weights = False  # for accurate logging
    if is_hf_text_model:
        text_cfg['hf_model_pretrained'] = enable_default_text_weights
    else:
        enable_default_text_weights = False  # for accurate logging

    # Determine model class (CLIP, CustomTextCLIP, CoCa)
    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_text_model
    if custom_text:
        # Use CustomTextCLIP (or CoCa if multimodal_cfg is present)
        if "multimodal_cfg" in model_cfg:
            model_class = CoCa
        else:
            model_class = CustomTextCLIP
    else:
        # Default to standard CLIP
        model_class = CLIP

    # Apply final **kwargs overrides (highest priority) to a copy of model_cfg
    final_model_cfg = deepcopy(model_cfg)
    final_model_cfg.update(model_kwargs)

    # Get casting dtype based on precision argument
    cast_dtype = get_cast_dtype(precision)

    # --- Stage 3: instantiate and load weights ---
    # Instantiate the model
    logging.info(f"Instantiating model architecture: {model_class.__name__}")
    model = model_class(**final_model_cfg, cast_dtype=cast_dtype)
    _set_model_device_and_precision(model, device, precision, is_timm_model)

    # Load Full Pretrained CLIP Weights (if path exists)
    pretrained_loaded = False
    if checkpoint_path:
        logging.info(f'Loading full pretrained weights from: {checkpoint_path}')
        # Use the load_checkpoint helper which handles state dict loading, conversions, etc.
        # Use strict=True by default for full model loading to catch mismatches.
        load_checkpoint(
            model,
            checkpoint_path,
            strict=True,
            weights_only=weights_only,
            device='cpu'  # Load to CPU first
        )
        pretrained_loaded = True

    # Load tower-specific weights (image and text), after the full CLIP checkpoint, potentially overwriting parts.
    pretrained_image_loaded = False  # Track if specific image weights loaded
    if pretrained_image_path:
        if os.path.isfile(pretrained_image_path):
            logging.info(f"Attempting to load image tower weights from: {pretrained_image_path}")
            try:
                # Load the state dict from the file
                image_state_dict = load_state_dict(
                    pretrained_image_path,
                    device='cpu',
                    weights_only=weights_only
                )
                # Check if model has the 'visual' attribute
                if hasattr(model, 'visual'):
                    # Load into the visual tower, use strict=False for flexibility
                    incompatible_keys = model.visual.load_state_dict(image_state_dict, strict=False)
                    logging.info(
                        f"Loaded image tower weights from {pretrained_image_path}. Incompatible keys: {incompatible_keys}")
                    pretrained_image_loaded = True  # Mark specific image weights as loaded
                else:
                    # Model structure doesn't match expectation
                    logging.warning(
                        f"Model does not have a 'visual' attribute, cannot load image tower weights from {pretrained_image_path}")
            except Exception as e:
                # Handle errors during image tower weight loading
                logging.error(f"Error loading image tower weights from {pretrained_image_path}: {e}")
        else:
            # Path provided is not a valid file
            logging.warning(f"Invalid file path specified for pretrained_image_path: {pretrained_image_path}")

    pretrained_text_loaded = False  # Track if specific text weights loaded
    if pretrained_text_path:
        if os.path.isfile(pretrained_text_path):
            logging.info(f"Attempting to load text tower weights from: {pretrained_text_path}")
            try:
                # Load the state dict from the file
                text_state_dict = load_state_dict(
                    pretrained_text_path,
                    device='cpu',
                    weights_only=weights_only
                )
                # Safely get the text attribute (usually 'text', but could be different)
                text_module = getattr(model, 'text', model)
                if text_module is not None:
                    # Load into the text tower, use strict=False for flexibility
                    incompatible_keys = text_module.load_state_dict(text_state_dict, strict=False)
                    logging.info(f"Loaded text tower weights from {pretrained_text_path}. Incompatible keys: {incompatible_keys}")
                    pretrained_text_loaded = True  # Mark specific text weights as loaded
                else:
                    # Model structure doesn't match expectation
                    logging.warning(f"Model does not have a standard 'text' attribute, cannot load text tower weights from {pretrained_text_path}")
            except Exception as e:
                # Handle errors during text tower weight loading
                logging.error(f"Error loading text tower weights from {pretrained_text_path}: {e}")
        else:
            # Path provided is not a valid file
            logging.warning(f"Invalid file path specified for pretrained_text_path: {pretrained_text_path}")

    # --- Stage 4: status checks and final configuration ---
    partially_loaded = enable_default_text_weights or enable_default_image_weights \
                       or pretrained_image_loaded or pretrained_text_loaded
    if require_pretrained and not pretrained_loaded:
        # If CLIP weights were required but failed to load, raise an error.
        # Loading tower-specific weights does not satisfy `require_pretrained`.
        raise RuntimeError(
            f"Required pretrained weights (`model_name='{model_name}', pretrained='{pretrained}'`) could not be loaded. "
        )
    elif not pretrained_loaded and partially_loaded:
        # Some tower weights loaded
        logging.warning(f"Model {model_name} initialized partially.")
    elif not pretrained_loaded and not partially_loaded:
        # Absolutely no weights were loaded from any source
        logging.warning(f"No pretrained weights loaded for model '{model_name}'. Model initialized randomly.")

    if output_dict and hasattr(model, "output_dict"):
        # Enable dictionary output if model supports it
        model.output_dict = True

    # If force_image_size was specified and we have a timm model, call set_input_size after loading weights
    if force_image_size is not None and is_timm_model and hasattr(model.visual, 'set_input_size'):
        logging.info(f"Calling set_input_size({force_image_size}) on timm vision model.")
        model.visual.set_input_size(force_image_size)

    if jit:
        logging.info("Attempting JIT scripting...")
        try:
            model = torch.jit.script(model)
            logging.info("JIT scripting successful.")
        except Exception as e:
            logging.warning(f"JIT scripting failed: {e}. Returning non-JIT model.")

    # Prepare and set final preprocessing configuration on the model
    final_preprocess_cfg = deepcopy(preprocess_cfg)  # Start with config determined earlier
    # Ensure image_size in preprocess config matches the actual model's visual component size, if possible
    visual_module = getattr(model, 'visual', None)
    if visual_module is not None and hasattr(visual_module, 'image_size'):
        # Update preprocess size from the instantiated visual module
        final_preprocess_cfg['size'] = visual_module.image_size
    # Apply force_preprocess_cfg overrides (highest priority for preprocessing)
    final_preprocess_cfg = merge_preprocess_dict(final_preprocess_cfg, force_preprocess_cfg or {})

    # Attach the final config to the model
    set_model_preprocess_cfg(model, final_preprocess_cfg)
    logging.info(f"Final image preprocessing configuration set: {final_preprocess_cfg}")

    # Log completion and return the configured model
    logging.info(f"Model {model_name} creation process complete.")
    return model
|
|
|
|
|
|
def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        cache_dir: Optional[str] = None,
        **kwargs,  # Additional tokenizer kwargs passed to constructor
):
    """
    Gets the appropriate tokenizer based on the model identifier schema or name.

    `model_name` can specify source via schema:
    - 'ViT-B-32': Looks up built-in config to determine tokenizer type.
    - 'hf-hub:org/repo': Loads config from HF Hub to determine tokenizer type.
    - 'local-dir:/path/to/folder': Loads config from local dir to determine tokenizer type.

    Args:
        model_name: Model identifier, potentially with schema ('hf-hub:', 'local-dir:').
        context_length: Override for tokenizer context length; falls back to the
            model's text_cfg value, then DEFAULT_CONTEXT_LENGTH.
        cache_dir: Cache directory for downloads.
        **kwargs: Additional tokenizer kwargs; override config-provided tokenizer_kwargs.

    Returns:
        An HFTokenizer, SigLipTokenizer, or SimpleTokenizer instance.
    """
    schema, identifier = parse_model_name(model_name)

    config = {}  # Stores the loaded model_cfg relevant section (usually text_cfg)
    local_dir_path = None  # Store path if schema is local-dir to resolve relative paths
    hf_fallback_id = None  # HF repo id used for the tokenizer if the hub config could not be read

    # Determine Configuration Source based on Schema
    logging.info(f"Parsing tokenizer identifier. Schema: {schema}, Identifier: {identifier}")

    if schema == 'local-dir':
        # Handle local directory schema
        local_dir_path = Path(identifier)  # Store the path for later use
        if not local_dir_path.is_dir():
            raise FileNotFoundError(f"Directory specified via 'local-dir:' schema not found at {local_dir_path}")
        local_config_path = local_dir_path / 'open_clip_config.json'
        logging.info(f"Attempting to load config from local-dir: {local_config_path}")
        if local_config_path.is_file():
            try:
                # Load and parse the JSON config
                with open(local_config_path, 'r', encoding='utf-8') as f:
                    local_json_config = json.load(f)
                if 'model_cfg' in local_json_config:
                    config = local_json_config['model_cfg']
                else:
                    raise ValueError(f"Local config {local_config_path} missing 'model_cfg'.")
            except Exception as e:
                raise ValueError(f"Could not load valid config for 'local-dir:{identifier}' ({e}).") from e
        else:
            raise FileNotFoundError(f"'local-dir:' specified, but config file missing: {local_config_path}")

    elif schema == 'hf-hub':
        # Handle Hugging Face Hub schema
        model_id = identifier
        logging.info(f"Attempting to load config from hf-hub:{model_id}")
        config_err = ''
        try:
            # Fetch config from HF Hub
            hf_config = _get_hf_config(model_id, cache_dir=cache_dir)
            config = hf_config.get('model_cfg', None)
            if not config:
                config_err = 'model_cfg key not found'
        except Exception as e:
            config_err = str(e)
        if not config:
            # Config fetch failed; fall back to using the repo id directly as the tokenizer source.
            hf_fallback_id = model_id
            config = {}
            logging.warning(
                f"Could not load config from hf-hub:{model_id} ({config_err})."
                f"Falling back to using model_id for tokenizer.")

    elif schema is None and identifier:
        # Try built-in config lookup using the identifier (original model_name)
        logging.info(f"Attempting to load config from built-in: {identifier}")
        config = get_model_config(identifier)

    # Check if config determination failed completely (should only be possible if initial schema parsing failed badly)
    if config is None:
        logging.warning(f"Model configuration not found, returning default SimpleTokenizer.")
        return SimpleTokenizer(context_length=context_length or DEFAULT_CONTEXT_LENGTH, **kwargs)

    # Safely access text_cfg even if config is {} (from non-builtin name case)
    text_config = config.get('text_cfg', {})

    # Resolve context length: argument > config > default
    if context_length is None:
        # Use context_length from text_cfg if available, otherwise default
        context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH)

    # Merge tokenizer kwargs: function kwargs override config kwargs
    tokenizer_kwargs = text_config.get('tokenizer_kwargs', {})  # Start with config kwargs
    tokenizer_kwargs.update(kwargs)  # Apply caller kwargs, overriding config ones

    # Get the specified HF tokenizer name from config, if any
    hf_tokenizer_name = text_config.get('hf_tokenizer_name', '')
    if not hf_tokenizer_name and hf_fallback_id:
        hf_tokenizer_name = hf_fallback_id

    if hf_tokenizer_name:
        # If 'hf_tokenizer_name' key exists in text_cfg (even if empty string): Use HFTokenizer.
        if schema == 'local-dir':
            # If config came from local-dir, ALWAYS use the local dir path for HFTokenizer.
            # This assumes the tokenizer files are inside that directory.
            tokenizer_source = local_dir_path
        else:
            tokenizer_source = hf_tokenizer_name
        tokenizer_mode = text_config.get('tokenizer_mode', None)

        logging.info(f"Using HFTokenizer with source: '{tokenizer_source}', mode: '{tokenizer_mode}'")
        tokenizer = HFTokenizer(
            tokenizer_source,
            context_length=context_length,
            cache_dir=cache_dir,
            tokenizer_mode=tokenizer_mode,
            **tokenizer_kwargs,
        )

    elif schema is None and 'siglip' in identifier.lower():
        # Check for SigLIP naming convention ONLY if no schema was present AND no hf_tokenizer_name found
        # Avoids misinterpreting 'local-dir:/path/with/siglip/in/name'
        tn_variant = 'gemma' if 'siglip2' in identifier.lower() else 'mc4' if 'i18n' in identifier.lower() else 'c4-en'
        logging.info(f"Using SigLipTokenizer variant: {tn_variant}")
        tokenizer = SigLipTokenizer(
            tn_variant,
            context_length=context_length,
        )

    else:
        # Default to SimpleTokenizer if no HF specified and not SigLIP name match
        logging.info("Using default SimpleTokenizer.")
        tokenizer = SimpleTokenizer(
            context_length=context_length,
            **tokenizer_kwargs,
        )

    return tokenizer
|
|
|
|
|
|
def _set_model_device_and_precision(
|
|
model: torch.nn.Module,
|
|
device: torch.device,
|
|
precision: str,
|
|
is_timm_model: bool = False
|
|
):
|
|
if precision in ("fp16", "bf16"):
|
|
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
|
# manual mixed precision that matches original OpenAI behaviour
|
|
if is_timm_model:
|
|
from .transformer import LayerNormFp32
|
|
# FIXME this is a bit janky, create timm based model in low-precision and
|
|
# then cast only LayerNormFp32 instances back to float32 so they don't break.
|
|
# Why? The convert_weights_to_lp fn only works with native models.
|
|
model.to(device=device, dtype=dtype)
|
|
|
|
def _convert_ln(m):
|
|
if isinstance(m, LayerNormFp32):
|
|
m.weight.data = m.weight.data.to(torch.float32)
|
|
m.bias.data = m.bias.data.to(torch.float32)
|
|
|
|
model.apply(_convert_ln)
|
|
else:
|
|
model.to(device=device)
|
|
convert_weights_to_lp(model, dtype=dtype)
|
|
elif precision in ("pure_fp16", "pure_bf16"):
|
|
dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
|
|
model.to(device=device, dtype=dtype)
|
|
else:
|
|
model.to(device=device)
|
|
|
|
|
|
def create_loss(args):
    """Select and instantiate the training loss implied by the parsed CLI args.

    Precedence: distillation loss, then CoCa loss (model name contains 'coca'),
    then SigLIP loss, otherwise the standard contrastive CLIP loss.

    Args:
        args: Parsed argparse namespace carrying distributed/loss settings.

    Returns:
        One of DistillClipLoss, CoCaLoss, SigLipLoss, or ClipLoss.
    """
    # Keyword args shared by every contrastive (non-SigLIP) loss variant.
    contrastive_kwargs = dict(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod,
    )

    if args.distill:
        return DistillClipLoss(**contrastive_kwargs)

    if "coca" in args.model.lower():
        return CoCaLoss(
            caption_loss_weight=args.coca_caption_loss_weight,
            clip_loss_weight=args.coca_contrastive_loss_weight,
            **contrastive_kwargs,
        )

    if args.siglip:
        assert not args.horovod, "Horovod not currently supported for SigLip"
        # SigLIP offers several distributed implementations; pick via args.
        return SigLipLoss(
            rank=args.rank,
            world_size=args.world_size,
            dist_impl=args.loss_dist_impl,
        )

    return ClipLoss(**contrastive_kwargs)
|
|
|
|
|
|
def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        load_weights: bool = True,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_context_length: Optional[int] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
        pretrained_image: bool = False,
        pretrained_text: bool = True,
        pretrained_image_path: Optional[str] = None,
        pretrained_text_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        weights_only: bool = True,
        **model_kwargs,
):
    """Build a CLIP-style model together with its train and val image transforms.

    Convenience wrapper around `create_model` for training workflows: it also
    constructs the two image preprocessing pipelines matching the model's
    preprocessing config — an augmented one for training and a deterministic
    one (resize, center crop, normalize) for validation/inference.

    How `model_name` is interpreted:
        - 'ViT-B-32': built-in config name; `pretrained` selects the weight
          source (tag or checkpoint path).
        - 'hf-hub:org/repo': config/weights from the HF Hub; `pretrained` is IGNORED.
        - 'local-dir:/path/to/folder': config/weights from a local dir; `pretrained` is IGNORED.

    Args:
        model_name: Model identifier, optionally schema-prefixed as above.
        pretrained: Weight tag or file path, used ONLY for schema-less names.
        load_weights: Load the resolved pretrained weights if True, otherwise
            random init or tower overrides only.
        precision: Model precision ('fp32', 'fp16', 'bf16', ...).
        device: Target device ('cpu', 'cuda', ...).
        jit: JIT compile the model when True.
        force_quick_gelu: Force QuickGELU activation in the model config.
        force_custom_text: Force the custom text encoder architecture.
        force_patch_dropout: Override the config's patch dropout value.
        force_image_size: Override the config's image size.
        force_context_length: Override the config's text context length.
        image_mean: Per-channel normalization mean override.
        image_std: Per-channel normalization std override.
        image_interpolation: Override interpolation method for image resizing.
        image_resize_mode: Resize mode override ('squash', 'longest',
            'shortest'); affects the inference transform only.
        aug_cfg: Augmentation config for the training transform (dict or
            AugmentationCfg); None uses model defaults.
        pretrained_image: Load default (timm) base weights for the image tower
            at creation if no full CLIP weights are loaded.
        pretrained_text: Load default (hf) base weights for the text tower at
            creation if no full CLIP weights are loaded.
        pretrained_image_path: Checkpoint to load into the image tower after creation.
        pretrained_text_path: Checkpoint to load into the text tower after creation.
        cache_dir: Cache directory for downloads.
        output_dict: If True and supported, the model returns dict outputs.
        weights_only: Use weights_only=True for torch.load (safer).
        **model_kwargs: Extra model constructor kwargs (highest override priority).

    Returns:
        Tuple of (model, preprocess_train, preprocess_val), where
        preprocess_train includes augmentation and preprocess_val does not.

    Example:
        >>> model, train_tf, val_tf = create_model_and_transforms('ViT-B-32', pretrained='openai')
        >>> model, train_tf, val_tf = create_model_and_transforms('hf-hub:org/model-repo')
    """
    # Collect explicit per-channel / resize overrides for the preprocess config.
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        load_weights=load_weights,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        force_context_length=force_context_length,
        pretrained_image=pretrained_image,
        pretrained_text=pretrained_text,
        pretrained_image_path=pretrained_image_path,
        pretrained_text_path=pretrained_text_path,
        cache_dir=cache_dir,
        output_dict=output_dict,
        weights_only=weights_only,
        **model_kwargs,
    )

    # Both transforms derive from the preprocess config baked into the model's
    # visual tower; only the train pipeline applies augmentation.
    pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg)
    preprocess_train = image_transform_v2(pp_cfg, is_train=True, aug_cfg=aug_cfg)
    preprocess_val = image_transform_v2(pp_cfg, is_train=False)

    return model, preprocess_train, preprocess_val
|
|
|
|
|
|
def create_model_from_pretrained(
        model_name: str,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_context_length: Optional[int] = None,
        image_mean: Optional[Tuple[float, ...]] = None,
        image_std: Optional[Tuple[float, ...]] = None,
        image_interpolation: Optional[str] = None,
        image_resize_mode: Optional[str] = None,  # only effective for inference
        return_transform: bool = True,
        cache_dir: Optional[str] = None,
        weights_only: bool = True,
        **model_kwargs,
):
    """Load a pretrained CLIP-style model, optionally with its inference transform.

    Wrapper around `create_model` that enforces pretrained weight loading
    (require_pretrained=True) — intended for feature extraction, zero-shot
    classification, or fine-tuning, where an uninitialized model is useless.
    Unlike `create_model`, this raises if pretrained weights cannot be loaded.

    How `model_name` is interpreted:
        - 'ViT-B-32': built-in config name; `pretrained` selects the weight
          source (tag or checkpoint path).
        - 'hf-hub:org/repo': config/weights from the HF Hub; `pretrained` is IGNORED.
        - 'local-dir:/path/to/folder': config/weights from a local dir; `pretrained` is IGNORED.

    Args:
        model_name: Model identifier, optionally schema-prefixed as above.
        pretrained: Weight tag or file path, used ONLY for schema-less names;
            an error is raised if required weights cannot be resolved.
        precision: Model precision ('fp32', 'fp16', 'bf16', ...).
        device: Target device ('cpu', 'cuda', ...).
        jit: JIT compile the model when True.
        force_quick_gelu: Force QuickGELU activation in the model config.
        force_custom_text: Force the custom text encoder architecture.
        force_image_size: Override the config's image size (e.g. to use a
            model at a different resolution).
        force_context_length: Override the config's text context length.
        image_mean: Per-channel normalization mean override.
        image_std: Per-channel normalization std override.
        image_interpolation: Interpolation override ('bicubic', 'bilinear', 'nearest').
        image_resize_mode: Resize mode override ('squash', 'longest',
            'shortest'); only affects the returned transform, not training.
        return_transform: Return (model, preprocess) if True, else just model.
        cache_dir: Cache directory for downloads.
        weights_only: Use weights_only=True for torch.load (safer).
        **model_kwargs: Extra model constructor kwargs (highest override priority).

    Returns:
        The model, or a (model, preprocess) tuple when return_transform is True,
        where preprocess is the inference preprocessing transform.

    Raises:
        RuntimeError: If pretrained weights are required but cannot be loaded.

    Example:
        >>> model, preprocess = create_model_from_pretrained('ViT-B-32', pretrained='openai')
        >>> model = create_model_from_pretrained('ViT-B-32', pretrained='openai', return_transform=False)
        >>> model, preprocess = create_model_from_pretrained('hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K')

    Note:
        For creating models without pretrained weights or with only partial
        weight loading, use `create_model` or `create_model_and_transforms`.
    """
    # Collect explicit per-channel / resize overrides for the preprocess config.
    preprocess_overrides = merge_preprocess_kwargs(
        {},
        mean=image_mean,
        std=image_std,
        interpolation=image_interpolation,
        resize_mode=image_resize_mode,
    )

    model = create_model(
        model_name,
        pretrained,
        precision=precision,
        device=device,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        force_custom_text=force_custom_text,
        force_image_size=force_image_size,
        force_preprocess_cfg=preprocess_overrides,
        force_context_length=force_context_length,
        cache_dir=cache_dir,
        require_pretrained=True,  # this entry point must have real weights
        weights_only=weights_only,
        **model_kwargs,
    )

    if not return_transform:
        return model

    # Inference-only transform derived from the model's preprocess config.
    preprocess = image_transform_v2(
        PreprocessCfg(**model.visual.preprocess_cfg),
        is_train=False,
    )
    return model, preprocess
|