Orion
path: orion.primitives.tadgan.TadGAN
description: This is a reconstruction model based on Generative Adversarial Networks (GANs). It combines multiple neural networks (an encoder, a generator, and two critics) with a cycle-consistency loss. The model is described in the related paper.
See json.
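The cycle-consistency idea can be illustrated with a toy NumPy sketch. A window is encoded into the latent space and generated back, and the sketch measures how far the round trip G(E(x)) lands from x. This is not TadGAN's implementation. The linear encoder and generator (`W_enc`, `W_gen`, `encode`, `generate`) are hypothetical stand-ins for the networks built from layers_encoder and layers_generator, and using the mean squared error is a choice made only for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

window, latent_dim = 100, 20  # sequence length; 20 matches the latent_dim default

# Hypothetical linear stand-ins for the encoder E and the generator G:
# E projects a window into the latent space, G maps it back with the
# pseudo-inverse of E's weight matrix.
W_enc = rng.normal(size=(latent_dim, window))
W_gen = np.linalg.pinv(W_enc)  # shape (window, latent_dim)


def encode(x):
    """E: (n, window) -> (n, latent_dim)."""
    return x @ W_enc.T


def generate(z):
    """G: (n, latent_dim) -> (n, window)."""
    return z @ W_gen.T


def cycle_consistency_loss(x):
    """Mean squared error between x and its round trip G(E(x))."""
    x_rec = generate(encode(x))
    return float(np.mean((x - x_rec) ** 2))


# A window the toy model can represent is reconstructed almost exactly,
# while an arbitrary window is not.
representable = rng.normal(size=(4, latent_dim)) @ W_enc
arbitrary = rng.normal(size=(4, window))
print(cycle_consistency_loss(representable))  # close to 0
print(cycle_consistency_loss(arbitrary))      # clearly larger than 0
```

A real TadGAN learns E and G jointly with the critics instead of fixing them as a linear projection. The round-trip error plays the same role: inputs the model has learned to represent come back close to themselves.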
argument | type | description
parameters
X | numpy.ndarray | n-dimensional array containing the input sequences for the model
y | numpy.ndarray | n-dimensional array containing the target sequences to reconstruct. Typically, y is the signal from a selected set of channels of X.
hyperparameters
epochs | int | number of epochs to train the model. An epoch is one iteration over the entire X data provided.
input_shape | tuple | tuple denoting the shape of an input sample
target_shape | tuple | tuple denoting the shape of a reconstructed sample
optimizer | str | string (name of the optimizer) or optimizer instance. Default is keras.optimizers.Adam
learning_rate | float | learning rate of the optimizer. Default is 0.005
latent_dim | int | dimension of the latent space. Default is 20
batch_size | int | number of samples per gradient update. Default is 64
iterations_critic | int | number of critic training steps per encoder/generator training step. Default is 5
layers_encoder | list | list containing the layers of the encoder
layers_generator | list | list containing the layers of the generator
layers_critic_x | list | list containing the layers of critic_x
layers_critic_z | list | list containing the layers of critic_z
output
y | numpy.ndarray | n-dimensional array containing the reconstruction of each input sequence
critic | numpy.ndarray | n-dimensional array containing the critic score of each input sequence
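The relation between X, y, input_shape and target_shape can be sketched with NumPy. This sketch assumes, as in the example session, that samples lie along the first axis (n_samples, window_size, n_channels). It also assumes that the shape hyperparameters describe a single sample, so they exclude that axis. The data is made up for illustration.

```python
import numpy as np

# Made-up data: 8 windows of 100 time steps over 3 channels,
# laid out as (n_samples, window_size, n_channels) like X in the example.
X = np.random.default_rng(0).normal(size=(8, 100, 3))

# y keeps only the channel(s) to reconstruct. Indexing with a list ([0])
# preserves the channel axis, so y stays 3-dimensional.
y = X[:, :, [0]]

# Per-sample shapes, i.e. the array shapes without the leading sample axis
# (an assumption about how input_shape / target_shape are meant).
input_shape = X.shape[1:]
target_shape = y.shape[1:]

print(input_shape)   # (100, 3)
print(target_shape)  # (100, 1)
```

Indexing with `[0]` rather than `0` matters here: `X[:, :, 0]` would drop the channel axis and yield a 2-dimensional array of shape (8, 100).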
```
In [1]: import numpy as np

In [2]: from mlstars import load_primitive

In [3]: X = np.array([1] * 100).reshape(1, -1, 1)

In [4]: y = X[:, :, [0]]  # signal to reconstruct from X (channel 0)

In [5]: primitive = load_primitive('orion.primitives.tadgan.TadGAN',
   ...:     arguments={"X": X, "y": y, "epochs": 5, "batch_size": 1,
   ...:                "iterations_critic": 1})
   ...:

In [6]: primitive.fit()
Epoch: 1/5, Losses: {'cx_loss': 8.5161, 'cz_loss': 0.13, 'eg_loss': 10.1249}
Epoch: 2/5, Losses: {'cx_loss': 8.0038, 'cz_loss': 0.3508, 'eg_loss': 9.4262}
Epoch: 3/5, Losses: {'cx_loss': 7.6097, 'cz_loss': -1.6323, 'eg_loss': 7.6745}
Epoch: 4/5, Losses: {'cx_loss': 7.1942, 'cz_loss': 1.9003, 'eg_loss': 5.1486}
Epoch: 5/5, Losses: {'cx_loss': 6.2477, 'cz_loss': 7.8678, 'eg_loss': 5.2253}

In [7]: y, critic = primitive.produce(X=X, y=y)
1/1 [==============================] - ETA: 0s
1/1 [==============================] - 0s 470ms/step
1/1 [==============================] - ETA: 0s
1/1 [==============================] - 1s 559ms/step
1/1 [==============================] - ETA: 0s
1/1 [==============================] - 0s 61ms/step

In [8]: print("average reconstructed value: {:.2f}, critic score {:.2f}".format(
   ...:     y.mean(), critic[0][0]))
   ...:
average reconstructed value: 0.42, critic score 0.12
```
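The fit log reports one loss per network group and epoch. Here cx_loss and cz_loss presumably belong to critic_x and critic_z, and eg_loss to the encoder/generator. How iterations_critic trades critic steps against encoder/generator steps can be sketched as a hypothetical alternating schedule. The function `update_counts`, the grouping into batch_size * iterations_critic samples, and the skipping of leftover samples are assumptions made for illustration. This is not Orion's training loop.

```python
def update_counts(n_samples, batch_size=64, iterations_critic=5):
    """Count critic and encoder/generator updates in one epoch.

    Hypothetical schedule: the data is cut into groups of
    batch_size * iterations_critic samples. Within each group the critics
    take one step per batch (iterations_critic steps in total), and the
    encoder/generator then takes a single step. Leftover samples that do
    not fill a whole group are skipped.
    """
    group_size = batch_size * iterations_critic
    n_groups = n_samples // group_size
    critic_steps = n_groups * iterations_critic
    generator_steps = n_groups
    return critic_steps, generator_steps


print(update_counts(1, batch_size=1, iterations_critic=1))  # (1, 1)
print(update_counts(1000))                                  # (15, 3)
```

Under this schedule, the session above (one sample, batch_size=1, iterations_critic=1) performs one critic step and one encoder/generator step per epoch. The defaults on 1000 samples give five critic steps for every encoder/generator step.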