89 lines
3.4 KiB
Python
89 lines
3.4 KiB
Python
from itertools import repeat
|
|
import collections.abc
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
from torchvision.ops.misc import FrozenBatchNorm2d
|
|
|
|
|
|
def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict | set | None): Container of full (dot-separated) module names to freeze, tested with
            ``in``. If ``None`` or empty, every BatchNorm layer is frozen.
        name (str): Full module name (prefix) of ``module``; used to build the dotted names of its children.

    Returns:
        torch.nn.Module: Resulting module (a new `FrozenBatchNorm2d` when ``module`` itself was converted,
        otherwise ``module`` with its matching descendants replaced in place).

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # A missing/empty ``module_match`` means "freeze everything"; otherwise only the listed names.
    is_match = not module_match or name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            # Affine parameters are cloned so the frozen layer owns its own copies.
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # NOTE: running statistics are NOT cloned; they share storage with the original module's buffers.
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        return res

    # Not a (matching) BatchNorm layer: recurse and swap converted children in place.
    for child_name, child in module.named_children():
        full_child_name = '.'.join([name, child_name]) if name else child_name
        new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
        if new_child is not child:
            module.add_module(child_name, new_child)
    return module
|
|
|
|
|
|
# Adapted from PyTorch internals (torch.nn.modules.utils._ntuple).
def _ntuple(n):
    """Return a parser that expands a scalar into an ``n``-tuple.

    The returned ``parse(x)`` passes non-string iterables through unchanged and
    otherwise returns a tuple containing ``x`` repeated ``n`` times. Strings are
    treated as scalars: although they are iterable, ``_ntuple(2)('ab')`` yields
    ``('ab', 'ab')`` rather than the string ``'ab'``.
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse
|
|
|
|
|
|
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)


def to_ntuple(n, x):
    """Expand ``x`` into an ``n``-tuple by applying ``_ntuple(n)`` to it."""
    return _ntuple(n)(x)
|
|
|
|
# Replaces the Linear layers whose names are in include_modules with linear_replacement
# TODO: add int8 support for other linear layers including attn and convnets
def replace_linear(model, linear_replacement, include_modules=('c_fc', 'c_proj'), copy_weights=True):
    """Recursively swap selected ``torch.nn.Linear`` children of ``model`` in place.

    Args:
        model (torch.nn.Module): Module whose descendants are scanned; modified in place.
        linear_replacement (callable): Factory called as
            ``linear_replacement(in_features, out_features, has_bias)`` to build each replacement.
        include_modules (Container[str]): Child attribute names (the child's own name, not its
            dotted path) of ``Linear`` layers to replace. Defaults to ``('c_fc', 'c_proj')``.
        copy_weights (bool): If True, copy the old layer's weight (and bias, when the
            replacement has one) into the replacement.

    Returns:
        torch.nn.Module: ``model`` itself.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, include_modules, copy_weights)

        if isinstance(module, torch.nn.Linear) and name in include_modules:
            new_module = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                # Copy through ``.data`` for both tensors so neither copy is tracked by autograd.
                new_module.weight.data.copy_(module.weight.data)
                if new_module.bias is not None:
                    new_module.bias.data.copy_(module.bias.data)
            model._modules[name] = new_module

    return model
|
|
|
|
def convert_int8_model_to_inference_mode(model):
    """Call ``prepare_for_eval()`` on every submodule of ``model`` that defines it.

    Before each call, the submodule's current ``weight.dtype`` is captured. After the
    call it is stored on the submodule as ``int8_original_dtype``. Modules without a
    ``prepare_for_eval`` attribute are left untouched. Returns ``None``; ``model`` is
    modified in place.
    """
    eval_ready = (sub for sub in model.modules() if hasattr(sub, 'prepare_for_eval'))
    for sub in eval_ready:
        # Record the dtype *before* prepare_for_eval(), which may replace the weight.
        dtype_before = sub.weight.dtype
        sub.prepare_for_eval()
        sub.int8_original_dtype = dtype_before