diff --git a/1-pretrain.py b/1-pretrain.py
index 50fee2a..ad1a45c 100644
--- a/1-pretrain.py
+++ b/1-pretrain.py
@@ -1,5 +1,6 @@
 import os
 import platform
+import argparse
 import time
 import math
 import warnings
@@ -23,67 +24,66 @@ def Logger(content):
 
 
 def get_lr(it, all):
-    warmup_iters = 0
+    warmup_iters = args.warmup_iters
     lr_decay_iters = all
-    min_lr = learning_rate / 10
+    min_lr = args.learning_rate / 10
 
     if it < warmup_iters:
-        return learning_rate * it / warmup_iters
+        return args.learning_rate * it / warmup_iters
     if it > lr_decay_iters:
         return min_lr
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (learning_rate - min_lr)
+    return min_lr + coeff * (args.learning_rate - min_lr)
 
 
-def train_epoch(epoch, wandb, accumulation_steps=8):
+def train_epoch(epoch, wandb):
     start_time = time.time()
     for step, (X, Y) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
+        X = X.to(args.device)
+        Y = Y.to(args.device)
 
-        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
+        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
         with ctx:
             out = model(X, Y)
-            loss = out.last_loss / accumulation_steps
+            loss = out.last_loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
 
-        if (step + 1) % accumulation_steps == 0:
+        if (step + 1) % args.accumulation_steps == 0:
             scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
 
             scaler.step(optimizer)
             scaler.update()
 
             optimizer.zero_grad(set_to_none=True)
 
-        if step % 100 == 0:
+        if step % args.log_interval == 0:
             spend_time = time.time() - start_time
             Logger(
                 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
-                    epochs,
+                    args.epochs,
                     step,
                     iter_per_epoch,
-                    loss.item() * accumulation_steps,
+                    loss.item() * args.accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
             if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss.item() * accumulation_steps,
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
 
-        if (step + 1) % 1000 == 0 and (not ddp or dist.get_rank() == 0):
+        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
-            # torch.save(model.state_dict(), '{}/iter_{}.pth'.format(save_dir, int(step + epoch * iter_per_epoch)))
             moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
+            ckp = f'{args.save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
 
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                 state_dict = model.module.state_dict()
@@ -98,17 +98,8 @@ def init_model():
     def count_parameters(model):
         return sum(p.numel() for p in model.parameters() if p.requires_grad)
 
-    # model init
-    model = Transformer(lm_config).to(device)
+    model = Transformer(lm_config).to(args.device)
     moe_path = '_moe' if lm_config.use_moe else ''
-    # ckp = f'{save_dir}/pretrain_{lm_config.dim}{moe_path}.pth'
-    #
-    # state_dict = torch.load(ckp, map_location=device)
-    # unwanted_prefix = '_orig_mod.'
-    # for k, v in list(state_dict.items()):
-    #     if k.startswith(unwanted_prefix):
-    #         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-    # model.load_state_dict(state_dict, strict=False)
 
     Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
     return model
@@ -127,79 +118,79 @@ def init_distributed_mode():
 
 
 # torchrun --nproc_per_node 2 1-pretrain.py
-# I/O
 if __name__ == "__main__":
-    # -----------------------------------------------------------------------------
+    parser = argparse.ArgumentParser(description="MiniMind Pretraining")
+    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
+    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
+    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
+    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Pretrain", help="Weights & Biases project name")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
+    parser.add_argument("--data_path", type=str, default="./dataset/pretrain_data.bin", help="Path to training data")
+    parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel")
+    parser.add_argument("--accumulation_steps", type=int, default=8, help="Gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
+    parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
+    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
+
+    args = parser.parse_args()
+
     lm_config = LMConfig()
     max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 20
-    batch_size = 64
-    learning_rate = 2e-4
-    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-    dtype = 'bfloat16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    os.makedirs(out_dir, exist_ok=True)
-    tokens_per_iter = batch_size * max_seq_len
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * max_seq_len
     torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
+    device_type = "cuda" if "cuda" in args.device else "cpu"
 
-    use_wandb = False  # 是否使用wandb
-    wandb_project = "MiniMind-Pretrain"
-    wandb_run_name = f"MiniMind-Pretrain-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    if use_wandb:
-        import wandb
-        wandb.init(project=wandb_project, name=wandb_run_name)
-    else:
-        wandb = None
+    args.wandb_run_name = f"MiniMind-Pretrain-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
+
+    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
 
-    ctx = (
-        nullcontext()
-        if device_type == "cpu"
-        else torch.cuda.amp.autocast()
-    )
     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
     ddp_local_rank, DEVICE = 0, "cuda:0"
     if ddp:
         init_distributed_mode()
-        device = torch.device(DEVICE)
-    # -----------------------------------------------------------------------------
+        args.device = torch.device(DEVICE)
 
-    # -----init dataloader------
-    data_path_list = ['./dataset/pretrain_data.bin']
+    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+        import wandb
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
+    else:
+        wandb = None
+
+    data_path_list = [args.data_path]
     train_ds = PretrainDataset(data_path_list, max_length=max_seq_len, memmap=True)
     train_sampler = DistributedSampler(train_ds) if ddp else None
-    num_workers = 8  # 可以根据系统的 CPU 核心数来调整
     train_loader = DataLoader(
         train_ds,
-        batch_size=batch_size,
+        batch_size=args.batch_size,
         pin_memory=True,
         drop_last=False,
         shuffle=False,
-        num_workers=num_workers,
+        num_workers=args.num_workers,
         sampler=train_sampler
     )
 
-    # init model
     model = init_model()
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == dtype))
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-    # compile the model
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
         Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
         model = torch.compile(model)
 
     if ddp:
-        # Ignore the freqs_cis buffer so that DDP does not broadcast it at
-        # construction time since NCCL does not support ComplexFloat
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
 
-    # training loop
     iter_per_epoch = len(train_loader)
-    for epoch in range(epochs):
+    for epoch in range(args.epochs):
         train_epoch(epoch, wandb)
diff --git a/3-full_sft.py b/3-full_sft.py
index c413de0..d82e662 100644
--- a/3-full_sft.py
+++ b/3-full_sft.py
@@ -1,5 +1,6 @@
 import os
 import platform
+import argparse
 import time
 import math
 import warnings
@@ -12,7 +13,6 @@
 
 from torch import optim
 from torch.nn.parallel import DistributedDataParallel
-from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer, AutoModel
 from model.model import Transformer
@@ -28,28 +28,27 @@ def Logger(content):
 
 
 def get_lr(it, all):
-    warmup_iters = 0
+    warmup_iters = args.warmup_iters
     lr_decay_iters = all
-    min_lr = learning_rate / epochs
+    min_lr = args.learning_rate / 10
 
     if it < warmup_iters:
-        return learning_rate * it / warmup_iters
+        return args.learning_rate * it / warmup_iters
     if it > lr_decay_iters:
         return min_lr
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (learning_rate - min_lr)
+    return min_lr + coeff * (args.learning_rate - min_lr)
 
 
-# ------------------------------------------------------------------------------
 def train_epoch(epoch, wandb):
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
-        loss_mask = loss_mask.to(device)
-        lr = get_lr(epoch * iter_per_epoch + step, epochs * iter_per_epoch)
+        X = X.to(args.device)
+        Y = Y.to(args.device)
+        loss_mask = loss_mask.to(args.device)
+        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
@@ -59,30 +58,26 @@ def train_epoch(epoch, wandb):
             loss_mask = loss_mask.view(-1)
             loss = torch.sum(loss * loss_mask) / loss_mask.sum()
 
-        # Backward pass
         scaler.scale(loss).backward()
 
-        # Unscale gradients and clip them
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        if (step + 1) % args.accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
 
-        # Update parameters
-        scaler.step(optimizer)
-        scaler.update()
+            scaler.step(optimizer)
+            scaler.update()
 
-        # Zero the gradients
-        optimizer.zero_grad(set_to_none=True)
+            optimizer.zero_grad(set_to_none=True)
 
-        # 打印日志
-        if step % 100 == 0:
+        if step % args.log_interval == 0:
             spend_time = time.time() - start_time
             Logger(
-                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.8f} epoch_Time:{}min:'.format(
+                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
-                    epochs,
+                    args.epochs,
                     step,
                     iter_per_epoch,
-                    loss,
+                    loss.item(),
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
 
@@ -91,11 +86,11 @@ def train_epoch(epoch, wandb):
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
 
-        if (step + 1) % 1000 == 0 and (not ddp or dist.get_rank() == 0):
+        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
             model.eval()
-            # torch.save(model.state_dict(), '{}/sft_iter_{}.pth'.format(save_dir, int(step + epoch * iter_per_epoch)))
             moe_path = '_moe' if lm_config.use_moe else ''
-            ckp = f'{save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
+            ckp = f'{args.save_dir}/full_sft_{lm_config.dim}{moe_path}.pth'
+
             if isinstance(model, torch.nn.parallel.DistributedDataParallel):
                 state_dict = model.module.state_dict()
             else:
@@ -105,7 +100,7 @@ def train_epoch(epoch, wandb):
             model.train()
 
 
-def init_model(lm_config):
+def init_model():
     tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
     model_from = 1  # 1从权重,2用transformers
 
@@ -116,7 +111,7 @@ def count_parameters(model):
         model = Transformer(lm_config)
         moe_path = '_moe' if lm_config.use_moe else ''
         ckp = f'./out/pretrain_{lm_config.dim}{moe_path}.pth'
-        state_dict = torch.load(ckp, map_location=device)
+        state_dict = torch.load(ckp, map_location=args.device)
         unwanted_prefix = '_orig_mod.'
         for k, v in list(state_dict.items()):
             if k.startswith(unwanted_prefix):
@@ -126,7 +121,7 @@ def count_parameters(model):
         model = AutoModel.from_pretrained('./minimind', trust_remote_code=True)
 
     Logger(f'LLM总参数量:{count_parameters(model) / 1e6:.3f} 百万')
-    model = model.to(device)
+    model = model.to(args.device)
 
     return model, tokenizer
 
@@ -143,84 +138,78 @@ def init_distributed_mode():
     torch.cuda.set_device(DEVICE)
 
 
-# I/O
 if __name__ == "__main__":
-    # -----------------------------------------------------------------------------
+    parser = argparse.ArgumentParser(description="MiniMind Full SFT")
+    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
+    parser.add_argument("--epochs", type=int, default=19, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=40, help="Batch size")
+    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
+    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-Full-SFT", help="Weights & Biases project name")
+    parser.add_argument("--num_workers", type=int, default=8, help="Number of workers for data loading")
+    parser.add_argument("--ddp", action="store_true", help="Use DistributedDataParallel")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
+    parser.add_argument("--warmup_iters", type=int, default=0, help="Number of warmup iterations")
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
+    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
+
+    args = parser.parse_args()
+
     lm_config = LMConfig()
     max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 19
-    gradient_accumulation_steps = 1
-    batch_size = 40
-    learning_rate = 1e-4
-    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-    dtype = 'bfloat16'
-    # dtype = 'float16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    tokens_per_iter = gradient_accumulation_steps * batch_size * max_seq_len
-    os.makedirs(out_dir, exist_ok=True)
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * max_seq_len
     torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
-
-    use_wandb = False  # 是否使用wandb
-    wandb_project = "MiniMind-Full-SFT"
-    wandb_run_name = f"MiniMind-Full-SFT-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    if use_wandb:
-        import wandb
+    device_type = "cuda" if "cuda" in args.device else "cpu"
 
-        wandb.init(project=wandb_project, name=wandb_run_name)
-    else:
-        wandb = None
-
-    ctx = (
-        nullcontext()
-        if device_type == "cpu"
-        else torch.cuda.amp.autocast()
-    )
+    args.wandb_run_name = f"MiniMind-Full-SFT-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
 
-    ### ddp config
+    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
     ddp_local_rank, DEVICE = 0, "cuda:0"
     if ddp:
         init_distributed_mode()
-        device = torch.device(DEVICE)
-    # -----------------------------------------------------------------------------
+        args.device = torch.device(DEVICE)
+
+    if args.use_wandb and (not ddp or ddp_local_rank == 0):
+        import wandb
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
+    else:
+        wandb = None
+
+    model, tokenizer = init_model()
 
-    model, tokenizer = init_model(lm_config)
-    # -----init dataloader------
     df = pd.read_csv('./dataset/sft_data_single.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_sampler = DistributedSampler(train_ds) if ddp else None
     train_loader = DataLoader(
         train_ds,
-        batch_size=batch_size,
-        pin_memory=False,
+        batch_size=args.batch_size,
+        pin_memory=True,
         drop_last=False,
         shuffle=False,
-        num_workers=8,
+        num_workers=args.num_workers,
         sampler=train_sampler
     )
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == dtype))
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    iter_per_epoch = len(train_loader)
-    # compile the model
-    if False and not lm_config.use_moe and platform.system() != 'Windows' and float(
-            torch.__version__.split('.')[0]) >= 2:
+    if False and not lm_config.use_moe and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
         Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
-        model = torch.compile(model)  # requires PyTorch 2.0
+        model = torch.compile(model)
 
     if ddp:
-        # Ignore the pos_cis buffer so that DDP does not broadcast it at
-        # construction time since NCCL does not support ComplexFloat
         model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
 
-    # training loop
-    for epoch in range(epochs):
+    iter_per_epoch = len(train_loader)
+    for epoch in range(args.epochs):
         train_epoch(epoch, wandb)
diff --git a/4-lora_sft.py b/4-lora_sft.py
index ab8ba31..e72f8ca 100644
--- a/4-lora_sft.py
+++ b/4-lora_sft.py
@@ -1,5 +1,6 @@
 import os
 import platform
+import argparse
 import time
 import math
 import warnings
@@ -16,32 +17,36 @@
 from model.LMConfig import LMConfig
 from model.dataset import SFTDataset
 
-warnings.filterwarnings('ignore', category=UserWarning)
+warnings.filterwarnings('ignore')
 
 
-def get_lr(it):
-    warmup_iters = 1000
-    lr_decay_iters = 80000
-    min_lr = 1e-5
+def Logger(content):
+    print(content)
+
+
+def get_lr(it, all):
+    warmup_iters = args.warmup_iters
+    lr_decay_iters = all
+    min_lr = args.learning_rate / 10
 
     if it < warmup_iters:
-        return learning_rate * it / warmup_iters
+        return args.learning_rate * it / warmup_iters
     if it > lr_decay_iters:
         return min_lr
     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
     assert 0 <= decay_ratio <= 1
     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
-    return min_lr + coeff * (learning_rate - min_lr)
+    return min_lr + coeff * (args.learning_rate - min_lr)
 
 
-# ------------------------------------------------------------------------------
 def train_epoch(epoch, wandb):
     start_time = time.time()
     for step, (X, Y, loss_mask) in enumerate(train_loader):
-        X = X.to(device)
-        Y = Y.to(device)
-        loss_mask = loss_mask.to(device)
-        lr = get_lr(epoch * iter_per_epoch + step)
+        X = X.to(args.device)
+        Y = Y.to(args.device)
+        loss_mask = loss_mask.to(args.device)
+
+        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch)
         for param_group in optimizer.param_groups:
             param_group['lr'] = lr
 
@@ -50,33 +55,38 @@ def train_epoch(epoch, wandb):
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=0, reduction='none')
             loss_mask = loss_mask.view(-1)
             loss = torch.sum(loss * loss_mask) / loss_mask.sum()
+            loss = loss / args.accumulation_steps
 
         scaler.scale(loss).backward()
 
-        scaler.unscale_(optimizer)
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+        if (step + 1) % args.accumulation_steps == 0:
+            scaler.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
 
-        scaler.step(optimizer)
-        scaler.update()
+            scaler.step(optimizer)
+            scaler.update()
 
-        optimizer.zero_grad(set_to_none=True)
+            optimizer.zero_grad(set_to_none=True)
 
-        if step % 100 == 0:
+        if step % args.log_interval == 0:
             spend_time = time.time() - start_time
-            print(
+            Logger(
                 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                     epoch,
-                    epochs,
+                    args.epochs,
                     step,
                     iter_per_epoch,
-                    loss.item(),
+                    loss.item() * args.accumulation_steps,
                     optimizer.param_groups[-1]['lr'],
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
-
             if wandb is not None:
-                wandb.log({"loss": loss.item(), "lr": optimizer.param_groups[-1]['lr'],
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
+                           "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
 
+        if (step + 1) % args.save_interval == 0:
+            model.save_pretrained(args.save_dir)
+
 
 def find_all_linear_names(model):
     cls = torch.nn.Linear
@@ -95,7 +105,7 @@ def init_model():
     model_name_or_path = "./minimind-v1-small"
     tokenizer_name_or_path = "./minimind-v1-small"
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True, use_fast=False)
-    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(device)
+    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True).to(args.device)
 
     target_modules = find_all_linear_names(model)
     peft_config = LoraConfig(
@@ -108,74 +118,70 @@ def init_model():
     )
     model = get_peft_model(model, peft_config)
     model.print_trainable_parameters()
-    model = model.to(device)
+    model = model.to(args.device)
     return model, tokenizer
 
 
-# I/O
 if __name__ == "__main__":
-    # -----------------------------------------------------------------------------
+    parser = argparse.ArgumentParser(description="MiniMind LoRA Fine-tuning")
+    parser.add_argument("--out_dir", type=str, default="out", help="Output directory")
+    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
+    parser.add_argument("--batch_size", type=int, default=16, help="Batch size")
+    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
+    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
+    parser.add_argument("--dtype", type=str, default="bfloat16", help="Data type")
+    parser.add_argument("--use_wandb", action="store_true", help="Use Weights & Biases")
+    parser.add_argument("--wandb_project", type=str, default="MiniMind-LoRA", help="Weights & Biases project name")
+    parser.add_argument("--num_workers", type=int, default=0, help="Number of workers for data loading")
+    parser.add_argument("--accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
+    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping threshold")
+    parser.add_argument("--warmup_iters", type=int, default=1000, help="Number of warmup iterations")
+    parser.add_argument("--log_interval", type=int, default=100, help="Logging interval")
+    parser.add_argument("--save_interval", type=int, default=1000, help="Model saving interval")
+
+    args = parser.parse_args()
+
     lm_config = LMConfig()
     max_seq_len = lm_config.max_seq_len
-    out_dir = 'out'
-    epochs = 20
-    gradient_accumulation_steps = 1
-    batch_size = 16
-    learning_rate = 1e-4
-    weight_decay = 1e-1
-    device = 'cuda:0'
-    dtype = 'bfloat16'
-    save_dir = os.path.join(out_dir)
-    os.makedirs(save_dir, exist_ok=True)
-    tokens_per_iter = gradient_accumulation_steps * batch_size * max_seq_len
-    os.makedirs(out_dir, exist_ok=True)
+    args.save_dir = os.path.join(args.out_dir)
+    os.makedirs(args.save_dir, exist_ok=True)
+    os.makedirs(args.out_dir, exist_ok=True)
+    tokens_per_iter = args.batch_size * max_seq_len
     torch.manual_seed(1337)
-    device_type = device if "cuda" in device else "cpu"
+    device_type = "cuda" if "cuda" in args.device else "cpu"
 
-    use_wandb = False  # 是否使用wandb
-    wandb_project = "MiniMind-LoRA-SFT"
-    wandb_run_name = f"MiniMind-LoRA-SFT-Epoch-{epochs}-BatchSize-{batch_size}-LearningRate-{learning_rate}"
-    if use_wandb:
-        import wandb
+    args.wandb_run_name = f"MiniMind-LoRA-Epoch-{args.epochs}-BatchSize-{args.batch_size}-LearningRate-{args.learning_rate}"
 
-        wandb.init(project=wandb_project, name=wandb_run_name)
+    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
+
+    if args.use_wandb:
+        import wandb
+        wandb.init(project=args.wandb_project, name=args.wandb_run_name)
     else:
         wandb = None
 
-    ctx = (
-        nullcontext()
-        if device_type == "cpu"
-        else torch.cuda.amp.autocast()
-    )
-    # -----------------------------------------------------------------------------
-
     model, tokenizer = init_model()
 
-    # -----init dataloader------
-    df = pd.read_csv('./dataset/sft_data_single.csv')
+    df = pd.read_csv('./dataset/sft_data.csv')
     df = df.sample(frac=1.0)
     train_ds = SFTDataset(df, tokenizer, max_length=max_seq_len)
     train_loader = DataLoader(
         train_ds,
-        batch_size=batch_size,
-        pin_memory=False,
+        batch_size=args.batch_size,
+        pin_memory=True,
         drop_last=False,
         shuffle=False,
-        num_workers=0,
+        num_workers=args.num_workers,
     )
 
-    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
-    # optimizer
-    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
-    iter_per_epoch = len(train_loader)
-    # compile the model
+    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype == 'float16'))
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
     if False and platform.system() != 'Windows' and float(torch.__version__.split('.')[0]) >= 2:
-        print("compiling the model... (takes a ~minute)")
+        Logger("compiling the model... (takes a ~minute)")
         unoptimized_model = model
         model = torch.compile(model)
 
-    raw_model = model
-    # training loop
-    for epoch in range(epochs):
+    iter_per_epoch = len(train_loader)
+    for epoch in range(args.epochs):
         train_epoch(epoch, wandb)
-        model.save_pretrained('minimind')
diff --git a/README.md b/README.md
index d7f8c34..265b5a0 100644
--- a/README.md
+++ b/README.md
@@ -213,6 +213,14 @@ streamlit run fast_inference.py
     deepspeed --master_port 29500 --num_gpus=N 3-full_sft.py
     ```
 
+* Record the training process
+    ```bash
+    torchrun --nproc_per_node N 1-pretrain.py --use_wandb
+    # and
+    python 1-pretrain.py --use_wandb
+    ```
+    Adding the `--use_wandb` flag logs the training run; once training finishes, you can review it on the wandb website. The project name and run name can be set via the `wandb_project` and `wandb_run_name` parameters.
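+    For example, you can also override the wandb project name on the command line via the `--wandb_project` argument (the project name below is only an illustrative placeholder):
+    ```bash
+    python 1-pretrain.py --use_wandb --wandb_project My-MiniMind-Pretrain
+    ```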
+
 # 📌 Data sources
 
 - 🤖 分词器:nlp中的Tokenizer类似于词典,将单词从自然语言通过“词典”映射到0,1,36这样的数字,可以理解为数字就代表了单词在“词典”中的页码。
diff --git a/README_en.md b/README_en.md
index 4b1bbbf..a0fa5a4 100644
--- a/README_en.md
+++ b/README_en.md
@@ -234,6 +234,13 @@ git clone https://github.com/jingyaogong/minimind.git
     # and
     deepspeed --master_port 29500 --num_gpus=N 3-full_sft.py
     ```
+* Record the training process
+    ```bash
+    torchrun --nproc_per_node N 1-pretrain.py --use_wandb
+    # and
+    python 1-pretrain.py --use_wandb
+    ```
+    Adding the `--use_wandb` flag logs the training run; once training finishes, you can review it on the wandb website. The project name and run name can be set via the `wandb_project` and `wandb_run_name` parameters.
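+    For example, you can also override the wandb project name on the command line via the `--wandb_project` argument (the project name below is only an illustrative placeholder):
+    ```bash
+    python 1-pretrain.py --use_wandb --wandb_project My-MiniMind-Pretrain
+    ```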
 
 # 📌 Data sources