package org.deeplearning4j.nn.layers.mkldnn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
import org.nd4j.linalg.activations.impl.ActivationSoftSign;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.activations.impl.ActivationThresholdedReLU;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.class */
/**
 * {@link LSTMHelper} that computes the LSTM forward pass with the native {@code lstmLayer} custom op
 * when MKL-DNN is enabled.
 *
 * <p>Only the forward pass is provided; {@link #backpropGradient} always returns {@code null}.
 * {@link #checkSupported} reports support only for a sigmoid gate activation combined with a tanh
 * layer activation and no peephole connections.
 */
public class MKLDNNLSTMHelper implements LSTMHelper {

    /**
     * Reports whether this helper can run the given LSTM configuration.
     *
     * @param gateActivationFn       gate activation; must be {@link ActivationSigmoid}
     * @param activationFn           layer activation; must be {@link ActivationTanH}
     * @param hasPeepholeConnections must be {@code false}; {@link #activate} rejects peepholes
     * @return {@code true} only if the configuration is supported and MKL-DNN is enabled
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.LSTMHelper
    public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn,
                                  boolean hasPeepholeConnections) {
        // activate() throws for peephole connections, so advertise them as unsupported here
        // rather than accepting the configuration and failing during the forward pass.
        return gateActivationFn instanceof ActivationSigmoid
                && activationFn instanceof ActivationTanH
                && !hasPeepholeConnections
                && BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Backpropagation is not implemented by this helper.
     *
     * <p>NOTE(review): presumably the calling layer falls back to its built-in backprop when this
     * returns {@code null} — confirm against the caller.
     *
     * @return always {@code null}
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.LSTMHelper
    public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn,
                                                     INDArray input, INDArray recurrentWeights,
                                                     INDArray inputWeights, INDArray epsilon,
                                                     boolean truncatedBPTT, int tbpttBackwardLength,
                                                     FwdPassReturn fwdPass, boolean forwards,
                                                     String inputWeightKey, String recurrentWeightKey,
                                                     String biasWeightKey, Map<String, INDArray> gradientViews,
                                                     INDArray maskArray, boolean hasPeepholeConnections,
                                                     LayerWorkspaceMgr workspaceMgr) {
        return null;
    }

    /**
     * Runs the LSTM forward pass through the native {@code lstmLayer} op.
     *
     * <p>Op configuration used here:
     * <ul>
     *   <li>integer arguments: data format 2 and direction mode 0 (presumably DL4J's
     *       [minibatch, size, timeSteps] layout, forward direction — confirm against the op
     *       documentation);</li>
     *   <li>gate activation taken from {@code gateActivationFn}; the layer's own activation is used
     *       for both the cell-state and the output activation;</li>
     *   <li>bias always supplied, flattened to 1-D;</li>
     *   <li>full-sequence output, last output and last cell state all requested.</li>
     * </ul>
     * Output arrays are allocated from {@code workspaceMgr} as {@link ArrayType#ACTIVATIONS}.
     *
     * <p>NOTE(review): {@code layer}, {@code training}, {@code forBackprop}, {@code forwards} and
     * {@code inputWeightKey} are not used. The op always runs with direction mode 0 — confirm callers
     * never rely on this helper for a reverse-direction pass.
     *
     * @param conf                   configuration whose layer must be an {@link LSTM}
     * @param gateActivationFn       activation applied to the gates
     * @param input                  op input 0 ({@code x})
     * @param recurrentWeights       op input 2 ({@code Wr})
     * @param originalInputWeights   op input 1 ({@code Wx})
     * @param biases                 bias parameters; reshaped to 1-D before being passed to the op
     * @param prevOutputActivations  optional initial output state, or {@code null}
     * @param prevMemCellState       optional initial cell state, or {@code null}
     * @param maskArray              optional mask, or {@code null}; converted to per-row sequence
     *                               lengths along dimension 1
     * @param hasPeepholeConnections must be {@code false}
     * @param workspaceMgr           source of the output arrays
     * @return the full-sequence output, last output activations and last memory-cell state
     * @throws IllegalStateException if {@code hasPeepholeConnections} is {@code true}
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.LSTMHelper
    public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn,
                                  INDArray input, INDArray recurrentWeights, INDArray originalInputWeights,
                                  INDArray biases, boolean training, INDArray prevOutputActivations,
                                  INDArray prevMemCellState, boolean forBackprop, boolean forwards,
                                  String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections,
                                  LayerWorkspaceMgr workspaceMgr) {
        if (hasPeepholeConnections) {
            // Fail before doing any work: peephole weights are never passed to the op.
            throw new IllegalStateException("Not yet implemented");
        }

        INDArray bias1d = biases.reshape(new long[]{biases.length()});
        INDArray seqLen = sequenceLengths(maskArray);

        // Order matters: x, Wx, Wr, b, then the optional seqLen, initial output, initial cell.
        List<INDArray> opInputs = new ArrayList<>();
        opInputs.add(input);
        opInputs.add(originalInputWeights);
        opInputs.add(recurrentWeights);
        opInputs.add(bias1d);
        if (seqLen != null) {
            opInputs.add(seqLen);
        }
        if (prevOutputActivations != null) {
            opInputs.add(prevOutputActivations);
        }
        if (prevMemCellState != null) {
            opInputs.add(prevMemCellState);
        }

        IActivation layerActivation = ((LSTM) conf.getLayer()).getActivationFn();

        DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
                .addInputs(opInputs.toArray(new INDArray[0]))
                .addBooleanArguments(new boolean[]{
                        true,                           // bias supplied
                        seqLen != null,                 // sequence lengths supplied
                        prevOutputActivations != null,  // initial output supplied
                        prevMemCellState != null,       // initial cell state supplied
                        hasPeepholeConnections,         // peepholes (always false here, see above)
                        true,                           // return full sequence
                        true,                           // return last output
                        true                            // return last cell state
                })
                .addIntegerArguments(new int[]{
                        2,                                  // data format
                        0,                                  // direction mode
                        activationToArg(gateActivationFn),  // gate activation
                        activationToArg(layerActivation),   // cell-state activation
                        activationToArg(layerActivation)    // output activation
                })
                .build();

        for (LongShapeDescriptor shape : op.calculateOutputShape()) {
            op.addOutputArgument(workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS,
                    shape.dataType(), shape.getShape(), shape.getOrder()));
        }

        // The outputs were allocated uninitialized; the op must actually run to fill them.
        Nd4j.getExecutioner().exec(op);

        FwdPassReturn result = new FwdPassReturn();
        result.fwdPassOutput = op.getOutputArgument(0);
        result.lastAct = op.getOutputArgument(1);
        result.lastMemCell = op.getOutputArgument(2);
        return result;
    }

    @Override // org.deeplearning4j.nn.layers.LayerHelper
    public Map<String, Long> helperMemoryUse() {
        // No memory beyond the workspace-allocated outputs is tracked by this helper.
        return Collections.emptyMap();
    }

    /**
     * Converts a mask into per-row sequence lengths.
     *
     * <p>Each row's length is the index of its first 0 along dimension 1, or {@code mask.size(1)}
     * if the row contains no 0.
     *
     * @param maskArray mask to convert, or {@code null}
     * @return the sequence lengths, or {@code null} when {@code maskArray} is {@code null}
     */
    private static INDArray sequenceLengths(INDArray maskArray) {
        if (maskArray == null) {
            return null;
        }
        INDArray firstZero = BooleanIndexing.firstIndex(maskArray, Conditions.equals(0), new int[]{1});
        // firstIndex reports a negative "not found" index for rows with no 0. Such a row is
        // unmasked for its whole length, so use the full extent of dimension 1 instead.
        // Rows that do contain a 0 have indices >= 0 and are left untouched.
        BooleanIndexing.replaceWhere(firstZero, maskArray.size(1), Conditions.lessThan(0));
        return firstZero;
    }

