Skip to content

Commit

Permalink
misc refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jan 2, 2021
1 parent 8209e77 commit e9550bd
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
32 changes: 10 additions & 22 deletions Agent/DQNEnsembleAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,23 @@
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow as tf
from Agent.MaskedSoftmax import MaskedSoftmax

def combineModels(models, combiner):
shape = models[0].layers[0].input_shape[0][1:]
inputs = layers.Input(shape=shape)
actionsMask = layers.Input(shape=(4, ))
res = layers.Lambda(combiner)([actionsMask] + [ x(inputs) for x in models ])
return keras.Model(inputs=[inputs, actionsMask], outputs=res)

def maskedSoftmax(mask, inputs):
mask = tf.where(tf.equal(mask, 1))
return [
tf.sparse.to_dense(
tf.sparse.softmax(
tf.sparse.SparseTensor(
indices=mask,
values=tf.gather_nd(x, mask),
dense_shape=tf.shape(x, out_type=tf.int64)
)
)
) for x in inputs
]

def multiplyOutputs(inputs):
outputs = maskedSoftmax(inputs[0], inputs[1:])
predictions = [ layers.Reshape((1, -1))(
MaskedSoftmax()( x(inputs), actionsMask )
) for x in models ]

res = 1 + outputs[0]
for x in outputs[1:]:
res = tf.math.multiply(res, 1 + x)
return res
res = layers.Lambda(combiner)( layers.Concatenate(axis=1)(predictions) )
return keras.Model(inputs=[inputs, actionsMask], outputs=res)

@tf.function
def multiplyOutputs(outputs):
return tf.math.reduce_prod(1 + outputs, axis=1)

ENSEMBLE_MODE = {
'multiply': multiplyOutputs
Expand Down
14 changes: 14 additions & 0 deletions Agent/MaskedSoftmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import tensorflow as tf

class MaskedSoftmax(tf.keras.layers.Layer):
def call(self, inputLayer, mask):
mask = tf.where(tf.equal(mask, 1))
return tf.sparse.to_dense(
tf.sparse.softmax(
tf.sparse.SparseTensor(
indices=mask,
values=tf.gather_nd(inputLayer, mask),
dense_shape=tf.shape(inputLayer, out_type=tf.int64)
)
)
)
2 changes: 1 addition & 1 deletion view_maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# -*- coding: utf-8 -*-
import tensorflow as tf
import os
from Agent.DQNEnsembleAgent import DQNEnsembleAgent
# limit GPU usage
gpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_virtual_device_configuration(
Expand All @@ -15,6 +14,7 @@
import pygame.locals as G
import random
from Agent.DQNAgent import DQNAgent
from Agent.DQNEnsembleAgent import DQNEnsembleAgent
import glob
from collections import namedtuple
from model import createModel
Expand Down

0 comments on commit e9550bd

Please sign in to comment.