# Source listing metadata (from a web file viewer): 81 lines, 2.2 KiB, Python.

from typing import Callable, Dict, List, Optional, Union, Tuple, Type
import torch
from torch import nn
# Optional dependency: torchvision's FX feature-extraction API.
# ``has_fx_feature_extraction`` records whether it imported successfully.
try:
    # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
    from torchvision.models.feature_extraction import (
        create_feature_extractor as _create_feature_extractor,
        get_graph_node_names as _get_graph_node_names,
    )
except ImportError:
    has_fx_feature_extraction = False
else:
    has_fx_feature_extraction = True
# Public API of this module.
__all__ = [
    # leaf-module (no-trace) registry
    'register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
    # autowrap-function (no-trace) registry
    'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
    # torchvision FX wrappers
    'create_feature_extractor', 'get_graph_node_names',
]
# Registry of module classes that FX tracing should treat as leaves, i.e. not
# trace through. The wrappers in this file pass a list copy of it to torchvision
# as ``tracer_kwargs['leaf_modules']``.
_leaf_modules = set()


def register_notrace_module(module: Type[nn.Module]):
    """Register ``module`` as an FX leaf module.

    Any module not under timm.models.layers should get this decorator if we
    don't want to trace through it.

    Args:
        module: the ``nn.Module`` subclass to register.

    Returns:
        ``module`` itself, unchanged, so this works as a class decorator.
    """
    _leaf_modules.add(module)
    return module


def is_notrace_module(module: Type[nn.Module]) -> bool:
    """Report whether ``module`` has been registered via ``register_notrace_module``.

    Lookup is plain set membership on the class object, so subclasses of a
    registered class are not matched.
    """
    return module in _leaf_modules


def get_notrace_modules() -> List[Type[nn.Module]]:
    """Return a new list holding the registered leaf module classes.

    The list is a copy; mutating it does not alter the registry. Its order
    follows set iteration and is therefore arbitrary.
    """
    return [*_leaf_modules]
# Registry of functions FX tracing should autowrap, i.e. treat as leaves rather
# than trace into. The wrappers in this file pass a list copy of it to
# torchvision as ``tracer_kwargs['autowrap_functions']``.
_autowrap_functions = set()


def register_notrace_function(name_or_fn):
    """Register a function to be autowrapped (not traced through) by FX.

    Args:
        name_or_fn: the function to register. It is stored as-is.
            NOTE(review): despite the parameter name, torch.fx documents
            ``autowrap_functions`` as a tuple of callables — confirm whether
            passing a name string is actually honoured by the tracer.

    Returns:
        ``name_or_fn`` itself, unchanged, so this works as a function decorator.
    """
    _autowrap_functions.add(name_or_fn)
    return name_or_fn


def is_notrace_function(func: Callable) -> bool:
    """Report whether ``func`` has been registered via ``register_notrace_function``."""
    return func in _autowrap_functions


def get_notrace_functions() -> List[Callable]:
    """Return a new list holding the registered autowrap functions.

    The list is a copy; mutating it does not alter the registry. Its order
    follows set iteration and is therefore arbitrary.
    """
    return [*_autowrap_functions]
def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
    """Return the FX graph node names of ``model``, honouring timm's no-trace registries.

    Thin wrapper around torchvision's ``get_graph_node_names``. It forwards list
    copies of the registered leaf modules and autowrap functions as
    ``tracer_kwargs``, so the tracer does not trace through them.

    Args:
        model: the model to trace.

    Returns:
        Whatever torchvision's ``get_graph_node_names`` returns: a pair of
        node-name lists (per torchvision's documentation, the train-mode and
        eval-mode node names).

    Raises:
        AssertionError: if torchvision's FX feature-extraction API could not be
            imported. The check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    # Same availability guard as create_feature_extractor. Without it, a missing
    # or too-old torchvision surfaces here as an opaque NameError on
    # ``_get_graph_node_names`` instead of an actionable message.
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _get_graph_node_names(
        model,
        tracer_kwargs={
            'leaf_modules': list(_leaf_modules),
            'autowrap_functions': list(_autowrap_functions),
        },
    )
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    """Build a torchvision FX feature extractor that honours timm's no-trace registries.

    Args:
        model: the model to extract features from.
        return_nodes: the nodes to return, given either as a list of node names
            or as a ``Dict[str, str]``. It is forwarded unchanged to torchvision's
            ``create_feature_extractor``; see its documentation for the mapping
            semantics.

    Returns:
        The object returned by torchvision's ``create_feature_extractor``.

    Raises:
        AssertionError: if torchvision's FX feature-extraction API could not be
            imported. The check is an ``assert``, so it is skipped under
            ``python -O``.
    """
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    # list() copies the registries, so torchvision receives plain lists rather
    # than the live module-level sets.
    tracer_kwargs = dict(
        leaf_modules=list(_leaf_modules),
        autowrap_functions=list(_autowrap_functions),
    )
    return _create_feature_extractor(model, return_nodes, tracer_kwargs=tracer_kwargs)