/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.local.transforms;

import com.codepoetics.protonpack.StreamUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.datavec.api.transform.DataAction;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.filter.Filter;
import org.datavec.api.transform.join.Join;
import org.datavec.api.transform.ops.IAggregableReduceOp;
import org.datavec.api.transform.rank.CalculateSortedRank;
import org.datavec.api.transform.reduce.IAssociativeReducer;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.transform.sequence.ConvertToSequence;
import org.datavec.api.transform.sequence.SequenceSplit;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.comparator.WritableComparator;
import org.datavec.arrow.ArrowConverter;
import org.datavec.local.transforms.SequenceEmptyRecordFunction;
import org.datavec.local.transforms.functions.EmptyRecordFunction;
import org.datavec.local.transforms.join.ExecuteJoinFromCoGroupFlatMapFunction;
import org.datavec.local.transforms.join.ExtractKeysFunction;
import org.datavec.local.transforms.misc.ColumnAsKeyPairFunction;
import org.datavec.local.transforms.rank.UnzipForCalculateSortedRankFunction;
import org.datavec.local.transforms.reduce.MapToPairForReducerFunction;
import org.datavec.local.transforms.sequence.ConvertToSequenceLengthOne;
import org.datavec.local.transforms.sequence.LocalGroupToSequenceFunction;
import org.datavec.local.transforms.sequence.LocalMapToPairByMultipleColumnsFunction;
import org.datavec.local.transforms.sequence.LocalSequenceFilterFunction;
import org.datavec.local.transforms.sequence.LocalSequenceTransformFunction;
import org.datavec.local.transforms.transform.LocalTransformFunction;
import org.datavec.local.transforms.transform.SequenceSplitFunction;
import org.datavec.local.transforms.transform.filter.LocalFilterFunction;
import org.nd4j.common.function.Function;
import org.nd4j.common.function.FunctionalUtils;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LocalTransformExecutor {
    /** Class logger. */
    private static final Logger log = LoggerFactory.getLogger(LocalTransformExecutor.class);
    /** Name of the JVM system property read by {@link #isTryCatch()}. */
    public static final String LOG_ERROR_PROPERTY = "org.datavec.spark.transform.logerrors";
    /**
     * Arrow allocator shared by all conversions done in this class, created with a limit of
     * {@code Long.MAX_VALUE}. It is never reassigned, hence {@code final}.
     */
    private static final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

    /**
     * Runs a transform process over plain (non-sequence) records and returns plain records.
     *
     * <p>Input rows whose size differs from the number of columns of the initial schema are dropped
     * before execution; a warning with the number of dropped rows is logged when that happens.
     *
     * @param inputWritables   the records to transform
     * @param transformProcess the transform process to apply
     * @return the transformed records
     * @throws IllegalStateException if the final schema of the process is a {@link SequenceSchema}
     */
    public static List<List<Writable>> execute(List<List<Writable>> inputWritables, TransformProcess transformProcess) {
        boolean producesSequences = transformProcess.getFinalSchema() instanceof SequenceSchema;
        if (producesSequences) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        List<List<Writable>> wellFormed = inputWritables.parallelStream()
                .filter(row -> row.size() == transformProcess.getInitialSchema().numColumns())
                .collect(Collectors.toList());
        int dropped = inputWritables.size() - wellFormed.size();
        if (dropped != 0) {
            log.warn("Filtered out " + dropped + " values");
        }
        Pair<List<List<Writable>>, List<List<List<Writable>>>> result =
                LocalTransformExecutor.execute(wellFormed, null, transformProcess);
        return result.getFirst();
    }

    /**
     * Runs a transform process over plain records and returns the sequences it produces.
     *
     * @param inputWritables   the records to transform
     * @param transformProcess the transform process to apply
     * @return the resulting sequences
     * @throws IllegalStateException if the final schema of the process is not a {@link SequenceSchema}
     */
    public static List<List<List<Writable>>> executeToSequence(List<List<Writable>> inputWritables, TransformProcess transformProcess) {
        boolean producesSequences = transformProcess.getFinalSchema() instanceof SequenceSchema;
        if (!producesSequences) {
            throw new IllegalStateException("Cannot return non-sequence data with this method");
        }
        Pair<List<List<Writable>>, List<List<List<Writable>>>> result =
                LocalTransformExecutor.execute(inputWritables, null, transformProcess);
        return result.getSecond();
    }

    /**
     * Runs a transform process over sequences and returns the plain records it produces.
     *
     * @param inputSequence    the sequences to transform
     * @param transformProcess the transform process to apply
     * @return the resulting records
     * @throws IllegalStateException if the final schema of the process is a {@link SequenceSchema}
     */
    public static List<List<Writable>> executeSequenceToSeparate(List<List<List<Writable>>> inputSequence, TransformProcess transformProcess) {
        boolean producesSequences = transformProcess.getFinalSchema() instanceof SequenceSchema;
        if (producesSequences) {
            throw new IllegalStateException("Cannot return sequence data with this method");
        }
        Pair<List<List<Writable>>, List<List<List<Writable>>>> result =
                LocalTransformExecutor.execute(null, inputSequence, transformProcess);
        return result.getFirst();
    }

    /**
     * Runs a transform process over sequences and returns sequences.
     *
     * <p>When the final schema is a {@link SequenceSchema} the whole input is executed in one pass.
     * Otherwise each input sequence is executed on its own as a list of plain records, and the
     * per-sequence results are returned in input order.
     *
     * @param inputSequence    the sequences to transform
     * @param transformProcess the transform process to apply
     * @return the resulting sequences
     */
    public static List<List<List<Writable>>> executeSequenceToSequence(List<List<List<Writable>>> inputSequence, TransformProcess transformProcess) {
        if (transformProcess.getFinalSchema() instanceof SequenceSchema) {
            Pair<List<List<Writable>>, List<List<List<Writable>>>> result =
                    LocalTransformExecutor.execute(null, inputSequence, transformProcess);
            return result.getSecond();
        }
        List<List<List<Writable>>> perSequence = new ArrayList<>(inputSequence.size());
        for (List<List<Writable>> steps : inputSequence) {
            Pair<List<List<Writable>>, List<List<List<Writable>>>> result =
                    LocalTransformExecutor.execute(steps, null, transformProcess);
            perSequence.add(result.getFirst());
        }
        return perSequence;
    }

    /**
     * Converts each record of writables into a record of strings using {@code toString()} on every value.
     *
     * <p>The previous implementation collected the converted rows into a scratch list but returned a
     * different, never-populated list, so the result was always empty. The converted rows are now returned.
     *
     * @param stringInput the records to convert (despite the name, these are writables)
     * @param schema      unused; kept for signature compatibility
     * @return one list of strings per input record, in input order
     */
    public static List<List<String>> convertWritableInputToString(List<List<Writable>> stringInput, Schema schema) {
        List<List<String>> ret = new ArrayList<>(stringInput.size());
        for (List<Writable> row : stringInput) {
            List<String> rowAsStrings = new ArrayList<>(row.size());
            for (Writable value : row) {
                rowAsStrings.add(value.toString());
            }
            ret.add(rowAsStrings);
        }
        return ret;
    }

    /**
     * Parses records of strings into records of writables, using the column type the schema reports for
     * each column index.
     *
     * <p>The previous implementation collected the parsed rows into a scratch list but returned a
     * different, never-populated list, so the result was always empty. The parsed rows are now returned.
     *
     * <p>Only Double, Float, Integer, Long, String and Time columns are converted (Time is parsed as a
     * long). Values in columns of any other type are skipped, as before.
     *
     * @param stringInput the records to parse
     * @param schema      supplies the column type for each column index
     * @return one list of writables per input record, in input order
     * @throws NumberFormatException if a value cannot be parsed as the numeric type of its column
     */
    public static List<List<Writable>> convertStringInput(List<List<String>> stringInput, Schema schema) {
        List<List<Writable>> ret = new ArrayList<>(stringInput.size());
        for (List<String> row : stringInput) {
            List<Writable> parsedRow = new ArrayList<>(row.size());
            for (int k = 0; k < row.size(); ++k) {
                switch (schema.getType(k)) {
                    case Double:
                        parsedRow.add(new DoubleWritable(Double.parseDouble(row.get(k))));
                        break;
                    case Float:
                        parsedRow.add(new FloatWritable(Float.parseFloat(row.get(k))));
                        break;
                    case Integer:
                        parsedRow.add(new IntWritable(Integer.parseInt(row.get(k))));
                        break;
                    case Long:
                    case Time:
                        parsedRow.add(new LongWritable(Long.parseLong(row.get(k))));
                        break;
                    case String:
                        parsedRow.add(new Text(row.get(k)));
                        break;
                    default:
                        // Other column types were silently skipped by the original code; keep that.
                        break;
                }
            }
            ret.add(parsedRow);
        }
        return ret;
    }

    /**
     * Converts sequences of writable records into sequences of string records using {@code toString()}
     * on every value.
     *
     * @param stringInput the sequences to convert (despite the name, these are writables)
     * @param schema      unused by this method
     * @return the converted sequences, with the same nesting and order as the input
     */
    public static List<List<List<String>>> convertWritableInputToStringTimeSeries(List<List<List<Writable>>> stringInput, Schema schema) {
        List<List<List<String>>> converted = new ArrayList<>();
        for (List<List<Writable>> series : stringInput) {
            List<List<String>> steps = new ArrayList<>();
            for (List<Writable> step : series) {
                List<String> values = new ArrayList<>();
                for (Writable value : step) {
                    values.add(value.toString());
                }
                steps.add(values);
            }
            converted.add(steps);
        }
        return converted;
    }

    /**
     * Parses sequences of string records into sequences of writable records, using the column type the
     * schema reports for each column index.
     *
     * <p>Only Double, Float, Integer, Long, String and Time columns are converted (Time is parsed as a
     * long); values in columns of any other type are skipped.
     *
     * @param stringInput the sequences to parse
     * @param schema      supplies the column type for each column index
     * @return the parsed sequences, with the same nesting and order as the input
     */
    public static List<List<List<Writable>>> convertStringInputTimeSeries(List<List<List<String>>> stringInput, Schema schema) {
        List<List<List<Writable>>> parsed = new ArrayList<>();
        for (List<List<String>> series : stringInput) {
            List<List<Writable>> steps = new ArrayList<>();
            for (List<String> step : series) {
                List<Writable> values = new ArrayList<>();
                for (int col = 0; col < step.size(); ++col) {
                    switch (schema.getType(col)) {
                        case Double:
                            values.add(new DoubleWritable(Double.parseDouble(step.get(col))));
                            break;
                        case Float:
                            values.add(new FloatWritable(Float.parseFloat(step.get(col))));
                            break;
                        case Integer:
                            values.add(new IntWritable(Integer.parseInt(step.get(col))));
                            break;
                        case Long:
                        case Time:
                            values.add(new LongWritable(Long.parseLong(step.get(col))));
                            break;
                        case String:
                            values.add(new Text(step.get(col)));
                            break;
                        default:
                            // Unsupported column types contribute no value.
                            break;
                    }
                }
                steps.add(values);
            }
            parsed.add(steps);
        }
        return parsed;
    }

    /**
     * Reports whether the system property named by {@link #LOG_ERROR_PROPERTY} is set to
     * {@code "true"} (case-insensitive).
     *
     * @return {@code true} if the property is set to "true", {@code false} otherwise (including unset)
     */
    public static boolean isTryCatch() {
        String configured = System.getProperty(LOG_ERROR_PROPERTY);
        return Boolean.parseBoolean(configured);
    }

    /**
     * Applies every {@link DataAction} of a transform process, in order, to either plain records or
     * sequences.
     *
     * <p>Execution starts in record mode when {@code inputWritables} is non-null and in sequence mode
     * otherwise. ConvertToSequence and ConvertFromSequence actions switch between the two modes while the
     * action list is processed.
     *
     * <p>Before returning, the result is round-tripped through {@link ArrowConverter}. Sequences are only
     * round-tripped when they are non-empty and all have the same number of time steps; otherwise they are
     * returned exactly as computed.
     *
     * <p>Empty input is accepted: column-count validation looks at the first record only and is skipped
     * when there is none.
     *
     * @param inputWritables records to process, or {@code null} when starting from sequences
     * @param inputSequence  sequences to process; only read when {@code inputWritables} is {@code null}
     * @param sequence       the transform process whose action list is executed
     * @return a pair with the resulting records in the first slot and {@code null} in the second, or
     *         {@code null} in the first slot and the resulting sequences in the second
     * @throws IllegalStateException if the first input record has the wrong number of columns, or an
     *                               action needs records (or sequences) while the data is in the other mode
     * @throws RuntimeException      if an action carries none of the supported operations
     */
    private static Pair<List<List<Writable>>, List<List<List<Writable>>>> execute(List<List<Writable>> inputWritables, List<List<List<Writable>>> inputSequence, TransformProcess sequence) {
        List<List<Writable>> currentWritables = inputWritables;
        List<List<List<Writable>>> currentSequence = inputSequence;
        List<DataAction> dataActions = sequence.getActionList();

        // Only the first record is checked against the initial schema; nothing to check on empty input.
        if (inputWritables != null) {
            if (!inputWritables.isEmpty()) {
                List<Writable> first = inputWritables.get(0);
                if (first.size() != sequence.getInitialSchema().numColumns()) {
                    throw new IllegalStateException("Input data number of columns (" + first.size() + ") does not match the number of columns for the transform process (" + sequence.getInitialSchema().numColumns() + ")");
                }
            }
        } else if (!inputSequence.isEmpty()) {
            List<List<Writable>> firstSeq = inputSequence.get(0);
            if (firstSeq.size() > 0 && firstSeq.get(0).size() != sequence.getInitialSchema().numColumns()) {
                throw new IllegalStateException("Input sequence data number of columns (" + firstSeq.get(0).size() + ") does not match the number of columns for the transform process (" + sequence.getInitialSchema().numColumns() + ")");
            }
        }

        for (DataAction d : dataActions) {
            if (d.getTransform() != null) {
                Transform t = d.getTransform();
                if (currentWritables != null) {
                    LocalTransformFunction function = new LocalTransformFunction(t);
                    if (LocalTransformExecutor.isTryCatch()) {
                        // With error logging enabled, results are additionally filtered by EmptyRecordFunction.
                        EmptyRecordFunction emptyRecordFunction = new EmptyRecordFunction();
                        currentWritables = currentWritables.stream()
                                .map(function::apply)
                                .filter(emptyRecordFunction::apply)
                                .collect(Collectors.toList());
                    } else {
                        currentWritables = currentWritables.stream()
                                .map(function::apply)
                                .collect(Collectors.toList());
                    }
                } else {
                    LocalSequenceTransformFunction function = new LocalSequenceTransformFunction(t);
                    if (LocalTransformExecutor.isTryCatch()) {
                        SequenceEmptyRecordFunction emptyRecordFunction = new SequenceEmptyRecordFunction();
                        currentSequence = currentSequence.stream()
                                .map(function::apply)
                                .filter(emptyRecordFunction::apply)
                                .collect(Collectors.toList());
                    } else {
                        currentSequence = currentSequence.stream()
                                .map(function::apply)
                                .collect(Collectors.toList());
                    }
                }
            } else if (d.getFilter() != null) {
                Filter f = d.getFilter();
                if (currentWritables != null) {
                    LocalFilterFunction localFilterFunction = new LocalFilterFunction(f);
                    currentWritables = currentWritables.stream()
                            .filter(localFilterFunction::apply)
                            .collect(Collectors.toList());
                } else {
                    LocalSequenceFilterFunction localSequenceFilterFunction = new LocalSequenceFilterFunction(f);
                    currentSequence = currentSequence.stream()
                            .filter(localSequenceFilterFunction::apply)
                            .collect(Collectors.toList());
                }
            } else if (d.getConvertToSequence() != null) {
                ConvertToSequence cts = d.getConvertToSequence();
                if (currentWritables == null) {
                    // Previously surfaced as a bare NullPointerException; report it like the sibling branches.
                    throw new IllegalStateException("Cannot execute ConvertToSequence operation: current writables are null");
                }
                if (cts.isSingleStepSequencesMode()) {
                    // Every record becomes its own sequence.
                    ConvertToSequenceLengthOne convertToSequenceLengthOne = new ConvertToSequenceLengthOne();
                    currentSequence = currentWritables.stream()
                            .map(convertToSequenceLengthOne::apply)
                            .collect(Collectors.toList());
                } else {
                    // Group records by the key columns, then turn each group into one sequence.
                    Schema schema = cts.getInputSchema();
                    int[] colIdxs = schema.getIndexOfColumns(cts.getKeyColumns());
                    LocalMapToPairByMultipleColumnsFunction toKeyedPair = new LocalMapToPairByMultipleColumnsFunction(colIdxs);
                    List<Pair<List<Writable>, List<Writable>>> withKey = currentWritables.stream()
                            .map(toKeyedPair::apply)
                            .collect(Collectors.toList());
                    Map<List<Writable>, List<List<Writable>>> grouped = FunctionalUtils.groupByKey(withKey);
                    LocalGroupToSequenceFunction groupToSequence = new LocalGroupToSequenceFunction(cts.getComparator());
                    currentSequence = grouped.entrySet().stream()
                            .map(entry -> groupToSequence.apply(entry.getValue()))
                            .collect(Collectors.toList());
                }
                currentWritables = null;
            } else if (d.getConvertFromSequence() != null) {
                if (currentSequence == null) {
                    throw new IllegalStateException("Cannot execute ConvertFromSequence operation: current sequence is null");
                }
                // Flatten all sequences into one list of records.
                currentWritables = currentSequence.stream()
                        .flatMap(seq -> seq.stream())
                        .collect(Collectors.toList());
                currentSequence = null;
            } else if (d.getSequenceSplit() != null) {
                SequenceSplit sequenceSplit = d.getSequenceSplit();
                if (currentSequence == null) {
                    throw new IllegalStateException("Error during execution of SequenceSplit: currentSequence is null");
                }
                SequenceSplitFunction sequenceSplitFunction = new SequenceSplitFunction(sequenceSplit);
                currentSequence = currentSequence.stream()
                        .flatMap(seq -> sequenceSplitFunction.call(seq).stream())
                        .collect(Collectors.toList());
            } else if (d.getReducer() != null) {
                if (currentWritables == null) {
                    throw new IllegalStateException("Error during execution of reduction: current writables are null. Trying to execute a reduce operation on a sequence?");
                }
                currentWritables = reduceByKey(currentWritables, d.getReducer());
            } else if (d.getCalculateSortedRank() != null) {
                if (currentWritables == null) {
                    throw new IllegalStateException("Error during execution of CalculateSortedRank: current writables are null. Trying to execute a CalculateSortedRank operation on a sequence? (not currently supported)");
                }
                currentWritables = calculateSortedRank(currentWritables, d.getCalculateSortedRank());
            } else {
                throw new RuntimeException("Unknown/not implemented action: " + d);
            }
        }

        Schema finalSchema = sequence.getFinalSchema();
        if (currentSequence != null) {
            // The Arrow round trip needs a first sequence with a first step (to size the conversion) and
            // equal-length sequences; without those the sequences are returned untouched.
            boolean convertible = !currentSequence.isEmpty() && !currentSequence.get(0).isEmpty();
            if (convertible) {
                int steps = currentSequence.get(0).size();
                for (List<List<Writable>> seq : currentSequence) {
                    if (seq.size() != steps) {
                        convertible = false;
                        break;
                    }
                }
            }
            if (convertible) {
                int timeSeriesLength = currentSequence.get(0).size() * currentSequence.get(0).get(0).size();
                currentSequence = ArrowConverter.toArrowWritablesTimeSeries(
                        ArrowConverter.toArrowColumnsTimeSeries(bufferAllocator, finalSchema, currentSequence),
                        finalSchema, timeSeriesLength);
            }
            return Pair.of(null, currentSequence);
        }
        return new Pair<>(ArrowConverter.toArrowWritables(ArrowConverter.toArrowColumns(bufferAllocator, finalSchema, currentWritables), finalSchema), null);
    }

    /**
     * Groups records by the key assigned by {@link MapToPairForReducerFunction} and reduces every group
     * to a single record with a fresh aggregable reducer.
     */
    private static List<List<Writable>> reduceByKey(List<List<Writable>> records, IAssociativeReducer reducer) {
        MapToPairForReducerFunction toKeyedPair = new MapToPairForReducerFunction(reducer);
        List<Pair<String, List<Writable>>> keyed = records.stream()
                .map(toKeyedPair::apply)
                .collect(Collectors.toList());
        Map<String, List<List<Writable>>> groupedByKey = FunctionalUtils.groupByKey(keyed);
        // Map keys are already unique, so each entry is one complete group; no further aggregation needed.
        Map<String, IAggregableReduceOp<List<Writable>, List<Writable>>> resultPerKey = new HashMap<>();
        for (Map.Entry<String, List<List<Writable>>> entry : groupedByKey.entrySet()) {
            IAggregableReduceOp<List<Writable>, List<Writable>> op = reducer.aggregableReducer();
            resultPerKey.put(entry.getKey(), op);
            for (List<Writable> value : entry.getValue()) {
                op.accept(value);
            }
        }
        return resultPerKey.entrySet().stream()
                .map(entry -> entry.getValue().get())
                .collect(Collectors.toList());
    }

    /**
     * Sorts records by the configured sort column and hands each record, together with its zero-based
     * position in the sorted order, to {@link UnzipForCalculateSortedRankFunction}.
     */
    private static List<List<Writable>> calculateSortedRank(List<List<Writable>> records, CalculateSortedRank csr) {
        WritableComparator comparator = csr.getComparator();
        String sortColumn = csr.getSortOnColumn();
        int sortColumnIdx = csr.getInputSchema().getIndexOfColumn(sortColumn);
        boolean ascending = csr.isAscending();
        ColumnAsKeyPairFunction toKeyedPair = new ColumnAsKeyPairFunction(sortColumnIdx);
        UnzipForCalculateSortedRankFunction unzip = new UnzipForCalculateSortedRankFunction();
        // Descending order swaps the operands instead of negating the result, which would overflow
        // for Integer.MIN_VALUE.
        Comparator<Pair<Writable, List<Writable>>> bySortKey = (a, b) -> ascending
                ? comparator.compare(a.getFirst(), b.getFirst())
                : comparator.compare(b.getFirst(), a.getFirst());
        List<Pair<Writable, List<Writable>>> sorted = records.stream()
                .map(toKeyedPair::apply)
                .sorted(bySortKey)
                .collect(Collectors.toList());
        return StreamUtils.zipWithIndex(sorted.stream())
                .map(indexed -> unzip.apply(Pair.of(indexed.getValue(), indexed.getIndex())))
                .collect(Collectors.toList());
    }

    /**
     * Joins two sets of records according to the given {@link Join}.
     *
     * <p>Both sides are keyed by their join columns, co-grouped by key, and each co-group is expanded by
     * {@link ExecuteJoinFromCoGroupFlatMapFunction}. The joined records are round-tripped through
     * {@link ArrowConverter} using the join's output schema before being returned.
     *
     * @param join  the join definition (join columns, schemas, output schema)
     * @param left  records of the left side
     * @param right records of the right side
     * @return the joined records
     */
    public static List<List<Writable>> executeJoin(Join join, List<List<Writable>> left, List<List<Writable>> right) {
        List<Pair<List<Writable>, List<Writable>>> leftJV = extractJoinKeys(left, join.getLeftSchema(), join.getJoinColumnsLeft());
        List<Pair<List<Writable>, List<Writable>>> rightJV = extractJoinKeys(right, join.getRightSchema(), join.getJoinColumnsRight());
        Map<List<Writable>, Pair<List<List<Writable>>, List<List<Writable>>>> cogroupedJV = FunctionalUtils.cogroup(leftJV, rightJV);
        ExecuteJoinFromCoGroupFlatMapFunction joinFunction = new ExecuteJoinFromCoGroupFlatMapFunction(join);
        List<List<Writable>> ret = cogroupedJV.entrySet().stream()
                .flatMap(entry -> joinFunction.call(Pair.of(entry.getKey(), entry.getValue())).stream())
                .collect(Collectors.toList());
        Schema retSchema = join.getOutputSchema();
        return ArrowConverter.toArrowWritables(ArrowConverter.toArrowColumns(bufferAllocator, retSchema, ret), retSchema);
    }

    /**
     * Resolves the join columns to indexes in {@code schema} and pairs every record with its join key
     * via {@link ExtractKeysFunction}.
     */
    private static List<Pair<List<Writable>, List<Writable>>> extractJoinKeys(List<List<Writable>> records, Schema schema, String[] joinColumns) {
        int[] columnIndexes = new int[joinColumns.length];
        for (int i = 0; i < joinColumns.length; ++i) {
            columnIndexes[i] = schema.getIndexOfColumn(joinColumns[i]);
        }
        ExtractKeysFunction extractKeys = new ExtractKeysFunction(columnIndexes);
        // NOTE(review): records whose size equals the number of join columns are dropped here, exactly as
        // in the original code. That looks unintended (a table made up only of its join columns loses
        // every row) -- confirm the intent before changing it.
        return records.stream()
                .filter(row -> row.size() != joinColumns.length)
                .map(extractKeys::apply)
                .collect(Collectors.toList());
    }

    /**
     * Synthetic lambda body recovered by the decompiler (the name indicates it originated in
     * {@code execute}). Applies {@code function} to {@code input} and returns the result as a list.
     */
    private static /* synthetic */ List lambda$execute$6(Function function, List input) {
        Object transformed = function.apply((Object)input);
        return (List)transformed;
    }

    /**
     * Synthetic lambda body recovered by the decompiler (the name indicates it originated in
     * {@code execute}). Applies {@code function} to {@code input} and returns the result as a list.
     */
    private static /* synthetic */ List lambda$execute$4(Function function, List input) {
        Object transformed = function.apply((Object)input);
        return (List)transformed;
    }

    /**
     * Synthetic lambda body recovered by the decompiler (the name indicates it originated in
     * {@code execute}). Applies {@code function} to {@code input} and returns the result as a list.
     */
    private static /* synthetic */ List lambda$execute$3(Function function, List input) {
        Object transformed = function.apply((Object)input);
        return (List)transformed;
    }

    /**
     * Synthetic lambda body recovered by the decompiler (the name indicates it originated in
     * {@code execute}). Applies {@code function} to {@code input} and returns the result as a list.
     */
    private static /* synthetic */ List lambda$execute$1(Function function, List input) {
        Object transformed = function.apply((Object)input);
        return (List)transformed;
    }
}