    /**
     * Maps a DL4J activation to the integer activation code passed to {@code lstmLayer}.
     *
     * <p>Code 6 has no mapping here.
     *
     * <p>NOTE(review): some of these codes may require extra floating-point (alpha/beta) op
     * arguments, which this helper never supplies. Only sigmoid and tanh pass
     * {@link #checkSupported}.
     *
     * @param activation activation to translate
     * @return the op's activation code
     * @throws IllegalStateException if the activation has no mapping
     */
    private static int activationToArg(IActivation activation) {
        if (activation instanceof ActivationTanH) {
            return 0;
        }
        if (activation instanceof ActivationReLU) {
            return 1;
        }
        if (activation instanceof ActivationSigmoid) {
            return 2;
        }
        if (activation instanceof ActivationIdentity) {
            return 3;
        }
        if (activation instanceof ActivationLReLU) {
            return 4;
        }
        if (activation instanceof ActivationThresholdedReLU) {
            return 5;
        }
        if (activation instanceof ActivationHardSigmoid) {
            return 7;
        }
        if (activation instanceof ActivationELU) {
            return 8;
        }
        if (activation instanceof ActivationSoftSign) {
            return 9;
        }
        if (activation instanceof ActivationSoftPlus) {
            return 10;
        }
        throw new IllegalStateException("Unknown or not supported activation function: " + activation);
    }
}
