/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.state.api;

import java.io.IOException;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.connector.sink.lib.OutputFormatSink;
import org.apache.flink.api.connector.sink2.Sink;
import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.metadata.CheckpointMetadata;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.state.api.OperatorIdentifier;
import org.apache.flink.state.api.StateBootstrapTransformation;
import org.apache.flink.state.api.output.FileCopyFunction;
import org.apache.flink.state.api.output.MergeOperatorStates;
import org.apache.flink.state.api.output.SavepointOutputFormat;
import org.apache.flink.state.api.output.StatePathExtractor;
import org.apache.flink.state.api.output.operators.GroupReduceOperator;
import org.apache.flink.state.api.runtime.SavepointLoader;
import org.apache.flink.state.api.runtime.StateBootstrapTransformationWithID;
import org.apache.flink.state.api.runtime.metadata.SavepointMetadataV2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;

@PublicEvolving
public class SavepointWriter {
    @Nullable
    private final StreamExecutionEnvironment executionEnvironment;
    private final Map<OperatorIdentifier, OperatorIdentifier> uidTransformationMap = new HashMap<OperatorIdentifier, OperatorIdentifier>();
    protected final SavepointMetadataV2 metadata;
    @Nullable
    protected final StateBackend stateBackend;
    private final Configuration configuration;

    public static SavepointWriter fromExistingSavepoint(StreamExecutionEnvironment executionEnvironment, String path) throws IOException {
        return new SavepointWriter(SavepointWriter.readSavepointMetadata(path), null, executionEnvironment);
    }

    public static SavepointWriter fromExistingSavepoint(StreamExecutionEnvironment executionEnvironment, String path, StateBackend stateBackend) throws IOException {
        return new SavepointWriter(SavepointWriter.readSavepointMetadata(path), stateBackend, executionEnvironment);
    }

    private static SavepointMetadataV2 readSavepointMetadata(String path) throws IOException {
        CheckpointMetadata metadata = SavepointLoader.loadSavepointMetadata(path);
        int maxParallelism = metadata.getOperatorStates().stream().map(OperatorState::getMaxParallelism).max(Comparator.naturalOrder()).orElseThrow(() -> new RuntimeException("Savepoint must contain at least one operator state."));
        return new SavepointMetadataV2(metadata.getCheckpointId(), maxParallelism, metadata.getMasterStates(), metadata.getOperatorStates());
    }

    public static SavepointWriter newSavepoint(StreamExecutionEnvironment executionEnvironment, int maxParallelism) {
        return new SavepointWriter(SavepointWriter.createSavepointMetadata(0L, maxParallelism), null, executionEnvironment);
    }

    public static SavepointWriter newSavepoint(StreamExecutionEnvironment executionEnvironment, long checkpointId, int maxParallelism) {
        return new SavepointWriter(SavepointWriter.createSavepointMetadata(checkpointId, maxParallelism), null, executionEnvironment);
    }

    public static SavepointWriter newSavepoint(StreamExecutionEnvironment executionEnvironment, StateBackend stateBackend, int maxParallelism) {
        return new SavepointWriter(SavepointWriter.createSavepointMetadata(0L, maxParallelism), stateBackend, executionEnvironment);
    }

    public static SavepointWriter newSavepoint(StreamExecutionEnvironment executionEnvironment, StateBackend stateBackend, long checkpointId, int maxParallelism) {
        return new SavepointWriter(SavepointWriter.createSavepointMetadata(checkpointId, maxParallelism), stateBackend, executionEnvironment);
    }

    private static SavepointMetadataV2 createSavepointMetadata(long checkpointId, int maxParallelism) {
        Preconditions.checkArgument((maxParallelism > 0 && maxParallelism <= 32768 ? 1 : 0) != 0, (Object)("Maximum parallelism must be between 1 and 32768. Found: " + maxParallelism));
        return new SavepointMetadataV2(checkpointId, maxParallelism, Collections.emptyList(), Collections.emptyList());
    }

    private SavepointWriter(SavepointMetadataV2 metadata, @Nullable StateBackend stateBackend, @Nullable StreamExecutionEnvironment executionEnvironment) {
        Preconditions.checkNotNull((Object)metadata, (String)"The savepoint metadata must not be null");
        this.metadata = metadata;
        this.stateBackend = stateBackend;
        this.configuration = new Configuration();
        this.executionEnvironment = executionEnvironment;
    }

    public SavepointWriter removeOperator(OperatorIdentifier identifier) {
        this.metadata.removeOperator(identifier);
        return this;
    }

    public <T> SavepointWriter withOperator(OperatorIdentifier identifier, StateBootstrapTransformation<T> transformation) {
        this.metadata.addOperator(identifier, transformation);
        return this;
    }

    public <T> SavepointWriter withConfiguration(ConfigOption<T> option, T value) {
        this.configuration.set(option, value);
        return this;
    }

