Skip to content

Commit

Permalink
Began converting A3C to KerasModel
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed Jul 1, 2019
1 parent af54fae commit 1d1878e
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 194 deletions.
6 changes: 4 additions & 2 deletions deepchem/models/keras_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ def _create_training_ops(self, example_batch):
if tf.executing_eagerly():
return
self._label_placeholders = [
tf.placeholder(dtype=tf.as_dtype(t), shape=x.shape)
tf.placeholder(dtype=tf.as_dtype(t), shape=(None,) + x.shape[1:])
for x, t in zip(example_batch[1], self._label_dtypes)
]
self._weights_placeholders = [
tf.placeholder(dtype=tf.as_dtype(t), shape=x.shape)
tf.placeholder(dtype=tf.as_dtype(t), shape=(None,) + x.shape[1:])
for x, t in zip(example_batch[2], self._weights_dtypes)
]
self._loss_tensor = self._loss_fn(
Expand Down Expand Up @@ -937,6 +937,7 @@ def restore(self, checkpoint=None):
checkpoint will be chosen automatically. Call get_checkpoints() to get a
list of all available checkpoints.
"""
self._ensure_built()
if checkpoint is None:
checkpoint = tf.train.latest_checkpoint(self.model_dir)
if checkpoint is None:
Expand Down Expand Up @@ -972,6 +973,7 @@ def __call__(self, outputs, labels, weights):
shape = tuple(w.shape.as_list())
else:
shape = w.shape
shape = tuple(-1 if x is None else x for x in shape)
w = tf.reshape(w, shape + (1,) * (len(losses.shape) - len(w.shape)))
loss = losses * w
return tf.reduce_mean(loss) + sum(self.model.losses)
8 changes: 8 additions & 0 deletions deepchem/rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(self,
self._state_dtype = numpy.float32
else:
self._state_dtype = state_dtype
print(self._state_dtype)

@property
def state(self):
Expand Down Expand Up @@ -185,6 +186,13 @@ class Policy(object):
or even on different computers.
"""

def create_model(self, **kwargs):
raise NotImplemented("Subclasses must implement this")

@property
def output_names(self):
raise NotImplemented("Subclasses must implement this")

def create_layers(self, state, **kwargs):
"""Create the TensorGraph Layers that define the policy.
Expand Down
Loading

0 comments on commit 1d1878e

Please sign in to comment.