ECViT: Efficient Convolutional Vision Transformer

This project is a PyTorch implementation of the paper "ECViT: Efficient Convolutional Vision Transformer with Local-Attention and Multi-scale Stages".

Model Overview

ECViT is a hybrid architecture that effectively combines the advantages of CNNs and Transformers. It introduces the inductive biases of CNNs (such as locality and translation invariance) into the Transformer framework by:

  1. Extracting patches from low-level features
  2. Enhancing encoders with convolutional operations
  3. Introducing local attention mechanisms and pyramid structures for efficient multi-scale feature extraction and representation

Model Architecture

The ECViT model is divided into three stages, each producing feature maps at a different scale:

  1. First Stage: Uses convolutional networks to extract low-level features from the image and converts them into a token sequence (including a class token)
  2. Second and Third Stages: Share a similar architecture, consisting of multiple convolution-enhanced Transformer encoder layers and a merging layer
    • Each encoder contains two sublayers: Partitioned Multi-head Self-Attention (P-MSA) and Interactive Feed-Forward Network (I-FFN)
    • Residual connections are applied after each module
    • Merging layers are used to reduce sequence length and increase feature dimensions
  3. Final Classification: Passes the class token through an MLP head for prediction
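
The sketch below outlines how these stages compose in a forward pass. It is a minimal illustration: the class names, constructor arguments, and the exact placement of the merging layers are assumptions rather than the repository's actual API (see model/ecvit.py for the real implementation).

```python
# Minimal sketch of the three-stage layout; all names here are illustrative.
import torch.nn as nn

class ECViTSketch(nn.Module):
    def __init__(self, tokenizer, stages, mergers, final_dim, num_classes=10):
        super().__init__()
        self.tokenizer = tokenizer              # stage 1: conv features -> tokens
        self.stages = nn.ModuleList(stages)     # stages 2-3: encoder blocks
        self.mergers = nn.ModuleList(mergers)   # merging layer closing each stage
        self.head = nn.Linear(final_dim, num_classes)  # MLP head on the class token

    def forward(self, x):
        tokens = self.tokenizer(x)              # (B, 1 + N, D), token 0 is the class token
        for blocks, merge in zip(self.stages, self.mergers):
            tokens = merge(blocks(tokens))      # encoders, then sequence-length reduction
        return self.head(tokens[:, 0])          # predict from the class token
```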

Key Components

1. Image Tokenization

  • Uses two depth-wise separable convolutions (kernel sizes: 7×1 and 1×7, stride: 2) to extract features
  • Output channels set to 32
  • Uses max pooling (kernel size: 3×3, stride: 2) to further extract important features and reduce the spatial resolution
  • Reshapes feature maps into patch sequences and adds positional encoding
  • Adds learnable class token for classification tasks
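
As a concrete illustration of the steps above, here is a minimal sketch of the tokenization stage, assuming a square input whose side is divisible by 8; the class and argument names are illustrative, not the module's actual interface (see model/tokenization.py).

```python
# Minimal sketch of image tokenization; names and defaults are illustrative.
import torch
import torch.nn as nn

def depthwise_separable(in_ch, out_ch, kernel, stride, padding):
    # Depth-wise conv followed by a 1x1 point-wise conv.
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, kernel, stride=stride, padding=padding, groups=in_ch),
        nn.Conv2d(in_ch, out_ch, kernel_size=1),
    )

class ImageTokenization(nn.Module):
    def __init__(self, img_size=224, in_chans=3, conv_chans=32, embed_dim=64):
        super().__init__()
        # Two depth-wise separable convolutions (7x1 and 1x7, stride 2), 32 output channels.
        self.conv1 = depthwise_separable(in_chans, conv_chans, (7, 1), 2, (3, 0))
        self.conv2 = depthwise_separable(conv_chans, conv_chans, (1, 7), 2, (0, 3))
        # Max pooling (3x3, stride 2) to keep salient features and shrink the map.
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.proj = nn.Linear(conv_chans, embed_dim)
        num_patches = (img_size // 8) ** 2      # two stride-2 convs + one stride-2 pool
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x):                                 # x: (B, 3, H, W)
        x = self.pool(self.conv2(self.conv1(x)))          # (B, 32, H/8, W/8)
        tokens = self.proj(x.flatten(2).transpose(1, 2))  # (B, N, D) patch sequence
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)          # prepend learnable class token
        return tokens + self.pos_embed                    # add positional encoding
```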

2. Partitioned Multi-head Self-Attention (P-MSA)

  • Evenly divides patch tokens into non-overlapping blocks (block size: 7)
  • Appends class token to each block to facilitate local feature extraction
  • Performs self-attention computation within each block
  • Merges class tokens from each block, integrating local information from all blocks
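
A minimal sketch of this idea is shown below. It partitions the tokens along the sequence dimension and merges the per-block class tokens by averaging; both choices, along with all names, are assumptions made for illustration rather than the paper's exact formulation (see model/attention.py).

```python
# Minimal sketch of partitioned multi-head self-attention; names are illustrative.
import torch
import torch.nn as nn

class PartitionedMSA(nn.Module):
    def __init__(self, dim=64, num_heads=4, block_size=7):
        super().__init__()
        self.block_size = block_size
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):                        # x: (B, 1 + N, D), token 0 is cls
        cls, patches = x[:, :1], x[:, 1:]
        B, N, D = patches.shape
        nb = N // self.block_size                # assumes N divides evenly into blocks
        blocks = patches.reshape(B * nb, self.block_size, D)
        cls_rep = cls.repeat_interleave(nb, dim=0)      # copy of the class token per block
        blocks = torch.cat([blocks, cls_rep], dim=1)    # (B*nb, block_size + 1, D)
        out, _ = self.attn(blocks, blocks, blocks)      # attention within each block
        patches_out = out[:, :-1].reshape(B, N, D)
        # Merge the per-block class tokens (here: mean) to pool local information.
        cls_out = out[:, -1].reshape(B, nb, D).mean(dim=1, keepdim=True)
        return torch.cat([cls_out, patches_out], dim=1)
```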

3. Interactive Feed-Forward Network (I-FFN)

  • Separates patch tokens and class token
  • Expands patch tokens into 2D form based on relative positions in original image
  • Applies two depth-wise separable convolutions (kernel sizes: 3×1 and 1×3, stride: 1)
  • Flattens processed patch tokens into sequences and concatenates with class token
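
Below is a minimal sketch of this component, assuming a square token grid; the names and exact layer ordering are illustrative (see model/ffn.py for the actual module).

```python
# Minimal sketch of the interactive feed-forward network; names are illustrative.
import torch
import torch.nn as nn

class InteractiveFFN(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        # Two depth-wise separable convolutions (3x1 and 1x3, stride 1).
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim, dim, (3, 1), stride=1, padding=(1, 0), groups=dim),
            nn.Conv2d(dim, dim, kernel_size=1),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(dim, dim, (1, 3), stride=1, padding=(0, 1), groups=dim),
            nn.Conv2d(dim, dim, kernel_size=1),
        )

    def forward(self, x):                       # x: (B, 1 + N, D), token 0 is cls
        cls, patches = x[:, :1], x[:, 1:]       # separate class and patch tokens
        B, N, D = patches.shape
        side = int(N ** 0.5)                    # assumes a square H' x W' token grid
        grid = patches.transpose(1, 2).reshape(B, D, side, side)  # back to 2D layout
        grid = self.conv2(self.conv1(grid))
        patches = grid.flatten(2).transpose(1, 2)                 # flatten to a sequence
        return torch.cat([cls, patches], dim=1)                   # re-attach class token
```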

4. Token Merging

  • Uses max pooling (kernel size: 4×4, stride: 4) to reduce token count
  • Applies linear layer to increase feature dimensions
  • Gradually reduces sequence length while enhancing feature dimensions, enabling tokens to capture more complex visual patterns
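
A minimal sketch follows; it assumes a square token grid and that the class token passes through the same linear projection, which are illustrative choices (see model/merging.py for the actual module).

```python
# Minimal sketch of token merging; names and defaults are illustrative.
import torch
import torch.nn as nn

class TokenMerging(nn.Module):
    def __init__(self, dim_in=64, dim_out=128):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=4, stride=4)   # 16x fewer patch tokens
        self.proj = nn.Linear(dim_in, dim_out)               # expand feature dimension

    def forward(self, x):                       # x: (B, 1 + N, D_in), token 0 is cls
        cls, patches = x[:, :1], x[:, 1:]
        B, N, D = patches.shape
        side = int(N ** 0.5)                    # assumes a square token grid
        grid = patches.transpose(1, 2).reshape(B, D, side, side)
        grid = self.pool(grid)                  # (B, D, side/4, side/4)
        patches = grid.flatten(2).transpose(1, 2)
        return self.proj(torch.cat([cls, patches], dim=1))   # (B, 1 + N/16, D_out)
```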

Project Structure

ecvit/
├── model/                  # Model implementation
│   ├── __init__.py
│   ├── ecvit.py           # ECViT model
│   ├── tokenization.py    # Image Tokenization module
│   ├── attention.py       # Partitioned Multi-head Self-Attention module
│   ├── ffn.py             # Interactive Feed-Forward Network module
│   ├── encoder.py         # Transformer encoder block
│   └── merging.py         # Token merging module
├── utils/                 # Utility functions
│   ├── __init__.py
│   ├── data_utils.py      # Data loading and processing tools
│   └── train_utils.py     # Training and evaluation tools
├── scripts/               # Scripts
│   ├── train.py           # Training script
│   └── validate.py        # Validation script
├── main.py                # Main script
├── requirements.txt       # Dependencies
└── README.md              # Project documentation

Installation

  1. Clone the repository:
git clone https://github.com/dongfangzhizhu/ECViT.git
cd ECViT
  2. Install dependencies:
pip install -r requirements.txt

Usage

Training Model

Train the model with simulated data:

python main.py --mode train --dataset simulated --train-size 1000 --val-size 200 --img-size 224 --num-classes 10 --batch-size 64 --epochs 30

Validating Model

Validate the trained model:

python main.py --mode validate --dataset simulated --val-size 200 --img-size 224 --num-classes 10 --checkpoint output/model_best.pth

Training and Validating Model

Complete training and validation in one go:

python main.py --mode train_and_validate --dataset simulated --train-size 1000 --val-size 200 --img-size 224 --num-classes 10 --batch-size 64 --epochs 30

Parameter Description

  • --mode: Operation mode, options: train, validate, or train_and_validate
  • --dataset: Dataset type, currently supports simulated
  • --train-size: Training dataset size
  • --val-size: Validation dataset size
  • --img-size: Image size
  • --num-classes: Number of classes
  • --embed-dims: Embedding dimensions for each stage, comma-separated
  • --depths: Number of Transformer encoder blocks for each stage, comma-separated
  • --num-heads: Number of attention heads for each stage, comma-separated
  • --block-size: Block size in partitioned multi-head self-attention
  • --mlp-ratios: MLP expansion ratios for each stage, comma-separated
  • --batch-size: Batch size
  • --epochs: Number of training epochs
  • --lr: Learning rate
  • --weight-decay: Weight decay
  • --optimizer: Optimizer type, options: adamw or sgd
  • --seed: Random seed
  • --device: Device, options: cuda or cpu
  • --output-dir: Output directory
  • --resume: Checkpoint path for resuming training
  • --checkpoint: Checkpoint path for validation mode
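
For example, an invocation that sets the stage-wise options explicitly might look like the following; the values are illustrative, not the defaults, and the number of comma-separated entries assumed here is one per stage:

python main.py --mode train --dataset simulated --img-size 224 --num-classes 10 --embed-dims 64,128,256 --depths 2,2,6 --num-heads 2,4,8 --block-size 7 --mlp-ratios 4,4,4 --batch-size 64 --epochs 30 --lr 1e-3 --optimizer adamw --device cuda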

Citation

If you use the code from this project, please cite the original paper:

@article{qian2025ecvit,
  title={ECViT: Efficient Convolutional Vision Transformer with Local-Attention and Multi-scale Stages},
  author={Qian, Zhoujie},
  journal={arXiv preprint arXiv:2504.14825},
  year={2025}
}

License

This project is licensed under the GNU General Public License v3.0.
See the LICENSE file for the full license text.
