/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators.util;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntComparator;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.core.io.IOReadableWritable;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
import org.apache.flink.runtime.operators.shipping.OutputEmitter;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.apache.flink.runtime.plugable.SerializationDelegate;
import org.apache.flink.runtime.testutils.recordutils.RecordComparator;
import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory;
import org.apache.flink.types.DeserializationException;
import org.apache.flink.types.DoubleValue;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.NullKeyFieldException;
import org.apache.flink.types.Record;
import org.apache.flink.types.StringValue;
import org.apache.flink.types.Value;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests for {@link OutputEmitter} channel selection under the hash, forward, forced-rebalance and
 * broadcast ship strategies.
 */
class OutputEmitterTest {

    /** Channel count used by every scenario in this test. */
    private static final int NUM_CHANNELS = 100;

    @Test
    void testPartitionHash() {
        verifyPartitionHashSelectedChannels(50000, NUM_CHANNELS, RecordType.INTEGER);
        verifyPartitionHashSelectedChannels(10000, NUM_CHANNELS, RecordType.STRING);

        // TestIntComparator hashes an Integer to the value itself, so these keys feed extreme and
        // negative hash codes straight into the selector; the chosen channel must stay in range.
        final ChannelSelector<SerializationDelegate<Integer>> selector =
                createChannelSelector(
                        ShipStrategyType.PARTITION_HASH, new TestIntComparator(), NUM_CHANNELS);
        final SerializationDelegate<Integer> delegate =
                new SerializationDelegate<>(new IntSerializer());
        for (int key : new int[] {Integer.MIN_VALUE, -1, 0, 1, Integer.MAX_VALUE}) {
            assertPartitionHashSelectedChannels(selector, delegate, key, NUM_CHANNELS);
        }
    }

    @Test
    void testForward() {
        // Record counts deliberately not a multiple of the channel count.
        verifyForwardSelectedChannels(50000 + NUM_CHANNELS / 2, NUM_CHANNELS, RecordType.INTEGER);
        verifyForwardSelectedChannels(10000 + NUM_CHANNELS / 2, NUM_CHANNELS, RecordType.STRING);
    }

    @Test
    void testForcedRebalance() {
        // IntValue records: the first record is expected on channel 85, and the 33 extra records
        // wrap around past the last channel back to channel 0.
        verifyForcedRebalance(85, 1, 50000, 33, RecordType.INTEGER);
        // StringValue records: the first record is expected on channel 20; the 22 extra records
        // do not wrap around.
        verifyForcedRebalance(20, 2, 10000, 22, RecordType.STRING);
    }

    @Test
    void testBroadcast() {
        // Arguments are (numRecords, numberOfChannels, recordType).
        verifyBroadcastSelectedChannels(50000, NUM_CHANNELS, RecordType.INTEGER);
        verifyBroadcastSelectedChannels(50000, NUM_CHANNELS, RecordType.STRING);
    }

    @Test
    void testMultiKeys() {
        final int numRecords = 5000;
        final RecordComparator multiComp =
                new RecordComparatorFactory(
                                new int[] {0, 1, 3},
                                new Class[] {
                                    IntValue.class, StringValue.class, DoubleValue.class
                                })
                        .createComparator();
        final ChannelSelector<SerializationDelegate<Record>> selector =
                createChannelSelector(ShipStrategyType.PARTITION_HASH, multiComp, NUM_CHANNELS);
        final SerializationDelegate<Record> delegate =
                new SerializationDelegate<>(new RecordSerializerFactory().getSerializer());

        final int[] hits = new int[NUM_CHANNELS];
        for (int i = 0; i < numRecords; i++) {
            // Field 2 is intentionally left unset; only fields 0, 1 and 3 are keys.
            final Record record = new Record(4);
            record.setField(0, new IntValue(i));
            record.setField(1, new StringValue("AB" + i + "CD" + i));
            record.setField(3, new DoubleValue((double) i * 3.141));
            delegate.setInstance(record);
            hits[selector.selectChannel(delegate)]++;
        }
        assertEveryChannelHit(hits, numRecords);
    }

