
Commit

cleaning up a little
mginoya committed Aug 26, 2024
1 parent 1f5f123 commit 2d72763

Showing 4 changed files with 31 additions and 19 deletions.
6 changes: 3 additions & 3 deletions alfredo/agents/aant/aant.py
```diff
@@ -88,9 +88,9 @@ def reset(self, rng: jax.Array) -> State:
         low, hi = -self._reset_noise_scale, self._reset_noise_scale

         jcmd = self._sample_command(rng3)
-        #wcmd = self._sample_waypoint(rng3)
+        wcmd = self._sample_waypoint(rng3)

-        wcmd = jp.array([0.0, 10.0])
+        #wcmd = jp.array([0.0, 10.0])

         q = self.sys.init_q + jax.random.uniform(
             rng1, (self.sys.q_size(),), minval=low, maxval=hi
@@ -255,7 +255,7 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
            key3, (1,), minval=z_range[0], maxval=z_range[1]
        )

-        wcmd = jp.array([x[0], y[0], z[0]])
+        wcmd = jp.array([x[0], y[0]])

        return wcmd
```

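The reset now samples the waypoint command instead of hard-coding it, and `_sample_waypoint` returns only the planar components. A self-contained sketch of a 2D waypoint sampler in the same spirit (the function name and ranges here are illustrative, not taken from the repository):

```python
import jax
import jax.numpy as jp

def sample_waypoint_2d(rng: jax.Array,
                       x_range=(-10.0, 10.0),
                       y_range=(-10.0, 10.0)) -> jax.Array:
    """Illustrative 2D waypoint sampler; ranges are assumptions, not repo values."""
    key_x, key_y = jax.random.split(rng)
    x = jax.random.uniform(key_x, (1,), minval=x_range[0], maxval=x_range[1])
    y = jax.random.uniform(key_y, (1,), minval=y_range[0], maxval=y_range[1])
    return jp.array([x[0], y[0]])  # planar command: no z component
```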
2 changes: 1 addition & 1 deletion experiments/AAnt-locomotion/one_physics_step.py
```diff
@@ -60,7 +60,7 @@
 policy_params = (params[0], params[1])
 inference_fn = make_policy(policy_params)

-wcmd = jp.array([0.0, 1000.0])
+wcmd = jp.array([0.0, 10.0])
 key_envs, _ = jax.random.split(rng)
 state = env.reset(rng=key_envs)

```
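For orientation, a sketch of the single policy-plus-physics step this script revolves around, in the standard Brax pattern (how `wcmd` is threaded into the observation is repo-specific and omitted; `act_rng` is an illustrative name):

```python
# Assumes `env`, `state`, `inference_fn`, and `rng` are set up as above.
jit_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fn)

act_rng, rng = jax.random.split(rng)
act, _ = jit_inference_fn(state.obs, act_rng)  # policy maps observation -> action
state = jit_step(state, act)                   # advance the physics one step
```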
38 changes: 25 additions & 13 deletions experiments/AAnt-locomotion/training.py
```diff
@@ -26,10 +26,22 @@
     config={
         "env_name": "AAnt",
         "backend": "positional",
-        "seed": 0,
+        "seed": 1,
         "len_training": 1_500_000,
+        "num_evals": 200,
+        "num_envs": 2048,
         "batch_size": 2048,
+        "num_minibatches": 8,
+        "updates_per_batch": 8,
         "episode_len": 1000,
-        "reward_scaling":0.1,
+        "unroll_len": 10,
+        "action_repeat": 1,
+        "discounting": 0.97,
+        "learning_rate": 3e-4,
+        "entropy_cost": 1e-3,
+        "reward_scaling": 0.1,
+        "normalize_obs": True,
     },
 )
```

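The hyperparameters added to `config` above are consumed through `wandb.config` in the next hunk, so a W&B sweep or command-line override can change them without editing the script. A minimal sketch of that round trip in isolation (the project name and offline mode are placeholders for illustration):

```python
import wandb

# Register hyperparameters with the run; a sweep can override any of them.
run = wandb.init(project="example-project", mode="offline",
                 config={"seed": 1, "learning_rate": 3e-4})

lr = wandb.config.learning_rate  # 3e-4 unless a sweep/override replaced it
seed = wandb.config.seed
run.finish()
```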
```diff
@@ -126,20 +138,20 @@ def progress(num_steps, metrics):
 train_fn = functools.partial(
     ppo.train,
     num_timesteps=wandb.config.len_training,
-    num_evals=400,
-    reward_scaling=0.1,
+    num_evals=wandb.config.num_evals,
+    reward_scaling=wandb.config.reward_scaling,
     episode_length=wandb.config.episode_len,
-    normalize_observations=True,
-    action_repeat=1,
-    unroll_length=10,
-    num_minibatches=8,
-    num_updates_per_batch=8,
-    discounting=0.97,
-    learning_rate=3e-4,
-    entropy_cost=1e-3,
-    num_envs=2048,
+    normalize_observations=wandb.config.normalize_obs,
+    action_repeat=wandb.config.action_repeat,
+    unroll_length=wandb.config.unroll_len,
+    num_minibatches=wandb.config.num_minibatches,
+    num_updates_per_batch=wandb.config.updates_per_batch,
+    discounting=wandb.config.discounting,
+    learning_rate=wandb.config.learning_rate,
+    entropy_cost=wandb.config.entropy_cost,
+    num_envs=wandb.config.num_envs,
     batch_size=wandb.config.batch_size,
-    seed=1,
+    seed=wandb.config.seed,
     in_params=mParams,
 )

```
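With every hyperparameter now routed through `wandb.config`, the partially-applied trainer would be invoked as in other Brax-style scripts; a hedged sketch only (exact argument names in this fork, e.g. alongside `in_params`, may differ from stock Brax):

```python
# `env` and `progress` are assumed to be defined elsewhere in training.py.
# Brax's ppo.train returns a policy factory, trained parameters, and metrics.
make_inference_fn, params, metrics = train_fn(environment=env, progress_fn=progress)
inference_fn = make_inference_fn(params)  # policy usable for evaluation rollouts
```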
4 changes: 2 additions & 2 deletions experiments/AAnt-locomotion/vis_traj.py
```diff
@@ -58,7 +58,7 @@
 jit_env_step = jax.jit(env.step)

 rollout = []
-rng = jax.random.PRNGKey(seed=13294)
+rng = jax.random.PRNGKey(seed=13194)
 state = jit_env_reset(rng=rng)

 normalize = lambda x, y: x
@@ -79,7 +79,7 @@
 #yaw_vel = 0.0 # rad/s
 #jcmd = jp.array([x_vel, y_vel, yaw_vel])

-wcmd = jp.array([0.0, 10.0])
+wcmd = jp.array([-10.0, 10.0])

 # generate policy rollout
 for _ in range(episode_length):
```
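For context, the hunk above sits at the top of the script's rollout loop; a sketch of that loop in the usual Brax visualization style (the `jit_inference_fn` name and how `wcmd` enters the observation are assumptions, not taken from the file):

```python
# Collect pipeline states so the trajectory can be rendered afterwards.
for _ in range(episode_length):
    rollout.append(state.pipeline_state)
    act_rng, rng = jax.random.split(rng)
    act, _ = jit_inference_fn(state.obs, act_rng)  # policy: observation -> action
    state = jit_env_step(state, act)
```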
