101 lines
3.2 KiB
Python
101 lines
3.2 KiB
Python
""" PyTorch FX Based Feature Extraction Helpers
|
|
Using https://pytorch.org/vision/stable/feature_extraction.html
|
|
"""
|
|
from typing import Callable, Dict, List, Optional, Union, Tuple, Type
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from timm.layers import (
|
|
create_feature_extractor,
|
|
get_graph_node_names,
|
|
register_notrace_module,
|
|
register_notrace_function,
|
|
is_notrace_module,
|
|
is_notrace_function,
|
|
get_notrace_functions,
|
|
get_notrace_modules,
|
|
Format,
|
|
)
|
|
from ._features import _get_feature_info, _get_return_layers
|
|
|
|
|
|
|
|
# Public names of this module: the FX helpers imported from timm.layers at the
# top of the file, followed by the two wrapper classes defined in this file.
__all__ = [
    # notrace module helpers
    'register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
    # notrace function helpers
    'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
    # graph construction / inspection
    'create_feature_extractor', 'get_graph_node_names',
    # wrapper modules
    'FeatureGraphNet', 'GraphExtractNet',
]
|
|
|
|
|
|
class FeatureGraphNet(nn.Module):
    """ FX graph feature extractor driven by the model's feature_info metadata

    The return nodes handed to `create_feature_extractor` are obtained from
    `_get_return_layers(feature_info, out_map)`, where `feature_info` comes from
    `_get_feature_info(model, out_indices)`.

    Args:
        model: model to extract features from
        out_indices: passed with the model to `_get_feature_info` to build `self.feature_info`
        out_map: optional mapping forwarded to `_get_return_layers`; when given, its length must
            equal the length of `out_indices` (checked with an assert)
        output_fmt: format name, converted via `Format(...)` and stored as `self.output_fmt`
        return_dict: if True, `forward` returns the extractor's dict as-is, otherwise a list of its values
    """
    return_dict: torch.jit.Final[bool]

    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...],
            out_map: Optional[Dict] = None,
            output_fmt: str = 'NCHW',
            return_dict: bool = False,
    ):
        super().__init__()
        self.return_dict = return_dict
        self.feature_info = _get_feature_info(model, out_indices)
        # A supplied out_map needs exactly one entry per requested index.
        assert out_map is None or len(out_map) == len(out_indices)
        self.output_fmt = Format(output_fmt)
        self.graph_module = create_feature_extractor(
            model,
            _get_return_layers(self.feature_info, out_map),
        )

    def forward(self, x):
        """ Run the graph module on `x`; return its dict output or the dict's values as a list
        """
        features = self.graph_module(x)
        if self.return_dict:
            return features
        else:
            return list(features.values())
|
|
|
|
|
|
class GraphExtractNet(nn.Module):
    """ Standalone FX feature extraction wrapper that turns the extractor's dict into a list or single tensor

    NOTE:
      * unlike FeatureGraphNet, this wrapper is meant for standalone use and does not rely on the
        model feature_info metadata of the builtin feature extraction mode
      * if dictionary output is all that is needed, create_feature_extractor can be used directly

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: when the list output holds exactly one entry, return that entry instead of the list
        return_dict: return the extractor's dictionary (node names as keys); squeeze_out is ignored in this case
    """
    return_dict: torch.jit.Final[bool]

    def __init__(
            self,
            model: nn.Module,
            return_nodes: Union[Dict[str, str], List[str]],
            squeeze_out: bool = True,
            return_dict: bool = False,
    ):
        super().__init__()
        self.return_dict = return_dict
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        """ Run the graph module on `x` and shape its output according to return_dict / squeeze_out
        """
        node_outputs = self.graph_module(x)
        if self.return_dict:
            # Dictionary mode hands back the extractor output untouched.
            return node_outputs
        features = list(node_outputs.values())
        if self.squeeze_out and len(features) == 1:
            # Unwrap a lone feature so callers get it directly rather than a 1-element list.
            return features[0]
        return features
|