This repository contains a Python script for training a U-Net model on image restoration tasks. The script utilizes PyTorch, Accelerate, and various metrics to train and evaluate the model. The training process includes data loading, model training, validation, and metric logging.
- Model Training: Utilizes a U-Net architecture for image restoration.
- Data Loading: Loads training and validation data using PyTorch's `DataLoader`.
- Metrics: Computes and logs metrics such as PSNR, SSIM, MAE, and LPIPS.
- Acceleration: Uses the `Accelerate` library for distributed and mixed-precision training.
- Logging: Logs training metrics to a CSV file and plots them using Matplotlib.
- Checkpointing: Saves the best model based on PSNR.

The following dependencies are required:

- Python 3.7+
- PyTorch
- TorchMetrics
- PyTorch-SSIM
- Accelerate
- Matplotlib
- CSV

To set up the project:

- Clone the repository:

  ```bash
  git clone https://github.com/yourusername/unet-image-restoration.git
  cd unet-image-restoration
  ```

- Install the required packages:

  ```bash
  pip install -r requirements.txt
  ```

The training configuration is managed via a YAML file (`config.yml`). It includes parameters for data directories, model settings, optimization settings, and training hyperparameters.
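
The exact keys depend on the repository's `config.yml`; as a rough sketch with hypothetical key names, the file can be loaded with PyYAML like this:

```python
# Hypothetical sketch of loading config.yml with PyYAML.
# The key names (data, optim, training, ...) are illustrative only and
# may not match the actual config.yml in this repository.
import yaml

with open("config.yml") as f:
    cfg = yaml.safe_load(f)

train_dir = cfg["data"]["train_dir"]    # directory of training images
val_dir = cfg["data"]["val_dir"]        # directory of validation images
lr = cfg["optim"]["lr"]                 # initial learning rate for AdamW
num_epochs = cfg["training"]["epochs"]  # number of training epochs
```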

To start training, run the `train.py` script:

```bash
python train.py
```
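
Because the script uses the Accelerate library, distributed or mixed-precision runs can typically be launched with `accelerate launch train.py` after configuring Accelerate once via `accelerate config`.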

During training, the script proceeds as follows:

- Data Loading: The script loads training and validation data using the `get_training_data` and `get_validation_data` functions from the `data` module.
- Model Initialization: The U-Net model is initialized and moved to the appropriate device (CPU or GPU).
- Loss Functions: The script uses PSNR, SSIM, and LPIPS as loss functions.
- Optimization: AdamW optimizer with a cosine annealing learning rate scheduler is used for optimization.
- Training Loop: The model is trained for a specified number of epochs. After each epoch, the model is validated and metrics are logged; a sketch of this loop follows the list.
- Checkpointing: The best model based on PSNR is saved.
- Plotting: After training, the metrics are plotted and saved as an image.
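
The loop might look roughly like the sketch below. The `UNet` class, the dataset function signatures, the loss (a plain L1 stand-in here rather than the PSNR/SSIM/LPIPS combination the script uses), and all hyperparameters are assumptions for illustration, not the repository's exact code:

```python
# Minimal sketch of the training workflow described above (not this repo's code).
# Assumed: UNet lives in model.py, datasets yield (degraded, clean) pairs in [0, 1].
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from torchmetrics.image import PeakSignalNoiseRatio

from model import UNet                                    # assumed module layout
from data import get_training_data, get_validation_data  # assumed signatures

accelerator = Accelerator(mixed_precision="fp16")

train_loader = DataLoader(get_training_data("data/train"), batch_size=8, shuffle=True)
val_loader = DataLoader(get_validation_data("data/val"), batch_size=8)

model = UNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

model, optimizer, train_loader, val_loader = accelerator.prepare(
    model, optimizer, train_loader, val_loader
)
psnr = PeakSignalNoiseRatio(data_range=1.0).to(accelerator.device)

best_psnr = 0.0
for epoch in range(100):
    # Train for one epoch.
    model.train()
    for degraded, clean in train_loader:
        restored = model(degraded)
        loss = torch.nn.functional.l1_loss(restored, clean)  # stand-in loss
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    scheduler.step()

    # Validate and keep the checkpoint with the best PSNR.
    model.eval()
    psnr.reset()
    with torch.no_grad():
        for degraded, clean in val_loader:
            psnr.update(model(degraded).clamp(0, 1), clean)
    epoch_psnr = psnr.compute().item()
    if epoch_psnr > best_psnr:
        best_psnr = epoch_psnr
        accelerator.save(accelerator.unwrap_model(model).state_dict(), "best_model.pth")
```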

The following metrics are computed and logged during training (a short computation sketch follows the list):
- PSNR (Peak Signal-to-Noise Ratio)
- SSIM (Structural Similarity Index Measure)
- MAE (Mean Absolute Error)
- LPIPS (Learned Perceptual Image Patch Similarity)
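
For reference, all four metrics are available in TorchMetrics; the snippet below is a standalone illustration (assuming image tensors scaled to [0, 1] and TorchMetrics' image extras installed), not code from this repository:

```python
# Standalone illustration: PSNR, SSIM, MAE, and LPIPS with TorchMetrics
# on a batch of restored vs. ground-truth images scaled to [0, 1].
import torch
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.regression import MeanAbsoluteError

restored = torch.rand(4, 3, 256, 256)  # placeholder restored batch
clean = torch.rand(4, 3, 256, 256)     # placeholder ground-truth batch

print("PSNR :", PeakSignalNoiseRatio(data_range=1.0)(restored, clean).item())
print("SSIM :", StructuralSimilarityIndexMeasure(data_range=1.0)(restored, clean).item())
print("MAE  :", MeanAbsoluteError()(restored, clean).item())
print("LPIPS:", LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True)(restored, clean).item())
```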

The training results, including metrics and plots, are saved in the `runs/exp` directory. Each training run creates a new subdirectory with a unique number.
Contributions are welcome! Please open an issue or submit a pull request.
This project is licensed under the MIT License. See the `LICENSE` file for details.
For any questions or issues, please open an issue on GitHub.