/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

/**
 * Default in-memory implementation of {@link org.nd4j.linalg.dataset.api.MultiDataSet}.
 *
 * <p>Holds zero or more feature arrays, zero or more label arrays, optional feature / label mask
 * arrays and optional per-example metadata. Individual slots of the mask arrays may be
 * {@code null} ("no mask for this input/output"); the methods here that iterate over mask arrays
 * skip such {@code null} slots.
 */
public class MultiDataSet
implements org.nd4j.linalg.dataset.api.MultiDataSet {
    /**
     * Per-thread sentinel written in place of a {@code null} mask slot by {@link #save(OutputStream)}
     * and mapped back to {@code null} by {@link #load(InputStream)}. Always obtain it through
     * {@link #getEmptyMaskArrayPlaceholder()}, which initialises it lazily on the calling thread.
     */
    private static final ThreadLocal<INDArray> EMPTY_MASK_ARRAY_PLACEHOLDER = new ThreadLocal<>();
    private INDArray[] features;
    private INDArray[] labels;
    private INDArray[] featuresMaskArrays;
    private INDArray[] labelsMaskArrays;
    private List<Serializable> exampleMetaData;

    /** Creates a MultiDataSet with no features, labels, masks or metadata. */
    public MultiDataSet() {
    }

    /**
     * Creates a MultiDataSet with (at most) one feature array and one label array and no masks.
     *
     * @param features the feature array, or {@code null} for none
     * @param labels   the label array, or {@code null} for none
     */
    public MultiDataSet(INDArray features, INDArray labels) {
        this(features, labels, null, null);
    }

    /**
     * Creates a MultiDataSet from single arrays. Each non-null argument becomes a one-element
     * array; each {@code null} argument becomes a {@code null} array.
     *
     * @param features     the feature array, or {@code null}
     * @param labels       the label array, or {@code null}
     * @param featuresMask the feature mask, or {@code null}
     * @param labelsMask   the label mask, or {@code null}
     */
    public MultiDataSet(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) {
        // Wrapping through a static helper keeps this(...) as the first statement, as Java requires.
        this(wrap(features), wrap(labels), wrap(featuresMask), wrap(labelsMask));
    }

    /**
     * Creates a MultiDataSet with the given feature and label arrays and no masks.
     *
     * @param features the feature arrays (may be {@code null})
     * @param labels   the label arrays (may be {@code null})
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels) {
        this(features, labels, null, null);
    }

    /**
     * Creates a MultiDataSet. The given arrays are stored by reference, not copied.
     *
     * @param features           the feature arrays (may be {@code null})
     * @param labels             the label arrays (may be {@code null})
     * @param featuresMaskArrays the feature masks (may be {@code null}; slots may be {@code null})
     * @param labelsMaskArrays   the label masks (may be {@code null}; slots may be {@code null})
     * @throws IllegalArgumentException if a non-null mask array has a different length than the
     *                                  corresponding non-null features/labels array
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays, INDArray[] labelsMaskArrays) {
        if (features != null && featuresMaskArrays != null && features.length != featuresMaskArrays.length) {
            throw new IllegalArgumentException("Invalid features / features mask arrays combination: features and features mask arrays must not be different lengths");
        }
        if (labels != null && labelsMaskArrays != null && labels.length != labelsMaskArrays.length) {
            throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: labels and labels mask arrays must not be different lengths");
        }
        this.features = features;
        this.labels = labels;
        this.featuresMaskArrays = featuresMaskArrays;
        this.labelsMaskArrays = labelsMaskArrays;
        Nd4j.getExecutioner().commit();
    }

    /** Returns {@code null} for a {@code null} array, otherwise a one-element array holding it. */
    private static INDArray[] wrap(INDArray array) {
        return array == null ? null : new INDArray[]{array};
    }

    @Override
    public List<Serializable> getExampleMetaData() {
        return this.exampleMetaData;
    }

    /**
     * Returns the example metadata typed as {@code List<T>}.
     *
     * <p>The element types are NOT checked against {@code metaDataType}; a mismatch surfaces as a
     * {@link ClassCastException} when the caller reads an element.
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends Serializable> List<T> getExampleMetaData(Class<T> metaDataType) {
        return (List<T>) (List<?>) this.exampleMetaData;
    }

    /** Stores the given metadata list by reference (no defensive copy). */
    @Override
    @SuppressWarnings("unchecked")
    public void setExampleMetaData(List<? extends Serializable> exampleMetaData) {
        // Safe as long as the list is only read through this class: nothing here inserts elements.
        this.exampleMetaData = (List<Serializable>) exampleMetaData;
    }

    @Override
    public int numFeatureArrays() {
        return this.features != null ? this.features.length : 0;
    }

    @Override
    public int numLabelsArrays() {
        return this.labels != null ? this.labels.length : 0;
    }

    @Override
    public INDArray[] getFeatures() {
        return this.features;
    }

    @Override
    public INDArray getFeatures(int index) {
        return this.features[index];
    }

    @Override
    public void setFeatures(INDArray[] features) {
        this.features = features;
    }

    @Override
    public void setFeatures(int idx, INDArray features) {
        this.features[idx] = features;
    }

    @Override
    public INDArray[] getLabels() {
        return this.labels;
    }

    @Override
    public INDArray getLabels(int index) {
        return this.labels[index];
    }

    @Override
    public void setLabels(INDArray[] labels) {
        this.labels = labels;
    }

    @Override
    public void setLabels(int idx, INDArray labels) {
        this.labels[idx] = labels;
    }

    /** Returns {@code true} if at least one feature or label mask slot is non-null. */
    @Override
    public boolean hasMaskArrays() {
        return !MultiDataSet.nullOrEmpty(this.featuresMaskArrays) || !MultiDataSet.nullOrEmpty(this.labelsMaskArrays);
    }

    @Override
    public INDArray[] getFeaturesMaskArrays() {
        return this.featuresMaskArrays;
    }

    /** Returns the feature mask at {@code index}, or {@code null} if there are no feature masks. */
    @Override
    public INDArray getFeaturesMaskArray(int index) {
        return this.featuresMaskArrays != null ? this.featuresMaskArrays[index] : null;
    }

    @Override
    public void setFeaturesMaskArrays(INDArray[] maskArrays) {
        this.featuresMaskArrays = maskArrays;
    }

    @Override
    public void setFeaturesMaskArray(int idx, INDArray maskArray) {
        this.featuresMaskArrays[idx] = maskArray;
    }

    @Override
    public INDArray[] getLabelsMaskArrays() {
        return this.labelsMaskArrays;
    }

    /** Returns the label mask at {@code index}, or {@code null} if there are no label masks. */
    @Override
    public INDArray getLabelsMaskArray(int index) {
        return this.labelsMaskArrays != null ? this.labelsMaskArrays[index] : null;
    }

    @Override
    public void setLabelsMaskArray(INDArray[] labelsMaskArrays) {
        this.labelsMaskArrays = labelsMaskArrays;
    }

    @Override
    public void setLabelsMaskArray(int idx, INDArray labelsMaskArray) {
        this.labelsMaskArrays[idx] = labelsMaskArray;
    }

    /**
     * Serializes this data set to {@code to} and then closes {@code to}.
     *
     * <p>Layout: four {@code int} counts (features, labels, feature masks, label masks). Next come
     * the arrays of each group, in that order, each written with {@code Nd4j.write}. A {@code null}
     * mask slot is written as a one-element {@code [-1]} placeholder array; only mask slots get
     * this substitution. If metadata is present and non-empty, an {@code int} marker {@code 1}
     * follows, then the metadata list as a Java-serialized object.
     */
    @Override
    public void save(OutputStream to) throws IOException {
        int numFArr = this.features == null ? 0 : this.features.length;
        int numLArr = this.labels == null ? 0 : this.labels.length;
        int numFMArr = this.featuresMaskArrays == null ? 0 : this.featuresMaskArrays.length;
        int numLMArr = this.labelsMaskArrays == null ? 0 : this.labelsMaskArrays.length;
        try (DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(to))) {
            dos.writeInt(numFArr);
            dos.writeInt(numLArr);
            dos.writeInt(numFMArr);
            dos.writeInt(numLMArr);
            this.saveINDArrays(this.features, dos, false);
            this.saveINDArrays(this.labels, dos, false);
            this.saveINDArrays(this.featuresMaskArrays, dos, true);
            this.saveINDArrays(this.labelsMaskArrays, dos, true);
            if (this.exampleMetaData != null && !this.exampleMetaData.isEmpty()) {
                dos.writeInt(1);
                ObjectOutputStream oos = new ObjectOutputStream(dos);
                oos.writeObject(this.exampleMetaData);
                oos.flush();
            }
        }
    }

    /** Returns this thread's empty-mask placeholder, creating it on first use. */
    private static INDArray getEmptyMaskArrayPlaceholder() {
        INDArray placeholder = EMPTY_MASK_ARRAY_PLACEHOLDER.get();
        if (placeholder == null) {
            placeholder = Nd4j.create(new float[]{-1.0f});
            EMPTY_MASK_ARRAY_PLACEHOLDER.set(placeholder);
        }
        return placeholder;
    }

    /** Writes every array in {@code arrays}, substituting the placeholder for null mask slots. */
    private void saveINDArrays(INDArray[] arrays, DataOutputStream dos, boolean isMask) throws IOException {
        if (arrays == null) {
            return;
        }
        for (INDArray arr : arrays) {
            INDArray toWrite = isMask && arr == null ? MultiDataSet.getEmptyMaskArrayPlaceholder() : arr;
            Nd4j.write(toWrite, dos);
        }
    }

    /** Serializes this data set to the given file, as {@link #save(OutputStream)}. */
    @Override
    public void save(File to) throws IOException {
        // save(OutputStream) closes the stream itself; this also guarantees closure on early failure.
        try (FileOutputStream fos = new FileOutputStream(to)) {
            this.save(fos);
        }
    }

    /**
     * Replaces the contents of this data set with data read from {@code from` and closes
     * {@code from}. The data must use the layout written by {@link #save(OutputStream)}. Metadata is
     * reset to {@code null} when the stream contains none.
     *
     * @throws RuntimeException if the serialized metadata refers to a class that is not available
     */
    @Override
    @SuppressWarnings("unchecked")
    public void load(InputStream from) throws IOException {
        try (DataInputStream dis = new DataInputStream(new BufferedInputStream(from))) {
            int numFArr = dis.readInt();
            int numLArr = dis.readInt();
            int numFMArr = dis.readInt();
            int numLMArr = dis.readInt();
            this.features = this.loadINDArrays(numFArr, dis, false);
            this.labels = this.loadINDArrays(numLArr, dis, false);
            this.featuresMaskArrays = this.loadINDArrays(numFMArr, dis, true);
            this.labelsMaskArrays = this.loadINDArrays(numLMArr, dis, true);
            // Reset so that metadata from this object's previous state cannot survive the load.
            this.exampleMetaData = null;
            int metaDataMarker;
            try {
                metaDataMarker = dis.readInt();
            } catch (EOFException ignored) {
                // No trailing marker: the data set was saved without metadata.
                return;
            }
            if (metaDataMarker == 1) {
                // SECURITY NOTE(review): Java native deserialization - only load data from trusted sources.
                ObjectInputStream ois = new ObjectInputStream(dis);
                try {
                    this.exampleMetaData = (List<Serializable>) ois.readObject();
                } catch (ClassNotFoundException e) {
                    throw new RuntimeException("Error reading metadata from serialized MultiDataSet", e);
                }
            }
        }
    }

    /**
     * Reads {@code numArrays} arrays; for masks, placeholder arrays are mapped back to {@code null}.
     * Returns {@code null} when {@code numArrays <= 0}.
     */
    private INDArray[] loadINDArrays(int numArrays, DataInputStream dis, boolean isMask) throws IOException {
        if (numArrays <= 0) {
            return null;
        }
        // Initialise the placeholder here too: a thread that never called save() would otherwise
        // compare against null and return placeholders as real masks.
        INDArray placeholder = isMask ? MultiDataSet.getEmptyMaskArrayPlaceholder() : null;
        INDArray[] result = new INDArray[numArrays];
        for (int i = 0; i < numArrays; ++i) {
            INDArray arr = Nd4j.read(dis);
            result[i] = isMask && arr.equals(placeholder) ? null : arr;
        }
        return result;
    }

    /** Loads this data set from the given file, as {@link #load(InputStream)}. */
    @Override
    public void load(File from) throws IOException {
        try (FileInputStream fis = new FileInputStream(from)) {
            this.load(fis);
        }
    }

    /**
     * Splits this data set into one MultiDataSet per example. The number of examples is taken from
     * {@code size(0)} of the first feature array. Null arrays and null slots are preserved as null.
     * Metadata is not carried over.
     */
    @Override
    public List<org.nd4j.linalg.dataset.api.MultiDataSet> asList() {
        long nExamples = this.features[0].size(0);
        List<org.nd4j.linalg.dataset.api.MultiDataSet> list = new ArrayList<>();
        for (int i = 0; i < nExamples; ++i) {
            list.add(new MultiDataSet(
                    MultiDataSet.getSubsetsForExample(this.features, i),
                    MultiDataSet.getSubsetsForExample(this.labels, i),
                    MultiDataSet.getSubsetsForExample(this.featuresMaskArrays, i),
                    MultiDataSet.getSubsetsForExample(this.labelsMaskArrays, i)));
        }
        return list;
    }

    /** Applies {@link #getSubsetForExample} to every non-null slot; null in, null out. */
    private static INDArray[] getSubsetsForExample(INDArray[] arrays, int idx) {
        if (arrays == null) {
            return null;
        }
        INDArray[] result = new INDArray[arrays.length];
        for (int j = 0; j < arrays.length; ++j) {
            result[j] = arrays[j] == null ? null : MultiDataSet.getSubsetForExample(arrays[j], idx);
        }
        return result;
    }

    /**
     * Extracts example {@code idx} along dimension 0. Rank 2 uses a point index; ranks 3 and 4 use
     * a one-element interval.
     *
     * @throws IllegalStateException for any other rank
     */
    private static INDArray getSubsetForExample(INDArray array, int idx) {
        switch (array.rank()) {
            case 2: {
                return array.get(NDArrayIndex.point(idx), NDArrayIndex.all());
            }
            case 3: {
                return array.get(NDArrayIndex.interval(idx, idx, true), NDArrayIndex.all(), NDArrayIndex.all());
            }
            case 4: {
                return array.get(NDArrayIndex.interval(idx, idx, true), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all());
            }
        }
        throw new IllegalStateException("Cannot get subset for rank " + array.rank() + " array");
    }

    /**
     * Returns a copy whose features, labels and masks are {@code dup()}ed. Null arrays and null
     * slots stay null. Metadata is not copied.
     */
    @Override
    public MultiDataSet copy() {
        MultiDataSet ret = new MultiDataSet(MultiDataSet.copy(this.getFeatures()), MultiDataSet.copy(this.getLabels()));
        if (this.labelsMaskArrays != null) {
            ret.setLabelsMaskArray(MultiDataSet.copy(this.labelsMaskArrays));
        }
        if (this.featuresMaskArrays != null) {
            ret.setFeaturesMaskArrays(MultiDataSet.copy(this.featuresMaskArrays));
        }
        return ret;
    }

    /** Duplicates each non-null slot of {@code arrays}; returns {@code null} for a null array. */
    private static INDArray[] copy(INDArray[] arrays) {
        if (arrays == null) {
            return null;
        }
        INDArray[] result = new INDArray[arrays.length];
        for (int i = 0; i < arrays.length; ++i) {
            result[i] = arrays[i] == null ? null : arrays[i].dup();
        }
        return result;
    }

    /**
     * Merges the given data sets along the example dimension (via {@code DataSetUtil}).
     *
     * <p>Empty data sets ({@link #isEmpty()}) are skipped. The expected numbers of input and output
     * arrays come from the first non-empty data set. If all of them are empty, the first element's
     * counts are used. A single-element collection is returned as-is when it is already a
     * {@code MultiDataSet}, without copying. Metadata is not merged.
     *
     * @throws IllegalArgumentException if {@code toMerge} is empty
     * @throws IllegalStateException    if the non-empty data sets disagree on array counts
     */
    public static MultiDataSet merge(Collection<? extends org.nd4j.linalg.dataset.api.MultiDataSet> toMerge) {
        if (toMerge.isEmpty()) {
            throw new IllegalArgumentException("Cannot merge an empty collection of MultiDataSets");
        }
        if (toMerge.size() == 1) {
            org.nd4j.linalg.dataset.api.MultiDataSet mds = toMerge.iterator().next();
            if (mds instanceof MultiDataSet) {
                return (MultiDataSet) mds;
            }
            return new MultiDataSet(mds.getFeatures(), mds.getLabels(), mds.getFeaturesMaskArrays(), mds.getLabelsMaskArrays());
        }
        // The collection is iterated directly, so any Collection implementation works.
        org.nd4j.linalg.dataset.api.MultiDataSet reference = null;
        int referenceIdx = 0;
        int nonEmpty = 0;
        int pos = 0;
        for (org.nd4j.linalg.dataset.api.MultiDataSet mds : toMerge) {
            if (!mds.isEmpty()) {
                if (reference == null) {
                    reference = mds;
                    referenceIdx = pos;
                }
                ++nonEmpty;
            }
            ++pos;
        }
        if (reference == null) {
            reference = toMerge.iterator().next();
            referenceIdx = 0;
        }
        int nInArrays = reference.numFeatureArrays();
        int nOutArrays = reference.numLabelsArrays();
        INDArray[][] features = new INDArray[nonEmpty][];
        INDArray[][] labels = new INDArray[nonEmpty][];
        INDArray[][] featuresMasks = new INDArray[nonEmpty][];
        INDArray[][] labelsMasks = new INDArray[nonEmpty][];
        int i = 0;
        pos = 0;
        for (org.nd4j.linalg.dataset.api.MultiDataSet mds : toMerge) {
            int thisPos = pos++;
            if (mds.isEmpty()) {
                continue;
            }
            features[i] = mds.getFeatures();
            labels[i] = mds.getLabels();
            featuresMasks[i] = mds.getFeaturesMaskArrays();
            labelsMasks[i] = mds.getLabelsMaskArrays();
            if (features[i] == null || features[i].length != nInArrays) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of input arrays: toMerge[" + referenceIdx + "] has " + nInArrays + " input arrays; toMerge[" + thisPos + "] has " + (features[i] != null ? String.valueOf(features[i].length) : "null") + " arrays");
            }
            if (labels[i] == null || labels[i].length != nOutArrays) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of output arrays: toMerge[" + referenceIdx + "] has " + nOutArrays + " output arrays; toMerge[" + thisPos + "] has " + (labels[i] != null ? String.valueOf(labels[i].length) : "null") + " arrays");
            }
            ++i;
        }
        INDArray[] mergedFeatures = new INDArray[nInArrays];
        INDArray[] mergedFeaturesMasks = new INDArray[nInArrays];
        boolean needFeaturesMasks = false;
        for (int k = 0; k < nInArrays; ++k) {
            Pair<INDArray, INDArray> pair = DataSetUtil.mergeFeatures(features, featuresMasks, k);
            mergedFeatures[k] = pair.getFirst();
            mergedFeaturesMasks[k] = pair.getSecond();
            needFeaturesMasks |= mergedFeaturesMasks[k] != null;
        }
        INDArray[] mergedLabels = new INDArray[nOutArrays];
        INDArray[] mergedLabelsMasks = new INDArray[nOutArrays];
        boolean needLabelsMasks = false;
        for (int k = 0; k < nOutArrays; ++k) {
            Pair<INDArray, INDArray> pair = DataSetUtil.mergeLabels(labels, labelsMasks, k);
            mergedLabels[k] = pair.getFirst();
            mergedLabelsMasks[k] = pair.getSecond();
            needLabelsMasks |= mergedLabelsMasks[k] != null;
        }
        // Drop all-null mask arrays entirely so the result reports no masks.
        return new MultiDataSet(mergedFeatures, mergedLabels,
                needFeaturesMasks ? mergedFeaturesMasks : null,
                needLabelsMasks ? mergedLabelsMasks : null);
    }

    /** Human-readable dump of all arrays and masks; {@code ';'} row separators become newlines. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("MultiDataSet: ").append(this.numFeatureArrays()).append(" input arrays, ")
                .append(this.numLabelsArrays()).append(" label arrays, ")
                .append(MultiDataSet.countNonNull(this.featuresMaskArrays)).append(" input masks, ")
                .append(MultiDataSet.countNonNull(this.labelsMaskArrays)).append(" label masks");
        for (int i = 0; i < this.numFeatureArrays(); ++i) {
            sb.append("\n=== INPUT ").append(i).append(" ===\n").append(this.getFeatures(i).toString().replaceAll(";", "\n"));
            INDArray mask = this.getFeaturesMaskArray(i);
            if (mask != null) {
                sb.append("\n--- INPUT MASK ---\n").append(mask.toString().replaceAll(";", "\n"));
            }
        }
        for (int i = 0; i < this.numLabelsArrays(); ++i) {
            sb.append("\n=== LABEL ").append(i).append(" ===\n").append(this.getLabels(i).toString().replaceAll(";", "\n"));
            INDArray mask = this.getLabelsMaskArray(i);
            if (mask != null) {
                sb.append("\n--- LABEL MASK ---\n").append(mask.toString().replaceAll(";", "\n"));
            }
        }
        return sb.toString();
    }

    /** Number of non-null slots in {@code arrays}; 0 for a null array. */
    private static int countNonNull(INDArray[] arrays) {
        int count = 0;
        if (arrays != null) {
            for (INDArray a : arrays) {
                if (a != null) {
                    ++count;
                }
            }
        }
        return count;
    }

    /**
     * Two MultiDataSets are equal when all four array groups are element-wise equal. Metadata is
     * not compared.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MultiDataSet)) {
            return false;
        }
        MultiDataSet m = (MultiDataSet) o;
        return MultiDataSet.bothNullOrEqual(this.features, m.features)
                && MultiDataSet.bothNullOrEqual(this.labels, m.labels)
                && MultiDataSet.bothNullOrEqual(this.featuresMaskArrays, m.featuresMaskArrays)
                && MultiDataSet.bothNullOrEqual(this.labelsMaskArrays, m.labelsMaskArrays);
    }

    /** Null-safe element-wise equality of two arrays of INDArrays. */
    private static boolean bothNullOrEqual(INDArray[] first, INDArray[] second) {
        if (first == null || second == null) {
            return first == second;
        }
        if (first.length != second.length) {
            return false;
        }
        for (int i = 0; i < first.length; ++i) {
            if (!Objects.equals(first[i], second[i])) {
                return false;
            }
        }
        return true;
    }

    /** Consistent with {@link #equals}; null slots contribute 0 instead of throwing. */
    @Override
    public int hashCode() {
        int result = 0;
        result = MultiDataSet.hashInto(result, this.features);
        result = MultiDataSet.hashInto(result, this.labels);
        result = MultiDataSet.hashInto(result, this.featuresMaskArrays);
        result = MultiDataSet.hashInto(result, this.labelsMaskArrays);
        return result;
    }

    /** Folds each slot's hash (0 for null) into {@code seed} with the usual 31 multiplier. */
    private static int hashInto(int seed, INDArray[] arrays) {
        int result = seed;
        if (arrays != null) {
            for (INDArray a : arrays) {
                result = result * 31 + Objects.hashCode(a);
            }
        }
        return result;
    }

    /**
     * Approximate size in bytes of all arrays. For every array this is its length times
     * {@code Nd4j.sizeOfDataType()}. Null arrays and null slots count as 0.
     */
    @Override
    public long getMemoryFootprint() {
        return MultiDataSet.memoryFootprint(this.features)
                + MultiDataSet.memoryFootprint(this.featuresMaskArrays)
                + MultiDataSet.memoryFootprint(this.labelsMaskArrays)
                + MultiDataSet.memoryFootprint(this.labels);
    }

    private static long memoryFootprint(INDArray[] arrays) {
        long bytes = 0L;
        if (arrays != null) {
            for (INDArray a : arrays) {
                if (a != null) {
                    bytes += a.lengthLong() * (long) Nd4j.sizeOfDataType();
                }
            }
        }
        return bytes;
    }

    /**
     * Replaces every non-null array by {@code array.migrate()}. This is done only when a workspace
     * is currently active; otherwise the method does nothing.
     */
    @Override
    public void migrate() {
        if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) {
            return;
        }
        MultiDataSet.migrateAll(this.features);
        MultiDataSet.migrateAll(this.labels);
        MultiDataSet.migrateAll(this.featuresMaskArrays);
        MultiDataSet.migrateAll(this.labelsMaskArrays);
    }

    private static void migrateAll(INDArray[] arrays) {
        if (arrays == null) {
            return;
        }
        for (int e = 0; e < arrays.length; ++e) {
            if (arrays[e] != null) {
                arrays[e] = arrays[e].migrate();
            }
        }
    }

    /** Replaces every non-null array by {@code array.detach()}. */
    @Override
    public void detach() {
        MultiDataSet.detachAll(this.features);
        MultiDataSet.detachAll(this.labels);
        MultiDataSet.detachAll(this.featuresMaskArrays);
        MultiDataSet.detachAll(this.labelsMaskArrays);
    }

    private static void detachAll(INDArray[] arrays) {
        if (arrays == null) {
            return;
        }
        for (int e = 0; e < arrays.length; ++e) {
            if (arrays[e] != null) {
                arrays[e] = arrays[e].detach();
            }
        }
    }

    /** Returns {@code true} when every array group is null or contains only null slots. */
    @Override
    public boolean isEmpty() {
        return MultiDataSet.nullOrEmpty(this.features) && MultiDataSet.nullOrEmpty(this.labels) && MultiDataSet.nullOrEmpty(this.featuresMaskArrays) && MultiDataSet.nullOrEmpty(this.labelsMaskArrays);
    }

    /**
     * Shuffles the examples in place: the data is split per example, shuffled and merged back. A
     * data set with no examples is left unchanged. Example metadata is cleared, because
     * {@link #asList()} and {@link #merge} do not carry it, so it would no longer match the order.
     */
    @Override
    public void shuffle() {
        List<org.nd4j.linalg.dataset.api.MultiDataSet> split = this.asList();
        if (split.isEmpty()) {
            return;
        }
        Collections.shuffle(split);
        MultiDataSet mds = MultiDataSet.merge(split);
        this.features = mds.features;
        this.labels = mds.labels;
        this.featuresMaskArrays = mds.featuresMaskArrays;
        this.labelsMaskArrays = mds.labelsMaskArrays;
        this.exampleMetaData = null;
    }

    /** Returns {@code true} if {@code arr} is null or all of its slots are null. */
    private static boolean nullOrEmpty(INDArray[] arr) {
        if (arr == null) {
            return true;
        }
        for (INDArray i : arr) {
            if (i != null) {
                return false;
            }
        }
        return true;
    }
}

