Better config for Feature Extractor

This commit is contained in:
Dominik Moritz Roth 2024-05-27 17:07:00 +02:00
parent 17ff693ee5
commit 8bf85d1a65
2 changed files with 39 additions and 33 deletions

View File

@ -1,6 +1,8 @@
name: EXAMPLE name: EXAMPLE
feature_extractor: feature_extractor:
input_size: 1953 # Input size for the Feature Extractor (length of snippets). (=0.1s)
transforms:
- type: 'identity' # Pass the last n samples of the input data directly. - type: 'identity' # Pass the last n samples of the input data directly.
length: 8 # Number of last samples to pass directly. Use full input size if set to null. length: 8 # Number of last samples to pass directly. Use full input size if set to null.
- type: 'fourier' # Apply Fourier transform to the input data. - type: 'fourier' # Apply Fourier transform to the input data.
@ -35,7 +37,6 @@ feature_extractor:
latent_projector: latent_projector:
type: 'fc' # Type of latent projector: 'fc', 'rnn', 'fourier' type: 'fc' # Type of latent projector: 'fc', 'rnn', 'fourier'
input_size: 1953 # Input size for the Latent Projector (length of snippets). (=0.1s)
latent_size: 4 # Size of the latent representation before message passing. latent_size: 4 # Size of the latent representation before message passing.
layer_shapes: [32, 8] # List of layer sizes for the latent projector if type is 'fc' or 'fourier'. layer_shapes: [32, 8] # List of layer sizes for the latent projector if type is 'fc' or 'fourier'.
activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers if type is 'fc' or 'fourier'. activations: ['ReLU', 'ReLU'] # Activation functions for the latent projector layers if type is 'fc' or 'fourier'.
@ -146,6 +147,11 @@ middle_out:
name: FC name: FC
import: $ import: $
feature_extractor:
input size: 10
transforms:
- type: 'identity'
latent_projector: latent_projector:
type: fc type: fc
input_size: 1953 input_size: 1953

View File

@ -24,7 +24,7 @@ class FeatureExtractor(nn.Module):
transforms = [] transforms = []
for item in config: for item in config:
transform_type = item['type'] transform_type = item['type']
length = item.get('length', self.input_size) length = item.get('length', None)
if length in [None, -1]: if length in [None, -1]:
length = self.input_size length = self.input_size