add transpose and one_layer arguments and use save_dir tied to run_id
ethanabrooks committed Aug 5, 2021
1 parent ac5a9be commit 4df83ed
Showing 3 changed files with 53 additions and 27 deletions.
41 changes: 28 additions & 13 deletions src/gpt_agent.py

@@ -1,3 +1,4 @@
+from pathlib import Path
 from typing import Optional, Tuple
 
 import torch
@@ -12,11 +13,12 @@
 
 
 class Agent(agent.Agent):
-    def __init__(self, obs_shape, action_space, save_interval, save_path, **kwargs):
+    def __init__(self, obs_shape, action_space, save_interval, save_dir, **kwargs):
         nn.Module.__init__(self)
 
         self.step = 0
-        self.save_path = save_path
+        Path(save_dir).mkdir(parents=True, exist_ok=True)
+        self.save_path = Path(save_dir, "linguistic-analysis.pkl")
         self.save_interval = save_interval
         self.base = Base(obs_shape[0], **kwargs)
 
@@ -42,7 +44,11 @@ def act(self, inputs, rnn_hxs, masks, deterministic=False):
         action = dist.sample()
 
         action_log_probs = dist.log_probs(action)
-        if self.step % self.save_interval == 0 and self.save_path is not None:
+        if (
+            self.save_interval is not None
+            and self.step % self.save_interval == 0
+            and self.save_path is not None
+        ):
             torch.save(
                 dict(
                     inputs=inputs,
@@ -81,14 +87,13 @@ def __init__(
         randomize_parameters: bool,
         hidden_size: int,
         action_hidden_size: Optional[int],
-        kernel: int,
-        stride: int,
+        transpose: bool,
+        one_layer: bool,
         recurrent=False,
     ):
         super().__init__(recurrent, hidden_size, hidden_size)
 
-        self.stride = stride
-        self.kernel = kernel
+        self.transpose = transpose
         gpt_size = "" if gpt_size == "small" else f"-{gpt_size}"
         gpt_size = f"gpt2{gpt_size}"
         self.gpt = (
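The two gpt_size lines in the hunk above build a Hugging Face model id from the size flag. A minimal sketch of that mapping, assuming the standard gpt2 checkpoint family (gpt2, gpt2-medium, gpt2-large, gpt2-xl):

    # "small" maps to the bare "gpt2" id; other sizes get a suffix
    for size in ["small", "medium", "large", "xl"]:
        suffix = "" if size == "small" else f"-{size}"
        print(f"gpt2{suffix}")  # gpt2, gpt2-medium, gpt2-large, gpt2-xl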
@@ -121,11 +126,15 @@ def __init__(
             nn.init.calculate_gain("relu"),
         )
         self.perception = nn.Sequential(
-            init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
-            nn.ReLU(),
-            init_(nn.Conv2d(32, 64, 8, stride=4)),
-            # nn.ReLU(),
-            # init_(nn.Conv2d(64, embedding_size, 3, stride=1)),
+            *(
+                [init_(nn.Conv2d(num_inputs, embedding_size, 64, stride=8))]
+                if one_layer
+                else [
+                    init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
+                    nn.ReLU(),
+                    init_(nn.Conv2d(32, embedding_size, 16, stride=2)),
+                ]
+            )
         )
         self.action = (
            None
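Both perception branches are sized to emit the same feature-map geometry. A minimal sketch of the output-shape arithmetic, assuming 84x84 Atari-style frames with 4 stacked channels and GPT-2 small's 768-dim embeddings (none of these values are shown in this diff):

    import torch
    import torch.nn as nn

    x = torch.zeros(1, 4, 84, 84)  # assumed: batch 1, 4 stacked frames, 84x84
    embedding_size = 768           # assumed: GPT-2 small hidden size

    one_layer = nn.Conv2d(4, embedding_size, 64, stride=8)
    two_layer = nn.Sequential(
        nn.Conv2d(4, 32, 8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, embedding_size, 16, stride=2),
    )
    # conv output side = floor((84 - kernel) / stride) + 1
    print(one_layer(x).shape)  # torch.Size([1, 768, 3, 3])
    print(two_layer(x).shape)  # torch.Size([1, 768, 3, 3])

Under these assumptions either branch yields a 3x3 feature map, i.e. nine spatial positions of width embedding_size for the transformer.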
@@ -168,7 +177,13 @@ def forward(self, inputs, rnn_hxs, masks):
         if self.is_recurrent:
             x, rnn_hxs = self._forward_gru(perception, rnn_hxs, masks)
         else:
-            inputs_embeds = perception.reshape(inputs.size(0), -1, self.embedding_size)
+            inputs_embeds = (
+                perception.reshape(inputs.size(0), self.embedding_size, -1).transpose(
+                    2, 1
+                )
+                if self.transpose
+                else perception.reshape(inputs.size(0), -1, self.embedding_size)
+            )
             x = self.gpt(inputs_embeds=inputs_embeds).last_hidden_state[:, -1]
         if self.action is not None:
             x = self.action(x)
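The new transpose flag changes how the conv feature map is flattened into GPT tokens. A minimal sketch of the difference, assuming a (batch, channels, H, W) perception output: the transpose path keeps each spatial position's channel vector together as one token embedding, while the plain reshape cuts the channel-major buffer into length-C rows that do not correspond to spatial positions.

    import torch

    batch, embedding_size = 2, 4
    perception = torch.arange(batch * embedding_size * 9).float()
    perception = perception.reshape(batch, embedding_size, 3, 3)  # (B, C, H, W)

    # transpose=True: (B, C, H*W) -> (B, H*W, C); token i is the full
    # channel vector at spatial position i
    tokens_t = perception.reshape(batch, embedding_size, -1).transpose(2, 1)

    # transpose=False: the flat channel-major memory is sliced into rows
    # of length C, mixing one channel's spatial values across tokens
    tokens_f = perception.reshape(batch, -1, embedding_size)

    print(tokens_t.shape, tokens_f.shape)   # both (2, 9, 4), different contents
    print(torch.equal(tokens_t, tokens_f))  # False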
18 changes: 11 additions & 7 deletions src/gpt_main.py

@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 from pprint import pformat
 from typing import Literal, Optional
@@ -16,11 +17,13 @@ class Args(main.Args):
         "small", "medium", "large", "xl"
     ] = "medium"  # what size of pretrained GPT to use
     kernel: int = 16
-    linguistic_analysis_path: Optional[
+    linguistic_analysis_save_interval: Optional[
         str
     ] = None  # path to save linguistic analysis data
     randomize_parameters: bool = False
     stride: int = 8
+    transpose: bool = True
+    one_layer: bool = False
 
 
 class Trainer(main.Trainer):
@@ -31,13 +34,13 @@ def make_agent(obs_shape, action_space, args: Args) -> Agent:
             action_space=action_space,
             gpt_size=args.gpt_size,
             hidden_size=args.hidden_size,
-            kernel=args.kernel,
             obs_shape=obs_shape,
+            one_layer=args.one_layer,
             randomize_parameters=args.randomize_parameters,
             recurrent=args.recurrent_policy,
-            save_interval=args.save_interval,
-            save_path=args.linguistic_analysis_path,
-            stride=args.stride,
+            save_interval=args.linguistic_analysis_save_interval,
+            save_dir=args.save_dir,
+            transpose=args.transpose,
         )
 
     @staticmethod
@@ -49,14 +52,15 @@ def save(agent: Agent, args, envs):
         }
         logging.info("Saving parameters:")
         logging.info(pformat([*non_gpt_params]))
+        save_path = Path(args.save_dir, f"checkpoint.pkl")
         torch.save(
             dict(
                 **non_gpt_params,
                 obs_rms=getattr(utils.get_vec_normalize(envs), "obs_rms", None),
             ),
-            args.save_path,
+            save_path,
        )
-        logging.info(f"Saved to {args.save_path}")
+        logging.info(f"Saved to {save_path}")
 
 
 if __name__ == "__main__":
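With save_dir now tied to the run id (see the main.py change below), checkpoints and linguistic-analysis dumps from one run land in the same per-run directory. A minimal sketch of the resulting layout, assuming the default save_dir of "/tmp/logs" and a hypothetical run_id of 42:

    from pathlib import Path

    save_dir = Path("/tmp/logs", str(42))  # train() appends logger.run_id
    print(Path(save_dir, "checkpoint.pkl"))           # /tmp/logs/42/checkpoint.pkl
    print(Path(save_dir, "linguistic-analysis.pkl"))  # /tmp/logs/42/linguistic-analysis.pkl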
21 changes: 14 additions & 7 deletions src/main.py

@@ -9,7 +9,7 @@
 import numpy as np
 import torch
 import yaml
-from run_logger import HasuraLogger, Logger
+from sweep_logger import HasuraLogger, Logger
 from tap import Tap
 
 import utils
@@ -70,8 +70,8 @@ class Args(Tap):
     num_steps: int = 128  # number of forward steps in A2C
     ppo_epoch: int = 3  # number of PPO updates
     recurrent_policy: bool = False  # use recurrence in the policy
-    save_interval: int = 1000  # how many updates to save between
-    save_path: Optional[str] = None  # path to save parameters if saving locally
+    save_interval: Optional[int] = None  # how many updates to save between
+    save_dir: str = "/tmp/logs"  # path to save parameters if saving locally
     seed: int = 0  # random seed
     use_proper_time_limits: bool = False  # compute returns with time limits
     value_coef: float = 1  # value loss coefficient
@@ -95,6 +95,9 @@ def train(cls, args: Args, logger: Optional[Logger] = None):
         torch.backends.cudnn.benchmark = False
         torch.backends.cudnn.deterministic = True
 
+        if logger is not None:
+            args.save_dir = Path(args.save_dir, str(logger.run_id))
+
         torch.set_num_threads(1)
         device = torch.device("cuda:0" if args.cuda else "cpu")
@@ -208,9 +211,13 @@ def train(cls, args: Args, logger: Optional[Logger] = None):
             rollouts.after_update()
 
             # save for every interval-th episode or for the last epoch
-            if j % args.save_interval == 0 or j == num_updates - 1:
-                if args.save_path:
-                    Path(args.save_path).parent.mkdir(parents=True, exist_ok=True)
+            if (
+                args.save_interval is not None
+                and j % args.save_interval == 0
+                or j == num_updates - 1
+            ):
+                if args.save_dir:
+                    args.save_dir.mkdir(parents=True, exist_ok=True)
                 cls.save(agent, args, envs)
 
             if j % args.log_interval == 0:  # and len(episode_rewards) > 1:
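One subtlety in the new save condition: Python's and binds tighter than or, so the check parses as "(interval is set and it divides j) or this is the last update", and short-circuiting avoids the modulo when save_interval is None. A minimal check of that grouping:

    save_interval = None
    num_updates = 10
    for j in (0, 9):
        do_save = (
            save_interval is not None
            and j % save_interval == 0
            or j == num_updates - 1
        )
        print(j, do_save)  # 0 False, 9 True: only the final update saves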
@@ -257,7 +264,7 @@ def save(agent, args, envs):
                 agent,
                 getattr(utils.get_vec_normalize(envs), "obs_rms", None),
             ],
-            args.save_path,
+            Path(args.save_dir, f"checkpoint.pkl"),
         )
 
     @staticmethod
