/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor.serializer;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.linalg.dataset.api.preprocessor.MinMaxStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.MultiNormalizerHybrid;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.StandardizeStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerSerializerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType;
import org.nd4j.linalg.dataset.api.preprocessor.stats.DistributionStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.MinMaxStats;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;

public class MultiHybridSerializerStrategy
implements NormalizerSerializerStrategy<MultiNormalizerHybrid> {
    @Override
    public void write(@NonNull MultiNormalizerHybrid normalizer, @NonNull OutputStream stream) throws IOException {
        if (normalizer == null) {
            throw new NullPointerException("normalizer is marked non-null but is null");
        }
        if (stream == null) {
            throw new NullPointerException("stream is marked non-null but is null");
        }
        try (DataOutputStream dos = new DataOutputStream(stream);){
            MultiHybridSerializerStrategy.writeStatsMap(normalizer.getInputStats(), dos);
            MultiHybridSerializerStrategy.writeStatsMap(normalizer.getOutputStats(), dos);
            MultiHybridSerializerStrategy.writeStrategy(normalizer.getGlobalInputStrategy(), dos);
            MultiHybridSerializerStrategy.writeStrategy(normalizer.getGlobalOutputStrategy(), dos);
            MultiHybridSerializerStrategy.writeStrategyMap(normalizer.getPerInputStrategies(), dos);
            MultiHybridSerializerStrategy.writeStrategyMap(normalizer.getPerOutputStrategies(), dos);
        }
    }

    @Override
    public MultiNormalizerHybrid restore(@NonNull InputStream stream) throws IOException {
        if (stream == null) {
            throw new NullPointerException("stream is marked non-null but is null");
        }
        DataInputStream dis = new DataInputStream(stream);
        MultiNormalizerHybrid result = new MultiNormalizerHybrid();
        result.setInputStats(MultiHybridSerializerStrategy.readStatsMap(dis));
        result.setOutputStats(MultiHybridSerializerStrategy.readStatsMap(dis));
        result.setGlobalInputStrategy(MultiHybridSerializerStrategy.readStrategy(dis));
        result.setGlobalOutputStrategy(MultiHybridSerializerStrategy.readStrategy(dis));
        result.setPerInputStrategies(MultiHybridSerializerStrategy.readStrategyMap(dis));
        result.setPerOutputStrategies(MultiHybridSerializerStrategy.readStrategyMap(dis));
        return result;
    }

    @Override
    public NormalizerType getSupportedType() {
        return NormalizerType.MULTI_HYBRID;
    }

    private static void writeStatsMap(Map<Integer, NormalizerStats> statsMap, DataOutputStream dos) throws IOException {
        Set<Integer> indices = statsMap.keySet();
        dos.writeInt(indices.size());
        for (int index : indices) {
            dos.writeInt(index);
            MultiHybridSerializerStrategy.writeNormalizerStats(statsMap.get(index), dos);
        }
    }

    private static Map<Integer, NormalizerStats> readStatsMap(DataInputStream dis) throws IOException {
        HashMap<Integer, NormalizerStats> result = new HashMap<Integer, NormalizerStats>();
        int numEntries = dis.readInt();
        for (int i = 0; i < numEntries; ++i) {
            int index = dis.readInt();
            result.put(index, MultiHybridSerializerStrategy.readNormalizerStats(dis));
        }
        return result;
    }

    private static void writeNormalizerStats(NormalizerStats normalizerStats, DataOutputStream dos) throws IOException {
        if (normalizerStats instanceof DistributionStats) {
            MultiHybridSerializerStrategy.writeDistributionStats((DistributionStats)normalizerStats, dos);
        } else if (normalizerStats instanceof MinMaxStats) {
            MultiHybridSerializerStrategy.writeMinMaxStats((MinMaxStats)normalizerStats, dos);
        } else {
            throw new RuntimeException("Unsupported stats class " + normalizerStats.getClass());
        }
    }

    private static NormalizerStats readNormalizerStats(DataInputStream dis) throws IOException {
        Strategy strategy = Strategy.values()[dis.readInt()];
        switch (strategy) {
            case STANDARDIZE: {
                return MultiHybridSerializerStrategy.readDistributionStats(dis);
            }
            case MIN_MAX: {
                return MultiHybridSerializerStrategy.readMinMaxStats(dis);
            }
        }
        throw new RuntimeException("Unsupported strategy " + strategy.name());
    }

    private static void writeDistributionStats(DistributionStats normalizerStats, DataOutputStream dos) throws IOException {
        dos.writeInt(Strategy.STANDARDIZE.ordinal());
        Nd4j.write(normalizerStats.getMean(), dos);
        Nd4j.write(normalizerStats.getStd(), dos);
    }

    private static NormalizerStats readDistributionStats(DataInputStream dis) throws IOException {
        return new DistributionStats(Nd4j.read(dis), Nd4j.read(dis));
    }

    private static void writeMinMaxStats(MinMaxStats normalizerStats, DataOutputStream dos) throws IOException {
        dos.writeInt(Strategy.MIN_MAX.ordinal());
        Nd4j.write(normalizerStats.getLower(), dos);
        Nd4j.write(normalizerStats.getUpper(), dos);
    }

    private static NormalizerStats readMinMaxStats(DataInputStream dis) throws IOException {
        return new MinMaxStats(Nd4j.read(dis), Nd4j.read(dis));
    }

    private static void writeStrategyMap(Map<Integer, NormalizerStrategy> strategyMap, DataOutputStream dos) throws IOException {
        Set<Integer> indices = strategyMap.keySet();
        dos.writeInt(indices.size());
        for (int index : indices) {
            dos.writeInt(index);
            MultiHybridSerializerStrategy.writeStrategy(strategyMap.get(index), dos);
        }
    }

    private static Map<Integer, NormalizerStrategy> readStrategyMap(DataInputStream dis) throws IOException {
        HashMap<Integer, NormalizerStrategy> result = new HashMap<Integer, NormalizerStrategy>();
        int numIndices = dis.readInt();
        for (int i = 0; i < numIndices; ++i) {
            result.put(dis.readInt(), MultiHybridSerializerStrategy.readStrategy(dis));
        }
        return result;
    }

    private static void writeStrategy(NormalizerStrategy strategy, DataOutputStream dos) throws IOException {
        if (strategy == null) {
            MultiHybridSerializerStrategy.writeNoStrategy(dos);
        } else if (strategy instanceof StandardizeStrategy) {
            MultiHybridSerializerStrategy.writeStandardizeStrategy(dos);
        } else if (strategy instanceof MinMaxStrategy) {
            MultiHybridSerializerStrategy.writeMinMaxStrategy((MinMaxStrategy)strategy, dos);
        } else {
            throw new RuntimeException("Unsupported strategy class " + strategy.getClass());
        }
    }

    private static NormalizerStrategy readStrategy(DataInputStream dis) throws IOException {
        Strategy strategy = Strategy.values()[dis.readInt()];
        switch (strategy) {
            case NULL: {
                return null;
            }
            case STANDARDIZE: {
                return MultiHybridSerializerStrategy.readStandardizeStrategy();
            }
            case MIN_MAX: {
                return MultiHybridSerializerStrategy.readMinMaxStrategy(dis);
            }
        }
        throw new RuntimeException("Unsupported strategy " + strategy.name());
    }

    private static void writeNoStrategy(DataOutputStream dos) throws IOException {
        dos.writeInt(Strategy.NULL.ordinal());
    }

    private static void writeStandardizeStrategy(DataOutputStream dos) throws IOException {
        dos.writeInt(Strategy.STANDARDIZE.ordinal());
    }

    private static NormalizerStrategy readStandardizeStrategy() {
        return new StandardizeStrategy();
    }

    private static void writeMinMaxStrategy(MinMaxStrategy strategy, DataOutputStream dos) throws IOException {
        dos.writeInt(Strategy.MIN_MAX.ordinal());
        dos.writeDouble(strategy.getMinRange());
        dos.writeDouble(strategy.getMaxRange());
    }

    private static NormalizerStrategy readMinMaxStrategy(DataInputStream dis) throws IOException {
        return new MinMaxStrategy(dis.readDouble(), dis.readDouble());
    }

    private static enum Strategy {
        NULL,
        STANDARDIZE,
        MIN_MAX;

    }
}

