Skip to content

Commit

Permalink
Addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dibyaghosh committed Nov 30, 2023
1 parent 8fc3828 commit a05515f
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
pip install --upgrade "jax[cpu]==0.4.13"
- name: Train on debug dataset
run: WANDB_MODE=disabled python train.py --config experiments/dibya/debug_config.py --name debug
run: WANDB_MODE=disabled python train.py --config tests/debug_config.py --name debug
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ See the [Jax Github page](https://github.com/google/jax) for more details on ins

Test the installation by training on the debug dataset:
```
python train.py --config experiments/dibya/debug_config.py --name debug
python train.py --config tests/debug_config.py --debug
```

## Training
Expand Down Expand Up @@ -67,7 +67,7 @@ Steps to contribute:
1. Fork the repo and create your branch from `master`.
2. Use `pre-commit` to enable code checks and auto-formatting.
3. Test that a basic training starts with the debug dataset with: ```
python train.py --config experiments/dibya/debug_config.py --name debug
python train.py --config tests/debug_config.py
```
Expand Down
8 changes: 7 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ def update_config(config, **kwargs):


def wrap(f):
"""Simple wrapper to enable passing config strings to `get_config`"""
"""Simple wrapper to enable passing config strings to `get_config`
Usage:
python train.py --config=config.py:vit_s,multimodal
python train.py --config=config.py:transformer_size=vit_s
"""

def wrapped_f(config_string=None):
if config_string is None:
Expand Down
8 changes: 4 additions & 4 deletions experiments/dibya/debug_config.py → tests/debug_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_config():
del base_config["dataset_kwargs"]["oxe_kwargs"]
config = update_config(
base_config,
num_steps=20,
num_steps=2,
optimizer=dict(
learning_rate=dict(
warmup_steps=1,
Expand All @@ -18,16 +18,16 @@ def get_config():
batch_size=64,
shuffle_buffer_size=1000,
num_val_batches=1,
log_interval=10,
eval_interval=10,
log_interval=1,
eval_interval=2,
eval_datasets=None,
trajs_for_metrics=1,
trajs_for_viz=1,
dataset_kwargs={
"data_kwargs_list": [
{
"name": "bridge_dataset",
"data_dir": "./datasets/debug_dataset",
"data_dir": "./tests/debug_dataset",
"image_obs_keys": ["image_0"],
"state_obs_keys": ["state"],
},
Expand Down

0 comments on commit a05515f

Please sign in to comment.