package org.nd4j.linalg.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.math3.util.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/nd4j/linalg/dataset/MultiDataSet.class */
/**
 * In-memory {@link org.nd4j.linalg.dataset.api.MultiDataSet}: one or more feature arrays and one or
 * more label arrays, each optionally accompanied by a mask array.
 *
 * <p>Mask arrays are positionally aligned with the data they mask: {@code featuresMaskArrays[i]}
 * belongs to {@code features[i]} and {@code labelsMaskArrays[i]} belongs to {@code labels[i]}. The
 * mask arrays themselves, and individual entries within them, may be {@code null}.
 *
 * <p>Arrays are stored by reference (no defensive copies) and no synchronization is performed.
 */
public class MultiDataSet implements org.nd4j.linalg.dataset.api.MultiDataSet {
    private INDArray[] features;
    private INDArray[] labels;
    private INDArray[] featuresMaskArrays;
    private INDArray[] labelsMaskArrays;

    /**
     * Creates a data set with a single feature array, a single label array and no masks.
     *
     * @param features the feature array
     * @param labels the label array
     */
    public MultiDataSet(INDArray features, INDArray labels) {
        this(new INDArray[] {features}, new INDArray[] {labels});
    }

    /**
     * Creates a data set from feature and label arrays, with no masks.
     *
     * @param features the feature arrays (stored by reference)
     * @param labels the label arrays (stored by reference)
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels) {
        this(features, labels, null, null);
    }

    /**
     * Creates a data set from feature arrays, label arrays and their optional mask arrays.
     *
     * @param features the feature arrays (stored by reference)
     * @param labels the label arrays (stored by reference)
     * @param featuresMaskArrays masks aligned with {@code features}, or {@code null}
     * @param labelsMaskArrays masks aligned with {@code labels}, or {@code null}
     * @throws IllegalArgumentException if a mask array and the data array it accompanies are both
     *     non-null but have different lengths
     */
    public MultiDataSet(INDArray[] features, INDArray[] labels, INDArray[] featuresMaskArrays,
                        INDArray[] labelsMaskArrays) {
        if (features != null && featuresMaskArrays != null && features.length != featuresMaskArrays.length) {
            throw new IllegalArgumentException("Invalid features / features mask arrays combination: features and features mask arrays must not be different lengths");
        }
        if (labels != null && labelsMaskArrays != null && labels.length != labelsMaskArrays.length) {
            throw new IllegalArgumentException("Invalid labels / labels mask arrays combination: labels and labels mask arrays must not be different lengths");
        }
        this.features = features;
        this.labels = labels;
        this.featuresMaskArrays = featuresMaskArrays;
        this.labelsMaskArrays = labelsMaskArrays;
    }

    /** Returns the number of feature arrays, or 0 when no feature arrays are set. */
    @Override
    public int numFeatureArrays() {
        return features == null ? 0 : features.length;
    }

    /** Returns the number of label arrays, or 0 when no label arrays are set. */
    @Override
    public int numLabelsArrays() {
        return labels == null ? 0 : labels.length;
    }

    /** Returns the feature arrays by reference (may be {@code null}). */
    @Override
    public INDArray[] getFeatures() {
        return features;
    }

    /** Returns the feature array at position {@code index}. */
    @Override
    public INDArray getFeatures(int index) {
        return features[index];
    }

    /** Replaces all feature arrays (stored by reference, not validated against the masks). */
    @Override
    public void setFeatures(INDArray[] features) {
        this.features = features;
    }

    /** Replaces the feature array at position {@code index}. */
    @Override
    public void setFeatures(int index, INDArray feature) {
        features[index] = feature;
    }

    /** Returns the label arrays by reference (may be {@code null}). */
    @Override
    public INDArray[] getLabels() {
        return labels;
    }

    /** Returns the label array at position {@code index}. */
    @Override
    public INDArray getLabels(int index) {
        return labels[index];
    }

    /** Replaces all label arrays (stored by reference, not validated against the masks). */
    @Override
    public void setLabels(INDArray[] labels) {
        this.labels = labels;
    }

    /** Replaces the label array at position {@code index}. */
    @Override
    public void setLabels(int index, INDArray label) {
        labels[index] = label;
    }

    /** Returns {@code true} if at least one feature or label mask array is non-null. */
    @Override
    public boolean hasMaskArrays() {
        return containsNonNull(featuresMaskArrays) || containsNonNull(labelsMaskArrays);
    }

    /** Returns {@code true} if {@code arrays} is non-null and holds at least one non-null element. */
    private static boolean containsNonNull(INDArray[] arrays) {
        if (arrays == null) {
            return false;
        }
        for (INDArray array : arrays) {
            if (array != null) {
                return true;
            }
        }
        return false;
    }

