Official codebase for STFA-ResNet: Joint Spatiotemporal-Frequency Channel Estimation for High-Mobility MIMO-OFDM Systems.
This repository provides a reproducible research pipeline for:
- STF-aware model training (STFA, SF-CNN, DeepMIMO-Net)
- NMSE/BER evaluation under high mobility
- cross-channel generalization analysis (CDL-C only vs. mixed CDL-A + CDL-C + CDL-D)
- ablation studies and result summarization
STFA-ResNet is designed for doubly selective channels in high-mobility MIMO-OFDM systems.
It combines:
- 3D convolution for joint spatiotemporal-frequency feature encoding
- BiLSTM for Doppler-aware bidirectional temporal dynamics
- SE attention for channel-wise denoising and feature recalibration
- Physics-constrained loss with time/frequency smoothness regularization
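The exact loss is specified in the paper and implemented in the training code. As an illustration only, the sketch below assumes an MSE data term plus penalties on first-order differences of the estimate along time (adjacent OFDM symbols) and frequency (adjacent subcarriers). The function name `physics_constrained_loss`, the weights `lambda_t`/`lambda_f`, and the (time, subcarrier, ...) axis layout are hypothetical.

```python
import numpy as np

def physics_constrained_loss(h_hat, h_true, lambda_t=0.1, lambda_f=0.1):
    """Hypothetical sketch: MSE data term + time/frequency smoothness penalties.

    h_hat, h_true: complex arrays shaped (T, K, ...), with OFDM symbols (time)
    on axis 0 and subcarriers (frequency) on axis 1. This layout is assumed.
    """
    data_term = np.mean(np.abs(h_hat - h_true) ** 2)
    # Penalize changes between adjacent OFDM symbols (time smoothness)
    time_term = np.mean(np.abs(np.diff(h_hat, axis=0)) ** 2)
    # Penalize changes between adjacent subcarriers (frequency smoothness)
    freq_term = np.mean(np.abs(np.diff(h_hat, axis=1)) ** 2)
    return data_term + lambda_t * time_term + lambda_f * freq_term
```

A smooth, accurate estimate is penalized less than a noisy one, which is the intended effect of the regularizers on doubly selective channels.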
Compared with LS/MMSE and representative deep baselines, the paper reports better NMSE/BER robustness, especially under low-SNR and high-Doppler conditions.
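For reference, NMSE is commonly reported in dB as 10·log10(Σ|Ĥ − H|² / Σ|H|²), and BER as the fraction of wrongly detected bits. The sketch below computes both, assuming Gray-mapped QPSK and per-subcarrier zero-forcing equalization (y / ĥ) on a single-antenna link; the repository's evaluation scripts may use a different modulation, detector, or averaging, and all function names here are hypothetical.

```python
import numpy as np

def nmse_db(h_hat, h_true):
    """NMSE in dB: 10*log10(||h_hat - h_true||^2 / ||h_true||^2)."""
    err = np.sum(np.abs(h_hat - h_true) ** 2)
    ref = np.sum(np.abs(h_true) ** 2)
    return 10.0 * np.log10(err / ref)

def qpsk_modulate(bits):
    """Gray-mapped QPSK: bit pair (b0, b1) -> ((1-2*b0) + 1j*(1-2*b1)) / sqrt(2)."""
    b = bits.reshape(-1, 2)
    return ((1 - 2 * b[:, 0]) + 1j * (1 - 2 * b[:, 1])) / np.sqrt(2)

def qpsk_demodulate(symbols):
    """Hard decision on the signs of the real and imaginary parts."""
    bits = np.empty((symbols.size, 2), dtype=int)
    bits[:, 0] = symbols.real < 0
    bits[:, 1] = symbols.imag < 0
    return bits.reshape(-1)

def ber_zf(bits, h_true, h_hat, noise_std, rng):
    """Send QPSK over per-subcarrier gains h_true, equalize with h_hat, count bit errors."""
    x = qpsk_modulate(bits)
    noise = noise_std * (rng.standard_normal(x.shape) + 1j * rng.standard_normal(x.shape)) / np.sqrt(2)
    y = h_true * x + noise
    bits_hat = qpsk_demodulate(y / h_hat)  # zero-forcing equalization
    return np.mean(bits_hat != bits)
```

Because the equalizer divides by the estimate, channel-estimation errors translate directly into detection errors, which is why NMSE and BER are reported together.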
.
├── main.py # train/eval entrypoint
├── train.py # training loop and checkpointing
├── config.py # experiment profiles and hyperparameters
├── dataset.py # online/precomputed dataset loaders
├── channel_model.py # 3GPP-style CDL channel generation
├── mimo_ofdm.py # pilot transmission and classical estimators
├── scripts/
│ ├── run_benchmarks.py # unified benchmark table
│ ├── export_nmse_curves.py # NMSE vs SNR / velocity CSV
│ ├── export_ber_curves.py # BER vs SNR CSV
│ ├── export_cross_model_generalization.py # Table-I style generalization CSV
│ ├── plot_paper_figures.py # optional plotting utility
│ └── reproduce_all.sh # optional all-in-one script
└── summarize_ablation.py # ablation log summarization
Install the dependencies with pip:
pip install -r requirements.txt

or create the conda environment:
conda env create -f environment.yml
conda activate stfa-resnet

Use STFA_PROFILE to select the experiment scale:
- small: smoke test
- repro (default): single-GPU friendly
- full: paper-scale setting (Nt=64, Nr=16, K=256, train/val/test = 15000/2000/2000)
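To get a feel for why `full` is the paper-scale setting, the back-of-the-envelope sketch below sizes one channel tensor. It assumes Nt, Nr and K denote transmit antennas, receive antennas and subcarriers, complex64 storage (8 bytes per coefficient), and a hypothetical `n_symbols` OFDM symbols per sample; the actual tensor layout in dataset.py may differ.

```python
def channel_tensor_bytes(nt, nr, k, n_symbols=1, bytes_per_coeff=8):
    """Bytes for one (n_symbols, nr, nt, k) complex channel tensor (complex64 = 8 bytes)."""
    return n_symbols * nr * nt * k * bytes_per_coeff

# Dimensions of the 'full' profile
per_symbol = channel_tensor_bytes(nt=64, nr=16, k=256)
print(per_symbol, "bytes =", per_symbol / 2**20, "MiB per OFDM symbol")
# -> 2097152 bytes = 2.0 MiB per OFDM symbol

# 15000 training samples, counting only one OFDM symbol each
train_bytes = 15000 * per_symbol
print(train_bytes / 2**30, "GiB")
# -> 29.296875 GiB
```

Every additional OFDM symbol per sample multiplies these figures, which is one reason a smaller profile is useful on a single GPU.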
Recommended default:
export STFA_PROFILE=repro
export USE_AMP=1
export ACCUM_STEPS=2
export USE_WANDB=0

This repository is organized as modular stages. Run each stage independently for transparent and auditable reproduction.

Stage 1: remove bad data chunks, then generate any missing ones:
python check_and_remove_bad_chunks.py --remove-bad --manifest
python generate_dataset.py --resume --only-missing --manifest

Artifacts:
- data/data_manifest.json
- data/{train,val,test}_chunk_*.npz
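To spot-check the chunks yourself, a minimal sketch in the spirit of check_and_remove_bad_chunks.py (not its actual implementation) is shown below. It globs `data/{split}_chunk_*.npz`, opens each file with NumPy, reads every array, and reports files that fail. The function name `find_bad_chunks` is hypothetical, and it only reports problems; it never deletes files.

```python
import glob
import os
import zipfile

import numpy as np

def find_bad_chunks(data_dir="data", splits=("train", "val", "test")):
    """Return paths of chunk files that cannot be opened or fully read."""
    bad = []
    for split in splits:
        pattern = os.path.join(data_dir, f"{split}_chunk_*.npz")
        for path in sorted(glob.glob(pattern)):
            try:
                with np.load(path, allow_pickle=False) as npz:
                    for name in npz.files:
                        npz[name]  # forces decompression and CRC check of each array
            except (OSError, ValueError, EOFError, zipfile.BadZipFile):
                bad.append(path)
    return bad

if __name__ == "__main__":
    for path in find_bad_chunks():
        print("bad chunk:", path)
```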
Train the three models (CKPT_PATH sets where each checkpoint is written):
CKPT_PATH=checkpoints/stfa_best.pt python main.py --mode train --model-name stfa
CKPT_PATH=checkpoints/sfcnn_best.pt python main.py --mode train --model-name sfcnn
CKPT_PATH=checkpoints/deepmimo_best.pt python main.py --mode train --model-name deepmimo

Evaluate STFA-ResNet (overall and across velocities) and the LS/MMSE baselines:
python main.py --mode eval --model-name stfa --model checkpoints/stfa_best.pt
python main.py --mode eval_velocity --model-name stfa --model checkpoints/stfa_best.pt
python main.py --mode baseline --baseline-type ls
python main.py --mode baseline --baseline-type mmse

Build the unified benchmark table:
python scripts/run_benchmarks.py \
--stfa-ckpt checkpoints/stfa_best.pt \
--sfcnn-ckpt checkpoints/sfcnn_best.pt \
--deepmimo-ckpt checkpoints/deepmimo_best.pt
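The LS and MMSE baselines (`--mode baseline`) are classical estimators from mimo_ofdm.py. As a hedged, textbook illustration (not the repository's code), the sketch below estimates a single-antenna frequency response from unit-modulus pilots on every subcarrier and smooths the LS estimate with LMMSE, using an assumed exponential frequency-correlation model. `ls_estimate`, `lmmse_estimate`, `exp_correlation` and `rho` are hypothetical.

```python
import numpy as np

def ls_estimate(y, x):
    """Least-squares estimate per pilot subcarrier: h_ls = y / x."""
    return y / x

def lmmse_estimate(h_ls, r_hh, noise_var):
    """LMMSE smoothing of the LS estimate for unit-modulus pilots:
    h_mmse = R_hh (R_hh + noise_var * I)^{-1} h_ls."""
    k = r_hh.shape[0]
    return r_hh @ np.linalg.solve(r_hh + noise_var * np.eye(k), h_ls)

def exp_correlation(k, rho):
    """Assumed frequency-correlation model R[i, j] = rho ** |i - j|."""
    idx = np.arange(k)
    return rho ** np.abs(idx[:, None] - idx[None, :])

rng = np.random.default_rng(0)
K, rho, noise_var, trials = 64, 0.9, 0.1, 200
R = exp_correlation(K, rho)
L = np.linalg.cholesky(R)  # used to draw channels with covariance R
x = np.exp(1j * np.pi / 4 * (2 * rng.integers(0, 4, K) + 1))  # unit-modulus QPSK pilots

mse_ls = mse_mmse = 0.0
for _ in range(trials):
    h = L @ (rng.standard_normal(K) + 1j * rng.standard_normal(K)) / np.sqrt(2)
    n = np.sqrt(noise_var / 2) * (rng.standard_normal(K) + 1j * rng.standard_normal(K))
    h_ls = ls_estimate(h * x + n, x)
    h_mmse = lmmse_estimate(h_ls, R, noise_var)
    mse_ls += np.mean(np.abs(h_ls - h) ** 2) / trials
    mse_mmse += np.mean(np.abs(h_mmse - h) ** 2) / trials

print(f"LS MSE: {mse_ls:.4f}, LMMSE MSE: {mse_mmse:.4f}")
```

LS needs no channel statistics but passes the full noise through; LMMSE exploits the correlation R_hh, which in practice must itself be estimated or assumed.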
Export NMSE vs. SNR and NMSE vs. velocity curves as CSV:
python scripts/export_nmse_curves.py \
--stfa-ckpt checkpoints/stfa_best.pt \
--sfcnn-ckpt checkpoints/sfcnn_best.pt \
--deepmimo-ckpt checkpoints/deepmimo_best.pt

run_benchmarks.py and export_nmse_curves.py require every checkpoint to exist, which prevents accidental evaluation of untrained models.
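If you script the stages yourself, a minimal sketch of the same guard is shown below. It builds the training commands from the CKPT_PATH convention used above and raises before benchmarking if any checkpoint is missing. `training_jobs`, `run_jobs` and `require_checkpoints` are hypothetical helpers, not part of the repository.

```python
import os
import subprocess
import sys
from pathlib import Path

MODELS = ("stfa", "sfcnn", "deepmimo")

def training_jobs(models=MODELS, ckpt_dir="checkpoints"):
    """Return (env_overrides, argv) pairs matching the training commands above."""
    jobs = []
    for name in models:
        env = {"CKPT_PATH": f"{ckpt_dir}/{name}_best.pt"}
        # sys.executable runs main.py with the current interpreter
        argv = [sys.executable, "main.py", "--mode", "train", "--model-name", name]
        jobs.append((env, argv))
    return jobs

def run_jobs(jobs, dry_run=True):
    """Print (dry run) or execute each training command with its CKPT_PATH."""
    for env, argv in jobs:
        if dry_run:
            print(f"CKPT_PATH={env['CKPT_PATH']}", " ".join(argv))
        else:
            subprocess.run(argv, env={**os.environ, **env}, check=True)

def require_checkpoints(paths):
    """Raise FileNotFoundError listing every checkpoint that does not exist."""
    missing = [str(p) for p in paths if not Path(p).is_file()]
    if missing:
        raise FileNotFoundError("missing checkpoints: " + ", ".join(missing))
```

In a real run you would call `run_jobs(training_jobs(), dry_run=False)` and then `require_checkpoints(env["CKPT_PATH"] for env, _ in training_jobs())` before invoking run_benchmarks.py.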
Export BER vs. SNR as CSV:
python scripts/export_ber_curves.py \
--stfa-ckpt checkpoints/stfa_best.pt \
--sfcnn-ckpt checkpoints/sfcnn_best.pt \
--deepmimo-ckpt checkpoints/deepmimo_best.pt \
--n-samples 50 \
--output results/ber_vs_snr.csv

For the cross-channel generalization study, train two STFA checkpoints:
# Setting A: trained on CDL-C only
export CDL_MODEL=CDL-C
CKPT_PATH=checkpoints/stfa_cdlc_only.pt python main.py --mode train --model-name stfa
# Setting B: trained on mixed CDL-A/CDL-C/CDL-D
export CDL_MODEL=MIXED
CKPT_PATH=checkpoints/stfa_mixed_acd.pt python main.py --mode train --model-name stfa

Export the Table-I style generalization CSV:
python scripts/export_cross_model_generalization.py \
--model-name stfa \
--single-ckpt checkpoints/stfa_cdlc_only.pt \
--mixed-ckpt checkpoints/stfa_mixed_acd.pt \
--n-samples 100 \
--repeats 5 \
--output results/cross_model_generalization.csv

Run the ablation studies and summarize their logs:
bash run_ablation_background.sh
python summarize_ablation.py

Main outputs:
- results/benchmark_summary.csv
- results/nmse_vs_snr.csv
- results/nmse_vs_velocity.csv
- results/ber_vs_snr.csv
- results/cross_model_generalization.csv
- ablation_summary.md
- ablation_summary.csv
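After a full run, a quick way to confirm that every output listed above was written is sketched below. The helper `missing_outputs` is hypothetical and assumes the paths are relative to the repository root exactly as listed; it treats empty files as missing.

```python
from pathlib import Path

EXPECTED_OUTPUTS = (
    "results/benchmark_summary.csv",
    "results/nmse_vs_snr.csv",
    "results/nmse_vs_velocity.csv",
    "results/ber_vs_snr.csv",
    "results/cross_model_generalization.csv",
    "ablation_summary.md",
    "ablation_summary.csv",
)

def missing_outputs(root=".", expected=EXPECTED_OUTPUTS):
    """Return expected output files that are absent or empty under root."""
    root = Path(root)
    return [rel for rel in expected
            if not (root / rel).is_file() or (root / rel).stat().st_size == 0]

if __name__ == "__main__":
    missing = missing_outputs()
    print("all outputs present" if not missing else f"missing: {missing}")
```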
Channelformer is treated as an external baseline and merged via CSV.
python scripts/run_benchmarks.py \
--stfa-ckpt checkpoints/stfa_best.pt \
--sfcnn-ckpt checkpoints/sfcnn_best.pt \
--deepmimo-ckpt checkpoints/deepmimo_best.pt \
--channelformer-csv results/channelformer_eval.csv

Required CSV fields:
model,nmse_db_at_train_snr,params_m,inference_sec_100samples,source
and one row with model=channelformer.
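A hedged sketch for writing and validating that file with Python's standard csv module follows. `write_channelformer_csv` and `validate_channelformer_csv` are hypothetical helpers; the numeric values must come from your own Channelformer evaluation, and reading `params_m` as millions of parameters and `inference_sec_100samples` as seconds per 100 samples is an assumption based on the column names.

```python
import csv

REQUIRED_FIELDS = ["model", "nmse_db_at_train_snr", "params_m",
                   "inference_sec_100samples", "source"]

def write_channelformer_csv(path, nmse_db, params_m, inference_sec, source):
    """Write a one-row CSV in the format expected by --channelformer-csv."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=REQUIRED_FIELDS)
        writer.writeheader()
        writer.writerow({
            "model": "channelformer",
            "nmse_db_at_train_snr": nmse_db,
            "params_m": params_m,                       # assumed: millions of parameters
            "inference_sec_100samples": inference_sec,  # assumed: seconds per 100 samples
            "source": source,                           # where the numbers came from
        })

def validate_channelformer_csv(path):
    """Raise ValueError if columns, the channelformer row, or numeric values are invalid."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        missing = [c for c in REQUIRED_FIELDS if c not in (reader.fieldnames or [])]
        if missing:
            raise ValueError(f"missing columns: {missing}")
        rows = [r for r in reader if r["model"] == "channelformer"]
    if not rows:
        raise ValueError("no row with model=channelformer")
    for field in REQUIRED_FIELDS[1:4]:
        float(rows[0][field])  # raises ValueError if a numeric column does not parse
```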
- If GPU memory is limited: use STFA_PROFILE=repro, reduce BATCH_SIZE, and increase ACCUM_STEPS.
- If precomputed data checks fail: rerun data cleaning/completion in Stage 1.
- For paper-level numbers, use STFA_PROFILE=full and report hardware/runtime details.
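Lowering BATCH_SIZE while raising ACCUM_STEPS keeps the effective batch size (BATCH_SIZE × ACCUM_STEPS) unchanged: gradients of several micro-batches are averaged before a single optimizer step. The NumPy sketch below checks this equivalence for a least-squares model with equal-sized micro-batches; it illustrates the idea only, and train.py's loop may differ in details. `mse_grad` and `accumulated_grad` are hypothetical names.

```python
import numpy as np

def mse_grad(w, x, y):
    """Gradient of mean((x @ w - y)**2) with respect to w."""
    return 2.0 * x.T @ (x @ w - y) / len(y)

def accumulated_grad(w, x, y, batch_size, accum_steps):
    """Average the gradients of accum_steps equal micro-batches of size batch_size."""
    g = np.zeros_like(w)
    for step in range(accum_steps):
        sl = slice(step * batch_size, (step + 1) * batch_size)
        g += mse_grad(w, x[sl], y[sl]) / accum_steps  # divide so the sum is an average
    return g

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 5))
y = rng.standard_normal(64)
w = rng.standard_normal(5)

full = mse_grad(w, x, y)                                          # one batch of 64
accum = accumulated_grad(w, x, y, batch_size=16, accum_steps=4)   # 4 micro-batches of 16
print(np.allclose(full, accum))  # True
```

Memory usage scales with the micro-batch, while the update matches the larger batch, at the cost of more forward/backward passes per step.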
Please cite this repository using CITATION.cff.