-
Notifications
You must be signed in to change notification settings - Fork 47
/
monkey_patches.py
45 lines (37 loc) · 1.42 KB
/
monkey_patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import traceback
import warnings
import sys
import os
"""
Get rid of tensorboard warnings.
"""
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
"""
Warning:
Inspect where warning happens.
"""
if False:
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, 'write') else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
import pdb; pdb.set_trace()
warnings.showwarning = warn_with_traceback
"""
PL's batch size cannot be correctly computed in NKSR, fix it.
"""
if True:
import pytorch_lightning as pl
# Monkey-patch `extract_batch_size` to not raise warning from weird tensor sizes
def extract_bs(self, *args, **kwargs):
batch_size = 1
self.batch_size = batch_size
return batch_size
pl.trainer.connectors.logger_connector.result._ResultCollection._extract_batch_size = extract_bs