    /** Returns the feature mask arrays by reference (may be {@code null}). */
    @Override
    public INDArray[] getFeaturesMaskArrays() {
        return featuresMaskArrays;
    }

    /** Returns the feature mask at position {@code index}, or {@code null} if no feature masks are set. */
    @Override
    public INDArray getFeaturesMaskArray(int index) {
        return featuresMaskArrays == null ? null : featuresMaskArrays[index];
    }

    /** Replaces all feature mask arrays (stored by reference). */
    @Override
    public void setFeaturesMaskArrays(INDArray[] featuresMaskArrays) {
        this.featuresMaskArrays = featuresMaskArrays;
    }

    /**
     * Replaces the feature mask at position {@code index}. The feature mask array must already
     * exist; this method does not allocate it.
     */
    @Override
    public void setFeaturesMaskArray(int index, INDArray maskArray) {
        featuresMaskArrays[index] = maskArray;
    }

    /** Returns the label mask arrays by reference (may be {@code null}). */
    @Override
    public INDArray[] getLabelsMaskArrays() {
        return labelsMaskArrays;
    }

    /** Returns the label mask at position {@code index}, or {@code null} if no label masks are set. */
    @Override
    public INDArray getLabelsMaskArray(int index) {
        return labelsMaskArrays == null ? null : labelsMaskArrays[index];
    }

    /** Replaces all label mask arrays (stored by reference). */
    @Override
    public void setLabelsMaskArray(INDArray[] labelsMaskArrays) {
        this.labelsMaskArrays = labelsMaskArrays;
    }

    /**
     * Replaces the label mask at position {@code index}. The label mask array must already exist;
     * this method does not allocate it.
     */
    @Override
    public void setLabelsMaskArray(int index, INDArray maskArray) {
        labelsMaskArrays[index] = maskArray;
    }