    @Test
    void testMissingKey() {
        Assertions.assertThat(verifyWrongPartitionHashKey(1, 0))
                .withFailMessage("Expected a KeyFieldOutOfBoundsException.")
                .isTrue();
    }

    @Test
    void testNullKey() {
        Assertions.assertThat(verifyWrongPartitionHashKey(0, 1))
                .withFailMessage("Expected a NullKeyFieldException.")
                .isTrue();
    }

    @Test
    void testWrongKeyClass() throws Exception {
        // The comparator expects a DoubleValue key, but the record carries an IntValue.
        final RecordComparator doubleComp =
                new RecordComparatorFactory(new int[] {0}, new Class[] {DoubleValue.class})
                        .createComparator();
        final ChannelSelector<SerializationDelegate<Record>> selector =
                createChannelSelector(ShipStrategyType.PARTITION_HASH, doubleComp, NUM_CHANNELS);
        final SerializationDelegate<Record> delegate =
                new SerializationDelegate<>(new RecordSerializerFactory().getSerializer());

        Record record = new Record(1);
        record.setField(0, new IntValue());

        // Round-trip the record through a pipe so the key field must be deserialized when the
        // selector reads it. The pipe buffer (1 MiB) is large enough that the single-threaded
        // write-then-read cannot block. All streams are closed once the round trip is done.
        try (PipedInputStream pipedInput = new PipedInputStream(1024 * 1024);
                DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(pipedInput);
                DataOutputViewStreamWrapper out =
                        new DataOutputViewStreamWrapper(new PipedOutputStream(pipedInput))) {
            record.write(out);
            record = new Record();
            record.read(in);
        }
        delegate.setInstance(record);

        Assertions.assertThatThrownBy(() -> selector.selectChannel(delegate))
                .isInstanceOf(DeserializationException.class);
    }

    /**
     * Hash-partitions {@code numRecords} records and asserts every channel is used and that the
     * per-channel counts add up to {@code numRecords}.
     */
    private static void verifyPartitionHashSelectedChannels(
            int numRecords, int numberOfChannels, RecordType recordType) {
        final int[] hits =
                getSelectedChannelsHitCount(
                        ShipStrategyType.PARTITION_HASH, numRecords, numberOfChannels, recordType);
        assertEveryChannelHit(hits, numRecords);
    }

    /** Asserts that the FORWARD strategy sends every record to channel 0 and none elsewhere. */
    private static void verifyForwardSelectedChannels(
            int numRecords, int numberOfChannels, RecordType recordType) {
        final int[] hits =
                getSelectedChannelsHitCount(
                        ShipStrategyType.FORWARD, numRecords, numberOfChannels, recordType);
        Assertions.assertThat(hits[0]).isEqualTo(numRecords);
        for (int channel = 1; channel < hits.length; channel++) {
            Assertions.assertThat(hits[channel]).isZero();
        }
    }

    /** Asserts that selecting a single channel under the BROADCAST strategy is unsupported. */
    private static void verifyBroadcastSelectedChannels(
            int numRecords, int numberOfChannels, RecordType recordType) {
        Assertions.assertThatThrownBy(
                        () ->
                                getSelectedChannelsHitCount(
                                        ShipStrategyType.BROADCAST,
                                        numRecords,
                                        numberOfChannels,
                                        recordType))
                .isInstanceOf(UnsupportedOperationException.class);
    }

