Brax
Installation
Installing Brax with CUDA acceleration can be tricky; see the JAX notes on GPU installation: https://github.com/google/jax#pip-installation-gpu-cuda.
The following steps worked best for me:
```shell
# Create a Conda environment (or use an existing one)
conda create --name sf_brax python=3.9
# Activate the environment
conda activate sf_brax
# cuda-nvcc seems to be necessary, and the order of the conda channels matters
conda install cudatoolkit cuda-nvcc -c conda-forge -c nvidia
# Install JAX/jaxlib with CUDA support from the JAX release index
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Install Brax
pip install brax
```
Then follow the general installation instructions to install Sample Factory, if you have not already done so.
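To confirm that the CUDA-enabled jaxlib was actually picked up, a quick check like the one below may help. It is a minimal sketch that relies only on `jax.default_backend()`, which returns `gpu` when JAX runs on a CUDA device and `cpu` when it falls back to the CPU; the `jax_backend` helper name is hypothetical and not part of JAX or Sample Factory.

```shell
#!/usr/bin/env bash
# Hypothetical post-install check: ask JAX which backend it selected.
# "gpu" means the CUDA-enabled jaxlib is active; "cpu" means JAX fell back to the CPU.
jax_backend() {
  # JAX warnings go to stderr and are discarded; a failed import prints "missing" instead.
  python -c 'import jax; print(jax.default_backend())' 2>/dev/null || echo "missing"
}

backend="$(jax_backend)"
case "$backend" in
  gpu)     echo "OK: JAX is running on the GPU" ;;
  missing) echo "Could not import jax (is the sf_brax environment activated?)" ;;
  *)       echo "WARNING: JAX backend is '$backend', CUDA acceleration is not active" ;;
esac
```

If this reports `cpu`, the usual suspects are the jaxlib wheel (CPU-only instead of CUDA) or a CUDA toolkit that JAX cannot find.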
Running Experiments
```shell
# To avoid OOM errors, consider disabling JAX's VRAM preallocation (may not be necessary)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# Train for 100M steps with default hyperparameters
python -m sf_examples.brax.train_brax --env=ant --experiment=ant_brax
# Evaluate the agent
python -m sf_examples.brax.enjoy_brax --env=ant --experiment=ant_brax
# The Brax software renderer is slow, so instead of visualizing the agent in a window
# you can render a video offscreen; the video is saved to the experiment directory
python -m sf_examples.brax.enjoy_brax --env=ant --experiment=ant_brax --save_video --video_name=ant
```
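Disabling preallocation is one of three GPU memory settings that JAX documents. By default JAX reserves a large fraction of VRAM up front; `XLA_PYTHON_CLIENT_MEM_FRACTION` caps that fraction, and `XLA_PYTHON_CLIENT_ALLOCATOR=platform` allocates exactly what is needed on demand (slowly). The sketch below wraps the three options in a hypothetical `configure_jax_memory` helper (not part of Sample Factory) to run before `train_brax`; which mode works best for Brax training has not been measured here.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of Sample Factory): select one JAX GPU memory mode
# before launching sf_examples.brax.train_brax. All three variables are standard JAX settings.
configure_jax_memory() {
  local mode="$1"
  case "$mode" in
    off)
      # No up-front preallocation; JAX allocates as needed (the setting used above)
      export XLA_PYTHON_CLIENT_PREALLOCATE=false ;;
    fraction)
      # Keep preallocation but cap it at a fraction of total VRAM (.50 if not given)
      export XLA_PYTHON_CLIENT_MEM_FRACTION="${2:-.50}" ;;
    platform)
      # Allocate exactly what is needed and free it when unused; slow but minimal footprint
      export XLA_PYTHON_CLIENT_ALLOCATOR=platform ;;
    *)
      echo "unknown mode: '$mode' (expected off, fraction or platform)" >&2
      return 1 ;;
  esac
}

# Example: reserve at most 40% of VRAM, then train as usual
configure_jax_memory fraction .40
echo "XLA_PYTHON_CLIENT_MEM_FRACTION=$XLA_PYTHON_CLIENT_MEM_FRACTION"
# python -m sf_examples.brax.train_brax --env=ant --experiment=ant_brax
```

Pick only one mode per run: the platform allocator does not preallocate at all, so combining it with a memory fraction has no useful effect.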
Results
Reports
The reports below were produced by running the launcher script on a Slurm cluster with this command:
```shell
python -m sample_factory.launcher.run \
    --run=sf_examples.brax.experiments.brax_basic_envs \
    --backend=slurm --slurm_workdir=./slurm_brax --experiment_suffix=slurm \
    --slurm_gpus_per_job=1 --slurm_cpus_per_gpu=8 \
    --slurm_sbatch_template=./sf_examples/brax/experiments/sbatch_timeout_brax.sh \
    --pause_between=0 --slurm_print_only=False
```
- ant: https://api.wandb.ai/report/apetrenko/ji9jygss
- humanoid: https://api.wandb.ai/report/apetrenko/m520i16m
- halfcheetah: https://api.wandb.ai/report/apetrenko/7xlp3hh8
- walker2d: https://api.wandb.ai/report/apetrenko/pvb9d11c
Models
Environment | HuggingFace Hub Models | Evaluation Metrics |
---|---|---|
ant | https://huggingface.co/apetrenko/sample_factory_brax_ant | 12565.17 ± 3350.51 |
humanoid | https://huggingface.co/apetrenko/sample_factory_brax_humanoid | 33847.53 ± 6327.36 |
halfcheetah | https://huggingface.co/apetrenko/sample_factory_brax_halfcheetah | 22298.35 ± 1882.48 |
walker2d | https://huggingface.co/apetrenko/sample_factory_brax_walker2d | 5459.17 ± 2198.74 |
Example command used to generate a model and upload it to the HuggingFace Hub:
```shell
python -m sf_examples.brax.enjoy_brax \
--env=humanoid --experiment=02_v083_brax_basic_benchmark_see_2322090_env_humanoid_u.rnn_False_n.epo_5 \
--train_dir=/home/alex/all/projects/sf2/train_dir/v083_brax_basic_benchmark/v083_brax_basic_benchmark_slurm \
--save_video --video_frames=500 --max_num_episodes=500 \
--enjoy_script=sf_examples.brax.enjoy_brax --train_script=sf_examples.brax.train_brax \
--push_to_hub --hf_repository=apetrenko/sample_factory_brax_humanoid --brax_render_res=320 --load_checkpoint_kind=best
```
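Going the other way, a published model can be downloaded and evaluated locally without retraining. The sketch below assumes Sample Factory's `sample_factory.huggingface.load_from_hub` module with `-r` (repository id) and `-d` (train directory) flags, and assumes the downloaded experiment directory is named after the repository without the `apetrenko/` prefix; neither detail is described on this page, so verify both against the Sample Factory HuggingFace documentation. `fetch_and_enjoy` and `DRY_RUN` are hypothetical names.

```shell
#!/usr/bin/env bash
# Hypothetical helper: download a pretrained Brax model from the HuggingFace Hub and evaluate it.
# ASSUMPTIONS: sample_factory.huggingface.load_from_hub with -r/-d flags, and an experiment
# directory named after the repository (user prefix stripped).
# Set DRY_RUN=1 to print the commands instead of running them.
fetch_and_enjoy() {
  local env_name="$1"                                # ant, humanoid, halfcheetah or walker2d
  local repo="apetrenko/sample_factory_brax_${env_name}"
  local experiment="${repo#*/}"                      # strip the "apetrenko/" prefix
  local train_dir="./train_dir"
  local run=""
  if [[ "${DRY_RUN:-0}" == "1" ]]; then
    run="echo"                                       # prefix every command with echo
  fi

  # Stop early if the download fails, so enjoy_brax is not run on a missing checkpoint
  $run python -m sample_factory.huggingface.load_from_hub -r "$repo" -d "$train_dir" || return 1
  $run python -m sf_examples.brax.enjoy_brax --env="$env_name" \
    --experiment="$experiment" --train_dir="$train_dir" --save_video --video_name="$env_name"
}

# Usage: print the commands for the humanoid model without executing them
DRY_RUN=1 fetch_and_enjoy humanoid
```

Only flags that appear elsewhere on this page are passed to `enjoy_brax`; everything about the download step is the labeled assumption above.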
Videos
Ant Environment
Humanoid Environment