add arbitrary dataset feature to each script
kohya-ss committed Jun 15, 2023
1 parent f2989b3 commit 9806b00
Showing 6 changed files with 116 additions and 93 deletions.
54 changes: 29 additions & 25 deletions fine_tune.py
@@ -42,33 +42,37 @@ def train(args):
 
     tokenizer = train_util.load_tokenizer(args)
 
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
-    if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "in_json"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
+    # データセットを準備する
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
                 )
-            )
-    else:
-        user_config = {
-            "datasets": [
-                {
-                    "subsets": [
-                        {
-                            "image_dir": args.train_data_dir,
-                            "metadata_file": args.in_json,
-                        }
-                    ]
-                }
-            ]
-        }
+        else:
+            user_config = {
+                "datasets": [
+                    {
+                        "subsets": [
+                            {
+                                "image_dir": args.train_data_dir,
+                                "metadata_file": args.in_json,
+                            }
+                        ]
+                    }
+                ]
+            }
 
-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
 
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
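With this change, fine_tune.py (and train_db.py and train_textual_inversion.py below) selects its dataset in one of two ways: the existing blueprint/config path, or a user-supplied class whenever --dataset_class is given. A hypothetical invocation as a sketch — the package, module, and class names are placeholders, not part of this commit:

    python fine_tune.py --dataset_class my_package.my_datasets.MyDataset ...

On this path --train_data_dir and --in_json are simply unused, since the custom class is responsible for producing its own batches; see the load_arbitrary_dataset helper added to library/train_util.py below.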
17 changes: 16 additions & 1 deletion library/train_util.py
@@ -1579,6 +1579,15 @@ def __getitem__(self, idx):
         raise NotImplementedError
 
 
+def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
+    module = ".".join(args.dataset_class.split(".")[:-1])
+    dataset_class = args.dataset_class.split(".")[-1]
+    module = importlib.import_module(module)
+    dataset_class = getattr(module, dataset_class)
+    train_dataset_group: MinimalDataset = dataset_class(tokenizer, args.max_token_length, args.resolution, args.debug_dataset)
+    return train_dataset_group
+
+
 # endregion
 
 # region モジュール入れ替え部
@@ -2455,14 +2464,20 @@ def add_dataset_arguments(
         default=1,
         help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
     )
     parser.add_argument(
         "--token_warmup_step",
         type=float,
         default=0,
         help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)",
     )
+    parser.add_argument(
+        "--dataset_class",
+        type=str,
+        default=None,
+        help="dataset class for arbitrary dataset (package.module.Class) / 任意のデータセットを用いるときのクラス名 (package.module.Class)",
+    )
 
     if support_caption_dropout:
         # Textual Inversion はcaptionのdropoutをsupportしない
         # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
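load_arbitrary_dataset resolves the dotted path from --dataset_class: everything before the last dot is imported as a module via importlib, the final component is looked up as the class, and the class is instantiated with (tokenizer, max_token_length, resolution, debug_dataset). A minimal sketch of a user module this mechanism could load — the file name, class name, and batch handling are assumptions for illustration, and it presumes MinimalDataset's constructor accepts the same four arguments passed here:

    # my_datasets.py (hypothetical) -- selected with --dataset_class my_datasets.StubDataset
    from library.train_util import MinimalDataset


    class StubDataset(MinimalDataset):
        def __init__(self, tokenizer, max_token_length, resolution, debug_dataset):
            super().__init__(tokenizer, max_token_length, resolution, debug_dataset)
            self.batches = []  # a real implementation would build batches here

        def __len__(self):
            return len(self.batches)

        def __getitem__(self, idx):
            # must return one complete training batch; the base class's
            # __getitem__ raises NotImplementedError, so subclasses override it
            return self.batches[idx]

Because the class is resolved with a plain importlib.import_module, the module has to be importable from the working directory (or from sys.path) when the training script runs.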
40 changes: 22 additions & 18 deletions train_db.py
@@ -46,26 +46,30 @@ def train(args):
 
     tokenizer = train_util.load_tokenizer(args)
 
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
-    if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "reg_data_dir"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
+    # データセットを準備する
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
                 )
-            )
-    else:
-        user_config = {
-            "datasets": [
-                {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
-            ]
-        }
+        else:
+            user_config = {
+                "datasets": [
+                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                ]
+            }
 
-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
 
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
14 changes: 1 addition & 13 deletions train_network.py
@@ -135,13 +135,7 @@ def train(args):
         train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
     else:
         # use arbitrary dataset class
-        module = ".".join(args.dataset_class.split(".")[:-1])
-        dataset_class = args.dataset_class.split(".")[-1]
-        module = importlib.import_module(module)
-        dataset_class = getattr(module, dataset_class)
-        train_dataset_group: train_util.MinimalDataset = dataset_class(
-            tokenizer, args.max_token_length, args.resolution, args.debug_dataset
-        )
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
 
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
@@ -867,12 +861,6 @@ def setup_parser() -> argparse.ArgumentParser:
         nargs="*",
         help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
     )
-    parser.add_argument(
-        "--dataset_class",
-        type=str,
-        default=None,
-        help="dataset class for arbitrary dataset / 任意のデータセットのクラス名",
-    )
     return parser
73 changes: 38 additions & 35 deletions train_textual_inversion.py
@@ -153,43 +153,46 @@ def train(args):
     print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
 
     # データセットを準備する
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
-    if args.dataset_config is not None:
-        print(f"Load dataset config from {args.dataset_config}")
-        user_config = config_util.load_user_config(args.dataset_config)
-        ignored = ["train_data_dir", "reg_data_dir", "in_json"]
-        if any(getattr(args, attr) is not None for attr in ignored):
-            print(
-                "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
-                    ", ".join(ignored)
+    if args.dataset_class is None:
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+        if args.dataset_config is not None:
+            print(f"Load dataset config from {args.dataset_config}")
+            user_config = config_util.load_user_config(args.dataset_config)
+            ignored = ["train_data_dir", "reg_data_dir", "in_json"]
+            if any(getattr(args, attr) is not None for attr in ignored):
+                print(
+                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
+                        ", ".join(ignored)
+                    )
                 )
-            )
-    else:
-        use_dreambooth_method = args.in_json is None
-        if use_dreambooth_method:
-            print("Use DreamBooth method.")
-            user_config = {
-                "datasets": [
-                    {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
-                ]
-            }
-        else:
-            print("Train with captions.")
-            user_config = {
-                "datasets": [
-                    {
-                        "subsets": [
-                            {
-                                "image_dir": args.train_data_dir,
-                                "metadata_file": args.in_json,
-                            }
-                        ]
-                    }
-                ]
-            }
+        else:
+            use_dreambooth_method = args.in_json is None
+            if use_dreambooth_method:
+                print("Use DreamBooth method.")
+                user_config = {
+                    "datasets": [
+                        {"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
+                    ]
+                }
+            else:
+                print("Train with captions.")
+                user_config = {
+                    "datasets": [
+                        {
+                            "subsets": [
+                                {
+                                    "image_dir": args.train_data_dir,
+                                    "metadata_file": args.in_json,
+                                }
+                            ]
+                        }
+                    ]
+                }
 
-    blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
-    train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
+        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
+    else:
+        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
 
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
11 changes: 10 additions & 1 deletion train_textual_inversion_XTI.py
@@ -20,7 +20,13 @@
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight, prepare_scheduler_for_custom_training, pyramid_noise_like, apply_noise_offset, scale_v_prediction_loss_like_noise_prediction
+from library.custom_train_functions import (
+    apply_snr_weight,
+    prepare_scheduler_for_custom_training,
+    pyramid_noise_like,
+    apply_noise_offset,
+    scale_v_prediction_loss_like_noise_prediction,
+)
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
 
 imagenet_templates_small = [
@@ -88,6 +94,9 @@ def train(args):
         print(
             "sample_every_n_steps and sample_every_n_epochs are not supported in this script currently / sample_every_n_stepsとsample_every_n_epochsは現在このスクリプトではサポートされていません"
         )
+    assert (
+        args.dataset_class is None
+    ), "dataset_class is not supported in this script currently / dataset_classは現在このスクリプトではサポートされていません"
 
     cache_latents = args.cache_latents
