/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class DataSetUtil {

    /**
     * Flattens either the features or the labels of a {@link DataSet} into a 2d matrix.
     *
     * <p>The matching mask array (features mask or labels mask) is used if present. The work is
     * delegated to {@link #tailor2d(INDArray, INDArray)}.
     *
     * @param dataSet     data set to read from; must not be null
     * @param areFeatures {@code true} to use the features and features mask, {@code false} to use
     *                    the labels and labels mask
     * @return the 2d representation of the selected array (may be {@code null}, see
     *         {@link #tailor3d2d(INDArray, INDArray)})
     * @throws NullPointerException     if {@code dataSet} or the selected array is null
     * @throws IllegalArgumentException if the selected array has an unsupported rank
     */
    public static INDArray tailor2d(@NonNull DataSet dataSet, boolean areFeatures) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        INDArray data = areFeatures ? dataSet.getFeatures() : dataSet.getLabels();
        INDArray mask = areFeatures ? dataSet.getFeaturesMaskArray() : dataSet.getLabelsMaskArray();
        return tailor2d(data, mask);
    }

    /**
     * Converts an array of rank 1 to 4 into a 2d matrix.
     *
     * <ul>
     *   <li>rank 1 and 2: {@code data} itself is returned unchanged</li>
     *   <li>rank 3: delegated to {@link #tailor3d2d(INDArray, INDArray)}, honouring {@code mask}</li>
     *   <li>rank 4: delegated to {@link #tailor4d2d(INDArray)}; {@code mask} is ignored</li>
     * </ul>
     *
     * @param data array to convert; must not be null
     * @param mask optional mask array (only used for rank 3 data); may be null
     * @return the 2d representation (may be {@code null} for fully-masked rank 3 data)
     * @throws NullPointerException     if {@code data} is null
     * @throws IllegalArgumentException if {@code data} has a rank other than 1, 2, 3 or 4
     */
    public static INDArray tailor2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        int rank = data.rank();
        switch (rank) {
            case 1:
            case 2:
                return data;
            case 3:
                return tailor3d2d(data, mask);
            case 4:
                return tailor4d2d(data);
            default:
                // IllegalArgumentException extends RuntimeException, so existing catch blocks still work.
                throw new IllegalArgumentException(
                        "Unsupported data rank: " + rank + " (expected 1, 2, 3 or 4)");
        }
    }

    /**
     * Flattens the rank 3 features or labels of a {@link DataSet} into a 2d matrix.
     *
     * <p>The matching mask array is used if present. See {@link #tailor3d2d(INDArray, INDArray)}.
     *
     * @param dataset     data set to read from; must not be null
     * @param areFeatures {@code true} for features/features mask, {@code false} for labels/labels mask
     * @return the 2d matrix, or {@code null} if the mask excludes every row
     * @throws NullPointerException if {@code dataset} or the selected array is null
     */
    public static INDArray tailor3d2d(@NonNull DataSet dataset, boolean areFeatures) {
        if (dataset == null) {
            throw new NullPointerException("dataset");
        }
        INDArray data = areFeatures ? dataset.getFeatures() : dataset.getLabels();
        INDArray mask = areFeatures ? dataset.getFeaturesMaskArray() : dataset.getLabelsMaskArray();
        return tailor3d2d(data, mask);
    }

    /**
     * Flattens a rank 3 array into a 2d matrix with one row per (instance, time step) pair.
     *
     * <p>The sizes of dimensions 0, 1 and 2 of {@code data} are treated as
     * {@code instances}, {@code features} and {@code timesteps}. Row {@code j * timesteps + k} of
     * the unmasked result holds the {@code features} values of instance {@code j} at time step
     * {@code k}.
     *
     * <p>When {@code mask} is non-null, the following applies:
     * <ul>
     *   <li>It is expected to be element-wise compatible with each {@code [instances, timesteps]}
     *       slice of {@code data}. NOTE(review): no explicit shape check is performed.</li>
     *   <li>Values are multiplied by the mask.</li>
     *   <li>Only rows whose mask entry is non-zero (as read by {@code getInt}) are kept, in their
     *       original order.</li>
     * </ul>
     * The input {@code data} is never modified.
     *
     * @param data rank 3 array to flatten; must not be null
     * @param mask optional mask; may be null
     * @return a {@code [instances * timesteps, features]} matrix when {@code mask} is null;
     *         otherwise a {@code [keptRows, features]} matrix, or {@code null} if no row is kept
     * @throws NullPointerException if {@code data} is null
     */
    public static INDArray tailor3d2d(@NonNull INDArray data, INDArray mask) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        int instances = data.size(0);
        int features = data.size(1);
        int timesteps = data.size(2);
        boolean hasMasks = mask != null;

        // Built as [features, instances * timesteps] and transposed once at the end.
        INDArray in2d = Nd4j.create(features, timesteps * instances);
        int tads = data.tensorssAlongDimension(2, 0);
        for (int i = 0; i < tads; ++i) {
            // Tensor i along dimensions 2 and 0: all (instance, timestep) values of feature i.
            INDArray thisTAD = data.tensorAlongDimension(i, 2, 0);
            if (hasMasks) {
                // thisTAD is a view of the caller's array. Use the out-of-place mul so the
                // masking does not overwrite the input data (muli did).
                thisTAD = thisTAD.mul(mask);
            }
            in2d.putRow(i, Nd4j.toFlattened('c', thisTAD));
        }
        in2d = in2d.transpose();
        if (!hasMasks) {
            return in2d;
        }
        return selectUnmaskedRows(in2d, mask, instances * timesteps, features);
    }

    /**
     * Copies the rows of {@code in2d} whose mask entry is non-zero into a new
     * {@code [keptRows, features]} matrix.
     *
     * <p>Rows are matched to mask entries by flattening {@code mask} in 'c' order.
     *
     * @return the compacted matrix, or {@code null} if no mask entry is non-zero
     */
    private static INDArray selectUnmaskedRows(INDArray in2d, INDArray mask, int totalRows, int features) {
        INDArray columnMask = Nd4j.toFlattened('c', mask).transpose();
        // Count with the same predicate used for selection below. A sum of the mask only matches
        // this count for strictly 0/1 masks.
        int actualSamples = 0;
        for (int row = 0; row < totalRows; ++row) {
            if (columnMask.getInt(row, 0) != 0) {
                ++actualSamples;
            }
        }
        if (actualSamples == 0) {
            return null;
        }
        INDArray in2dMask = Nd4j.create(actualSamples, features);
        int out = 0;
        for (int row = 0; row < totalRows; ++row) {
            if (columnMask.getInt(row, 0) != 0) {
                in2dMask.putRow(out++, in2d.getRow(row));
            }
        }
        return in2dMask;
    }

    /**
     * Flattens the rank 4 features or labels of a {@link DataSet} into a 2d matrix.
     *
     * <p>Mask arrays are not used. See {@link #tailor4d2d(INDArray)}.
     *
     * @param dataset     data set to read from; must not be null
     * @param areFeatures {@code true} for features, {@code false} for labels
     * @return the 2d matrix
     * @throws NullPointerException if {@code dataset} or the selected array is null
     */
    public static INDArray tailor4d2d(@NonNull DataSet dataset, boolean areFeatures) {
        if (dataset == null) {
            throw new NullPointerException("dataset");
        }
        return tailor4d2d(areFeatures ? dataset.getFeatures() : dataset.getLabels());
    }

    /**
     * Flattens a rank 4 array into a 2d matrix with one row per (instance, height, width) position.
     *
     * <p>The sizes of dimensions 0, 1, 2 and 3 are treated as {@code instances}, {@code channels},
     * {@code height} and {@code width}. The result has shape
     * {@code [instances * height * width, channels]}. Each channel's values are flattened in 'c'
     * order.
     *
     * @param data rank 4 array to flatten; must not be null
     * @return the 2d matrix
     * @throws NullPointerException if {@code data} is null
     */
    public static INDArray tailor4d2d(@NonNull INDArray data) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        int instances = data.size(0);
        int channels = data.size(1);
        int height = data.size(2);
        int width = data.size(3);
        INDArray in2d = Nd4j.create(channels, height * width * instances);
        int tads = data.tensorssAlongDimension(3, 2, 0);
        for (int i = 0; i < tads; ++i) {
            // Tensor i along dimensions 3, 2 and 0: all (instance, y, x) values of channel i.
            INDArray thisTAD = data.tensorAlongDimension(i, 3, 2, 0);
            // Explicit 'c' order, matching tailor3d2d. This keeps the row layout independent of
            // the global default ordering.
            in2d.putRow(i, Nd4j.toFlattened('c', thisTAD));
        }
        return in2d.transposei();
    }
}