    /**
     * Runs {@code baseRecords + extraRecords} records through a forced-rebalance emitter created
     * with subtask index {@code toTaskIndex + wraps * NUM_CHANNELS} and checks the distribution.
     *
     * <p>The test expects the first record on channel {@code toTaskIndex}, followed by a cyclic
     * pass over all channels. So the {@code extraRecords} channels starting at {@code toTaskIndex}
     * (wrapping around) receive one record more than the rest. {@code baseRecords} must be a
     * multiple of {@link #NUM_CHANNELS}, and {@code extraRecords} must be smaller than it.
     */
    private static void verifyForcedRebalance(
            int toTaskIndex, int wraps, int baseRecords, int extraRecords, RecordType recordType) {
        final int fromTaskIndex = toTaskIndex + wraps * NUM_CHANNELS;
        final int numRecords = baseRecords + extraRecords;
        final SerializationDelegate<Record> delegate =
                new SerializationDelegate<>(new RecordSerializerFactory().getSerializer());
        final OutputEmitter<Record> selector =
                new OutputEmitter<>(ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex);
        selector.setup(NUM_CHANNELS);

        final int[] hits =
                getSelectedChannelsHitCount(selector, delegate, recordType, numRecords, NUM_CHANNELS);
        int totalHitCount = 0;
        for (int channel = 0; channel < hits.length; channel++) {
            // Cyclic distance from the start channel; floorMod covers the wrap-around case.
            final boolean getsExtraRecord =
                    Math.floorMod(channel - toTaskIndex, NUM_CHANNELS) < extraRecords;
            final int expected = numRecords / NUM_CHANNELS + (getsExtraRecord ? 1 : 0);
            Assertions.assertThat(hits[channel]).isEqualTo(expected);
            totalHitCount += hits[channel];
        }
        Assertions.assertThat(totalHitCount).isEqualTo(numRecords);
    }

    /**
     * Hash-partitions a 2-field record whose only set field is {@code fieldNum}, using an IntValue
     * key at {@code position}.
     *
     * @return {@code true} iff channel selection threw a {@link NullKeyFieldException}; in that
     *     case the exception's field number is also asserted to equal {@code position}
     */
    private static boolean verifyWrongPartitionHashKey(int position, int fieldNum) {
        final RecordComparator comparator =
                new RecordComparatorFactory(new int[] {position}, new Class[] {IntValue.class})
                        .createComparator();
        final ChannelSelector<SerializationDelegate<Record>> selector =
                createChannelSelector(ShipStrategyType.PARTITION_HASH, comparator, NUM_CHANNELS);
        final SerializationDelegate<Record> delegate =
                new SerializationDelegate<>(new RecordSerializerFactory().getSerializer());

        final Record record = new Record(2);
        record.setField(fieldNum, new IntValue(1));
        delegate.setInstance(record);

        try {
            selector.selectChannel(delegate);
        } catch (NullKeyFieldException re) {
            Assertions.assertThat(re.getFieldNumber()).isEqualTo(position);
            return true;
        }
        return false;
    }

    /** Asserts that every channel received at least one record and the counts sum to the total. */
    private static void assertEveryChannelHit(int[] hits, int expectedTotal) {
        int totalHitCount = 0;
        for (int hit : hits) {
            Assertions.assertThat(hit).isGreaterThan(0);
            totalHitCount += hit;
        }
        Assertions.assertThat(totalHitCount).isEqualTo(expectedTotal);
    }

    /**
     * Builds a selector for {@code shipStrategyType} keyed on field 0 (IntValue or StringValue,
     * depending on {@code recordType}) and returns its per-channel hit counts.
     */
    private static int[] getSelectedChannelsHitCount(
            ShipStrategyType shipStrategyType,
            int numRecords,
            int numberOfChannels,
            RecordType recordType) {
        final Class<?> keyClass =
                recordType == RecordType.INTEGER ? IntValue.class : StringValue.class;
        final RecordComparator comparator =
                new RecordComparatorFactory(new int[] {0}, new Class[] {keyClass})
                        .createComparator();
        final ChannelSelector<SerializationDelegate<Record>> selector =
                createChannelSelector(shipStrategyType, comparator, numberOfChannels);
        final SerializationDelegate<Record> delegate =
                new SerializationDelegate<>(new RecordSerializerFactory().getSerializer());
        return getSelectedChannelsHitCount(
                selector, delegate, recordType, numRecords, numberOfChannels);
    }

