Replace tensor_utils.compile_function() with make_callable #326
There will be two ways of doing it:

import tensorflow as tf

def compile_function(inputs, outputs, log_name=None):
    def run(*input_vals):
        sess = tf.get_default_session()
        # Note: this builds a new callable on every call to run().
        return sess.make_callable(outputs, feed_list=inputs)(*input_vals)
    return run

f = compile_function(inputs=[obs_var], outputs=[mean_var, log_std_var])
values = f(obs)

or using only tf.Session:

sess = tf.get_default_session()
values = sess.run([mean_var, log_std_var], feed_dict={obs_var: obs})

In my opinion, using only tf.Session is easier to debug, since it is clear which tensors are involved when we get the values. However, with compile_function we can pass the placeholder immediately after building the network, and it is easier to get the values by calling the function directly.
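A minimal sketch, assuming a TF 1.x-style session API, of how the two properties could be combined: the session is still looked up lazily at call time (so f can be created right after the network is built), but make_callable runs only once per session instead of on every call. Everything here is hypothetical: this compile_function is not garage's tensor_utils.compile_function, the get_session parameter is an assumption (with TF 1.x you would pass tf.get_default_session), and ToySession is a stand-in that mimics only the call shape of Session.make_callable, with "tensors" modeled as plain Python functions of the feed dict, so the sketch runs without TensorFlow.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


class ToySession:
    """Hypothetical stand-in for a TF 1.x session; mimics make_callable's call shape only."""

    def __init__(self) -> None:
        self.callables_made = 0

    def make_callable(self, fetches: Sequence[Callable[[Dict[str, Any]], Any]],
                      feed_list: Sequence[str]) -> Callable[..., List[Any]]:
        self.callables_made += 1

        def fn(*vals: Any) -> List[Any]:
            # Pair each placeholder with the value fed for it, then evaluate the fetches.
            feed = dict(zip(feed_list, vals))
            return [fetch(feed) for fetch in fetches]

        return fn


def compile_function(inputs: Sequence[Any], outputs: Sequence[Any],
                     get_session: Callable[[], Optional[Any]]) -> Callable[..., Any]:
    """Hypothetical variant: resolve the session lazily, build the callable once per session."""
    # Holds a strong reference to the most recent session and its callable.
    state: Dict[str, Any] = {"sess": None, "fn": None}

    def run(*input_vals: Any) -> Any:
        sess = get_session()  # resolved at call time, not at construction time
        if sess is None:
            raise RuntimeError("no default session is available")
        if sess is not state["sess"]:
            # First call, or the default session changed: build one new callable.
            state["sess"] = sess
            state["fn"] = sess.make_callable(list(outputs), feed_list=list(inputs))
        return state["fn"](*input_vals)

    return run


# Usage: f is built before any session exists, as with compile_function today.
current: Dict[str, Optional[ToySession]] = {"sess": None}
obs_var = "obs"
mean_var = lambda feed: feed["obs"] * 2
log_std_var = lambda feed: feed["obs"] - 1
f = compile_function([obs_var], [mean_var, log_std_var], lambda: current["sess"])

current["sess"] = ToySession()
values = f(3)   # builds the callable
values2 = f(5)  # reuses it
```

Keeping only the latest session is a simplification; a fuller version would also need to decide when cached callables are released once their session is closed.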
@ahtsan The idea of … I think the initial intention of this issue is to replace this:

def compile_function(inputs, outputs, log_name=None):
    def run(*input_vals):
        sess = tf.get_default_session()
        return sess.run(outputs, feed_dict=dict(zip(inputs, input_vals)))
    return run

f = compile_function(inputs=[obs_var], outputs=[mean_var, log_std_var])
values = f(obs)

with this:

f = sess.make_callable([mean_var, log_std_var], feed_list=[obs_var])
values = f(obs)
But we may not have the session available at the point where we would call f = sess.make_callable([mean_var, log_std_var], feed_list=[obs_var]), so we have to wrap it in a function and get the default session later on?
That makes sense. But if this is the case, then it looks to me that …

I think the idea here is that … Therefore, using …

@krzentner, your thoughts?
Let's assume we will remove … It will be easy to adopt either way in the algorithms, since a session is available there anyway, i.e. inside …:

sess = tf.Session()
# Fetch policy.outputs unwrapped so f(obs) returns the same structure as sess.run below.
f = sess.make_callable(policy.outputs, feed_list=[policy.inputs])
...
action = f(obs)
action2 = f(obs2)

and

sess = tf.Session()
...
action = sess.run(policy.outputs, feed_dict={policy.inputs: obs})
action2 = sess.run(policy.outputs, feed_dict={policy.inputs: obs2})
@ahtsan IMO that's a feature, not a bug. Magic like that is rarely worth it. I think you are pointing out that implementing this would take some significant refactoring in the primitives, so that the session is available when a primitive is constructed. I think that would be a positive refactor. Perhaps this is too large a change to address right now, though, and it would be better done after models are implemented.
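A minimal sketch of the refactor described in this comment, in which a primitive receives the session when it is constructed and builds its callable once, eagerly, with no default-session lookup later. Everything in it is hypothetical: ToyPolicy is not garage's policy class, and ToySession is a stand-in that mimics only the call shape of Session.make_callable, so the sketch runs without TensorFlow.

```python
from typing import Any, Callable, Dict, List, Sequence


class ToySession:
    """Hypothetical stand-in for a TF 1.x session; mimics make_callable's call shape only."""

    def __init__(self) -> None:
        self.callables_made = 0

    def make_callable(self, fetches: Sequence[Callable[[Dict[str, Any]], Any]],
                      feed_list: Sequence[str]) -> Callable[..., List[Any]]:
        self.callables_made += 1

        def fn(*vals: Any) -> List[Any]:
            # Pair each placeholder with the value fed for it, then evaluate the fetches.
            feed = dict(zip(feed_list, vals))
            return [fetch(feed) for fetch in fetches]

        return fn


class ToyPolicy:
    """Hypothetical primitive that requires a session at construction time."""

    def __init__(self, sess: ToySession) -> None:
        self.inputs = ["obs"]
        self.outputs = [lambda feed: feed["obs"] * 2]
        # The callable is built exactly once, here, because the session already exists.
        self._f = sess.make_callable(self.outputs, feed_list=self.inputs)

    def get_action(self, obs: Any) -> Any:
        # Unwrap the single fetched value.
        return self._f(obs)[0]


sess = ToySession()
policy = ToyPolicy(sess)
action = policy.get_action(3)
action2 = policy.get_action(4)
```

The catch is the one raised in this thread: the session must already exist before the primitive is constructed, which is the refactor being deferred until models are in place.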
@ryanjulian Hmm, I think you missed something in @ahtsan's comments, and I agree with him on #326 (comment). Since currently everything is initialized in the …

I think we all agree. By "significant refactor", I mean that a session already exists when … Let's wait on this one, but keep it in mind as we switch to models.
https://www.tensorflow.org/api_docs/python/tf/Session#make_callable