# common.py (from a fork of carefree0910/carefree-creator)
import os
import json
import torch
import secrets
import numpy as np
from abc import ABCMeta
from PIL import Image
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
from typing import Type
from typing import Union
from typing import TypeVar
from typing import Callable
from typing import Optional
from fastapi import Response
from pathlib import Path
from pydantic import Field
from pydantic import BaseModel
from functools import partial
from cftool.cv import np_to_bytes
from cftool.misc import shallow_copy_dict
from cftool.types import TNumberPair
from cflearn.zoo import DLZoo
from cflearn.parameters import OPT
from cfclient.models import TextModel
from cfclient.models import ImageModel
from cfclient.models import AlgorithmBase
from cflearn.api.cv import SDVersions
from cflearn.api.cv import DiffusionAPI
from cflearn.api.cv import TranslatorAPI
from cflearn.api.cv import ImageHarmonizationAPI
from cflearn.api.cv import ControlledDiffusionAPI
from cflearn.api.cv.diffusion import InpaintingMode
from cflearn.api.cv.diffusion import InpaintingSettings
from cflearn.misc.toolkit import download_model
from cflearn.models.cv.diffusion import StableDiffusion
from cflearn.api.cv.third_party.blip import BLIPAPI
from cflearn.api.cv.third_party.lama import LaMa
from cflearn.api.cv.third_party.isnet import ISNetAPI
from cflearn.api.cv.third_party.prompt import PromptEnhanceAPI
from .cos import download_with_retry
from .cos import download_image_with_retry
from .utils import api_pool
from .utils import to_canvas
from .utils import APIs
from .parameters import verbose
from .parameters import get_focus
from .parameters import pool_limit
from .parameters import use_controlnet
from .parameters import use_controlnet_annotator
from .parameters import Focus
class SDInpaintingVersions(str, Enum):
    v1_5 = "v1.5"


BaseSDTag = "_base_sd"
NUM_CONTROL_POOL = 1


def _base_sd_path() -> str:
    root = os.path.join(OPT.cache_dir, DLZoo.model_dir)
    return download_model("ldm_sd_v1.5", root=root)


def _get(init_fn: Callable, init_to_cpu: bool) -> Any:
    if init_to_cpu:
        return init_fn()
    return init_fn("cuda:0", use_half=True)
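# A minimal sketch of how `_get` is meant to be used by the registration
# helpers below (illustrative only; `TranslatorAPI.from_esr` is one of the
# factories that actually appears in this module):
#
#     init_fn = partial(TranslatorAPI.from_esr)
#     api = _get(init_fn, init_to_cpu=True)   # CPU, full precision
#     api = _get(init_fn, init_to_cpu=False)  # "cuda:0", half precision
#
# Initializing to CPU presumably lets `api_pool` keep more models resident
# than the GPU could hold at once, moving each to the GPU only when needed.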
def init_sd(init_to_cpu: bool) -> ControlledDiffusionAPI:
    version = SDVersions.v1_5
    kw = dict(num_pool=NUM_CONTROL_POOL, lazy=True)
    init_fn = partial(ControlledDiffusionAPI.from_sd_version, version, **kw)
    m: ControlledDiffusionAPI = _get(init_fn, init_to_cpu)
    focus = get_focus()
    keep_controlnet = use_controlnet()
    if not keep_controlnet:
        m.disable_control()
    if focus != Focus.SYNC:
        m.sd_weights.limit = pool_limit()
    m.current_sd_version = version
    print("> registering base sd")
    m.prepare_sd([version])
    m.sd_weights.register(BaseSDTag, _base_sd_path())
    if keep_controlnet:
        print("> warmup ControlNet")
        m.switch_control(*m.preset_control_hints)
    if use_controlnet_annotator():
        print("> prepare ControlNet Annotators")
        m.prepare_annotators()
    return m


def init_sd_inpainting(init_to_cpu: bool) -> ControlledDiffusionAPI:
    kw = dict(num_pool=NUM_CONTROL_POOL, lazy=True)
    init_fn = partial(ControlledDiffusionAPI.from_sd_inpainting, **kw)
    api: ControlledDiffusionAPI = _get(init_fn, init_to_cpu)
    # manually maintain sd_weights
    ## original weights
    api.sd_weights.register(BaseSDTag, _base_sd_path())
    ## inpainting weights
    root = os.path.join(OPT.cache_dir, DLZoo.model_dir)
    inpainting_path = download_model("ldm.sd_inpainting", root=root)
    api.sd_weights.register(SDInpaintingVersions.v1_5, inpainting_path)
    api.current_sd_version = SDInpaintingVersions.v1_5
    # inject properties from sd
    register_sd()
    sd: ControlledDiffusionAPI = api_pool.get(APIs.SD, no_change=True)
    api.annotators = sd.annotators
    api.controlnet_weights = sd.controlnet_weights
    if use_controlnet():
        api.switch_control(*api.preset_control_hints)
    return api
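# Note: `init_sd_inpainting` registers APIs.SD first and then borrows
# `annotators` / `controlnet_weights` from that instance, so the two APIs can
# share one set of ControlNet resources instead of loading them twice.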
def register_sd() -> None:
    api_pool.register(APIs.SD, init_sd)


def register_sd_inpainting() -> None:
    api_pool.register(APIs.SD_INPAINTING, init_sd_inpainting)


def register_esr() -> None:
    api_pool.register(
        APIs.ESR,
        lambda init_to_cpu: _get(TranslatorAPI.from_esr, init_to_cpu),
    )


def register_esr_anime() -> None:
    api_pool.register(
        APIs.ESR_ANIME,
        lambda init_to_cpu: _get(TranslatorAPI.from_esr_anime, init_to_cpu),
    )


def register_esr_ultrasharp() -> None:
    def _init(*args: Any, **kw: Any) -> TranslatorAPI:
        m = TranslatorAPI.from_esr(*args, **kw)
        sr_folder = os.path.join(OPT.external_dir, "sr")
        model_path = os.path.join(sr_folder, "4x-UltraSharp.ckpt")
        if not os.path.isfile(model_path):
            raise ValueError(f"cannot find {model_path}")
        m.m.load_state_dict(torch.load(model_path))
        return m

    api_pool.register(
        APIs.ESR_ULTRASHARP,
        lambda init_to_cpu: _get(_init, init_to_cpu),
    )


def register_inpainting() -> None:
    api_pool.register(
        APIs.INPAINTING,
        lambda init_to_cpu: _get(DiffusionAPI.from_inpainting, init_to_cpu),
    )


def register_lama() -> None:
    api_pool.register(
        APIs.LAMA,
        lambda init_to_cpu: _get(LaMa, init_to_cpu),
    )


def register_semantic() -> None:
    api_pool.register(
        APIs.SEMANTIC,
        lambda init_to_cpu: _get(DiffusionAPI.from_semantic, init_to_cpu),
    )


def register_hrnet() -> None:
    api_pool.register(
        APIs.HRNET,
        lambda init_to_cpu: _get(ImageHarmonizationAPI, init_to_cpu),
    )


def register_isnet() -> None:
    api_pool.register(
        APIs.ISNET,
        lambda init_to_cpu: _get(ISNetAPI, init_to_cpu),
    )


def register_blip() -> None:
    api_pool.register(
        APIs.BLIP,
        lambda init_to_cpu: _get(BLIPAPI, init_to_cpu),
    )


def register_prompt_enhance() -> None:
    api_pool.register(
        APIs.PROMPT_ENHANCE,
        lambda init_to_cpu: _get(PromptEnhanceAPI, init_to_cpu),
    )


def get_normalized_arr_from_diffusion(img_arr: np.ndarray) -> np.ndarray:
    img_arr = 0.5 * (img_arr + 1.0)
    img_arr = img_arr.transpose([1, 2, 0])
    return img_arr
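# Example: diffusion models emit CHW arrays in [-1, 1]; this maps them to HWC
# arrays in [0, 1]. A quick sanity check (doctest-style):
#
#     >>> arr = -np.ones([3, 4, 4], np.float32)  # CHW, everything at -1
#     >>> out = get_normalized_arr_from_diffusion(arr)
#     >>> out.shape, float(out.min()), float(out.max())
#     ((4, 4, 3), 0.0, 0.0)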
# API models


class CallbackModel(BaseModel):
    callback_url: str = Field("", description="The callback url to post to.")


class UseAuditModel(BaseModel):
    use_audit: bool = Field(False, description="Whether to audit the outputs.")


class MaxWHModel(BaseModel):
    max_wh: int = Field(1024, description="The maximum resolution.")


class VariationModel(BaseModel):
    seed: int = Field(..., description="Seed of the variation.")
    strength: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Strength of the variation.",
    )


class SDSamplers(str, Enum):
    DDIM = "ddim"
    PLMS = "plms"
    KLMS = "klms"
    SOLVER = "solver"
    K_EULER = "k_euler"
    K_EULER_A = "k_euler_a"
    K_HEUN = "k_heun"
    K_DPMPP_2M = "k_dpmpp_2m"
    LCM = "lcm"


class SigmasScheduler(str, Enum):
    KARRAS = "karras"


class TomeInfoModel(BaseModel):
    enable: bool = Field(False, description="Whether to enable tomesd.")
    ratio: float = Field(0.5, description="The ratio of tokens to merge.")
    max_downsample: int = Field(
        1,
        description="Apply ToMe to layers with at most this amount of downsampling.",
    )
    sx: int = Field(2, description="The stride along the x-axis for computing dst sets.")
    sy: int = Field(2, description="The stride along the y-axis for computing dst sets.")
    seed: int = Field(
        -1,
        ge=-1,
        lt=2**32,
        description="""
Seed of the generation.
> If `-1`, the seed from `DiffusionModel` will be used.
> If `DiffusionModel.seed` is also `-1`, a random seed will be used.
""",
    )
    use_rand: bool = Field(True, description="Whether to allow random perturbations.")
    merge_attn: bool = Field(True, description="Whether to merge attention.")
    merge_crossattn: bool = Field(False, description="Whether to merge cross attention.")
    merge_mlp: bool = Field(False, description="Whether to merge mlp.")
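# Sketch: a hypothetical tomesd configuration that merges half of the tokens
# in the self-attention layers, trading a little fidelity for speed; this is
# what `DiffusionModel.tome_info` (defined below) accepts:
#
#     tome = TomeInfoModel(enable=True, ratio=0.5, merge_attn=True)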
class StyleReferenceModel(BaseModel):
    url: Optional[str] = Field(
        None,
        description="The url of the style image; `None` means style reference is disabled.",
    )
    style_fidelity: float = Field(
        0.5,
        description="Style fidelity; a larger value follows the given style image more closely.",
    )
    reference_weight: float = Field(
        1.0,
        ge=0.0,
        le=1.0,
        description=(
            "Reference weight, similar to `control_strength`, "
            "but values > 1.0 or < 0.0 would take no effect, "
            "so we strictly restrict it to [0.0, 1.0]."
        ),
    )


class HighresModel(BaseModel):
    fidelity: float = Field(0.3, description="Fidelity of the original latent.")
    upscale_factor: float = Field(2.0, description="Upscale factor.")
    upscale_method: str = Field("nearest-exact", description="Upscale method.")
    max_wh: int = Field(1024, description="Max width or height of the output image.")
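# Sketch: a hypothetical "highres fix" setup, rendering at the base resolution
# first and then re-diffusing a 2x upscale while keeping roughly 30% fidelity
# to the original latent:
#
#     highres = HighresModel(fidelity=0.3, upscale_factor=2.0)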
class DiffusionModel(CallbackModel):
    use_circular: bool = Field(
        False,
        description="Whether to use a circular pattern (e.g. to generate tileable textures).",
    )
    seed: int = Field(
        -1,
        ge=-1,
        lt=2**32,
        description="""
Seed of the generation.
> If `-1`, a random seed will be used.
""",
    )
    variation_seed: int = Field(
        0,
        ge=0,
        lt=2**32,
        description="""
Seed of the variation generation.
> Only takes effect when `variation_strength` is larger than 0.
""",
    )
    variation_strength: float = Field(
        0.0,
        ge=0.0,
        le=1.0,
        description="Strength of the variation generation.",
    )
    variations: List[VariationModel] = Field(
        default_factory=list,
        description="Variation ingredients.",
    )
    num_steps: int = Field(20, description="Number of sampling steps.", ge=1, le=100)
    guidance_scale: float = Field(
        7.5,
        description="Guidance scale for classifier-free guidance.",
    )
    negative_prompt: str = Field(
        "",
        description="Negative prompt for classifier-free guidance.",
    )
    is_anime: bool = Field(
        False,
        description="Whether to generate anime images.",
    )
    version: str = Field(
        SDVersions.v1_5,
        description="Version of the diffusion model.",
    )
    sampler: SDSamplers = Field(
        SDSamplers.K_EULER,
        description="Sampler of the diffusion model.",
    )
    sigmas_scheduler: Optional[SigmasScheduler] = Field(
        None,
        description="Sigmas scheduler of the k-samplers; `None` will use the default.",
    )
    clip_skip: int = Field(
        -1,
        ge=-1,
        le=8,
        description="""
Number of CLIP layers that we want to skip.
> If set to `-1`, then `clip_skip` = 1 if `is_anime` else 0.
""",
    )
    custom_embeddings: Dict[str, Union[str, List[List[float]]]] = Field(
        {},
        description="Custom embeddings, often used in textual inversion.",
    )
    tome_info: TomeInfoModel = Field(TomeInfoModel(), description="tomesd settings.")
    style_reference: StyleReferenceModel = Field(
        StyleReferenceModel(),
        description="Style reference settings.",
    )
    highres_info: Optional[HighresModel] = Field(None, description="Highres info.")
    lora_scales: Optional[Dict[str, float]] = Field(
        None,
        description="LoRA scales; the key is the name, the value is the weight.",
    )
    lora_paths: Optional[List[str]] = Field(
        None,
        description="If provided, LoRA weights will be dynamically loaded from the given paths.",
    )


class ReturnArraysModel(BaseModel):
    return_arrays: bool = Field(
        False,
        description="Whether to return `List[np.ndarray]` directly; only for internal usage.",
    )

class CommonSDInpaintingModel(ReturnArraysModel, MaxWHModel):
    keep_original: bool = Field(
        False,
        description="Whether to keep the original image strictly identical in the output image.",
    )
    keep_original_num_fade_pixels: Optional[int] = Field(
        50,
        description="Number of pixels used to fade the original image into the generated result.",
    )
    use_raw_inpainting: bool = Field(
        False,
        description="""
Whether to use the raw inpainting method.
> This is useful when you want to apply inpainting with custom SD models.
""",
    )
    use_background_guidance: bool = Field(
        False,
        description="""
Whether to inject the latent of the background during the generation.
> If `use_raw_inpainting`, this will always be `True`, because in that case
the latent of the background is the only information we have for inpainting.
""",
    )
    use_reference: bool = Field(
        False,
        description="Whether to use the original image as reference.",
    )
    use_background_reference: bool = Field(
        False,
        description="Whether to use the background of the original image as reference.",
    )
    reference_fidelity: float = Field(
        0.0,
        description="Fidelity of the reference image; only takes effect when `use_reference` is `True`.",
    )
    inpainting_mode: InpaintingMode = Field(
        InpaintingMode.NORMAL,
        description="Inpainting mode. MASKED is preferred when the masked area is small.",
    )
    inpainting_mask_blur: Optional[int] = Field(
        None,
        description="The smoothness of the inpainting mask; `None` means no smoothing.",
    )
    inpainting_mask_padding: Optional[int] = Field(
        32,
        description="Padding of the inpainting mask under MASKED mode. If `None`, no padding is applied.",
    )
    inpainting_mask_binary_threshold: Optional[int] = Field(
        32,
        description="Binary threshold of the inpainting mask under MASKED mode. If `None`, no thresholding is applied.",
    )
    inpainting_target_wh: TNumberPair = Field(
        None,
        description="Target width and height of the images under MASKED mode.",
    )
    inpainting_padding_mode: Optional[str] = Field(None, description="Padding mode.")


class Txt2ImgModel(DiffusionModel, MaxWHModel, TextModel):
    pass
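# A minimal txt2img request sketch (hypothetical prompt); everything not set
# falls back to the defaults above (20 steps, `k_euler`, guidance 7.5, ...):
#
#     data = Txt2ImgModel(text="a watercolor painting of a fox", seed=123)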
class Img2ImgModel(MaxWHModel, ImageModel):
    pass


class Img2ImgDiffusionModel(DiffusionModel, Img2ImgModel):
    pass

class _ControlNetCoreModel(BaseModel):
    hint_url: str = Field(
        "",
        description="""
The `cdn` / `cos` url of the user's hint image.
> If an empty string is provided, `url` will be used as `hint_url`.
> `cos` url from `qcloud` is preferred.
""",
    )
    hint_annotator: Optional[str] = Field(
        None,
        description="""
The annotator type of the hint.
> If not specified, the control type will be used as the annotator's type.
""",
    )
    hint_start: Optional[float] = Field(None, description="Start ratio of the control.")
    hint_end: Optional[float] = Field(None, description="End ratio of the control.")
    control_strength: float = Field(1.0, description="The strength of the control.")
    hint_binarize_threshold: Optional[int] = Field(
        None,
        ge=0,
        le=255,
        description="The threshold for binarizing the hint; `None` means no binarization.",
    )
    extra_annotator_params: Optional[Dict[str, Any]] = Field(
        None,
        description="Extra parameters for the annotator.",
    )
    bypass_annotator: bool = Field(False, description="Whether to bypass the annotator.")
    guess_mode: bool = Field(False, description="Whether to enable guess mode.")
    no_switch: bool = Field(
        False,
        description="If `True`, do not switch the ControlNet weights even when the base model has switched.",
    )

class _ControlNetModel(_ControlNetCoreModel, UseAuditModel):
    url: Optional[str] = Field(None, description="Specify this to perform img2img.")
    prompt: str = Field("", description="Prompt.")
    fidelity: float = Field(
        0.05,
        ge=0.0,
        le=1.0,
        description="The fidelity of the input image; only takes effect when `url` is not `None`.",
    )
    num_samples: int = Field(1, ge=1, le=4, description="Number of samples.")
    base_model: str = Field(
        SDVersions.v1_5,
        description="The base model.",
    )
    mask_url: Optional[str] = Field(None, description="Specify this to perform inpainting.")
    use_inpainting: bool = Field(False, description="Whether to use the inpainting model.")

    @property
    def version(self) -> str:
        return self.base_model


class ControlNetModel(CommonSDInpaintingModel, DiffusionModel, _ControlNetModel):
    pass

def handle_diffusion_model(
    m: DiffusionAPI,
    data: DiffusionModel,
    *,
    always_uncond: bool = True,
) -> Dict[str, Any]:
    if data.seed >= 0:
        seed = data.seed
    else:
        seed = secrets.randbelow(2**32)
    variation_seed = None
    variation_strength = None
    if data.variation_strength > 0:
        variation_seed = data.variation_seed
        variation_strength = data.variation_strength
    if data.variations is None:
        variations = None
    else:
        variations = [(v.seed, v.strength) for v in data.variations]
    m.switch_circular(data.use_circular)
    if not always_uncond and not data.negative_prompt:
        unconditional_cond = None
    else:
        unconditional_cond = [data.negative_prompt]
    clip_skip = data.clip_skip
    if clip_skip == -1:
        if data.is_anime or data.version.startswith("anime"):
            clip_skip = 1
        else:
            clip_skip = 0
    # lora
    model = m.m
    if isinstance(model, StableDiffusion):
        manager = model.lora_manager
        if manager.injected:
            m.cleanup_sd_lora()
        if data.lora_scales:
            user_folder = os.path.expanduser("~")
            external_folder = os.path.join(user_folder, ".cache", "external")
            lora_folder = os.path.join(external_folder, "lora")
            for key in data.lora_scales:
                if model.lora_manager.has(key):
                    continue
                if not os.path.isdir(lora_folder):
                    raise ValueError(
                        f"'{key}' does not exist in current loaded lora "
                        f"and '{lora_folder}' does not exist either."
                    )
                for lora_file in os.listdir(lora_folder):
                    lora_name = os.path.splitext(lora_file)[0]
                    if key != lora_name:
                        continue
                    try:
                        print(f">> loading {key}")
                        lora_path = os.path.join(lora_folder, lora_file)
                        m.load_sd_lora(lora_name, path=lora_path)
                    except Exception as err:
                        raise ValueError(f"failed to load {key}: {err}")
            m.inject_sd_lora(*list(data.lora_scales))
            m.set_sd_lora_scales(data.lora_scales)
    # custom embeddings
    if not data.custom_embeddings:
        custom_embeddings = None
    else:
        custom_embeddings = {}
        for k, v in data.custom_embeddings.items():
            if isinstance(v, str):
                with open(v, "r") as f:
                    v = json.load(f)
            custom_embeddings[k] = v
    # return
    return dict(
        seed=seed,
        variation_seed=variation_seed,
        variation_strength=variation_strength,
        variations=variations,
        num_steps=data.num_steps,
        unconditional_guidance_scale=data.guidance_scale,
        unconditional_cond=unconditional_cond,
        sampler=data.sampler,
        sigmas_scheduler=data.sigmas_scheduler,
        verbose=verbose(),
        clip_skip=clip_skip,
        custom_embeddings=custom_embeddings,
        highres_info=None if data.highres_info is None else data.highres_info.dict(),
    )
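# Typical call site for `handle_diffusion_model` with a `Txt2ImgModel`-style
# request (a sketch; the exact sampling signatures live in `cflearn` and may
# differ):
#
#     diffusion_kwargs = handle_diffusion_model(m, data)
#     await handle_diffusion_hooks(m, data, algorithm, diffusion_kwargs)
#     img_arr = m.txt2img(data.text, max_wh=data.max_wh, **diffusion_kwargs)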
def handle_diffusion_inpainting_model(data: CommonSDInpaintingModel) -> Dict[str, Any]:
    return dict(
        anchor=64,
        max_wh=data.max_wh,
        keep_original=data.keep_original,
        keep_original_num_fade_pixels=data.keep_original_num_fade_pixels,
        use_raw_inpainting=data.use_raw_inpainting,
        use_background_guidance=data.use_background_guidance,
        use_reference=data.use_reference,
        use_background_reference=data.use_background_reference,
        reference_fidelity=data.reference_fidelity,
        inpainting_settings=InpaintingSettings(
            data.inpainting_mode,
            data.inpainting_mask_blur,
            data.inpainting_mask_padding,
            data.inpainting_mask_binary_threshold,
            data.inpainting_target_wh,
            data.inpainting_padding_mode,
        ),
    )
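# Sketch: a hypothetical MASKED-mode request, where the mask region is padded
# by the default 32px, binarized at the default threshold of 32, blurred by
# 4px, and diffused at 512x512:
#
#     data = CommonSDInpaintingModel(
#         inpainting_mode=InpaintingMode.MASKED,
#         inpainting_mask_blur=4,
#         inpainting_target_wh=(512, 512),
#     )
#     kwargs = handle_diffusion_inpainting_model(data)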
async def handle_diffusion_hooks(
    m: DiffusionAPI,
    data: DiffusionModel,
    algorithm: "IAlgorithm",
    kwargs: Dict[str, Any],
) -> None:
    # tomesd
    tome_info = data.tome_info.dict()
    enable_tome = tome_info.pop("enable")
    if not enable_tome:
        tome_info = None
    else:
        if tome_info["seed"] == -1:
            tome_info["seed"] = kwargs.get("seed", secrets.randbelow(2**32))
    # style reference
    style_reference = data.style_reference.dict()
    style_url = style_reference.pop("url")
    existing = kwargs.get("style_reference.url")
    if existing is not None and isinstance(existing, Image.Image):
        style_image = existing
    else:
        if style_url is not None:
            style_image = await algorithm.download_image_with_retry(style_url)
        else:
            style_image = None
            style_reference = None
    # setup
    m.setup_hooks(
        tome_info=tome_info,
        style_reference_image=style_image,
        style_reference_states=style_reference,
    )

class GetPromptModel(BaseModel):
    text: str
    need_translate: bool = Field(
        True,
        description="Whether we need to translate the input text.",
    )


class GetPromptResponse(BaseModel):
    text: str
    success: bool
    reason: str

def endpoint2algorithm(endpoint: str) -> str:
    return endpoint[1:].replace("/", ".")
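# e.g. the endpoint "/txt2img/sd" maps to the algorithm key "txt2img.sd"
# (doctest-style):
#
#     >>> endpoint2algorithm("/txt2img/sd")
#     'txt2img.sd'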
TAlgo = TypeVar("TAlgo", bound=Type[AlgorithmBase])


class IAlgorithm(AlgorithmBase, metaclass=ABCMeta):
    model_class: Type[BaseModel]
    response_model_class: Optional[Type[BaseModel]] = None
    last_latencies: Dict[str, float] = {}

    @classmethod
    def auto_register(cls) -> Callable[[TAlgo], TAlgo]:
        def _register(cls_: TAlgo) -> TAlgo:
            return cls.register(endpoint2algorithm(cls_.endpoint))(cls_)

        return _register

    def log_times(self, latencies: Dict[str, float]) -> None:
        from cfcreator.sdks.apis import ALL_LATENCIES_KEY

        self.last_latencies = shallow_copy_dict(latencies)
        latencies.pop(ALL_LATENCIES_KEY, None)
        super().log_times(latencies)

    async def download_with_retry(self, url: str, **kw: Any) -> bytes:
        return await download_with_retry(self.http_client.session, url, **kw)

    async def download_image_with_retry(self, url: str, **kw: Any) -> Image.Image:
        return await download_image_with_retry(self.http_client.session, url, **kw)

    async def get_image_from(
        self,
        key: str,
        data: BaseModel,
        kwargs: Dict[str, Any],
        **request_kw: Any,
    ) -> Image.Image:
        existing = kwargs.pop(key, None)
        if existing is not None and isinstance(existing, Image.Image):
            return existing
        return await self.download_image_with_retry(getattr(data, key), **request_kw)
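# A hypothetical subclass sketch (not part of this module) showing how
# `auto_register` wires an endpoint to an algorithm:
#
#     @IAlgorithm.auto_register()
#     class HypotheticalAlgorithm(IAlgorithm):
#         model_class = Txt2ImgModel
#         endpoint = "/txt2img/hypothetical"
#
#         async def run(self, data: Txt2ImgModel, *args: Any) -> Response:
#             ...
#
# `auto_register` derives the registration key ("txt2img.hypothetical") from
# `endpoint` via `endpoint2algorithm`.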
class IWrapperAlgorithm(IAlgorithm):
    algorithms: Optional[Dict[str, IAlgorithm]] = None

    def initialize(self) -> None:
        from cfcreator.sdks.apis import APIs
        from cfcreator.sdks.apis import ALL_LATENCIES_KEY
        from cfcreator.sdks.apis import EXCEPTION_MESSAGE_KEY

        if self.algorithms is None:
            raise ValueError("`algorithms` should be provided for `IWrapperAlgorithm`.")
        self.apis = APIs(
            clients=self.clients,
            algorithms=self.algorithms,
            verbose=None,
            lazy_load=None,
        )
        self.latencies_key = ALL_LATENCIES_KEY
        self.exception_message_key = EXCEPTION_MESSAGE_KEY

# kafka


class Status(str, Enum):
    PENDING = "pending"
    WORKING = "working"
    FINISHED = "finished"
    EXCEPTION = "exception"
    INTERRUPTED = "interrupted"
    NOT_FOUND = "not_found"


# shortcuts


class SDParameters(BaseModel):
    is_anime: bool
    version: str
    lora_paths: Optional[List[str]]


def load_sd_lora_with(sd: ControlledDiffusionAPI, data: SDParameters) -> None:
    if data.lora_paths:
        for lora_path in data.lora_paths:
            key = Path(lora_path).stem
            if isinstance(sd.m, StableDiffusion) and sd.m.lora_manager.has(key):
                continue
            sd.load_sd_lora(key, path=lora_path)

def get_sd_from(api_key: APIs, data: SDParameters, **kw: Any) -> ControlledDiffusionAPI:
    if not data.is_anime:
        version = data.version
    else:
        version = data.version if data.version.startswith("anime") else "anime"
    sd: ControlledDiffusionAPI = api_pool.get(api_key, **kw)
    if api_key != APIs.SD_INPAINTING:
        sd.prepare_sd([version])
    elif version != SDInpaintingVersions.v1_5:
        sd.prepare_sd([version], sub_folder="inpainting", force_external=True)
    sd.switch_sd(version)
    sd.disable_control()
    load_sd_lora_with(sd, data)
    return sd
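# Sketch: fetching a ready-to-use SD instance for a hypothetical anime
# request; `is_anime=True` redirects any non-"anime*" version to the "anime"
# weights:
#
#     params = SDParameters(is_anime=True, version=SDVersions.v1_5, lora_paths=None)
#     sd = get_sd_from(APIs.SD, params)  # prepares and switches to "anime"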
def get_response(data: ReturnArraysModel, results: List[np.ndarray]) -> Any:
    if data.return_arrays:
        return results
    return Response(content=np_to_bytes(to_canvas(results)), media_type="image/png")
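# Sketch: internal pipelines set `return_arrays=True` to get the raw arrays
# back; external callers get the results stitched into one canvas and encoded
# as a PNG response:
#
#     response = get_response(ReturnArraysModel(return_arrays=True), results)
#     assert isinstance(response, list)  # List[np.ndarray]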
__all__ = [
    "endpoint2algorithm",
    "DiffusionModel",
    "HighresModel",
    "Txt2ImgModel",
    "Img2ImgModel",
    "Img2ImgDiffusionModel",
    "UseAuditModel",
    "ReturnArraysModel",
    "ControlNetModel",
    "GetPromptModel",
    "GetPromptResponse",
    "Status",
    "IAlgorithm",
    "IWrapperAlgorithm",
]