sanity check with A0
mginoya committed Sep 25, 2023
1 parent 8ad6c58 commit c3a0ef9
Showing 2 changed files with 26 additions and 19 deletions.
8 changes: 5 additions & 3 deletions alfredo/agents/A0/alfredo_0.py
@@ -294,7 +294,7 @@ def step(self, state: State, action: jp.ndarray) -> State:
         # print(f"a_vel -> {a_velocity}")
 
         reward_vel = math.safe_norm(a_velocity)
-        forward_reward = self._forward_reward_weight * reward_vel
+        forward_reward = self._forward_reward_weight * a_velocity[0]
         # print(f"a_vel -> {a_velocity}")
         # print(f"target_pos -> {pipeline_state.q[-2:]}")
         dist_diff = jp.array(
@@ -318,7 +318,7 @@ def step(self, state: State, action: jp.ndarray) -> State:
         else:
             healthy_reward = self._healthy_reward * is_healthy
 
-        reward = healthy_reward - ctrl_cost # + forward_reward #+ reward_to_target
+        reward = healthy_reward - ctrl_cost + forward_reward # + reward_to_target
 
         done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
@@ -373,7 +373,9 @@ def _get_obs(self, pipeline_state: base.State, action: jp.ndarray) -> jp.ndarray
         # com_ang = xd_i.ang
         # com_velocity = jp.hstack([com_vel, com_ang])
 
-        qfrc_actuator = actuator.to_tau(self.sys, action, pipeline_state.q, pipeline_state.qd)
+        qfrc_actuator = actuator.to_tau(
+            self.sys, action, pipeline_state.q, pipeline_state.qd
+        )
 
         # external_contact_forces are excluded
         # return jp.concatenate([
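Note on the reward change in this file: the forward term is now computed from the x-component of the agent's velocity rather than its norm, and it is re-enabled in the total reward. A minimal sketch of the difference, using plain jax.numpy in place of brax's math.safe_norm; the weight and velocity values here are illustrative, not the repository's defaults:

import jax.numpy as jp

forward_reward_weight = 1.25                # illustrative weight, not the repo default
a_velocity = jp.array([0.8, -0.1, 0.02])    # hypothetical COM velocity (vx, vy, vz)

# Old behaviour: reward scales with speed, so motion in any direction counts.
old_forward_reward = forward_reward_weight * jp.linalg.norm(a_velocity)

# New behaviour: reward scales with the x-component only, so forward (+x)
# motion is rewarded and backward motion is penalized.
new_forward_reward = forward_reward_weight * a_velocity[0]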
37 changes: 21 additions & 16 deletions experiments/Alfredo-ex1/seq_training.py
@@ -12,6 +12,7 @@
 import jax
 import matplotlib.pyplot as plt
 import optax
+import wandb
 from brax import envs
 # from brax.envs.wrappers import training
 from brax.io import html, json, model
@@ -23,17 +24,16 @@
 from alfredo.agents.A0 import Alfredo
 from alfredo.train import ppo
 
-
-import wandb
 # Initialize a new run
-wandb.init(project="alfredo",
-    config = {
+wandb.init(
+    project="alfredo",
+    config={
         "env_name": "A0",
         "backend": "positional",
         "seed": 0,
-        "len_training": 1_000_000,
+        "len_training": 1_500_000,
         "batch_size": 1024,
-    }
+    },
 )
 
 # ==============================
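Aside on the config block above: whatever is passed to wandb.init(config=...) is read back through wandb.config, which is how the training length and batch size are consumed later in this script. For example:

# After wandb.init(...), the stored values are available as attributes
# (1_500_000 and 1024 after this commit).
num_timesteps = wandb.config.len_training
batch_size = wandb.config.batch_size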
@@ -45,14 +45,19 @@
 
 def progress(num_steps, metrics):
     print(num_steps)
-    wandb.log({"step": num_steps,
-               "Total Reward": metrics['eval/episode_reward'],
-               "Target Reward": metrics['eval/episode_reward_to_target'],
-               "Vel Reward": metrics['eval/episode_reward_velocity'],
-               "Alive Reward": metrics['eval/episode_reward_alive'],
-               "Ctrl Reward": metrics['eval/episode_reward_ctrl'],
-               "a_vel_x": metrics['eval/episode_agent_x_velocity'],
-               "a_vel_y": metrics['eval/episode_agent_y_velocity']})
+    wandb.log(
+        {
+            "step": num_steps,
+            "Total Reward": metrics["eval/episode_reward"],
+            "Target Reward": metrics["eval/episode_reward_to_target"],
+            "Vel Reward": metrics["eval/episode_reward_velocity"],
+            "Alive Reward": metrics["eval/episode_reward_alive"],
+            "Ctrl Reward": metrics["eval/episode_reward_ctrl"],
+            "a_vel_x": metrics["eval/episode_agent_x_velocity"],
+            "a_vel_y": metrics["eval/episode_agent_y_velocity"],
+        }
+    )
+
 
 # ==============================
 # General Variable Defs
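For reference, the trainer invokes this callback periodically with the cumulative environment step count and a dict of evaluation metrics. A hypothetical call with made-up values, just to show the expected shape (the keys match those logged above):

progress(
    10_240,
    {
        "eval/episode_reward": 512.3,            # all numbers here are illustrative
        "eval/episode_reward_to_target": 0.0,
        "eval/episode_reward_velocity": 84.1,
        "eval/episode_reward_alive": 400.0,
        "eval/episode_reward_ctrl": -12.6,
        "eval/episode_agent_x_velocity": 0.9,
        "eval/episode_agent_y_velocity": 0.05,
    },
)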
@@ -124,13 +129,13 @@ def progress(num_steps, metrics):
 train_fn = functools.partial(
     ppo.train,
     num_timesteps=wandb.config.len_training,
-    num_evals=1000,
+    num_evals=200,
     reward_scaling=0.1,
     episode_length=1000,
     normalize_observations=True,
     action_repeat=1,
     unroll_length=10,
-    num_minibatches=32,
+    num_minibatches=8,
     num_updates_per_batch=8,
     discounting=0.97,
     learning_rate=3e-4,
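The call site of train_fn is outside this diff. A hedged sketch of how these pieces are typically wired together, assuming the forked ppo.train keeps brax's environment/progress_fn keyword arguments and that the Alfredo constructor accepts a backend keyword; the output path is hypothetical:

# Not part of this commit: build the env from the wandb config, train, and
# save the learned parameters with brax's model helper.
env = Alfredo(backend=wandb.config.backend)      # assumption about the constructor
make_inference_fn, params, _ = train_fn(
    environment=env,
    progress_fn=progress,                        # the logging callback above
)
model.save_params("a0_policy_params", params)    # hypothetical output path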
