/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators.testutils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.DefaultOpenContext;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.runtime.memory.MemoryManagerBuilder;
import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.ResettableDriver;
import org.apache.flink.runtime.operators.TaskContext;
import org.apache.flink.runtime.operators.sort.ExternalSorter;
import org.apache.flink.runtime.operators.sort.Sorter;
import org.apache.flink.runtime.operators.testutils.DummyInvokable;
import org.apache.flink.runtime.operators.util.TaskConfig;
import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 * Base class for tests of single-input operator {@link Driver}s.
 *
 * <p>The test instance itself is the {@link TaskContext} handed to the driver under test. It
 * supplies the following to the driver:
 *
 * <ul>
 *   <li>the input iterator, either raw or sorted through an {@link ExternalSorter};
 *   <li>the input serializer;
 *   <li>the driver comparators;
 *   <li>the output collector;
 *   <li>the memory and I/O managers.
 * </ul>
 *
 * <p>Each test is parameterized with two {@link ExecutionConfig}s, one with object reuse disabled
 * and one with it enabled (see {@link #getConfigurations()}).
 *
 * @param <S> the user function type of the driver under test
 * @param <IN> the input record type
 * @param <OUT> the output record type
 */
@ExtendWith(ParameterizedTestExtension.class)
public abstract class UnaryOperatorTestBase<S extends Function, IN, OUT>
        implements TaskContext<S, OUT> {

    /** Default memory reserved per sorter: 16 MiB. */
    protected static final long DEFAULT_PER_SORT_MEM = 16L * 1024 * 1024;

    /** Page size constant (32 KiB) available to subclasses. */
    protected static final int PAGE_SIZE = 32 * 1024;

    private final TaskManagerRuntimeInfo taskManagerInfo;

    private final IOManager ioManager;

    /** {@code null} when the test was configured with zero total memory. */
    private final MemoryManager memManager;

    /** Raw input; {@code null} while the input is still to be obtained from {@link #sorter}. */
    private MutableObjectIterator<IN> input;

    private TypeSerializer<IN> inputSerializer;

    private final List<TypeComparator<IN>> comparators;

    /** Non-null only after {@link #addInputSorted} was called. */
    private Sorter<IN> sorter;

    private final AbstractInvokable owner;

    private final TaskConfig taskConfig;

    protected final long perSortMem;

    /** Fraction of the managed memory that one sorter may use. */
    protected final double perSortFractionMem;

    private Collector<OUT> output;

    protected int numFileHandles;

    private S stub;

    private Driver<S, OUT> driver;

    /** Cleared by {@link #cancel()}; read by the test driver loop. */
    private volatile boolean running;

    private final ExecutionConfig executionConfig;

    /**
     * Creates a test base whose sorters each get {@link #DEFAULT_PER_SORT_MEM} bytes.
     *
     * @param executionConfig the execution config exposed to the driver
     * @param memory additional managed memory in bytes, beyond the sorter memory
     * @param maxNumSorters number of sorters to reserve memory for
     */
    protected UnaryOperatorTestBase(ExecutionConfig executionConfig, long memory, int maxNumSorters) {
        this(executionConfig, memory, maxNumSorters, DEFAULT_PER_SORT_MEM);
    }

    /**
     * Creates a test base with explicit per-sorter memory.
     *
     * @param executionConfig the execution config exposed to the driver
     * @param memory additional managed memory in bytes, beyond the sorter memory
     * @param maxNumSorters number of sorters to reserve memory for
     * @param perSortMemory memory in bytes reserved for each sorter
     * @throws IllegalArgumentException if any of the sizes is negative
     * @throws ArithmeticException if the total memory size overflows a {@code long}
     */
    protected UnaryOperatorTestBase(
            ExecutionConfig executionConfig, long memory, int maxNumSorters, long perSortMemory) {
        if (memory < 0L || maxNumSorters < 0 || perSortMemory < 0L) {
            throw new IllegalArgumentException(
                    "memory, maxNumSorters and perSortMemory must be non-negative, but were "
                            + memory
                            + ", "
                            + maxNumSorters
                            + ", "
                            + perSortMemory);
        }
        // Exact arithmetic: fail loudly rather than silently wrapping to a negative size.
        final long totalMem =
                Math.addExact(memory, Math.multiplyExact((long) maxNumSorters, perSortMemory));

        this.perSortMem = perSortMemory;
        // Avoid 0/0 (= NaN) when no memory at all is configured.
        this.perSortFractionMem =
                totalMem > 0L ? (double) perSortMemory / (double) totalMem : 0.0;

        this.ioManager = new IOManagerAsync();
        this.memManager =
                totalMem > 0L
                        ? MemoryManagerBuilder.newBuilder().setMemorySize(totalMem).build()
                        : null;
        this.owner = new DummyInvokable();
        this.taskConfig = new TaskConfig(new Configuration());
        this.executionConfig = executionConfig;
        this.comparators = new ArrayList<>(2);
        this.taskManagerInfo = new TestingTaskManagerRuntimeInfo();
    }

    /** Test parameters: one run without object reuse, one run with object reuse. */
    @Parameters
    private static Collection<Object[]> getConfigurations() {
        final ExecutionConfig withReuse = new ExecutionConfig();
        withReuse.enableObjectReuse();

        final ExecutionConfig withoutReuse = new ExecutionConfig();
        withoutReuse.disableObjectReuse();

        return Arrays.asList(new Object[] {withoutReuse}, new Object[] {withReuse});
    }

    /** Uses {@code input} directly (unsorted) as the driver's input 0. */
    protected void setInput(MutableObjectIterator<IN> input, TypeSerializer<IN> serializer) {
        this.input = input;
        this.inputSerializer = serializer;
        this.sorter = null;
    }

    /**
     * Feeds {@code input} through an {@link ExternalSorter} ordered by {@code comp}. The sorted
     * iterator is fetched lazily on the first {@link #getInput(int)} call.
     */
    protected void addInputSorted(
            MutableObjectIterator<IN> input, TypeSerializer<IN> serializer, TypeComparator<IN> comp)
            throws Exception {
        this.input = null;
        this.inputSerializer = serializer;
        this.sorter =
                ExternalSorter.newBuilder(
                                this.memManager,
                                this.owner,
                                this.<IN>getInputSerializer(0).getSerializer(),
                                comp)
                        .maxNumFileHandles(32)
                        .sortBuffers(1)
                        .enableSpilling(this.ioManager, 0.8f)
                        .memoryFraction(this.perSortFractionMem)
                        .objectReuse(false)
                        .largeRecords(true)
                        .build(input);
    }

    /** Appends a comparator; the driver sees it at the next index of getDriverComparator. */
    protected void addDriverComparator(TypeComparator<IN> comparator) {
        this.comparators.add(comparator);
    }

    protected void setOutput(Collector<OUT> output) {
        this.output = output;
    }

    /** Collects every emitted record into {@code output}, copying each one with the serializer. */
    protected void setOutput(List<OUT> output, TypeSerializer<OUT> outSerializer) {
        this.output = new ListOutputCollector<>(output, outSerializer);
    }

    protected int getNumFileHandlesForSort() {
        return this.numFileHandles;
    }

    protected void setNumFileHandlesForSort(int numFileHandles) {
        this.numFileHandles = numFileHandles;
    }

    /** Runs {@code driver} to completion with a fresh instance of {@code stubClass}. */
    @SuppressWarnings("rawtypes")
    protected void testDriver(Driver driver, Class stubClass) throws Exception {
        testDriverInternal(driver, stubClass);
    }

    /**
     * Runs one full driver life cycle.
     *
     * <p>The cycle is: setup, instantiate the stub, prepare, open the stub, run, close the stub,
     * and close the output. If anything fails, the stub is closed quietly and a
     * {@link ResettableDriver} is torn down. The failure is rethrown unless the test was
     * cancelled. {@link Driver#cleanup()} always runs.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected void testDriverInternal(Driver driver, Class stubClass) throws Exception {
        this.driver = driver;
        driver.setup(this);

        this.stub = (S) stubClass.getDeclaredConstructor().newInstance();

        this.running = true;
        boolean stubOpen = false;

        try {
            try {
                driver.prepare();
            } catch (Throwable t) {
                throw new Exception("The data preparation caused an error: " + t.getMessage(), t);
            }

            try {
                FunctionUtils.openFunction(this.stub, DefaultOpenContext.INSTANCE);
                stubOpen = true;
            } catch (Throwable t) {
                throw new Exception(
                        "The user defined 'open()' method caused an exception: " + t.getMessage(),
                        t);
            }

            driver.run();

            // Close inside the try so that a failing close() fails the test.
            if (this.running) {
                FunctionUtils.closeFunction(this.stub);
                stubOpen = false;
            }

            this.output.close();
        } catch (Exception ex) {
            if (stubOpen) {
                try {
                    FunctionUtils.closeFunction(this.stub);
                } catch (Throwable ignored) {
                    // 'ex' is the root cause; a secondary close() failure would only mask it.
                }
            }

            if (this.driver instanceof ResettableDriver) {
                final ResettableDriver<?, ?> resDriver = (ResettableDriver<?, ?>) this.driver;
                try {
                    resDriver.teardown();
                } catch (Throwable t) {
                    throw new Exception(
                            "Error while shutting down an iterative operator: " + t.getMessage(),
                            t);
                }
            }

            // Failures after cancel() are expected and therefore dropped.
            if (this.running) {
                throw ex;
            }
        } finally {
            driver.cleanup();
        }
    }

    /**
     * Runs {@code driver} for {@code iterations} rounds.
     *
     * <p>The driver is initialized before the first round and reset before each later round.
     * It is torn down after the last round.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected void testResettableDriver(ResettableDriver driver, Class stubClass, int iterations)
            throws Exception {
        driver.setup(this);

        for (int i = 0; i < iterations; i++) {
            if (i == 0) {
                driver.initialize();
            } else {
                driver.reset();
            }
            testDriver(driver, stubClass);
        }

        driver.teardown();
    }

    /** Marks the test as cancelled and cancels the current driver (which must already be set). */
    protected void cancel() throws Exception {
        this.running = false;
        this.driver.cancel();
    }

    @Override
    public TaskConfig getTaskConfig() {
        return this.taskConfig;
    }

    @Override
    public ExecutionConfig getExecutionConfig() {
        return this.executionConfig;
    }

    @Override
    public ClassLoader getUserCodeClassLoader() {
        return getClass().getClassLoader();
    }

    @Override
    public IOManager getIOManager() {
        return this.ioManager;
    }

    @Override
    public MemoryManager getMemoryManager() {
        return this.memManager;
    }

    @Override
    public TaskManagerRuntimeInfo getTaskManagerInfo() {
        return this.taskManagerInfo;
    }

    /**
     * Returns the input iterator.
     *
     * <p>If the input was added via {@link #addInputSorted}, the sorted iterator is fetched from
     * the sorter on the first call and cached.
     *
     * @throws IllegalStateException if no input was configured
     * @throws RuntimeException wrapping an interruption or I/O failure of the sorter
     */
    @Override
    @SuppressWarnings("unchecked")
    public <X> MutableObjectIterator<X> getInput(int index) {
        MutableObjectIterator<IN> in = this.input;
        if (in == null) {
            if (this.sorter == null) {
                throw new IllegalStateException(
                        "No input configured: call setInput() or addInputSorted() first.");
            }
            try {
                in = this.sorter.getIterator();
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers up the stack can still observe it.
                Thread.currentThread().interrupt();
                throw new RuntimeException("Interrupted", e);
            } catch (IOException e) {
                throw new RuntimeException("IOException", e);
            }
            this.input = in;
        }
        return (MutableObjectIterator<X>) in;
    }

    /**
     * Returns a serializer factory for the single input.
     *
     * @throws IllegalArgumentException if {@code index != 0}
     */
    @Override
    @SuppressWarnings("unchecked")
    public <X> TypeSerializerFactory<X> getInputSerializer(int index) {
        if (index != 0) {
            throw new IllegalArgumentException();
        }
        return new RuntimeSerializerFactory<X>(
                (TypeSerializer<X>) this.inputSerializer,
                (Class<X>) this.inputSerializer.createInstance().getClass());
    }

    @Override
    @SuppressWarnings("unchecked")
    public <X> TypeComparator<X> getDriverComparator(int index) {
        return (TypeComparator<X>) this.comparators.get(index);
    }

    @Override
    public S getStub() {
        return this.stub;
    }

    @Override
    public Collector<OUT> getOutputCollector() {
        return this.output;
    }

    @Override
    public AbstractInvokable getContainingTask() {
        return this.owner;
    }

    @Override
    public String formatLogString(String message) {
        return "Driver Tester: " + message;
    }

    @Override
    public OperatorMetricGroup getMetricGroup() {
        return UnregisteredMetricGroups.createUnregisteredOperatorMetricGroup();
    }

    /**
     * Releases the sorter, the I/O manager and the memory manager.
     *
     * <p>Before shutting down the memory manager, this asserts that all of its memory was freed.
     * Each step runs even if an earlier one throws.
     */
    @AfterEach
    void shutdownAll() throws Exception {
        try {
            if (this.sorter != null) {
                this.sorter.close();
            }
        } finally {
            try {
                this.ioManager.close();
            } finally {
                final MemoryManager memMan = getMemoryManager();
                if (memMan != null) {
                    try {
                        Assertions.assertThat(memMan.verifyEmpty())
                                .withFailMessage(
                                        "Memory Manager managed memory was not completely freed.")
                                .isTrue();
                    } finally {
                        memMan.shutdown();
                    }
                }
            }
        }
    }

    /** Collector that appends a serializer copy of every record to a list. */
    private static final class ListOutputCollector<OUT> implements Collector<OUT> {

        private final List<OUT> output;
        private final TypeSerializer<OUT> serializer;

        public ListOutputCollector(List<OUT> outputList, TypeSerializer<OUT> serializer) {
            this.output = outputList;
            this.serializer = serializer;
        }

        @Override
        public void collect(OUT record) {
            // Copy, because the driver may reuse 'record' when object reuse is enabled.
            this.output.add(this.serializer.copy(record));
        }

        @Override
        public void close() {}
    }
}

