/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerEvaluations;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.EvaluationRecord;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Base class for listeners that compute {@link IEvaluation}s on variable activations during
 * training.
 *
 * <p>Subclasses declare what to evaluate by implementing {@link #evaluations()}. At the start of
 * every epoch this class creates fresh copies (via {@link IEvaluation#newInstance()}) of the
 * declared training and validation evaluations. It then feeds them activations together with the
 * corresponding labels and label masks from the current batch. Finally it hands the populated
 * evaluations to {@link #epochEndEvaluations} and {@link #validationDoneEvaluations}.
 *
 * <p>The {@link BaseListener} lifecycle methods used here are {@code final}. Subclasses hook in
 * through the matching {@code ...Evaluations} methods instead.
 */
public abstract class BaseEvaluationListener
extends BaseListener {
    /** Training evaluations keyed by variable name; replaced with fresh instances in {@link #epochStart}. */
    private Map<String, List<IEvaluation>> trainingEvaluations = new HashMap<>();
    /**
     * Validation evaluations keyed by variable name. They are replaced only in {@link #epochStart},
     * so every validation activation seen within one epoch accumulates into the same instances.
     */
    private Map<String, List<IEvaluation>> validationEvaluations = new HashMap<>();

    /**
     * Declares the evaluations this listener computes and the label index used for each variable.
     *
     * <p>This method may be called several times per epoch, so implementations should be cheap.
     *
     * @return the evaluation specification for this listener
     */
    public abstract ListenerEvaluations evaluations();

    /**
     * Returns the variables required by {@link #evaluations()} merged with
     * {@link #otherRequiredVariables(SameDiff)}.
     */
    @Override
    public final ListenerVariables requiredVariables(SameDiff sd) {
        return evaluations().requiredVariables().merge(otherRequiredVariables(sd));
    }

    /**
     * Hook for subclasses that need variables beyond those used by their evaluations.
     *
     * @param sd the SameDiff instance being trained
     * @return additional required variables; empty by default
     */
    public ListenerVariables otherRequiredVariables(SameDiff sd) {
        return ListenerVariables.empty();
    }

    /**
     * Resets both training and validation evaluations to fresh instances for the new epoch, then
     * delegates to {@link #epochStartEvaluations(SameDiff, At)}.
     */
    @Override
    public final void epochStart(SameDiff sd, At at) {
        // Fetch the specification once so both maps are built from the same object.
        ListenerEvaluations declared = evaluations();
        trainingEvaluations = freshInstances(declared.trainEvaluations());
        validationEvaluations = freshInstances(declared.validationEvaluations());
        epochStartEvaluations(sd, at);
    }

    /** Hook called at the end of {@link #epochStart}, after the evaluations were reset. No-op by default. */
    public void epochStartEvaluations(SameDiff sd, At at) {
    }

    /**
     * Wraps this epoch's training evaluations in an {@link EvaluationRecord} and passes it to
     * {@link #epochEndEvaluations}.
     */
    @Override
    public final ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        return epochEndEvaluations(sd, at, lossCurve, epochTimeMillis, new EvaluationRecord(trainingEvaluations));
    }

    /**
     * Hook receiving the training evaluations gathered during the epoch.
     *
     * @param evaluations record wrapping the training evaluations of the finished epoch
     * @return {@link ListenerResponse#CONTINUE} by default
     */
    public ListenerResponse epochEndEvaluations(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis, EvaluationRecord evaluations) {
        return ListenerResponse.CONTINUE;
    }

    /**
     * Wraps the validation evaluations in an {@link EvaluationRecord} and passes it to
     * {@link #validationDoneEvaluations}.
     */
    @Override
    public final ListenerResponse validationDone(SameDiff sd, At at, long validationTimeMillis) {
        return validationDoneEvaluations(sd, at, validationTimeMillis, new EvaluationRecord(validationEvaluations));
    }

    /**
     * Hook receiving the validation evaluations.
     *
     * @param evaluations record wrapping the validation evaluations
     * @return {@link ListenerResponse#CONTINUE} by default
     */
    public ListenerResponse validationDoneEvaluations(SameDiff sd, At at, long validationTimeMillis, EvaluationRecord evaluations) {
        return ListenerResponse.CONTINUE;
    }

    /**
     * Feeds {@code activation} into every evaluation registered for {@code varName}.
     *
     * <p>Training evaluations are used when the operation is {@link Operation#TRAINING}. Validation
     * evaluations are used when it is {@link Operation#TRAINING_VALIDATION}. Other operations are
     * ignored. {@link #activationAvailableEvaluations} is always invoked afterwards.
     */
    @Override
    public final void activationAvailable(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) {
        Operation operation = at.operation();
        if (operation == Operation.TRAINING) {
            List<IEvaluation> evals = trainingEvaluations.get(varName);
            if (evals != null) {
                evalAll(evals, batch, evaluations().trainEvaluationLabels().get(varName), activation);
            }
        } else if (operation == Operation.TRAINING_VALIDATION) {
            List<IEvaluation> evals = validationEvaluations.get(varName);
            if (evals != null) {
                evalAll(evals, batch, evaluations().validationEvaluationLabels().get(varName), activation);
            }
        }
        activationAvailableEvaluations(sd, at, batch, op, varName, activation);
    }

    /** Hook called at the end of every {@link #activationAvailable} call. No-op by default. */
    public void activationAvailableEvaluations(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, String varName, INDArray activation) {
    }

    /**
     * Builds a map with the same keys as {@code templates}. Each list holds a
     * {@link IEvaluation#newInstance()} copy of the corresponding template evaluation.
     */
    private static Map<String, List<IEvaluation>> freshInstances(Map<String, List<IEvaluation>> templates) {
        Map<String, List<IEvaluation>> result = new HashMap<>();
        for (Map.Entry<String, List<IEvaluation>> entry : templates.entrySet()) {
            List<IEvaluation> copies = new ArrayList<>(entry.getValue().size());
            for (IEvaluation template : entry.getValue()) {
                copies.add(template.newInstance());
            }
            result.put(entry.getKey(), copies);
        }
        return result;
    }

    /**
     * Evaluates {@code activation} against the labels and label mask at {@code labelIndex} of
     * {@code batch}, once for every evaluation in {@code evals}.
     */
    private static void evalAll(List<IEvaluation> evals, MultiDataSet batch, int labelIndex, INDArray activation) {
        INDArray labels = batch.getLabels(labelIndex);
        INDArray mask = batch.getLabelsMaskArray(labelIndex);
        for (IEvaluation e : evals) {
            e.eval(labels, activation, mask);
        }
    }
}

