-
Notifications
You must be signed in to change notification settings - Fork 40
/
enqueue_many_test.py
83 lines (63 loc) · 2.42 KB
/
enqueue_many_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os, sys
import numpy as np
# Set before TensorFlow is imported so that it starts with no CUDA devices
# visible (presumably to force CPU-only execution -- confirm intent).
os.environ["CUDA_VISIBLE_DEVICES"]=""
import tensorflow as tf
def create_session():
    """Create and return a ``tf.InteractiveSession`` for this script.

    The session is configured with device-placement logging disabled, a
    5000 ms per-operation timeout and a 0.3 per-process GPU memory fraction.

    Returns:
        The newly created ``tf.InteractiveSession``.
    """
    session_config = tf.ConfigProto(log_device_placement=False)
    # Give up on an op after 5 seconds instead of hanging indefinitely.
    session_config.operation_timeout_in_ms = 5000
    # Cap the GPU memory share so the process does not claim all of it.
    session_config.gpu_options.per_process_gpu_memory_fraction = 0.3
    return tf.InteractiveSession("", config=session_config)
import time
import threading
import os
# NOTE(review): PYTHONUNBUFFERED is normally read at interpreter startup, so
# setting it here probably only affects child processes -- confirm intent.
os.environ['PYTHONUNBUFFERED'] = 'True'
from google.protobuf.internal import api_implementation
# Abort early unless protobuf's default implementation type is the C++ one.
assert api_implementation._default_implementation_type == 'cpp'
from tensorflow.python.client import timeline
# Build the benchmark in a fresh default graph.
tf.reset_default_graph()

# Optional command-line switch: "reverse" starts the two enqueue threads in
# the opposite order.  The argument is validated with an explicit check
# rather than `assert`, so it is still enforced when Python runs with -O
# (which strips assert statements).
reverse = False
if len(sys.argv) > 1:
    if sys.argv[1] != 'reverse':
        sys.exit("usage: %s [reverse]" % sys.argv[0])
    reverse = True

# Number of elements pushed by each enqueue_many op.
n = 10**6
dtype = tf.int32

# Queue of scalars sized to hold both batches (capacity 2*n).
queue = tf.FIFOQueue(capacity=2*n, dtypes=[dtype], shapes=[()])

# Two length-n source vectors: all zeros and all ones.
zeros = tf.Variable(tf.zeros((n), name="0", dtype=dtype))
ones = tf.Variable(tf.ones((n), name="1", dtype=dtype))

# One enqueue_many op per source vector; these are what the threads run.
enqueue_zeros = queue.enqueue_many(zeros, name="zeros")
enqueue_ones = queue.enqueue_many(ones, name="ones")

sess = create_session()
sess.run(tf.global_variables_initializer())

# Reference timestamp for the relative times printed by run_op.
start_time0 = time.time()
# RunMetadata objects appended by run_op, one per executed op.
run_metadatas = []
def run_op(op):
    """Run ``op`` in the shared session with full tracing enabled.

    Prints a line when the op starts and another when it ends, each stamped
    with the elapsed milliseconds since ``start_time0``, then appends the
    collected ``tf.RunMetadata`` to the module-level ``run_metadatas`` list.

    Args:
        op: The TensorFlow operation to execute; its ``name`` is used in the
            printed messages.
    """
    began_at = time.time()
    # The newline is part of the message and end='' is used, so each line is
    # emitted as a single string.
    print("%10.2f ms: starting op %s\n" % ((began_at - start_time0) * 1000, op.name),
          flush=True, end='')

    trace_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
    collected = tf.RunMetadata()
    sess.run(op, options=trace_options, run_metadata=collected)

    finished_at = time.time()
    print("%10.2f ms: ending op %s\n" % ((finished_at - start_time0) * 1000, op.name),
          flush=True, end='')
    run_metadatas.append(collected)
# One worker thread per enqueue op, created in zeros-then-ones order.
workers = [
    threading.Thread(group=None, target=run_op, args=(enqueue_op,))
    for enqueue_op in (enqueue_zeros, enqueue_ones)
]
# In reverse mode the start (and join) order is flipped.
if reverse:
    workers = list(reversed(workers))

for worker in workers:
    worker.start()

# Block until both enqueue ops have completed.
for worker in workers:
    worker.join()

# Fold every per-op RunMetadata into one and dump it as a Chrome trace.
combined_metadata = tf.RunMetadata()
for metadata in run_metadatas:
    combined_metadata.MergeFrom(metadata)
trace_json = timeline.Timeline(combined_metadata.step_stats).generate_chrome_trace_format()
timeline_path = sys.argv[0]+'_%s_timeline.json'%(reverse)
with open(timeline_path, 'w') as trace_file:
    trace_file.write(trace_json)

# Both batches must have landed in the queue.
assert sess.run(queue.size()) == 2*n

# Drain the queue and compare each element with its neighbour; the single
# zero of padding on each side lines the two shifted copies up.
dequeued = sess.run(queue.dequeue_many(2*n))
edge = np.array([0])
left_padded = np.concatenate([edge, dequeued])
right_padded = np.concatenate([dequeued, edge])
change_total = abs(left_padded - right_padded).sum()
# Two contiguous runs (all zeros, all ones, in either order) give a total of
# exactly 2, so anything larger means the batches were interleaved.
print("Interleaving detected: %s" % (change_total>2))