-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy path09_multitask_procgen.py
More file actions
101 lines (91 loc) · 3.1 KB
/
09_multitask_procgen.py
File metadata and controls
101 lines (91 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from argparse import ArgumentParser
import wandb
from amago.envs.builtin.procgen_envs import (
TwoShotMTProcgen,
ProcgenAMAGO,
ALL_PROCGEN_GAMES,
)
from amago.nets.cnn import IMPALAishCNN
from amago import cli_utils
def add_cli(parser):
    """Register this example's command-line flags on *parser*.

    Adds ``--max_seq_len`` (policy context length), ``--distribution``
    (which Procgen preset from ``PROCGEN_SETTINGS`` to train on), and
    ``--train_seeds`` (number of level seeds reserved for training).
    Returns the same parser to allow chained use.
    """
    # Context length of the sequence model (in env timesteps).
    parser.add_argument("--max_seq_len", default=256, type=int)
    # Which benchmark preset to run; must match a PROCGEN_SETTINGS key.
    parser.add_argument(
        "--distribution",
        default="easy",
        type=str,
        choices=["easy", "easy-rescaled", "memory-hard"],
    )
    # Seeds [0, train_seeds] are for training; later seeds are held out.
    parser.add_argument("--train_seeds", default=10_000, type=int)
    return parser
# Benchmark presets, keyed by the --distribution CLI flag. Each entry supplies
# keyword arguments forwarded to TwoShotMTProcgen in the env factories below:
# the list of Procgen games, optional per-game reward rescaling (empty dict =
# native rewards), and the Procgen difficulty mode string.
PROCGEN_SETTINGS = {
    # Five-game split at easy difficulty with unmodified rewards.
    "easy": {
        "games": ["climber", "coinrun", "jumper", "ninja", "leaper"],
        "reward_scales": {},
        "distribution_mode": "easy",
    },
    # Same five games, but with deliberately unbalanced reward scales
    # (coinrun x100, climber x0.1) -- presumably to stress multi-task
    # optimization under mismatched reward magnitudes; confirm against the
    # AMAGO paper/examples.
    "easy-rescaled": {
        "games": ["climber", "coinrun", "jumper", "ninja", "leaper"],
        "reward_scales": {"coinrun": 100.0, "climber": 0.1},
        "distribution_mode": "easy",
    },
    # Full game set in "memory-hard" mode (defined by amago's procgen_envs).
    "memory-hard": {
        "games": ALL_PROCGEN_GAMES,
        "reward_scales": {},
        "distribution_mode": "memory-hard",
    },
}
if __name__ == "__main__":
    # Build the CLI: shared AMAGO flags plus this example's flags.
    parser = ArgumentParser()
    cli_utils.add_common_cli(parser)
    add_cli(parser)
    args = parser.parse_args()

    # Config overrides accumulated (in place) by the cli_utils switch helpers
    # and applied once via use_config.
    config = {}
    traj_encoder_type = cli_utils.switch_traj_encoder(
        config,
        arch=args.traj_encoder,
        memory_size=args.memory_size,
        layers=args.memory_layers,
    )
    # Procgen observations are images: use the CNN timestep encoder with an
    # IMPALA-style backbone and DrQ-v2-style image augmentation.
    tstep_encoder_type = cli_utils.switch_tstep_encoder(
        config,
        arch="cnn",
        cnn_type=IMPALAishCNN,
        channels_first=False,
        drqv2_aug=True,
    )
    agent_type = cli_utils.switch_agent(config, args.agent_type)
    cli_utils.use_config(config, args.configs)

    procgen_kwargs = PROCGEN_SETTINGS[args.distribution]
    # NOTE(review): 2000 vs 5000 presumably matches the wrapper's episode
    # horizon for each preset -- confirm against amago.envs.builtin.procgen_envs.
    horizon = 2000 if "easy" in args.distribution else 5000

    # Named factory functions instead of lambda assignments (PEP 8 E731);
    # behavior is unchanged. Train and test use disjoint seed ranges so the
    # evaluation levels are held out.
    def make_train_env():
        """Training env: selected preset on seeds [0, args.train_seeds]."""
        return ProcgenAMAGO(
            TwoShotMTProcgen(**procgen_kwargs, seed_range=(0, args.train_seeds)),
        )

    def make_test_env():
        """Held-out env: same preset on seeds past the training range."""
        return ProcgenAMAGO(
            TwoShotMTProcgen(
                **procgen_kwargs, seed_range=(args.train_seeds + 1, 10_000_000)
            ),
        )

    # Group related trials under one wandb group; each trial gets its own run.
    group_name = f"{args.run_name}_{args.distribution}_procgen_l_{args.max_seq_len}"
    for trial in range(args.trials):
        run_name = group_name + f"_trial_{trial}"
        experiment = cli_utils.create_experiment_from_cli(
            args,
            make_train_env=make_train_env,
            make_val_env=make_test_env,
            max_seq_len=args.max_seq_len,
            # Save trajectories in chunks longer than the context window.
            traj_save_len=args.max_seq_len * 4,
            run_name=run_name,
            tstep_encoder_type=tstep_encoder_type,
            traj_encoder_type=traj_encoder_type,
            agent_type=agent_type,
            group_name=group_name,
            # Enough validation steps for ~5 full-horizon episodes.
            val_timesteps_per_epoch=5 * horizon + 1,
        )
        experiment = cli_utils.switch_async_mode(experiment, args.mode)
        experiment.start()
        if args.ckpt is not None:
            experiment.load_checkpoint(args.ckpt)
        experiment.learn()
        # Final held-out evaluation (~20 episodes worth of timesteps).
        experiment.evaluate_test(make_test_env, timesteps=horizon * 20, render=False)
        # Free the replay buffer on disk before the next trial.
        experiment.delete_buffer_from_disk()
        wandb.finish()