    /**
     * Merges several data sets into one by concatenating, for every feature/label position, the
     * arrays of all data sets along dimension 0.
     *
     * <p>Rank 2 and rank 4 arrays must agree on every dimension except 0; their masks are not
     * carried over. Rank 3 (time series) arrays must agree on dimension 1. If their dimension-2
     * lengths differ, or any of them carries a mask, a merged mask is produced whose padding steps
     * are 0.
     *
     * <p>If {@code toMerge} holds exactly one element, that element is returned as-is when it is
     * already a {@code MultiDataSet}; otherwise its arrays are wrapped (not copied).
     *
     * @param toMerge the data sets to merge, in order
     * @return the merged data set
     * @throws IllegalArgumentException if {@code toMerge} is empty
     * @throws IllegalStateException if the data sets have different numbers of arrays, or arrays
     *     at the same position have different ranks or incompatible shapes
     * @throws UnsupportedOperationException if an array has a rank other than 2, 3 or 4
     */
    public static MultiDataSet merge(Collection<? extends org.nd4j.linalg.dataset.api.MultiDataSet> toMerge) {
        if (toMerge.isEmpty()) {
            throw new IllegalArgumentException("Cannot merge an empty collection of MultiDataSets");
        }
        if (toMerge.size() == 1) {
            org.nd4j.linalg.dataset.api.MultiDataSet only = toMerge.iterator().next();
            if (only instanceof MultiDataSet) {
                return (MultiDataSet) only;
            }
            return new MultiDataSet(only.getFeatures(), only.getLabels(), only.getFeaturesMaskArrays(),
                            only.getLabelsMaskArrays());
        }

        int nDataSets = toMerge.size();
        org.nd4j.linalg.dataset.api.MultiDataSet first = toMerge.iterator().next();
        int nFeatures = first.numFeatureArrays();
        int nLabels = first.numLabelsArrays();

        // Indexed as [dataSet][arrayPosition].
        INDArray[][] features = new INDArray[nDataSets][];
        INDArray[][] labels = new INDArray[nDataSets][];
        INDArray[][] featureMasks = new INDArray[nDataSets][];
        INDArray[][] labelMasks = new INDArray[nDataSets][];

        int ds = 0;
        for (org.nd4j.linalg.dataset.api.MultiDataSet mds : toMerge) {
            features[ds] = mds.getFeatures();
            labels[ds] = mds.getLabels();
            featureMasks[ds] = mds.getFeaturesMaskArrays();
            labelMasks[ds] = mds.getLabelsMaskArrays();

            if (features[ds] == null || features[ds].length != nFeatures) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of input arrays: toMerge[0] has "
                                + nFeatures + " input arrays; toMerge[" + ds + "] has "
                                + (features[ds] == null ? "null" : String.valueOf(features[ds].length)) + " arrays");
            }
            if (labels[ds] == null || labels[ds].length != nLabels) {
                throw new IllegalStateException("Cannot merge MultiDataSets with different number of output arrays: toMerge[0] has "
                                + nLabels + " output arrays; toMerge[" + ds + "] has "
                                + (labels[ds] == null ? "null" : String.valueOf(labels[ds].length)) + " arrays");
            }
            ds++;
        }

        INDArray[] mergedFeatures = new INDArray[nFeatures];
        INDArray[] mergedFeatureMasks = mergeAllPositions(features, featureMasks, mergedFeatures);
        INDArray[] mergedLabels = new INDArray[nLabels];
        INDArray[] mergedLabelMasks = mergeAllPositions(labels, labelMasks, mergedLabels);

        return new MultiDataSet(mergedFeatures, mergedLabels, mergedFeatureMasks, mergedLabelMasks);
    }

    /**
     * Merges every array position {@code 0 .. out.length - 1} across all data sets, writing the
     * merged data into {@code out}.
     *
     * @return the merged masks (same length as {@code out}), or {@code null} if no position
     *     produced a mask
     */
    private static INDArray[] mergeAllPositions(INDArray[][] arrays, INDArray[][] masks, INDArray[] out) {
        INDArray[] outMasks = new INDArray[out.length];
        boolean anyMask = false;
        for (int pos = 0; pos < out.length; pos++) {
            Pair<INDArray, INDArray> merged = mergeAt(arrays, masks, pos);
            out[pos] = merged.getFirst();
            outMasks[pos] = merged.getSecond();
            if (outMasks[pos] != null) {
                anyMask = true;
            }
        }
        return anyMask ? outMasks : null;
    }

    /**
     * Merges the arrays at position {@code inOutIdx} of every data set, choosing the strategy from
     * the array rank.
     *
     * @param arrays data arrays indexed as {@code [dataSet][position]}
     * @param masks mask arrays indexed the same way; the whole array or any row may be {@code null}
     * @param inOutIdx the feature/label position to merge
     * @return the merged data array and the merged mask ({@code null} when no mask is needed)
     */
    private static Pair<INDArray, INDArray> mergeAt(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        // Rank of the first data set's array at this position (NOT arrays[inOutIdx][0], which
        // would mix up the data-set and position indices).
        int rank = arrays[0][inOutIdx].rank();
        for (int ds = 1; ds < arrays.length; ds++) {
            int otherRank = arrays[ds][inOutIdx].rank();
            if (otherRank != rank) {
                throw new IllegalStateException("Cannot merge arrays with different ranks (input/output number: " + inOutIdx
                                + "): data[0] has rank " + rank + ", data[" + ds + "] has rank " + otherRank);
            }
        }
        switch (rank) {
            case 2:
                return new Pair<INDArray, INDArray>(merge2d(arrays, inOutIdx), null);
            case 3:
                return mergeTimeSeries(arrays, masks, inOutIdx);
            case 4:
                return new Pair<INDArray, INDArray>(merge4d(arrays, inOutIdx), null);
            default:
                throw new UnsupportedOperationException("Cannot merge arrays with rank " + rank
                                + ": only rank 2, 3 and 4 arrays are supported (input/output number: " + inOutIdx + ")");
        }
    }

    /** Returns {@code masks[ds][inOutIdx]}, or {@code null} when the mask or its row is absent. */
    private static INDArray maskAt(INDArray[][] masks, int ds, int inOutIdx) {
        if (masks == null || masks[ds] == null) {
            return null;
        }
        return masks[ds][inOutIdx];
    }

    /** Stacks 2d arrays row-wise; all must have the same number of columns. */
    private static INDArray merge2d(INDArray[][] arrays, int inOutIdx) {
        int nCols = arrays[0][inOutIdx].columns();
        int totalRows = 0;
        for (INDArray[] perDataSet : arrays) {
            INDArray arr = perDataSet[inOutIdx];
            if (arr.columns() != nCols) {
                throw new IllegalStateException("Cannot merge 2d arrays with different numbers of columns (firstNCols="
                                + nCols + ", ithNCols=" + arr.columns() + ")");
            }
            totalRows += arr.rows();
        }

        INDArray out = Nd4j.create(totalRows, nCols);
        int rowOffset = 0;
        for (INDArray[] perDataSet : arrays) {
            INDArray arr = perDataSet[inOutIdx];
            int rows = arr.rows();
            out.put(new INDArrayIndex[] {NDArrayIndex.interval(rowOffset, rowOffset + rows), NDArrayIndex.all()}, arr);
            rowOffset += rows;
        }
        return out;
    }

    /**
     * Stacks rank-3 arrays along dimension 0, padding dimension 2 to the longest length.
     *
     * <p>When all lengths are equal and no mask is present, only the data is returned. Otherwise a
     * 2d mask of shape {@code [totalExamples, maxLength]} is built. Each input mask is copied
     * as-is, with steps beyond its own length set to 0. Without an input mask, steps beyond the
     * series length are 0 and the rest are 1.
     */
    private static Pair<INDArray, INDArray> mergeTimeSeries(INDArray[][] arrays, INDArray[][] masks, int inOutIdx) {
        INDArray firstArr = arrays[0][inOutIdx];
        int firstLength = firstArr.size(2);
        int vectorSize = firstArr.size(1);
        int maxLength = firstLength;
        int totalExamples = 0;
        boolean anyMask = false;
        boolean lengthsDiffer = false;

        for (int ds = 0; ds < arrays.length; ds++) {
            INDArray arr = arrays[ds][inOutIdx];
            if (arr.size(1) != vectorSize) {
                throw new IllegalStateException("Cannot merge time series with different size for dimension 1 (first shape: "
                                + Arrays.toString(firstArr.shape()) + ", " + ds + "th shape: "
                                + Arrays.toString(arr.shape()) + ")");
            }
            totalExamples += arr.size(0);
            int length = arr.size(2);
            maxLength = Math.max(maxLength, length);
            if (length != firstLength) {
                lengthsDiffer = true;
            }
            if (maskAt(masks, ds, inOutIdx) != null) {
                anyMask = true;
            }
        }

        INDArray out = Nd4j.create(totalExamples, vectorSize, maxLength);
        int offset = 0;

        if (!anyMask && !lengthsDiffer) {
            // Fast path: identical lengths and no masks, so no mask array is needed.
            for (INDArray[] perDataSet : arrays) {
                INDArray arr = perDataSet[inOutIdx];
                int n = arr.size(0);
                out.put(new INDArrayIndex[] {NDArrayIndex.interval(offset, offset + n), NDArrayIndex.all(),
                                NDArrayIndex.all()}, arr);
                offset += n;
            }
            return new Pair<INDArray, INDArray>(out, null);
        }

        INDArray outMask = Nd4j.ones(totalExamples, maxLength);
        for (int ds = 0; ds < arrays.length; ds++) {
            INDArray arr = arrays[ds][inOutIdx];
            int n = arr.size(0);
            int length = arr.size(2);
            out.put(new INDArrayIndex[] {NDArrayIndex.interval(offset, offset + n), NDArrayIndex.all(),
                            NDArrayIndex.interval(0, length)}, arr);

            // Number of leading steps whose mask values come from the input (or stay 1 if no input mask).
            int validLength;
            INDArray mask = maskAt(masks, ds, inOutIdx);
            if (mask != null) {
                validLength = mask.size(1);
                outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(offset, offset + n),
                                NDArrayIndex.interval(0, validLength)}, mask);
            } else {
                validLength = length;
            }
            if (validLength < maxLength) {
                // Padding steps introduced by the merge are masked out.
                outMask.put(new INDArrayIndex[] {NDArrayIndex.interval(offset, offset + n),
                                NDArrayIndex.interval(validLength, maxLength)}, Nd4j.zeros(n, maxLength - validLength));
            }
            offset += n;
        }
        return new Pair<INDArray, INDArray>(out, outMask);
    }

    /** Stacks rank-4 arrays along dimension 0; dimensions 1-3 must match the first array's. */
    private static INDArray merge4d(INDArray[][] arrays, int inOutIdx) {
        int[] firstShape = arrays[0][inOutIdx].shape();
        int totalExamples = 0;
        for (int ds = 0; ds < arrays.length; ds++) {
            INDArray arr = arrays[ds][inOutIdx];
            int[] shape = arr.shape();
            if (shape.length != 4) {
                throw new IllegalStateException("Cannot merge 4d arrays with non 4d arrays");
            }
            for (int d = 1; d < 4; d++) {
                if (shape[d] != firstShape[d]) {
                    throw new IllegalStateException("Cannot merge 4d arrays with different shape (other than # examples):  data[0]["
                                    + inOutIdx + "].shape = " + Arrays.toString(firstShape) + ", data[" + ds + "]["
                                    + inOutIdx + "].shape = " + Arrays.toString(shape));
                }
            }
            totalExamples += arr.size(0);
        }

        INDArray out = Nd4j.create(totalExamples, firstShape[1], firstShape[2], firstShape[3]);
        int offset = 0;
        for (INDArray[] perDataSet : arrays) {
            INDArray arr = perDataSet[inOutIdx];
            int n = arr.size(0);
            out.put(new INDArrayIndex[] {NDArrayIndex.interval(offset, offset + n), NDArrayIndex.all(),
                            NDArrayIndex.all(), NDArrayIndex.all()}, arr);
            offset += n;
        }
        return out;
    }
}
