81 lines
2.2 KiB
Python
81 lines
2.2 KiB
Python
from typing import Callable, Dict, List, Optional, Union, Tuple, Type
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
try:
    # We wrap these torchvision helpers (below) so tracing uses timm's
    # leaf-module / no-trace registries.
    from torchvision.models.feature_extraction import (
        create_feature_extractor as _create_feature_extractor,
        get_graph_node_names as _get_graph_node_names,
    )
except ImportError:
    has_fx_feature_extraction = False
else:
    has_fx_feature_extraction = True
|
|
|
|
|
|
# Public API of this module.
__all__ = [
    # leaf-module registry
    'register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
    # autowrap-function registry
    'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
    # torchvision FX wrappers
    'create_feature_extractor', 'get_graph_node_names',
]
|
|
|
|
# Registry of module classes that FX tracing should treat as leaves
# (not traced into); populated by ``register_notrace_module``.
_leaf_modules: set = set()
|
|
|
|
|
|
def register_notrace_module(module: Type[nn.Module]):
    """Mark a module class as a leaf so FX tracing does not descend into it.

    Usable as a class decorator. Modules that live outside
    ``timm.models.layers`` should use it when they must not be traced through.
    The class is recorded in ``_leaf_modules`` and handed back unchanged.

    Args:
        module: The ``nn.Module`` subclass to register.

    Returns:
        The same ``module`` object, so the decorated name is unaffected.
    """
    _leaf_modules.update((module,))
    return module
|
|
|
|
|
|
def is_notrace_module(module: Type[nn.Module]):
    """Return True if ``module`` has been registered as an FX leaf module."""
    registered = module in _leaf_modules
    return registered
|
|
|
|
|
|
def get_notrace_modules():
    """Return a new list holding every registered FX leaf module class.

    The list is a fresh copy, so mutating it does not touch the registry.
    Ordering follows set iteration and is not guaranteed.
    """
    return [*_leaf_modules]
|
|
|
|
|
|
# Registry of functions to autowrap during FX tracing (kept as leaf calls
# instead of being traced through); populated by ``register_notrace_function``.
_autowrap_functions: set = set()
|
|
|
|
|
|
def register_notrace_function(name_or_fn):
    """Register a function to be autowrapped (treated as a leaf) during FX tracing.

    Usable as a decorator. The argument is stored in ``_autowrap_functions``
    and returned unchanged. Registry entries are later forwarded to torchvision
    as the tracer's ``autowrap_functions``.

    NOTE(review): despite the parameter name, no name-to-function resolution
    happens here. torch.fx appears to expect callables for
    ``autowrap_functions``; confirm before registering plain strings.
    """
    _autowrap_functions.update((name_or_fn,))
    return name_or_fn
|
|
|
|
|
|
def is_notrace_function(func: Callable):
    """Return True if ``func`` has been registered for FX autowrapping."""
    registered = func in _autowrap_functions
    return registered
|
|
|
|
|
|
def get_notrace_functions():
    """Return a new list holding every function registered for FX autowrapping.

    The list is a fresh copy, so mutating it does not touch the registry.
    Ordering follows set iteration and is not guaranteed.
    """
    return [*_autowrap_functions]
|
|
|
|
|
|
def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
    """List the FX graph node names of ``model`` using timm's no-trace registries.

    This is a thin wrapper over torchvision's ``get_graph_node_names``. It
    passes the registered leaf modules and autowrap functions through as
    ``tracer_kwargs``, so they show up as single nodes rather than being
    traced into.

    Args:
        model: The model to trace.

    Returns:
        The pair of name lists produced by torchvision. Per torchvision's
        documentation, these are the node names in train mode and in eval
        mode, respectively.

    Raises:
        AssertionError: If the torchvision feature-extraction helpers could not
            be imported (PyTorch 1.10+ / torchvision 0.11+ required).
    """
    if not has_fx_feature_extraction:
        # Same guard and message as ``create_feature_extractor``. Without it, a
        # missing torchvision surfaces as an opaque NameError on the call below.
        # Raised explicitly (not via ``assert``) so it also holds under ``python -O``.
        raise AssertionError('Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction')
    return _get_graph_node_names(
        model,
        tracer_kwargs={
            # Snapshot the registries as lists for the tracer.
            'leaf_modules': list(_leaf_modules),
            'autowrap_functions': list(_autowrap_functions),
        },
    )
|
|
|
|
|
|
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    """Build an FX-based feature extractor for ``model`` using timm's no-trace registries.

    This is a thin wrapper over torchvision's ``create_feature_extractor``. It
    passes the registered leaf modules and autowrap functions through as
    ``tracer_kwargs``, so they are kept as single nodes rather than traced into.

    Args:
        model: The model to wrap.
        return_nodes: Graph nodes whose outputs should be returned. Given either
            as a list of node names or as a dict mapping node names to output
            names (forwarded unchanged to torchvision).

    Returns:
        Whatever torchvision's ``create_feature_extractor`` returns. Per its
        documentation, this is a ``torch.fx.GraphModule``.

    Raises:
        AssertionError: If the torchvision feature-extraction helpers could not
            be imported (PyTorch 1.10+ / torchvision 0.11+ required).
    """
    if not has_fx_feature_extraction:
        # Previously an ``assert``. Under ``python -O`` that check is stripped,
        # and the call below then fails with a confusing NameError. An explicit
        # raise keeps the same exception type and message.
        raise AssertionError('Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction')
    return _create_feature_extractor(
        model,
        return_nodes,
        tracer_kwargs={
            # Snapshot the registries as lists for the tracer.
            'leaf_modules': list(_leaf_modules),
            'autowrap_functions': list(_autowrap_functions),
        },
    )