Files
railseek6/openclip_env/Lib/site-packages/timm/layers/mixed_conv2d.py

69 lines
2.1 KiB
Python

""" PyTorch Mixed Convolution
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import List, Union
import torch
from torch import nn as nn
from .conv2d_same import create_conv2d_pad
def _split_channels(num_chan, num_groups):
    """Divide ``num_chan`` channels into ``num_groups`` integer group sizes.

    Every group receives ``num_chan // num_groups`` channels; whatever is left
    over after that even division is added to the first group, so the returned
    sizes always sum to ``num_chan``.

    Args:
        num_chan: Total number of channels to distribute.
        num_groups: Number of groups to distribute them over.

    Returns:
        A list of ``num_groups`` channel counts.
    """
    group_sizes = [num_chan // num_groups for _ in range(num_groups)]
    # Floor division may leave some channels unassigned; hand them to group 0.
    remainder = num_chan - sum(group_sizes)
    group_sizes[0] += remainder
    return group_sizes
class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Holds one convolution per kernel size. The input is split along
    dimension 1 into one chunk per convolution, each chunk is passed through
    its own convolution, and the results are concatenated along dimension 1.

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, List[int]] = 3,
            stride: int = 1,
            padding: str = '',
            dilation: int = 1,
            depthwise: bool = False,
            **kwargs
    ):
        """Build one convolution per entry of ``kernel_size``.

        Args:
            in_channels: Total input channels, divided across the groups with
                ``_split_channels``.
            out_channels: Total output channels, divided the same way.
            kernel_size: A list gives one convolution per entry; any non-list
                value is wrapped into a single-entry list (one group).
            stride: Forwarded unchanged to every convolution.
            padding: Forwarded unchanged to ``create_conv2d_pad``.
            dilation: Forwarded unchanged to every convolution.
            depthwise: When true, each convolution uses ``groups`` equal to
                its own input channel count; otherwise ``groups=1``.
            **kwargs: Extra keyword arguments forwarded to
                ``create_conv2d_pad``.
        """
        super().__init__()

        # Normalise to a list so a single kernel size is just a one-group mix.
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size]
        group_count = len(kernel_size)

        in_splits = _split_channels(in_channels, group_count)
        out_splits = _split_channels(out_channels, group_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)

        per_group = zip(kernel_size, in_splits, out_splits)
        for index, (kernel, group_in, group_out) in enumerate(per_group):
            if depthwise:
                conv_groups = group_in
            else:
                conv_groups = 1
            conv = create_conv2d_pad(
                group_in,
                group_out,
                kernel,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=conv_groups,
                **kwargs,
            )
            # use add_module to keep key space clean
            self.add_module(str(index), conv)

        # Chunk sizes used by forward() when splitting the input.
        self.splits = in_splits

    def forward(self, x):
        """Apply each convolution to its chunk of ``x`` and concatenate.

        ``x`` is split along dimension 1 into chunks of sizes ``self.splits``;
        the i-th registered module processes the i-th chunk, and the outputs
        are joined back along dimension 1.
        """
        chunks = torch.split(x, self.splits, 1)
        outputs = []
        for index, conv in enumerate(self.values()):
            outputs.append(conv(chunks[index]))
        return torch.cat(outputs, 1)