CoAtNet
Overview
This is a PyTorch implementation of CoAtNet, the architecture proposed in "CoAtNet: Marrying Convolution and Attention for All Data Sizes" (arXiv, 2021).
Usage
import torch
from coatnet import coatnet_0
img = torch.randn(1, 3, 224, 224)  # dummy batch: one 3-channel 224x224 image
net = coatnet_0()
out = net(img)
Try out other block combinations mentioned in the paper:
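import torch
from coatnet import CoAtNet
img = torch.randn(1, 3, 224, 224)
num_blocks = [2, 2, 3, 5, 2]  # L: number of blocks in each stage (S0-S4)
channels = [64, 96, 192, 384, 768]  # D: number of channels in each stage (S0-S4)
# One entry per stage S1-S4 (S0 is the convolutional stem):
# 'C' for MBConv, 'T' for Transformer
block_types = ['C', 'T', 'T', 'T']
net = CoAtNet((224, 224), 3, num_blocks, channels, block_types=block_types)
out = net(img)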
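When picking channels and an image size, it can help to see the feature-map shape each stage would produce. The sketch below is a rough illustration rather than part of this repository: `stage_shapes` is a hypothetical helper, and it assumes that each of the five stages (S0-S4) halves the spatial resolution, as described in the paper, and that the input height and width are divisible by 32.

```python
from typing import List, Tuple


def stage_shapes(
    image_size: Tuple[int, int], channels: List[int]
) -> List[Tuple[str, int, int, int]]:
    """Return (stage, channels, height, width) for each stage.

    Hypothetical helper: assumes every stage downsamples by a factor of 2.
    """
    h, w = image_size
    shapes = []
    for i, c in enumerate(channels):
        # Assumption: each stage halves the height and width.
        h, w = h // 2, w // 2
        shapes.append((f"S{i}", c, h, w))
    return shapes


for name, c, h, w in stage_shapes((224, 224), [64, 96, 192, 384, 768]):
    print(f"{name}: {c} x {h} x {w}")
```

Under these assumptions, a 224x224 input shrinks to 112x112 after S0 and to 7x7 after S4, so the Transformer stages work on much shorter token sequences than the raw image would give.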
Citation
@article{dai2021coatnet,
title={CoAtNet: Marrying Convolution and Attention for All Data Sizes},
author={Dai, Zihang and Liu, Hanxiao and Le, Quoc V and Tan, Mingxing},
journal={arXiv preprint arXiv:2106.04803},
year={2021}
}
Credits
Code adapted from MobileNetV2 and ViT.