Skip to content

Commit

Permalink
style(activation): register activation function mish to tensorflow
Browse files Browse the repository at this point in the history
  • Loading branch information
StepNeverStop committed Jan 9, 2021
1 parent 1f374b0 commit 805b4d7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
15 changes: 11 additions & 4 deletions rls/nn/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
# encoding: utf-8

import tensorflow as tf
from tensorflow.keras.layers import Activation
from tensorflow.keras.utils import get_custom_objects

# from rls.utils.specs import DefaultActivationFuncType

class Mish(Activation):

def swish(x):
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return tf.multiply(x, tf.nn.sigmoid(x))
def __init__(self, activation, **kwargs):
super().__init__(activation, **kwargs)
self.__name__ = 'mish'


def mish(x):
Expand All @@ -18,5 +21,9 @@ def mish(x):
"""
return tf.multiply(x, tf.nn.tanh(tf.nn.softplus(x)))

get_custom_objects().update({
'mish': Mish(mish)
})

default_activation = swish # 'tanh', 'relu', swish, mish

default_activation = 'swish' # 'tanh', 'relu', 'swish', 'mish'
4 changes: 2 additions & 2 deletions rls/utils/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ class DefaultActivationFuncType(Enum):
TANH = 'tanh'
RELU = 'relu'
ELU = 'elu'
SWISH = 'swish'
MISH = 'mish'
SWISH = 'swish' # https://arxiv.org/abs/1710.05941
MISH = 'mish' # https://arxiv.org/abs/1908.08681


class OutputNetworkType(Enum):
Expand Down

0 comments on commit 805b4d7

Please sign in to comment.