metastable-projections/metastable_projections/projections/identity_projection_layer.py

6 lines
201 B
Python
Raw Normal View History

from .base_projection_layer import BaseProjectionLayer
class IdentityProjectionLayer(BaseProjectionLayer):
    """Projection layer that applies no projection at all.

    The incoming distribution is handed back untouched in both positions of
    the returned pair.
    """

    def project_from_rollouts(self, dist, rollout_data, **kwargs):
        """Return the input distribution unchanged, paired with itself.

        Args:
            dist: Distribution object to "project". The same object is
                returned; no copy is made.
            rollout_data: Accepted to match the projection-layer call
                signature; not read by this implementation.
            **kwargs: Accepted and ignored.

        Returns:
            The 2-tuple ``(dist, dist)``. Both entries are the very object
            that was passed in.
            NOTE(review): the meaning of each tuple slot is defined by
            ``BaseProjectionLayer``, which is not visible here. Presumably
            the slots are (projected, reference) — confirm against the base
            class.
        """
        # Identity projection: the "projected" distribution is the input itself.
        projected = dist
        return projected, projected