    public SavepointWriter changeOperatorIdentifier(OperatorIdentifier from, OperatorIdentifier to) {
        this.uidTransformationMap.put(from, to);
        return this;
    }

    public final void write(String path) {
        Path savepointPath = new Path(path);
        List<StateBootstrapTransformationWithID<?>> newOperatorTransformations = this.metadata.getNewOperators();
        Optional<DataStream<OperatorState>> newOperatorStates = this.writeOperatorStates(newOperatorTransformations, this.configuration, savepointPath);
        if (this.executionEnvironment == null && newOperatorStates.isEmpty()) {
            throw new IllegalStateException("Savepoint must contain at least one operator if no execution environment was provided.");
        }
        List<OperatorState> existingOperators = this.metadata.getExistingOperators();
        if (newOperatorStates.isEmpty() && existingOperators.isEmpty()) {
            throw new IllegalStateException("Savepoint must contain at least one operator to be created.");
        }
        SavepointWriter.getFinalOperatorStates(this.executionEnvironment != null ? this.executionEnvironment : newOperatorStates.get().getExecutionEnvironment(), existingOperators, (DataStream<OperatorState>)((DataStream)newOperatorStates.orElse(null)), path).transform("reduce(OperatorState)", TypeInformation.of(CheckpointMetadata.class), new GroupReduceOperator<OperatorState, CheckpointMetadata>(new MergeOperatorStates(this.metadata.getCheckpointId(), this.metadata.getMasterStates()))).forceNonParallel().map((MapFunction)new CheckpointMetadataCheckpointMetadataMapFunction(this.uidTransformationMap)).setParallelism(1).sinkTo((Sink)new OutputFormatSink((OutputFormat)new SavepointOutputFormat(savepointPath))).setParallelism(1).name(path);
    }

    private static DataStream<OperatorState> getFinalOperatorStates(StreamExecutionEnvironment executionEnvironment, List<OperatorState> existingOperators, @Nullable DataStream<OperatorState> newOperatorStates, String path) {
        if (existingOperators.isEmpty()) {
            return newOperatorStates;
        }
        SingleOutputStreamOperator existingOperatorStates = executionEnvironment.fromData(existingOperators).name("existingOperatorStates");
        existingOperatorStates.flatMap((FlatMapFunction)new StatePathExtractor()).setParallelism(1).sinkTo((Sink)new OutputFormatSink((OutputFormat)new FileCopyFunction(path)));
        return newOperatorStates != null ? newOperatorStates.union(new DataStream[]{existingOperatorStates}) : existingOperatorStates;
    }

    private Optional<DataStream<OperatorState>> writeOperatorStates(List<StateBootstrapTransformationWithID<?>> newOperatorStates, Configuration config, Path savepointWritePath) {
        return newOperatorStates.stream().map(newOperatorState -> newOperatorState.getBootstrapTransformation().writeOperatorState(newOperatorState.getOperatorIdentifier(), this.stateBackend, config, this.metadata.getMaxParallelism(), savepointWritePath)).reduce((rec$, xva$0) -> ((DataStream)rec$).union(new DataStream[]{xva$0}));
    }

    private static class CheckpointMetadataCheckpointMetadataMapFunction
    extends RichMapFunction<CheckpointMetadata, CheckpointMetadata> {
        private static final long serialVersionUID = 1L;
        private final Map<OperatorIdentifier, OperatorIdentifier> uidTransformationMap;

        public CheckpointMetadataCheckpointMetadataMapFunction(Map<OperatorIdentifier, OperatorIdentifier> uidTransformationMap) {
            this.uidTransformationMap = new HashMap<OperatorIdentifier, OperatorIdentifier>(uidTransformationMap);
        }

        public CheckpointMetadata map(CheckpointMetadata value) throws Exception {
            List mapped = value.getOperatorStates().stream().map((? super T operatorState) -> {
                OperatorIdentifier operatorIdentifier = operatorState.getOperatorUid().isPresent() ? OperatorIdentifier.forUid((String)operatorState.getOperatorUid().get()) : OperatorIdentifier.forUidHash(operatorState.getOperatorID().toHexString());
                OperatorIdentifier transformedIdentifier = this.uidTransformationMap.remove(operatorIdentifier);
                if (transformedIdentifier != null) {
                    return operatorState.copyWithNewIDs((String)transformedIdentifier.getUid().orElse(null), transformedIdentifier.getOperatorId());
                }
                return operatorState;
            }).collect(Collectors.toList());
            return new CheckpointMetadata(value.getCheckpointId(), mapped, value.getMasterStates());
        }

        public void close() throws Exception {
            if (!this.uidTransformationMap.isEmpty()) {
                throw new FlinkRuntimeException("Some identifier changes were never applied!" + this.uidTransformationMap.entrySet().stream().map(Object::toString).collect(Collectors.joining("\n\t", "\n\t", "")));
            }
        }
    }
}

