TomLjm/STFA-ResNet

STFA-ResNet

Official codebase for STFA-ResNet: Joint Spatiotemporal-Frequency Channel Estimation for High-Mobility MIMO-OFDM Systems.

This repository provides a reproducible research pipeline for:

  • training of STFA-ResNet and the deep baselines SF-CNN and DeepMIMO-Net
  • NMSE/BER evaluation under high mobility
  • cross-channel-model generalization analysis (trained on CDL-C only vs. trained on mixed CDL-A/C/D)
  • ablation studies and result summarization
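NMSE here is presumably the standard normalized mean squared error in dB, 10·log10(‖Ĥ − H‖² / ‖H‖²), summed over all antenna, time, and subcarrier entries. A minimal NumPy sketch under that assumption; `nmse_db` is a hypothetical helper, not necessarily how the repository computes it (averaging per sample before converting to dB is a common alternative):

```python
import numpy as np

def nmse_db(h_est: np.ndarray, h_true: np.ndarray) -> float:
    """NMSE in dB: 10*log10(sum|H_est - H|^2 / sum|H|^2) over all entries."""
    err = np.sum(np.abs(h_est - h_true) ** 2)
    ref = np.sum(np.abs(h_true) ** 2)
    return float(10.0 * np.log10(err / ref))

# Example: error with ~1% of the channel energy gives roughly -20 dB
rng = np.random.default_rng(0)
h = rng.standard_normal((16, 64)) + 1j * rng.standard_normal((16, 64))
h_hat = h + 0.1 * (rng.standard_normal(h.shape) + 1j * rng.standard_normal(h.shape))
print(f"{nmse_db(h_hat, h):.1f} dB")
```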

Method Overview

STFA-ResNet is designed for doubly selective channels in high-mobility MIMO-OFDM systems.
It combines:

  1. 3D convolution for joint spatiotemporal-frequency feature encoding
  2. BiLSTM for bidirectional modeling of Doppler-induced temporal dynamics
  3. SE attention for channel-wise denoising and feature recalibration
  4. Physics-constrained loss with time/frequency smoothness regularization
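The exact form and weights of the physics-constrained loss are defined in the paper and the training code, not in this README. A minimal sketch of the general idea, assuming an MSE data term plus squared first-difference penalties along the time (OFDM symbol) and frequency (subcarrier) axes; the weights `lam_t` and `lam_f` and the (..., T, K) layout are hypothetical:

```python
import numpy as np

def smoothness_regularized_loss(h_est, h_true, lam_t=0.1, lam_f=0.1):
    """MSE plus penalties on adjacent-symbol and adjacent-subcarrier differences.

    Assumes complex arrays shaped (..., T, K); lam_t and lam_f are hypothetical weights.
    """
    mse = np.mean(np.abs(h_est - h_true) ** 2)
    dt = np.diff(h_est, axis=-2)  # change between consecutive OFDM symbols
    df = np.diff(h_est, axis=-1)  # change between neighbouring subcarriers
    return float(mse + lam_t * np.mean(np.abs(dt) ** 2) + lam_f * np.mean(np.abs(df) ** 2))

h_true = np.zeros((14, 64), dtype=complex)         # T=14 symbols, K=64 subcarriers (hypothetical)
smooth = h_true + 0.1                              # constant error: MSE 0.01, no roughness
rough = h_true + 0.1 * (-1.0) ** np.arange(64)     # alternating error: same MSE, rough in frequency
print(smoothness_regularized_loss(smooth, h_true), smoothness_regularized_loss(rough, h_true))  # ~0.010, ~0.014
```

Both estimates have the same MSE, but the alternating one is penalized for roughness across subcarriers. In training, the same expression would be written with autograd-capable tensor operations.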

Compared with LS/MMSE and representative deep baselines, the paper reports more robust NMSE and BER performance, especially under low-SNR and high-Doppler conditions.

Repository Structure

.
├── main.py                                  # train/eval entrypoint
├── train.py                                 # training loop and checkpointing
├── config.py                                # experiment profiles and hyperparameters
├── dataset.py                               # online/precomputed dataset loaders
├── channel_model.py                         # 3GPP-style CDL channel generation
├── mimo_ofdm.py                             # pilot transmission and classical estimators
├── scripts/
│   ├── run_benchmarks.py                    # unified benchmark table
│   ├── export_nmse_curves.py                # NMSE vs SNR / velocity CSV
│   ├── export_ber_curves.py                 # BER vs SNR CSV
│   ├── export_cross_model_generalization.py # Table-I style generalization CSV
│   ├── plot_paper_figures.py                # optional plotting utility
│   └── reproduce_all.sh                     # optional all-in-one script
└── summarize_ablation.py                    # ablation log summarization
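channel_model.py generates 3GPP-style CDL channels, whose full procedure (clusters, per-ray angles, antenna patterns) is much richer than what fits here. The sketch below is a deliberately simplified single-antenna multipath model, not the repository's generator. It shows only where double selectivity comes from: path delays τ_p make the response vary across subcarriers, and per-path Doppler shifts ν_p = f_D·cos(θ_p), with f_D = v·f_c/c, make it vary across OFDM symbols. All numeric parameters are hypothetical.

```python
import numpy as np

C = 299_792_458.0  # speed of light in m/s

def max_doppler_hz(speed_kmh, carrier_hz):
    """Maximum Doppler shift f_D = v * f_c / c, with v converted from km/h to m/s."""
    return (speed_kmh / 3.6) * carrier_hz / C

def doubly_selective_channel(n_sym, n_sc, delays_s, gains, angles_rad, f_d, t_sym, delta_f):
    """H[t, k] = sum_p g_p * exp(j2π(ν_p t T_sym - k Δf τ_p)), with ν_p = f_D cos θ_p."""
    t = np.arange(n_sym)[:, None, None] * t_sym                              # (T, 1, 1)
    f = np.arange(n_sc)[None, :, None] * delta_f                             # (1, K, 1)
    tau = np.asarray(delays_s, dtype=float)[None, None, :]                   # (1, 1, P)
    nu = f_d * np.cos(np.asarray(angles_rad, dtype=float))[None, None, :]    # per-path Doppler
    g = np.asarray(gains, dtype=complex)[None, None, :]                      # path gains
    return np.sum(g * np.exp(2j * np.pi * (nu * t - f * tau)), axis=-1)      # (T, K)

# Hypothetical example: 500 km/h at 3.5 GHz, 14 symbols, 64 subcarriers at 15 kHz spacing
f_d = max_doppler_hz(500.0, 3.5e9)
H = doubly_selective_channel(14, 64, [0.0, 300e-9, 900e-9], [1.0, 0.5, 0.25],
                             [0.3, 1.2, 2.5], f_d, t_sym=1 / 14e3, delta_f=15e3)
print(round(f_d), H.shape)
```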

Environment Setup

Option A: pip

pip install -r requirements.txt

Option B: conda

conda env create -f environment.yml
conda activate stfa-resnet

Experiment Profiles

Set the STFA_PROFILE environment variable to select the experiment scale:

  • small: smoke test
  • repro (default): single-GPU friendly
  • full: paper-scale setting (Nt=64, Nr=16, K=256, train/val/test=15000/2000/2000)
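How config.py parses STFA_PROFILE is not shown in this README. A minimal sketch of the selection logic implied by the list above; the helper name and the case-insensitive matching are assumptions, and only the `full` settings are written out because only they are stated here:

```python
import os

VALID_PROFILES = ("small", "repro", "full")

# Paper-scale settings as stated in this README; "small" and "repro" values live in config.py
FULL_PROFILE = {"Nt": 64, "Nr": 16, "K": 256, "n_train": 15000, "n_val": 2000, "n_test": 2000}

def selected_profile(env=None):
    """Return the profile named by STFA_PROFILE, falling back to "repro" when unset."""
    env = os.environ if env is None else env
    name = env.get("STFA_PROFILE", "repro").strip().lower()
    if name not in VALID_PROFILES:
        raise ValueError(f"STFA_PROFILE must be one of {VALID_PROFILES}, got {name!r}")
    return name

print(selected_profile({"STFA_PROFILE": "full"}), FULL_PROFILE["K"])  # full 256
```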

Recommended defaults:

export STFA_PROFILE=repro
export USE_AMP=1
export ACCUM_STEPS=2
export USE_WANDB=0
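ACCUM_STEPS presumably controls gradient accumulation: gradients from ACCUM_STEPS micro-batches of BATCH_SIZE samples are accumulated before a single optimizer step, so the effective batch is BATCH_SIZE × ACCUM_STEPS while activation memory scales only with BATCH_SIZE. The NumPy sketch below uses a toy least-squares problem, not the repository's training loop, to check the underlying identity: with equal-size micro-batches and each micro-batch gradient scaled by 1/ACCUM_STEPS, the accumulated gradient equals the full-batch gradient of a mean loss.

```python
import numpy as np

def grad_mse(w, x, y):
    """Gradient of mean((x @ w - y)^2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3))
y = rng.standard_normal(8)
w = np.zeros(3)

accum_steps = 2  # plays the role of ACCUM_STEPS; each micro-batch holds 8 / 2 = 4 samples
g_accum = np.zeros_like(w)
for xb, yb in zip(np.array_split(x, accum_steps), np.array_split(y, accum_steps)):
    g_accum += grad_mse(w, xb, yb) / accum_steps  # accumulate the scaled micro-batch gradient

g_full = grad_mse(w, x, y)  # gradient over the whole batch at once
print(np.allclose(g_accum, g_full))  # True
```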

Reproducibility Protocol (Modular, Recommended)

The reproduction pipeline is organized into modular stages. Each stage can be run independently, which keeps the reproduction transparent and auditable.

1) Data Integrity Check and Completion

python check_and_remove_bad_chunks.py --remove-bad --manifest
python generate_dataset.py --resume --only-missing --manifest

Artifacts:

  • data/data_manifest.json
  • data/{train,val,test}_chunk_*.npz
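This README does not describe exactly what check_and_remove_bad_chunks.py checks. One plausible minimal check, sketched under that assumption, is to load every array in each chunk and flag files that fail to read. The helper name is hypothetical, and the real script additionally handles removal (--remove-bad) and the manifest (--manifest).

```python
import tempfile
from pathlib import Path

import numpy as np

def find_bad_chunks(data_dir):
    """Return the *_chunk_*.npz files in data_dir that cannot be fully read."""
    bad = []
    for path in sorted(Path(data_dir).glob("*_chunk_*.npz")):
        try:
            with np.load(path) as npz:
                for key in npz.files:
                    _ = npz[key]  # touch every array so truncated members are detected
        except Exception:  # any read failure marks the chunk as bad
            bad.append(path)
    return bad

# Demo: one valid chunk and one corrupted chunk in a temporary directory
with tempfile.TemporaryDirectory() as tmp:
    np.savez(Path(tmp) / "train_chunk_0.npz", H=np.zeros((2, 4)))
    (Path(tmp) / "train_chunk_1.npz").write_bytes(b"not a valid npz file")
    bad_names = [p.name for p in find_bad_chunks(tmp)]
print(bad_names)  # ['train_chunk_1.npz']
```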

2) Train Deep Models (Independent Checkpoints)

CKPT_PATH=checkpoints/stfa_best.pt python main.py --mode train --model-name stfa
CKPT_PATH=checkpoints/sfcnn_best.pt python main.py --mode train --model-name sfcnn
CKPT_PATH=checkpoints/deepmimo_best.pt python main.py --mode train --model-name deepmimo
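The `_best.pt` filenames suggest that train.py keeps the checkpoint with the best validation score at CKPT_PATH. A minimal sketch of that bookkeeping, assuming selection on validation NMSE in dB (lower is better); the class name, the fallback path, and the selection metric are hypothetical. In a PyTorch loop, the `True` branch would typically call `torch.save(model.state_dict(), ckpt_path)`.

```python
import math
import os

class BestCheckpointTracker:
    """Track the best validation NMSE (dB) seen so far; lower is better."""

    def __init__(self, ckpt_path: str):
        self.ckpt_path = ckpt_path
        self.best_nmse_db = math.inf

    def update(self, val_nmse_db: float) -> bool:
        """Return True when this epoch beats the best score, i.e. the checkpoint should be overwritten."""
        if val_nmse_db < self.best_nmse_db:
            self.best_nmse_db = val_nmse_db
            return True
        return False

# CKPT_PATH as in the training commands; the fallback path is hypothetical
tracker = BestCheckpointTracker(os.environ.get("CKPT_PATH", "checkpoints/stfa_best.pt"))
saved = [tracker.update(v) for v in [-10.2, -12.5, -11.9, -13.1]]  # dummy validation scores
print(saved)  # [True, True, False, True]
```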

3) Single-Model Evaluation and Classical Baselines

python main.py --mode eval --model-name stfa --model checkpoints/stfa_best.pt
python main.py --mode eval_velocity --model-name stfa --model checkpoints/stfa_best.pt
python main.py --mode baseline --baseline-type ls
python main.py --mode baseline --baseline-type mmse
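In its textbook form, the LS baseline divides the received pilot symbols by the known transmitted pilots and interpolates between pilot subcarriers. A minimal single-antenna, single-symbol sketch, assuming comb-type pilots and linear interpolation; the pilot pattern and interpolation in mimo_ofdm.py may differ, and the MMSE baseline, which additionally needs channel-correlation and noise statistics, is not shown.

```python
import numpy as np

def ls_estimate(y, x_pilot, pilot_idx, n_sc):
    """LS estimate H_p = Y_p / X_p on pilot subcarriers, linearly interpolated to all subcarriers.

    np.interp only accepts real data and holds the edge value beyond the last pilot,
    so real and imaginary parts are interpolated separately.
    """
    h_p = y[pilot_idx] / x_pilot
    k = np.arange(n_sc)
    return np.interp(k, pilot_idx, h_p.real) + 1j * np.interp(k, pilot_idx, h_p.imag)

# Hypothetical setup: 64 subcarriers, a unit pilot on every 4th subcarrier, no noise
n_sc = 64
pilot_idx = np.arange(0, n_sc, 4)
h_true = 0.8 * np.exp(-2j * np.pi * 0.01 * np.arange(n_sc))  # smooth single-path response
x_pilot = np.ones(len(pilot_idx), dtype=complex)
y = h_true.copy()  # received symbols equal H when X = 1 and there is no noise
h_ls = ls_estimate(y, x_pilot, pilot_idx, n_sc)
print(np.allclose(h_ls[pilot_idx], h_true[pilot_idx]))  # True
```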

4) Unified Benchmark and NMSE Curves

python scripts/run_benchmarks.py \
  --stfa-ckpt checkpoints/stfa_best.pt \
  --sfcnn-ckpt checkpoints/sfcnn_best.pt \
  --deepmimo-ckpt checkpoints/deepmimo_best.pt

python scripts/export_nmse_curves.py \
  --stfa-ckpt checkpoints/stfa_best.pt \
  --sfcnn-ckpt checkpoints/sfcnn_best.pt \
  --deepmimo-ckpt checkpoints/deepmimo_best.pt

Both run_benchmarks.py and export_nmse_curves.py check that the given checkpoints exist, to prevent accidental evaluation of untrained models.
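A checkpoint guard of this kind can be sketched as follows. The function name and error message are hypothetical; the scripts' actual checks and messages may differ.

```python
from pathlib import Path

def require_checkpoints(**ckpts):
    """Raise FileNotFoundError naming every checkpoint path that is not an existing file."""
    missing = {name: path for name, path in ckpts.items() if not Path(path).is_file()}
    if missing:
        details = ", ".join(f"{name}={path}" for name, path in missing.items())
        raise FileNotFoundError(f"missing checkpoint(s): {details}; train the models first")

# Mirrors the --stfa-ckpt/--sfcnn-ckpt/--deepmimo-ckpt arguments of the benchmark scripts
try:
    require_checkpoints(
        stfa_ckpt="checkpoints/stfa_best.pt",
        sfcnn_ckpt="checkpoints/sfcnn_best.pt",
        deepmimo_ckpt="checkpoints/deepmimo_best.pt",
    )
except FileNotFoundError as err:
    print(err)
```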

5) BER Curves (Paper BER Figure)

python scripts/export_ber_curves.py \
  --stfa-ckpt checkpoints/stfa_best.pt \
  --sfcnn-ckpt checkpoints/sfcnn_best.pt \
  --deepmimo-ckpt checkpoints/deepmimo_best.pt \
  --n-samples 50 \
  --output results/ber_vs_snr.csv
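BER is the fraction of transmitted bits that are detected incorrectly. This README does not state the modulation, equalizer, or detector used by export_ber_curves.py. The sketch below assumes Gray-mapped QPSK with one-tap zero-forcing equalization, only to show how channel-estimation error propagates into BER; the SNR, sizes, and estimate-error level are hypothetical.

```python
import numpy as np

def qpsk_mod(bits):
    """Gray-mapped QPSK with unit symbol energy: bit 0 -> +1, bit 1 -> -1 on each of I and Q."""
    b = bits.reshape(-1, 2)
    return ((1 - 2 * b[:, 0]) + 1j * (1 - 2 * b[:, 1])) / np.sqrt(2)

def qpsk_demod(symbols):
    """Hard decision: a negative in-phase/quadrature component decodes to bit 1."""
    bits = np.empty((len(symbols), 2), dtype=int)
    bits[:, 0] = symbols.real < 0
    bits[:, 1] = symbols.imag < 0
    return bits.reshape(-1)

def ber(tx_bits, rx_bits):
    """Fraction of bits that differ."""
    return float(np.mean(tx_bits != rx_bits))

rng = np.random.default_rng(0)

def cn(n):
    """n samples of circularly symmetric complex Gaussian noise CN(0, 1)."""
    return (rng.standard_normal(n) + 1j * rng.standard_normal(n)) / np.sqrt(2)

n_sym, snr_db = 10_000, 10.0                     # hypothetical settings
bits = rng.integers(0, 2, size=2 * n_sym)
h = cn(n_sym)                                    # per-subcarrier channel gains
noise_var = 10 ** (-snr_db / 10)                 # symbol energy is 1
y = h * qpsk_mod(bits) + np.sqrt(noise_var) * cn(n_sym)
h_hat = h + 0.1 * cn(n_sym)                      # stand-in for an imperfect channel estimate

# One-tap zero-forcing equalization with the true vs the estimated channel
ber_perfect = ber(bits, qpsk_demod(y / h))
ber_estimated = ber(bits, qpsk_demod(y / h_hat))
print(ber_perfect, ber_estimated)
```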

6) Cross-Model Generalization (Paper Table I Style)

Train two STFA checkpoints:

# Setting A: trained on CDL-C only
export CDL_MODEL=CDL-C
CKPT_PATH=checkpoints/stfa_cdlc_only.pt python main.py --mode train --model-name stfa

# Setting B: trained on mixed CDL-A/CDL-C/CDL-D
export CDL_MODEL=MIXED
CKPT_PATH=checkpoints/stfa_mixed_acd.pt python main.py --mode train --model-name stfa
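With CDL_MODEL=MIXED, training data come from CDL-A, CDL-C, and CDL-D. This README does not specify how the mixture is formed (per sample or per chunk, equal or weighted proportions). A minimal sketch assuming uniform per-sample selection; the generator name and the "CDL-C" fallback when CDL_MODEL is unset are hypothetical.

```python
import os
import random
from itertools import islice

MIXED_PROFILES = ("CDL-A", "CDL-C", "CDL-D")

def cdl_profile_sampler(cdl_model=None, seed=0):
    """Yield the CDL profile to use for each generated sample.

    A fixed profile such as "CDL-C" is repeated; "MIXED" draws uniformly from
    CDL-A/C/D (an assumption). The "CDL-C" fallback is hypothetical.
    """
    cdl_model = cdl_model or os.environ.get("CDL_MODEL", "CDL-C")
    rng = random.Random(seed)
    while True:
        if cdl_model.upper() == "MIXED":
            yield rng.choice(MIXED_PROFILES)
        else:
            yield cdl_model

mixed_draws = list(islice(cdl_profile_sampler("MIXED"), 12))
single_draws = list(islice(cdl_profile_sampler("CDL-C"), 3))
print(mixed_draws)
print(single_draws)  # ['CDL-C', 'CDL-C', 'CDL-C']
```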

Then export the Table-I style CSV:

python scripts/export_cross_model_generalization.py \
  --model-name stfa \
  --single-ckpt checkpoints/stfa_cdlc_only.pt \
  --mixed-ckpt checkpoints/stfa_mixed_acd.pt \
  --n-samples 100 \
  --repeats 5 \
  --output results/cross_model_generalization.csv
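With --repeats 5, each (training setting, test channel) cell is presumably evaluated several times. A natural summary is the mean and sample standard deviation of NMSE (dB) across repeats. A sketch with hypothetical column names and dummy values (not results); averaging in dB rather than in linear scale is itself an assumption.

```python
import csv
import io
import statistics

def summarize_repeats(records):
    """records: (train_setting, test_cdl, nmse_db) tuples, one per repeat.

    Returns (train_setting, test_cdl, mean_db, std_db, n_repeats) rows sorted by key.
    """
    groups = {}
    for train, test, nmse in records:
        groups.setdefault((train, test), []).append(nmse)
    rows = []
    for (train, test), vals in sorted(groups.items()):
        std = statistics.stdev(vals) if len(vals) > 1 else 0.0
        rows.append((train, test, statistics.fmean(vals), std, len(vals)))
    return rows

def rows_to_csv(rows):
    """Serialize summary rows to CSV text with a header line."""
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(["train_setting", "test_cdl", "nmse_db_mean", "nmse_db_std", "repeats"])
    writer.writerows(rows)
    return buf.getvalue()

# Dummy values for illustration only (not results)
records = [("cdlc_only", "CDL-A", -10.0), ("cdlc_only", "CDL-A", -12.0), ("mixed_acd", "CDL-A", -14.0)]
print(rows_to_csv(summarize_repeats(records)))
```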

7) Ablation Study

bash run_ablation_background.sh
python summarize_ablation.py
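The log format parsed by summarize_ablation.py is not documented here. Assuming the per-variant NMSE values have already been extracted, a minimal sketch of rendering a Markdown summary with the change relative to the full model; the variant names and values are hypothetical dummies, not results.

```python
def ablation_table(nmse_db, reference="full"):
    """Render a Markdown table of NMSE (dB) per variant and its change vs the reference model.

    nmse_db maps variant names to NMSE in dB (lower is better); rows are sorted best first.
    """
    ref = nmse_db[reference]
    lines = [f"| variant | NMSE (dB) | Δ vs {reference} (dB) |", "|---|---:|---:|"]
    for name, value in sorted(nmse_db.items(), key=lambda kv: kv[1]):
        lines.append(f"| {name} | {value:.2f} | {value - ref:+.2f} |")
    return "\n".join(lines)

# Hypothetical variant names and dummy values, for illustration only
table = ablation_table({"full": -20.0, "no_bilstm": -17.0, "no_se": -18.5})
print(table)
```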

Result Artifacts

  • results/benchmark_summary.csv
  • results/nmse_vs_snr.csv
  • results/nmse_vs_velocity.csv
  • results/ber_vs_snr.csv
  • results/cross_model_generalization.csv
  • ablation_summary.md
  • ablation_summary.csv

Channelformer Integration

Channelformer is treated as an external baseline: its results are produced outside this repository and merged into the benchmark table via a CSV file.

python scripts/run_benchmarks.py \
  --stfa-ckpt checkpoints/stfa_best.pt \
  --sfcnn-ckpt checkpoints/sfcnn_best.pt \
  --deepmimo-ckpt checkpoints/deepmimo_best.pt \
  --channelformer-csv results/channelformer_eval.csv

The CSV must use the following header and contain one row with model=channelformer:

model,nmse_db_at_train_snr,params_m,inference_sec_100samples,source
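A quick pre-merge check of that file can be sketched with the standard csv module. The function name is hypothetical, requiring exactly one channelformer row is an assumption, and the example row holds placeholders rather than measured values.

```python
import csv
import io

REQUIRED_FIELDS = ["model", "nmse_db_at_train_snr", "params_m", "inference_sec_100samples", "source"]

def validate_channelformer_csv(text):
    """Check the required columns and return the single row with model == "channelformer"."""
    reader = csv.DictReader(io.StringIO(text))
    missing = [f for f in REQUIRED_FIELDS if f not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    rows = [row for row in reader if row["model"].strip() == "channelformer"]
    if len(rows) != 1:
        raise ValueError(f"expected one channelformer row, found {len(rows)}")
    return rows[0]

# Placeholder values only; fill in real Channelformer measurements and their source
example = (
    "model,nmse_db_at_train_snr,params_m,inference_sec_100samples,source\n"
    "channelformer,<nmse_db>,<params_m>,<seconds>,<source>\n"
)
print(validate_channelformer_csv(example)["model"])  # channelformer
```

For the real file, pass `Path("results/channelformer_eval.csv").read_text()` (with `from pathlib import Path`).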

Practical Notes

  • If GPU memory is limited, use STFA_PROFILE=repro, reduce BATCH_SIZE, and increase ACCUM_STEPS.
  • If the precomputed-data checks fail, rerun the data cleaning and completion commands from Stage 1.
  • To obtain paper-level numbers, use STFA_PROFILE=full and report the hardware and runtime details.

Citation

Please cite this repository using CITATION.cff.
