Commit

feat: Agent classes for optimal acting

rafonsor committed Nov 21, 2023
1 parent 808cad6 commit 1cbe7ed

Showing 2 changed files with 64 additions and 1 deletion.
63 changes: 63 additions & 0 deletions unrl/agent.py

@@ -0,0 +1,63 @@
# Copyright 2023 The unRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unrl.types as t

import torch as pt

from unrl.functions import ActionValueFunction, DuelingActionValueFunction, Policy, ContinuousPolicy, VariationalPolicy


class Agent(t.Protocol):
def pick(self, state: pt.Tensor) -> pt.Tensor:
"""Pick greedy action for the current state"""
...


class QAgent(Agent):
"""Agent driven greedily by an optimal Action-value Function for discrete Action spaces"""
def __init__(self, action_value_model: ActionValueFunction | DuelingActionValueFunction):
self.action_value_model = action_value_model

def pick(self, state: pt.Tensor) -> pt.Tensor:
action_values = self.action_value_model(state)
return pt.argmax(action_values, dim=-1)


class PolicyAgent(Agent):
"""Optimal Policy Agent for discrete Action spaces"""
def __init__(self, policy: Policy):
self.policy = policy

def pick(self, state: pt.Tensor) -> pt.Tensor:
logprobs = self.policy(state)
return pt.argmax(logprobs, dim=-1)


class ContinuousPolicyAgent(Agent):
"""Optimal Policy Agent for continuous Action spaces"""
def __init__(self, policy: ContinuousPolicy):
self.policy = policy

def pick(self, state: pt.Tensor) -> pt.Tensor:
return self.policy(state)


class VariationalPolicyAgent(Agent):
"""Agent driven greedily by a probabilistic Policy for continuous Action spaces"""
def __init__(self, policy: VariationalPolicy):
self.policy = policy

def pick(self, state: pt.Tensor) -> pt.Tensor:
dist = self.policy.forward(state, dist=True)
return dist.mode
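
For context, a minimal usage sketch of the new QAgent. Since this commit does not show ActionValueFunction's constructor, a plain torch module stands in for the action-value model (duck typing suffices for what pick relies on); the state and action dimensions below are hypothetical.

import torch as pt

from unrl.agent import QAgent

# Stand-in for ActionValueFunction: any module mapping a state tensor to
# per-action value estimates works with QAgent.pick.
q_model = pt.nn.Sequential(
    pt.nn.Linear(4, 32),
    pt.nn.ReLU(),
    pt.nn.Linear(32, 2),  # two discrete actions (hypothetical sizes)
)

agent = QAgent(q_model)
state = pt.rand(4)
action = agent.pick(state)  # scalar tensor holding the greedy action index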
2 changes: 1 addition & 1 deletion unrl/functions.py
@@ -130,7 +130,7 @@ def forward(self,
            state: State.
            action: action to provide an estimation for.
            stochastic_actions: (Optional) actions from which to compute an action-advantage expectation to subtract
-               from the
+               from the value estimate.
            combine: when "True", returns action-value estimates by directly summing state-value and action advantage
                estimates.
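
The completed docstring describes the standard dueling decomposition (Wang et al., 2016): the expected action advantage is subtracted from the value estimate. A minimal sketch of that combination step, assuming the common mean-advantage form rather than unrl's exact implementation:

import torch as pt

def dueling_combine(value: pt.Tensor, advantages: pt.Tensor) -> pt.Tensor:
    # Q(s, a) = V(s) + A(s, a) - E_a'[A(s, a')]
    # Subtracting the expected advantage over actions keeps the
    # state-value and advantage estimates identifiable.
    return value + advantages - advantages.mean(dim=-1, keepdim=True)

# e.g. dueling_combine(pt.tensor([0.5]), pt.tensor([1.0, -1.0]))
#      -> tensor([1.5, -0.5])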
