/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.autoencoder;

import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.jblas.DoubleMatrix;

/**
 * Auto-encoder built from two networks of the same concrete class: an encoder,
 * supplied by the caller, and a decoder that {@link #train} creates from the
 * trained encoder.
 *
 * <p>The decoder is trained to map the encoder's predictions back to the original
 * input. Until {@link #train} has completed at least once, no decoder exists and
 * {@link #decode} throws {@link IllegalStateException}.
 *
 * <p>NOTE(review): no synchronization is present; concurrent use of one instance
 * is not guarded by this class.
 */
public class DeepAutoEncoder
implements Serializable {
    private static final long serialVersionUID = -3571832097247806784L;

    /** Network that maps raw input to its encoded representation. */
    private final BaseMultiLayerNetwork encoder;

    /** Network that maps encoded input back toward raw input; {@code null} until {@link #train} runs. */
    private BaseMultiLayerNetwork decoder;

    /**
     * Extra arguments forwarded unchanged to every {@code trainNetwork} call.
     * The array is stored by reference, not copied.
     */
    private final Object[] trainingParams;

    /**
     * Creates an auto-encoder around an existing encoder network.
     *
     * @param encoder        network used for encoding; also determines the concrete
     *                       class of the decoder built during {@link #train}
     * @param trainingParams extra arguments passed through to
     *                       {@code BaseMultiLayerNetwork.trainNetwork} for both the
     *                       encoder and the decoder
     */
    public DeepAutoEncoder(BaseMultiLayerNetwork encoder, Object[] trainingParams) {
        this.encoder = encoder;
        this.trainingParams = trainingParams;
    }

    /**
     * Trains the encoder, then builds and trains a fresh decoder.
     *
     * <p>Steps:
     * <ol>
     *   <li>train the encoder on {@code (input, labels)};</li>
     *   <li>build an empty network of the encoder's class and configure it with
     *       {@code asDecoder(encoder)};</li>
     *   <li>train that decoder to map {@code encoder.predict(input)} back to
     *       {@code input}.</li>
     * </ol>
     * Each call replaces any previously built decoder.
     *
     * @param input  training examples
     * @param labels targets used when training the encoder
     * @param lr     not read by this method. Kept for API compatibility.
     *               NOTE(review): presumably any learning rate must be supplied
     *               through {@code trainingParams}; confirm with callers.
     */
    public void train(DoubleMatrix input, DoubleMatrix labels, double lr) {
        this.encoder.trainNetwork(input, labels, this.trainingParams);
        this.decoder = new BaseMultiLayerNetwork.Builder().withClazz(this.encoder.getClass()).buildEmpty();
        this.decoder.asDecoder(this.encoder);
        // The decoder learns to reconstruct the original input from the encoder's output.
        DoubleMatrix encoded = this.encoder.predict(input);
        this.decoder.trainNetwork(encoded, input, this.trainingParams);
    }

    /**
     * Encodes {@code input} using the encoder network.
     *
     * @param input data to encode
     * @return the encoder's prediction for {@code input}
     */
    public DoubleMatrix encode(DoubleMatrix input) {
        return this.encoder.predict(input);
    }

    /**
     * Decodes {@code input} using the decoder network built by {@link #train}.
     *
     * @param input encoded data, e.g. the result of {@link #encode}
     * @return the decoder's prediction for {@code input}
     * @throws IllegalStateException if {@link #train} has not been called yet
     */
    public DoubleMatrix decode(DoubleMatrix input) {
        if (this.decoder == null) {
            // Fail with a clear message instead of an anonymous NullPointerException.
            throw new IllegalStateException("decoder is not available: call train(...) before decode(...)");
        }
        return this.decoder.predict(input);
    }
}

