-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
38 lines (29 loc) · 1.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from skimage.color import lab2rgb, rgb2lab
from torch import nn
from datetime import datetime
def freeze_module(module):
    """Stop gradient tracking for every parameter of ``module``.

    Each tensor yielded by ``module.parameters()`` is switched to
    ``requires_grad=False`` in place, so later backward passes will not
    compute gradients for it. Nothing is returned.
    """
    for weight in module.parameters():
        # In-place toggle; parameters are leaf tensors, so this is always allowed.
        weight.requires_grad_(False)
def get_device():
    """Return the name of the torch device to use.

    Returns ``"mps"`` when the MPS backend is both available and built,
    printing a warning first because MPS has been observed to break
    backpropagation silently. Otherwise returns ``"cuda"`` if CUDA is
    available, else ``"cpu"``.
    """
    # Query CUDA first (same order as before), then let MPS override it.
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    mps_backend = torch.backends.mps
    if not (mps_backend.is_available() and mps_backend.is_built()):
        return fallback
    print(
        "WARNING: MPS currently doesn't seem to work, and messes up backpropagation without any visible torch errors. I recommend using CUDA on a colab notebook or CPU instead if you're facing inexplicable issues with generations."
    )
    return "mps"
def show_pil(img):
    """Display ``img`` in a pyplot window with both axes hidden.

    ``img`` is handed directly to ``plt.imshow``. The x- and y-axis of the
    Axes that the image is drawn into are made invisible, and then
    ``plt.show()`` is called.
    """
    # plt.imshow returns an AxesImage, not a Figure; .axes is its host Axes.
    image_artist = plt.imshow(img)
    host_axes = image_artist.axes
    for axis in (host_axes.get_xaxis(), host_axes.get_yaxis()):
        axis.set_visible(False)
    plt.show()
def get_timestamp():
    """Return the current local wall-clock time formatted as ``HH:MM:SS``.

    Uses ``datetime.now()`` without a timezone, so the value reflects the
    host's local time.
    """
    return datetime.now().strftime("%H:%M:%S")