This repository is the official implementation of *Dendritic Integration Inspired Artificial Neural Networks Capture Data Correlation*.
Figure: A. Experiments confirmed the quadratic integration rule in general cases, along with a comprehensive theoretical framework for single-neuron computation (from Li et al. 2023). B. An illustration of the biological interpretation of our Dit-CNNs.
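A minimal sketch of what a quadratic integration rule looks like, assuming its pairwise form (the symbols $V_i$ and $k_{ij}$ are introduced here for illustration and may differ from the paper's notation): if $n$ inputs individually evoke responses $V_1, \dots, V_n$, their combined response is approximated by the linear sum plus pairwise multiplicative interaction terms.

$$
V \approx \sum_{i=1}^{n} V_i + \sum_{1 \le i < j \le n} k_{ij}\, V_i V_j
$$

For two inputs this reduces to $V \approx V_1 + V_2 + k_{12} V_1 V_2$. The cross term is the part that responds to joint activity of the inputs rather than only to their sum.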
Our Dit-CNN is inspired by neural networks in the visual system. For example, different types of cone cells encode different color (channel) information, and retinal ganglion cells receive inputs from multiple types of cone cells; their responses can be modeled as having receptive fields (convolutional kernels) related to different color channels.
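The channel analogy can be illustrated with a toy computation. The sketch below is an assumption for illustration, not the repository's `quadratic.py`: each color channel $c$ is filtered by its own kernel to give a response $u_c$ (like the input a ganglion cell receives from one cone type), and the per-channel responses at each position are combined with a pairwise quadratic rule, $y = \sum_c u_c + \sum_{c<c'} k_{cc'} u_c u_{c'}$, so the cross terms respond to joint activity across channels. The helper names, the coefficients `k`, and the use of 1-D signals instead of images are hypothetical simplifications.

```python
from itertools import combinations
from typing import List, Sequence


def conv1d_valid(signal: Sequence[float], kernel: Sequence[float]) -> List[float]:
    """1-D 'valid' cross-correlation, i.e. what deep-learning conv layers compute."""
    k = len(kernel)
    return [
        sum(signal[i + j] * kernel[j] for j in range(k))
        for i in range(len(signal) - k + 1)
    ]


def quadratic_integration(responses: Sequence[float], k: Sequence[Sequence[float]]) -> float:
    """Hypothetical rule: linear sum plus k[i][j] * r_i * r_j for every pair i < j."""
    linear = sum(responses)
    cross = sum(
        k[i][j] * responses[i] * responses[j]
        for i, j in combinations(range(len(responses)), 2)
    )
    return linear + cross


def channel_quadratic_response(
    channels: Sequence[Sequence[float]],
    kernels: Sequence[Sequence[float]],
    k: Sequence[Sequence[float]],
) -> List[float]:
    """Filter each channel with its own kernel, then combine channels quadratically per position."""
    # per_channel[c][p] is channel c's filtered response at position p
    per_channel = [conv1d_valid(x, w) for x, w in zip(channels, kernels)]
    return [
        quadratic_integration([resp[p] for resp in per_channel], k)
        for p in range(len(per_channel[0]))
    ]


if __name__ == "__main__":
    channels = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]  # two toy "color" channels
    kernels = [[1.0, 0.0], [0.0, 2.0]]             # one kernel per channel
    k = [[0.0, 0.1], [0.1, 0.0]]                   # only k[0][1] is used (i < j)
    print(channel_quadratic_response(channels, kernels, k))
```

With `k` set to all zeros the function reduces to a plain sum of per-channel convolutions, which is what a standard convolution layer computes for one output channel (without bias).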
To install requirements:
pip install -r requirements.txt
To train the models on CIFAR as described in the paper, run the following command:
python cifar10.py --model dit_resnet20
📋 For details on configuring data and training popular models on ImageNet-1K, refer here.
After configuring the data, run the following commands to integrate dit_convnext into the timm library:
mv quadratic.py .../env/lib/python3.10/site-packages/timm/layers
mv convnext.py .../env/lib/python3.10/site-packages/timm/models
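The destination paths above depend on where your environment installs `timm` (the prefix is elided as `.../env/...`, and the `python3.10` component changes with the Python version). As an optional aid that is not part of the repository, the sketch below asks Python for the installed `timm` directory with `importlib.util.find_spec` and prints the matching `mv` targets as a dry run. The helper names `package_dir` and `timm_destinations` are hypothetical. Note that moving `convnext.py` into `timm/models` overwrites the stock file.

```python
import importlib.util
from pathlib import Path
from typing import Dict, Optional


def package_dir(name: str) -> Optional[Path]:
    """Directory of an installed package, or None if it is missing or not a package."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(next(iter(spec.submodule_search_locations)))


def timm_destinations(timm_root: Path) -> Dict[str, Path]:
    """Target path for each repository file, mirroring the two mv commands."""
    return {
        "quadratic.py": timm_root / "layers" / "quadratic.py",
        "convnext.py": timm_root / "models" / "convnext.py",
    }


if __name__ == "__main__":
    root = package_dir("timm")
    if root is None:
        print("timm is not installed in this environment")
    else:
        # Dry run: print the commands instead of moving anything
        for src, dst in timm_destinations(root).items():
            print(f"mv {src} {dst}")
```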
Then train Dit-ConvNeXt on 8 GPUs with the following command, replacing `data_path` with the path to your ImageNet-1K data:
torchrun --nproc_per_node=8 train.py data_path -b 64 --model convnext_tiny --amp --resplit --weight-decay 0.08 --sched cosine --lr 0.006 --epochs 300 --warmup-epochs 20 --opt adamw --aa rand-m9-mstd0.5 --mixup 0.8 --cutmix 1.0 --reprob 0.25 --drop-path 0.1 --model-ema --grad-accum-steps 8 --crop-pct 0.95
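Assuming `-b` is the per-process batch size (as in timm's `train.py` under `torchrun`), each optimizer step in the command above sees 8 GPUs × 64 images × 8 accumulation steps. The sketch below makes that arithmetic explicit and approximates the learning-rate curve implied by `--lr 0.006 --warmup-epochs 20 --epochs 300 --sched cosine`. It is a simplified model, not timm's scheduler: it warms up linearly from 0 and decays to 0, ignoring timm's `--warmup-lr` and `--min-lr` floors. The helper names are hypothetical.

```python
import math


def effective_batch_size(per_gpu: int, n_gpus: int, grad_accum_steps: int) -> int:
    """Number of samples contributing to each optimizer step."""
    return per_gpu * n_gpus * grad_accum_steps


def approx_lr(epoch: float, base_lr: float = 0.006, warmup: int = 20, total: int = 300) -> float:
    """Simplified schedule: linear warmup from 0, then cosine decay to 0 at `total`."""
    if epoch < warmup:
        return base_lr * epoch / warmup
    progress = (epoch - warmup) / (total - warmup)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    print(effective_batch_size(64, 8, 8))  # 4096
    for e in (0, 10, 20, 160, 300):
        print(e, round(approx_lr(e), 6))
```

If you train on fewer GPUs, raising `--grad-accum-steps` keeps the effective batch size (and therefore the meaning of `--lr 0.006`) unchanged.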
Our models achieve the following top-1 accuracy on ImageNet-1K:
Model name | Top-1 Accuracy |
---|---|
Dit-ConvNeXt-T | 82.6% |
Dit-ConvNeXt-S | 83.6% |
Dit-ConvNeXt-B | 84.2% |
This project is licensed under the MIT License.