Skip to content

Commit

Permalink
Merge pull request FlagAI-Open#458 from Xav1erW/Xav1erW-patch-1
Browse files Browse the repository at this point in the history
Fix env_arg and env_trainer bug while training LoRA for Aquila
  • Loading branch information
BAAI-OpenPlatform authored Jul 3, 2023
2 parents 6c5e021 + 04a180a commit 95158b9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
18 changes: 13 additions & 5 deletions flagai/env_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ def __init__(self,
self.parser.add_argument('--bmt_loss_scale', default=bmt_loss_scale, type=float, help='loss scale in bmtrain')
self.parser.add_argument('--bmt_loss_scale_steps', default=bmt_loss_scale_steps, type=int, help='loss scale steps in bmtrain')

self.parser.add_argument('--lora', default=lora, help='Use lora')
self.parser.add_argument('--lora_r', default=lora_r, help='lora r value')
self.parser.add_argument('--lora_alpha', default=lora_alpha, help='lora alpha value')
self.parser.add_argument('--lora_dropout', default=lora_dropout, help='lora dropout value')
self.parser.add_argument('--lora', default=lora, type=bool, help='Use lora')
self.parser.add_argument('--lora_r', default=lora_r, typp=int, help='lora r value')
self.parser.add_argument('--lora_alpha', default=lora_alpha, type=float, help='lora alpha value')
self.parser.add_argument('--lora_dropout', default=lora_dropout, type=float, help='lora dropout value')
self.parser.add_argument('--lora_target_modules', default=lora_target_modules, help='lora_target_modules')

## TODO, Used in caller script, configs will be updated with yaml_config.
Expand Down Expand Up @@ -204,6 +204,14 @@ def parse_args(self):
if args.env_type == "pytorch":
# not need the "not_call_launch" parameter
args.not_call_launch = True

for arg in vars(args):
# change string format list to back to python list object
value = getattr(args, arg)
if isinstance(value, str):
value = value.strip("'\"")
if value[0] == '[' and value[-1] == ']':
value = value.strip("[] ").replace(" ", "")
value = value.split(",")
setattr(args,arg,value)
return args

9 changes: 8 additions & 1 deletion flagai/env_trainer_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,14 @@ def get_args_list(env_args):
for arg in args:
if not arg.startswith("__") and not arg.startswith("_") and arg not in not_need_to_launch_args:
args_list.append(f"--{arg}")
args_list.append(str(getattr(env_args, arg)))
if isinstance(getattr(env_args, arg), list):
# change list format param to string
# avoiding space in cmdline param like:
# --lora_target_modules ['wq', 'wv']
# this will interprete the wv as another cmdline param
args_list.append("'"+str(getattr(env_args, arg))+"'")
else:
args_list.append(str(getattr(env_args, arg)))

print(f"args list is {args_list}")
return args_list
Expand Down

0 comments on commit 95158b9

Please sign in to comment.