forked from ximinng/DiffSketcher
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_painterly_render.py
117 lines (102 loc) · 4.61 KB
/
run_painterly_render.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# -*- coding: utf-8 -*-
# Author: ximing
# Description: the main func of this project.
# Copyright (c) 2023, XiMing Xing.
# License: MIT License
import os
import sys
import argparse
from datetime import datetime
import random
from typing import Any, List
from functools import partial
from accelerate.utils import set_seed
import omegaconf
sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
from libs.engine import merge_and_update_config
from libs.utils.argparse import accelerate_parser, base_data_parser
def render_batch_wrap(args: "omegaconf.DictConfig",
                      seed_range: List,
                      pipeline: Any,
                      **pipe_args):
    """Render once per seed, constructing a fresh pipeline for every seed.

    Args:
        args: run configuration. ``args.seed`` is overwritten with each seed
            before ``pipeline(args)`` is called, so it ends up holding the
            last seed in ``seed_range``.
        seed_range: seeds to render, processed in order.
        pipeline: callable (e.g. a pipeline class) invoked as ``pipeline(args)``;
            the returned object must provide ``painterly_rendering``.
        **pipe_args: keyword arguments forwarded to ``painterly_rendering``.
    """
    start_time = datetime.now()
    total = len(seed_range)
    # 1-based counter so progress reads [1/N] ... [N/N].
    for idx, seed in enumerate(seed_range, start=1):
        # Update the seed on the shared config before building the pipeline.
        # NOTE(review): global RNGs are seeded only once at startup
        # (``set_seed`` in ``__main__``); presumably the pipeline re-seeds
        # from ``args.seed`` itself — confirm.
        args.seed = seed
        print(f"\n-> [{idx}/{total}], "
              f"current seed: {seed}, "
              f"current time: {datetime.now() - start_time}\n")
        pipe = pipeline(args)
        pipe.painterly_rendering(**pipe_args)
        # Drop the finished pipeline before the next iteration constructs a
        # new one; otherwise the old instance (and whatever it holds) stays
        # referenced by ``pipe`` while its successor is being built.
        del pipe
def main(args, seed_range):
    """Run the painterly-rendering pipeline selected by ``args.task``.

    Args:
        args: the run configuration (the ``__main__`` entry point passes the
            result of ``merge_and_update_config``). ``args.batch_size`` is
            overwritten with 1 because SVGs are rendered one at a time.
        seed_range: seeds to iterate over when ``args.render_batch`` is set;
            not used otherwise.

    Raises:
        ValueError: if ``args.task`` is not a supported task.
    """
    args.batch_size = 1  # rendering one SVG at a time

    if args.task == "diffsketcher":  # text2sketch
        # Imported inside the branch so the pipeline module is only loaded
        # when this task is selected.
        from pipelines.painter.diffsketcher_pipeline import DiffSketcherPipeline
        if args.render_batch:  # generate many SVGs, one per seed
            render_batch_wrap(args=args, seed_range=seed_range,
                              pipeline=DiffSketcherPipeline, prompt=args.prompt)
        else:
            pipe = DiffSketcherPipeline(args)
            pipe.painterly_rendering(args.prompt)
        return

    # TODO: support more tasks
    raise ValueError(f"unsupported task: {args.task!r}")
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="vary style and content painterly rendering",
        parents=[accelerate_parser(), base_data_parser()]
    )
    # flag
    parser.add_argument("-tk", "--task",
                        default="diffsketcher", type=str,
                        choices=['diffsketcher'],
                        help="choose a method.")
    # config (required, so argparse never falls back to a default)
    parser.add_argument("-c", "--config",
                        required=True, type=str,
                        help="YAML/YML file for configuration.")
    # TODO: data path
    parser.add_argument("-style", "--style_file",
                        default="", type=str,
                        help="the path of style img place.")
    # prompt
    parser.add_argument("-pt", "--prompt", default="A horse is drinking water by the lake", type=str)
    parser.add_argument("-npt", "--negative_prompt", default="", type=str)
    # DiffSVG
    parser.add_argument("--print_timing", "-timing", action="store_true",
                        help="set print svg rendering timing.")
    # diffuser
    parser.add_argument("--download", action="store_true",
                        help="download models from huggingface automatically.")
    parser.add_argument("--force_download", "-download", action="store_true",
                        help="force the models to be downloaded from huggingface.")
    parser.add_argument("--resume_download", "-dpm_resume", action="store_true",
                        help="download the models again from the breakpoint.")
    # rendering quantity
    # like: python run_painterly_render.py -c <config> -rdbz -srange 100 200
    parser.add_argument("--render_batch", "-rdbz", action="store_true")
    parser.add_argument("-srange", "--seed_range",
                        required=False, nargs=2, type=int,
                        metavar=("START", "END"),
                        help="with --render_batch, render seeds START..END-1 in order.")
    # visual rendering process
    parser.add_argument("-mv", "--make_video", action="store_true",
                        help="make a video of the rendering process.")
    parser.add_argument("-frame_freq", "--video_frame_freq",
                        default=1, type=int,
                        help="video frame control.")
    return parser


def _resolve_seed_range(args, parser, *, random_bounds=(1, 1000000), num_random=1000):
    """Return the seeds to render in batch mode, or ``None`` when not batching.

    With ``--seed_range START END`` the seeds are ``START..END-1`` in order.
    Otherwise ``num_random`` distinct seeds are sampled from
    ``range(*random_bounds)``. An empty or inverted range is reported through
    ``parser.error``, which prints usage and exits with status 2.
    """
    if not args.render_batch:
        return None
    if args.seed_range is not None:  # specified range: sequential sampling
        start, end = (int(v) for v in args.seed_range)
        if end <= start:
            parser.error(f"--seed_range: END ({end}) must be greater than START ({start})")
        return list(range(start, end))
    # No range given: sample without replacement. ``random.sample`` accepts a
    # ``range`` directly, so the large population is never materialized.
    return random.sample(range(*random_bounds), k=num_random)


if __name__ == '__main__':
    parser = _build_parser()
    args = parser.parse_args()
    # Seeds are drawn here, before ``set_seed`` is called below.
    # NOTE(review): random batches therefore differ between runs even with a
    # fixed ``args.seed`` — confirm this is intended.
    seed_range = _resolve_seed_range(args, parser)
    args = merge_and_update_config(args)
    set_seed(args.seed)
    main(args, seed_range)