allow history observation
This commit is contained in:
parent
2ddf63b8f5
commit
f13eb203e1
@ -162,7 +162,10 @@ To use DDIM fine-tuning, set `denoising_steps=100` in pre-training and set `mode
|
|||||||
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_length`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_length` is a 1-D array for indexing across num_total_steps.
|
Pre-training script is at [`agent/pretrain/train_diffusion_agent.py`](agent/pretrain/train_diffusion_agent.py). The pre-training dataset [loader](agent/dataset/sequence.py) assumes a npz file containing numpy arrays `states`, `actions`, `images` (if using pixel) and `traj_length`, where `states` and `actions` have the shape of num_total_steps x obs_dim/act_dim, `images` num_total_steps x C (concatenated if multiple images) x H x W, and `traj_length` is a 1-D array for indexing across num_total_steps.
|
||||||
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
|
<!-- One pre-processing example can be found at [`script/process_robomimic_dataset.py`](script/process_robomimic_dataset.py). -->
|
||||||
|
|
||||||
**Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131).
|
<!-- **Note:** The current implementation does not support loading history observations (only using observation at the current timestep). If needed, you can modify [here](agent/dataset/sequence.py#L130-L131). -->
|
||||||
|
|
||||||
|
#### Observation history
|
||||||
|
In our experiments we did not use any observation from previous timesteps (state or pixel), but it is implemented. You can set `cond_steps=<num_state_obs_step>` (and `img_cond_steps=<num_img_obs_step>`, no larger than `cond_steps`) in pre-training, and set the same when fine-tuning the newly pre-trained policy.
|
||||||
|
|
||||||
### Fine-tuning environment
|
### Fine-tuning environment
|
||||||
We follow the Gym format for interacting with the environments. The vectorized environments are initialized at [make_async](env/gym_utils/__init__.py#L10) (called in the parent fine-tuning agent class [here](agent/finetune/train_agent.py#L38-L39)). The current implementation is not the cleanest as we tried to make it compatible with Gym, Robomimic, Furniture-Bench, and D3IL environments, but it should be easy to modify and allow using other environments. We use [multi_step](env/gym_utils/wrapper/multi_step.py) wrapper for history observations (not used currently) and multi-environment-step action execution. We also use environment-specific wrappers such as [robomimic_lowdim](env/gym_utils/wrapper/robomimic_lowdim.py) and [furniture](env/gym_utils/wrapper/furniture.py) for observation/action normalization, etc. You can implement a new environment wrapper if needed.
|
We follow the Gym format for interacting with the environments. The vectorized environments are initialized at [make_async](env/gym_utils/__init__.py#L10) (called in the parent fine-tuning agent class [here](agent/finetune/train_agent.py#L38-L39)). The current implementation is not the cleanest as we tried to make it compatible with Gym, Robomimic, Furniture-Bench, and D3IL environments, but it should be easy to modify and allow using other environments. We use [multi_step](env/gym_utils/wrapper/multi_step.py) wrapper for history observations (not used currently) and multi-environment-step action execution. We also use environment-specific wrappers such as [robomimic_lowdim](env/gym_utils/wrapper/robomimic_lowdim.py) and [furniture](env/gym_utils/wrapper/furniture.py) for observation/action normalization, etc. You can implement a new environment wrapper if needed.
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
|
Pre-training data loader. Modified from https://github.com/jannerm/diffuser/blob/main/diffuser/datasets/sequence.py
|
||||||
|
|
||||||
TODO: implement history observation
|
|
||||||
|
|
||||||
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
|
No normalization is applied here --- we always normalize the data when pre-processing it with a different script, and the normalization info is also used in RL fine-tuning.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: lamp
|
furniture: lamp
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: lamp
|
furniture: lamp
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ env:
|
|||||||
furniture: lamp
|
furniture: lamp
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: lamp
|
furniture: lamp
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: lamp
|
furniture: lamp
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ env:
|
|||||||
furniture: lamp
|
furniture: lamp
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: one_leg
|
furniture: one_leg
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: one_leg
|
furniture: one_leg
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ env:
|
|||||||
furniture: one_leg
|
furniture: one_leg
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: one_leg
|
furniture: one_leg
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: one_leg
|
furniture: one_leg
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ env:
|
|||||||
furniture: one_leg
|
furniture: one_leg
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: round_table
|
furniture: round_table
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: round_table
|
furniture: round_table
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ env:
|
|||||||
furniture: round_table
|
furniture: round_table
|
||||||
randomness: low
|
randomness: low
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: round_table
|
furniture: round_table
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ env:
|
|||||||
furniture: round_table
|
furniture: round_table
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ env:
|
|||||||
furniture: round_table
|
furniture: round_table
|
||||||
randomness: med
|
randomness: med
|
||||||
normalization_path: ${normalization_path}
|
normalization_path: ${normalization_path}
|
||||||
|
obs_steps: ${cond_steps}
|
||||||
act_steps: ${act_steps}
|
act_steps: ${act_steps}
|
||||||
sparse_reward: True
|
sparse_reward: True
|
||||||
|
|
||||||
|
6
env/gym_utils/__init__.py
vendored
6
env/gym_utils/__init__.py
vendored
@ -24,6 +24,7 @@ def make_async(
|
|||||||
normalization_path=None,
|
normalization_path=None,
|
||||||
furniture="one_leg",
|
furniture="one_leg",
|
||||||
randomness="low",
|
randomness="low",
|
||||||
|
obs_steps=1,
|
||||||
act_steps=8,
|
act_steps=8,
|
||||||
sparse_reward=False,
|
sparse_reward=False,
|
||||||
# below for robomimic only
|
# below for robomimic only
|
||||||
@ -93,19 +94,16 @@ def make_async(
|
|||||||
stiffness=1_000,
|
stiffness=1_000,
|
||||||
damping=200,
|
damping=200,
|
||||||
)
|
)
|
||||||
|
|
||||||
env = FurnitureRLSimEnvMultiStepWrapper(
|
env = FurnitureRLSimEnvMultiStepWrapper(
|
||||||
env,
|
env,
|
||||||
n_obs_steps=1,
|
n_obs_steps=obs_steps,
|
||||||
n_action_steps=act_steps,
|
n_action_steps=act_steps,
|
||||||
reward_agg_method="sum",
|
|
||||||
prev_action=False,
|
prev_action=False,
|
||||||
reset_within_step=False,
|
reset_within_step=False,
|
||||||
pass_full_observations=False,
|
pass_full_observations=False,
|
||||||
normalization_path=normalization_path,
|
normalization_path=normalization_path,
|
||||||
sparse_reward=sparse_reward,
|
sparse_reward=sparse_reward,
|
||||||
)
|
)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
# avoid import error due incompatible gym versions
|
# avoid import error due incompatible gym versions
|
||||||
|
48
env/gym_utils/wrapper/furniture.py
vendored
48
env/gym_utils/wrapper/furniture.py
vendored
@ -5,8 +5,10 @@ Environment wrapper for Furniture-Bench environments.
|
|||||||
|
|
||||||
import gym
|
import gym
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from furniture_bench.envs.furniture_rl_sim_env import FurnitureRLSimEnv
|
|
||||||
import torch
|
import torch
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
from furniture_bench.envs.furniture_rl_sim_env import FurnitureRLSimEnv
|
||||||
from furniture_bench.controllers.control_utils import proprioceptive_quat_to_6d_rotation
|
from furniture_bench.controllers.control_utils import proprioceptive_quat_to_6d_rotation
|
||||||
from ..furniture_normalizer import LinearNormalizer
|
from ..furniture_normalizer import LinearNormalizer
|
||||||
from .multi_step import repeated_space
|
from .multi_step import repeated_space
|
||||||
@ -16,6 +18,32 @@ import logging
|
|||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def stack_last_n_obs_dict(all_obs, n_steps):
|
||||||
|
"""Apply padding"""
|
||||||
|
assert len(all_obs) > 0
|
||||||
|
all_obs = list(all_obs)
|
||||||
|
result = {
|
||||||
|
key: torch.zeros(
|
||||||
|
list(all_obs[-1][key].shape)[0:1]
|
||||||
|
+ [n_steps]
|
||||||
|
+ list(all_obs[-1][key].shape)[1:],
|
||||||
|
dtype=all_obs[-1][key].dtype,
|
||||||
|
).to(
|
||||||
|
all_obs[-1][key].device
|
||||||
|
) # add step dimension
|
||||||
|
for key in all_obs[-1]
|
||||||
|
}
|
||||||
|
start_idx = -min(n_steps, len(all_obs))
|
||||||
|
for key in all_obs[-1]:
|
||||||
|
result[key][:, start_idx:] = torch.concatenate(
|
||||||
|
[obs[key][:, None] for obs in all_obs[start_idx:]], dim=1
|
||||||
|
) # add step dimension
|
||||||
|
if n_steps > len(all_obs):
|
||||||
|
# pad
|
||||||
|
result[key][:start_idx] = result[key][start_idx]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
||||||
env: FurnitureRLSimEnv
|
env: FurnitureRLSimEnv
|
||||||
|
|
||||||
@ -26,7 +54,6 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
|||||||
n_action_steps=1,
|
n_action_steps=1,
|
||||||
max_episode_steps=None,
|
max_episode_steps=None,
|
||||||
sparse_reward=False,
|
sparse_reward=False,
|
||||||
reward_agg_method="sum", # never use other types
|
|
||||||
reset_within_step=False,
|
reset_within_step=False,
|
||||||
pass_full_observations=False,
|
pass_full_observations=False,
|
||||||
normalization_path=None,
|
normalization_path=None,
|
||||||
@ -35,8 +62,6 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
|||||||
assert (
|
assert (
|
||||||
not reset_within_step
|
not reset_within_step
|
||||||
), "reset_within_step must be False for furniture envs"
|
), "reset_within_step must be False for furniture envs"
|
||||||
assert n_obs_steps == 1, "n_obs_steps must be 1"
|
|
||||||
assert reward_agg_method == "sum", "reward_agg_method must be sum"
|
|
||||||
assert (
|
assert (
|
||||||
not pass_full_observations
|
not pass_full_observations
|
||||||
), "pass_full_observations is not implemented yet"
|
), "pass_full_observations is not implemented yet"
|
||||||
@ -68,6 +93,8 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
|||||||
):
|
):
|
||||||
"""Resets the environment."""
|
"""Resets the environment."""
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
|
self.obs = deque([obs], maxlen=max(self.n_obs_steps + 1, self.n_action_steps))
|
||||||
|
obs = stack_last_n_obs_dict(self.obs, self.n_obs_steps)
|
||||||
nobs = self.process_obs(obs)
|
nobs = self.process_obs(obs)
|
||||||
self.best_reward = torch.zeros(self.env.num_envs).to(self.device)
|
self.best_reward = torch.zeros(self.env.num_envs).to(self.device)
|
||||||
self.done = list()
|
self.done = list()
|
||||||
@ -118,6 +145,7 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
|||||||
for i in range(self.n_action_steps):
|
for i in range(self.n_action_steps):
|
||||||
# The dimensions of the action_chunk are (num_envs, chunk_size, action_dim)
|
# The dimensions of the action_chunk are (num_envs, chunk_size, action_dim)
|
||||||
obs, reward, done, info = self.env.step(action_chunk[:, i, :])
|
obs, reward, done, info = self.env.step(action_chunk[:, i, :])
|
||||||
|
self.obs.append(obs)
|
||||||
|
|
||||||
# track raw reward
|
# track raw reward
|
||||||
sparse_reward += reward.squeeze()
|
sparse_reward += reward.squeeze()
|
||||||
@ -130,21 +158,17 @@ class FurnitureRLSimEnvMultiStepWrapper(gym.Wrapper):
|
|||||||
|
|
||||||
dones = dones | done.squeeze()
|
dones = dones | done.squeeze()
|
||||||
|
|
||||||
|
obs = stack_last_n_obs_dict(self.obs, self.n_obs_steps)
|
||||||
return obs, sparse_reward, dense_reward, dones, info
|
return obs, sparse_reward, dense_reward, dones, info
|
||||||
|
|
||||||
def process_obs(self, obs: torch.Tensor) -> np.ndarray:
|
def process_obs(self, obs: torch.Tensor) -> np.ndarray:
|
||||||
robot_state = obs["robot_state"]
|
|
||||||
|
|
||||||
# Convert the robot state to have 6D pose
|
# Convert the robot state to have 6D pose
|
||||||
|
robot_state = obs["robot_state"]
|
||||||
robot_state = proprioceptive_quat_to_6d_rotation(robot_state)
|
robot_state = proprioceptive_quat_to_6d_rotation(robot_state)
|
||||||
|
|
||||||
parts_poses = obs["parts_poses"]
|
parts_poses = obs["parts_poses"]
|
||||||
|
|
||||||
obs = torch.cat([robot_state, parts_poses], dim=-1)
|
obs = torch.cat([robot_state, parts_poses], dim=-1)
|
||||||
|
|
||||||
nobs = self.normalizer(obs, "observations", forward=True)
|
nobs = self.normalizer(obs, "observations", forward=True)
|
||||||
nobs = torch.clamp(nobs, -5, 5)
|
nobs = torch.clamp(nobs, -5, 5)
|
||||||
|
return nobs.cpu().numpy() # (n_envs, n_obs_steps, obs_dim)
|
||||||
# Insert a dummy dimension for the n_obs_steps (n_envs, obs_dim) -> (n_envs, n_obs_steps, obs_dim)
|
|
||||||
nobs = nobs.unsqueeze(1).cpu().numpy()
|
|
||||||
|
|
||||||
return nobs
|
|
||||||
|
@ -246,7 +246,7 @@ def make_dataset(
|
|||||||
|
|
||||||
# Save to np file
|
# Save to np file
|
||||||
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
||||||
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl")
|
save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
|
||||||
with open(save_train_path, "wb") as f:
|
with open(save_train_path, "wb") as f:
|
||||||
pickle.dump(out_train, f)
|
pickle.dump(out_train, f)
|
||||||
with open(save_val_path, "wb") as f:
|
with open(save_val_path, "wb") as f:
|
||||||
|
@ -143,7 +143,7 @@ def make_dataset(env_name, save_dir, save_name_prefix, val_split, logger):
|
|||||||
|
|
||||||
# Save to np file
|
# Save to np file
|
||||||
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
||||||
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl")
|
save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
|
||||||
with open(save_train_path, "wb") as f:
|
with open(save_train_path, "wb") as f:
|
||||||
pickle.dump(out_train, f)
|
pickle.dump(out_train, f)
|
||||||
with open(save_val_path, "wb") as f:
|
with open(save_val_path, "wb") as f:
|
||||||
|
@ -193,7 +193,7 @@ def make_dataset(load_path, save_dir, save_name_prefix, env_type, val_split):
|
|||||||
|
|
||||||
# Save to np file
|
# Save to np file
|
||||||
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
||||||
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl")
|
save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
|
||||||
with open(save_train_path, "wb") as f:
|
with open(save_train_path, "wb") as f:
|
||||||
pickle.dump(out_train, f)
|
pickle.dump(out_train, f)
|
||||||
with open(save_val_path, "wb") as f:
|
with open(save_val_path, "wb") as f:
|
||||||
|
@ -304,7 +304,7 @@ def make_dataset(
|
|||||||
|
|
||||||
# Save to np file
|
# Save to np file
|
||||||
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
save_train_path = os.path.join(save_dir, save_name_prefix + "train.npz")
|
||||||
save_val_path = os.path.join(save_dir, save_name_prefix + "val.pkl")
|
save_val_path = os.path.join(save_dir, save_name_prefix + "val.npz")
|
||||||
with open(save_train_path, "wb") as f:
|
with open(save_train_path, "wb") as f:
|
||||||
pickle.dump(out_train, f)
|
pickle.dump(out_train, f)
|
||||||
with open(save_val_path, "wb") as f:
|
with open(save_val_path, "wb") as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user