    /**
     * Creates an {@link OutputEmitter} for the given strategy and comparator, sets it up for
     * {@code numberOfChannels} channels, and asserts that it reports broadcast mode exactly for
     * {@link ShipStrategyType#BROADCAST}.
     */
    private static <T> ChannelSelector<SerializationDelegate<T>> createChannelSelector(
            ShipStrategyType shipStrategyType, TypeComparator<T> comparator, int numberOfChannels) {
        final OutputEmitter<T> selector = new OutputEmitter<>(shipStrategyType, comparator);
        selector.setup(numberOfChannels);
        Assertions.assertThat(selector.isBroadcast())
                .isEqualTo(shipStrategyType == ShipStrategyType.BROADCAST);
        return selector;
    }

    /**
     * Feeds {@code numRecords} single-field records through {@code selector}. Record {@code i}
     * holds {@code IntValue(i)} or the {@code StringValue} of {@code i}. Returns how often each
     * channel was chosen.
     */
    private static int[] getSelectedChannelsHitCount(
            ChannelSelector<SerializationDelegate<Record>> selector,
            SerializationDelegate<Record> delegate,
            RecordType recordType,
            int numRecords,
            int numberOfChannels) {
        final int[] hits = new int[numberOfChannels];
        for (int i = 0; i < numRecords; i++) {
            final Value value =
                    recordType == RecordType.INTEGER
                            ? new IntValue(i)
                            : new StringValue(String.valueOf(i));
            delegate.setInstance(new Record(value));
            hits[selector.selectChannel(delegate)]++;
        }
        return hits;
    }

    /** Asserts that {@code record} is routed to a channel in {@code [0, numberOfChannels)}. */
    private static void assertPartitionHashSelectedChannels(
            ChannelSelector<SerializationDelegate<Integer>> selector,
            SerializationDelegate<Integer> serializationDelegate,
            int record,
            int numberOfChannels) {
        serializationDelegate.setInstance(record);
        final int selectedChannel = selector.selectChannel(serializationDelegate);
        Assertions.assertThat(selectedChannel)
                .isGreaterThanOrEqualTo(0)
                .isLessThan(numberOfChannels);
    }

    /** Key type of the single-field records generated by the hit-count helpers. */
    private enum RecordType {
        STRING,
        INTEGER
    }

    /**
     * Integer comparator whose {@link #hash} is the value itself, so tests can drive arbitrary hash
     * codes (including negative ones) into the emitter. Only hashing, key extraction and the flat
     * comparator list are supported; every other operation throws.
     */
    private static class TestIntComparator extends TypeComparator<Integer> {

        private final TypeComparator[] comparators = new TypeComparator[] {new IntComparator(true)};

        @Override
        public int hash(Integer record) {
            return record;
        }

        @Override
        public void setReference(Integer toCompare) {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean equalToReference(Integer candidate) {
            throw new UnsupportedOperationException();
        }

        @Override
        public int compareToReference(TypeComparator<Integer> referencedComparator) {
            throw new UnsupportedOperationException();
        }

        @Override
        public int compare(Integer first, Integer second) {
            throw new UnsupportedOperationException();
        }

        @Override
        public int compareSerialized(DataInputView firstSource, DataInputView secondSource) {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean supportsNormalizedKey() {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean supportsSerializationWithKeyNormalization() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int getNormalizeKeyLen() {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void putNormalizedKey(
                Integer record, MemorySegment target, int offset, int numBytes) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void writeWithKeyNormalization(Integer record, DataOutputView target)
                throws IOException {
            throw new UnsupportedOperationException();
        }

        @Override
        public Integer readWithKeyDenormalization(Integer reuse, DataInputView source)
                throws IOException {
            throw new UnsupportedOperationException();
        }

        @Override
        public boolean invertNormalizedKey() {
            throw new UnsupportedOperationException();
        }

        @Override
        public TypeComparator<Integer> duplicate() {
            throw new UnsupportedOperationException();
        }

        @Override
        public int extractKeys(Object record, Object[] target, int index) {
            // The whole Integer is the (single) key.
            target[index] = record;
            return 1;
        }

        @Override
        public TypeComparator[] getFlatComparators() {
            return this.comparators;
        }
    }
}

