""" Conv2d w/ Same Padding

Hacked together by / Copyright 2020 Ross Wightman
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Tuple, Optional, Union
|
|
|
|
from ._fx import register_notrace_module
|
|
from .config import is_exportable, is_scriptable
|
|
from .padding import pad_same, pad_same_arg, get_padding_value
|
|
|
|
|
|
# When True and export mode is active, create_conv2d_pad() builds Conv2dSameExport
# instead of Conv2dSame for dynamic 'SAME' padding — a workaround older PyTorch
# versions needed to export 'SAME' padding reasonably (see create_conv2d_pad below).
_USE_EXPORT_CONV = False
|
|
|
|
|
|
def conv2d_same(
        x,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0),
        dilation: Tuple[int, int] = (1, 1),
        groups: int = 1,
):
    """Run a 2D convolution with Tensorflow-style 'SAME' dynamic padding.

    The input is zero-padded per call based on its spatial size (via
    ``pad_same``), then convolved with an explicit padding of ``(0, 0)``.
    The ``padding`` argument is accepted for signature compatibility but
    is not used in the computation.
    """
    kernel_hw = weight.shape[-2:]
    padded = pad_same(x, kernel_hw, stride, dilation)
    return F.conv2d(padded, weight, bias, stride, (0, 0), dilation, groups)
|
|
|
|
|
|
@register_notrace_module
class Conv2dSame(nn.Conv2d):
    """Tensorflow-like 'SAME' convolution wrapper for 2D convolutions.

    The underlying ``nn.Conv2d`` is always constructed with ``padding=0``;
    the actual 'SAME' padding is computed dynamically from each input's
    spatial size inside ``forward`` via ``conv2d_same``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int], str] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        # The caller-supplied ``padding`` is deliberately discarded; padding
        # is applied dynamically per input in forward() instead.
        super().__init__(
            in_channels, out_channels, kernel_size, stride,
            0,  # padding
            dilation, groups, bias,
            device=device, dtype=dtype,
        )

    def forward(self, x):
        # Delegate to the functional helper, which pads x then convolves.
        return conv2d_same(
            x, self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )
|
|
|
|
|
|
class Conv2dSameExport(nn.Conv2d):
    """ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions

    Rather than padding functionally on every call, an ``nn.ZeroPad2d``
    module is built lazily on the first forward pass (sized for that
    input's spatial dims) and reused for subsequent calls.

    NOTE: This does not currently work with torch.jit.script
    """

    # pylint: disable=unused-argument
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]],
            stride: Union[int, Tuple[int, int]] = 1,
            padding: Union[int, Tuple[int, int], str] = 0,
            dilation: Union[int, Tuple[int, int]] = 1,
            groups: int = 1,
            bias: bool = True,
            device=None,
            dtype=None,
    ):
        # Caller-supplied ``padding`` is ignored; a pad module is created
        # lazily in forward() from the first input's size.
        super().__init__(
            in_channels, out_channels, kernel_size, stride,
            0,  # padding
            dilation, groups, bias,
            device=device, dtype=dtype,
        )
        # Lazily-built ZeroPad2d module and the input size it was built for.
        self.pad = None
        self.pad_input_size = (0, 0)

    def forward(self, x):
        spatial = x.size()[-2:]
        if self.pad is None:
            # First call: compute static pad amounts for this input size.
            pad_arg = pad_same_arg(spatial, self.weight.size()[-2:], self.stride, self.dilation)
            self.pad = nn.ZeroPad2d(pad_arg)
            self.pad_input_size = spatial

        return F.conv2d(
            self.pad(x), self.weight, self.bias,
            self.stride, self.padding, self.dilation, self.groups,
        )
|
|
|
|
|
|
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    """Create a conv2d module, choosing a dynamic-'SAME'-padding wrapper when needed.

    ``get_padding_value`` resolves the requested padding; if it reports the
    padding must be computed dynamically per input, a Conv2dSame (or, when
    exporting with _USE_EXPORT_CONV set, Conv2dSameExport) is returned.
    Otherwise a plain ``nn.Conv2d`` with static padding is used.
    ``bias`` defaults to False unless explicitly passed.
    """
    requested = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(requested, kernel_size, **kwargs)

    if not is_dynamic:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

    if _USE_EXPORT_CONV and is_exportable():
        # older PyTorch ver needed this to export same padding reasonably
        assert not is_scriptable()  # Conv2dSameExport does not work with jit
        return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs)

    return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
|
|
|
|
|