- Fix missing MUON optimizer by replacing it with optax.adam
- Fix Hydra configuration parameter paths (env.name instead of env_name)
- Fix BraxGymnaxWrapper method signatures to accept the params argument
- Fix training loop division by zero with proper total_time_steps
- Fix incorrect algorithm name in wandb (reppo instead of sac)
- Fix JAX key batching error in BraxGymnaxWrapper reset method
- Add comprehensive HoReKa SLURM integration with wandb logging
- Update README with detailed bug documentation and fixes
55 lines
1.4 KiB
Bash
Executable File
55 lines
1.4 KiB
Bash
Executable File
#!/bin/bash
#
# SLURM batch job: runs reppo_alg/jaxrl/reppo.py on a Brax environment and
# logs the run to wandb.
#
# Optional environment overrides (read below):
#   ENV_NAME         value forwarded as the env.name override (default: ant)
#   EXPERIMENT_TYPE  label printed in the job log (default: mjx_dmc_small_data)
#   WANDB_ENTITY     value forwarded as the wandb.entity override
#
# Keep every #SBATCH line ahead of the first executable command; sbatch stops
# reading directives once it reaches one.
#
#SBATCH --job-name=reppo_brax
#SBATCH --account=hk-project-p0022232
#SBATCH --partition=accelerated
#SBATCH --gres=gpu:1
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --time=04:00:00
#SBATCH --mem=24G
# NOTE(review): logs/ is a relative path; it presumably has to exist in the
# submit directory before the job is queued -- confirm for this cluster.
#SBATCH --output=logs/reppo_brax_%j.out
#SBATCH --error=logs/reppo_brax_%j.err

# `set -eu` is deliberately not enabled: `module` and the venv activate script
# are external shell code that may not be written for strict mode. Critical
# steps are checked explicitly instead.

# Load required modules. A failure is reported but not fatal, because the
# original script also carried on after this step.
if ! module load devel/cuda/12.4; then
  echo "warning: 'module load devel/cuda/12.4' failed; continuing" >&2
fi

# Set environment variables
export WANDB_MODE=online
export WANDB_PROJECT=reppo_brax

# Change to project directory. Everything below uses paths relative to it, so
# stop here rather than run from the wrong place.
project_dir=/hkfs/home/project/hk-project-robolear/ys1087/Projects/reppo
cd "${project_dir}" || {
  echo "error: cannot cd to ${project_dir}" >&2
  exit 1
}

# Activate virtual environment
# shellcheck disable=SC1091 -- the venv only exists on the cluster
source .venv/bin/activate || {
  echo "error: cannot activate ${project_dir}/.venv" >&2
  exit 1
}

# Note: Ensure WANDB_API_KEY and WANDB_ENTITY are set before running.
# Only the variable names are printed, never their values.
for wandb_var in WANDB_API_KEY WANDB_ENTITY; do
  if [[ -z "${!wandb_var:-}" ]]; then
    echo "warning: ${wandb_var} is not set" >&2
  fi
done

# Run REPPO with Brax environment
echo "Starting REPPO training with Brax..."
echo "Job ID: ${SLURM_JOB_ID:-}"
echo "Node: ${SLURM_NODELIST:-}"
echo "GPU: ${CUDA_VISIBLE_DEVICES:-}"

# Default environment: ant (can be overridden)
ENV_NAME=${ENV_NAME:-ant}
# NOTE(review): EXPERIMENT_TYPE is only echoed; it is not passed to reppo.py.
EXPERIMENT_TYPE=${EXPERIMENT_TYPE:-mjx_dmc_small_data}

echo "Environment: ${ENV_NAME}"
echo "Experiment type: ${EXPERIMENT_TYPE}"

# Overrides are collected in an array so each one reaches python as exactly
# one argument, even if a value is empty or contains whitespace.
overrides=(
  env=brax
  "env.name=${ENV_NAME}"
  hyperparameters.num_envs=1024
  hyperparameters.num_steps=128
  hyperparameters.num_mini_batches=128
  hyperparameters.num_epochs=4
  hyperparameters.total_time_steps=50000000
  "wandb.mode=${WANDB_MODE}"
  "wandb.entity=${WANDB_ENTITY:-}"
  "wandb.project=${WANDB_PROJECT}"
)

# Run the experiment
python reppo_alg/jaxrl/reppo.py "${overrides[@]}"
train_status=$?

# Propagate the trainer's exit status so a failed run is not reported to
# SLURM as a successful job.
if (( train_status != 0 )); then
  echo "error: training exited with status ${train_status}" >&2
  exit "${train_status}"
fi

echo "Training completed!"