- Add complete HoReKa installation guide without conda dependency
- Include SLURM job script with GPU configuration and account setup
- Add helper scripts for job submission and environment testing
- Integrate wandb logging with both online and offline modes
- Support MuJoCo Playground environments for humanoid control
- Update README with clear separation of added vs original content
94 lines
2.6 KiB
Python
94 lines
2.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to verify FastTD3 setup is working correctly.
|
|
This runs a minimal test to ensure all components are functioning.
|
|
"""
|
|
import os
|
|
import torch
|
|
import gymnasium as gym
|
|
import wandb
|
|
from fast_td3.hyperparams import get_args
|
|
|
|
def test_basic_imports(
    required=("torch", "gymnasium", "wandb", "numpy", "tensordict"),
):
    """Check that every required package can be imported.

    Each module is imported independently, so a single run reports *all*
    missing packages instead of stopping at the first failure.

    NOTE(review): torch, gymnasium and wandb are also imported at module
    level in this script, so if one of them is missing the script aborts
    before this check ever runs.

    Args:
        required: Importable module names to check. Defaults to the
            packages this setup test has always checked.

    Returns:
        True if every module imported successfully, False if at least one
        import failed (each failure is printed).
    """
    import importlib

    print("Testing basic imports...")
    all_ok = True
    for name in required:
        try:
            importlib.import_module(name)
        except ImportError as e:
            # Keep going so the user sees every missing package in one run.
            print(f"✗ Import error: {e}")
            all_ok = False
    if all_ok:
        print("✓ All basic packages imported successfully")
    return all_ok
|
|
|
|
def test_gpu_availability():
    """Report whether PyTorch can see a CUDA device.

    Returns:
        True when CUDA is available (the device count and the name of
        device 0 are printed), False otherwise (a CPU warning is printed).
    """
    print("Testing GPU availability...")
    if not torch.cuda.is_available():
        print("⚠ CUDA not available, will run on CPU")
        return False

    gpu_count = torch.cuda.device_count()
    print(f"✓ CUDA available, {gpu_count} GPU(s) found")
    device_name = torch.cuda.get_device_name(0)
    print(f" Current device: {device_name}")
    return True
|
|
|
|
def test_environment(env_id="Pendulum-v1"):
    """Check that a Gymnasium environment can be created and reset.

    The environment is closed in every case, including when ``reset()``
    fails, so a failed check does not leave it open. A failure in
    ``close()`` itself is printed as a warning and does not change the
    result.

    Args:
        env_id: Gymnasium environment id to instantiate.
            Defaults to ``"Pendulum-v1"``.

    Returns:
        True if the environment was created and reset successfully,
        False otherwise (the error is printed).
    """
    print("Testing environment creation...")
    env = None
    try:
        env = gym.make(env_id)
        # Unpacking also checks that reset() returns an (obs, info) pair.
        _obs, _info = env.reset()
        print("✓ Environment created successfully")
        print(f" Observation space: {env.observation_space}")
        print(f" Action space: {env.action_space}")
        return True
    except Exception as e:
        print(f"✗ Environment creation failed: {e}")
        return False
    finally:
        if env is not None:
            try:
                env.close()
            except Exception as e:
                # A failing close() should not abort the remaining checks.
                print(f"⚠ Environment close failed: {e}")
|
|
|
|
def test_wandb_setup():
    """Check that wandb can start and finish a run without logging in.

    The run is started with ``mode="offline"``, and ``WANDB_MODE`` is set
    to ``"offline"`` only for the duration of the check. The caller's
    previous value (or its absence) is restored afterwards.

    Returns:
        True if ``wandb.init``/``wandb.finish`` succeeded, False otherwise
        (the error is printed).
    """
    print("Testing wandb setup...")
    # Restore WANDB_MODE afterwards so that later wandb usage in the same
    # process is not silently forced into offline mode by this check.
    previous_mode = os.environ.get("WANDB_MODE")
    os.environ["WANDB_MODE"] = "offline"
    try:
        wandb.init(project="test", mode="offline")
        wandb.finish()
        print("✓ wandb can be initialized")
        return True
    except Exception as e:
        print(f"✗ wandb setup failed: {e}")
        return False
    finally:
        if previous_mode is None:
            os.environ.pop("WANDB_MODE", None)
        else:
            os.environ["WANDB_MODE"] = previous_mode
|
|
|
|
def main(tests=None):
    """Run the setup checks and print a pass/fail summary.

    Args:
        tests: Optional iterable of zero-argument callables, each returning
            a truthy value on success. Defaults to the import, GPU,
            environment and wandb checks defined in this script.

    Returns:
        True if every check passed, False otherwise.
    """
    print("FastTD3 Setup Test")
    print("==================")

    if tests is None:
        tests = [
            test_basic_imports,
            test_gpu_availability,
            test_environment,
            test_wandb_setup,
        ]
    # Materialize so len() works for generators and other one-shot iterables.
    tests = list(tests)

    passed = 0
    for test in tests:
        try:
            ok = test()
        except Exception as e:
            # Some checks (e.g. test_gpu_availability) have no error handling
            # of their own; count an unexpected error as a failure instead of
            # aborting the remaining checks and the summary.
            name = getattr(test, "__name__", repr(test))
            print(f"✗ {name} raised an unexpected error: {e}")
            ok = False
        if ok:
            passed += 1
        print()

    print(f"Results: {passed}/{len(tests)} tests passed")

    if passed == len(tests):
        print("🎉 All tests passed! Setup looks good.")
        return True
    print("❌ Some tests failed. Check the output above.")
    return False
|
|
|
|
if __name__ == "__main__":
    # Propagate the result as the process exit status (0 = all checks
    # passed, 1 = some failed) so shell and job scripts can detect failure.
    raise SystemExit(0 if main() else 1)