/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.spark;

import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.UnboundedReadFromBoundedSource;
import org.apache.beam.runners.core.construction.PTransformMatchers;
import org.apache.beam.runners.core.construction.ReplacementOutputs;
import org.apache.beam.runners.spark.SparkPipelineResult;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.TestSparkPipelineOptions;
import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator;
import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
import org.apache.beam.runners.spark.util.GlobalWatermarkHolder;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.BoundedReadFromUnboundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsValidator;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.runners.PipelineRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.util.ValueWithRecordId;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.spark.repackaged.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.spark.repackaged.com.google.common.base.Preconditions;
import org.apache.commons.io.FileUtils;
import org.hamcrest.Matcher;
import org.hamcrest.MatcherAssert;
import org.hamcrest.Matchers;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class TestSparkRunner
extends PipelineRunner<SparkPipelineResult> {
    private static final Logger LOG = LoggerFactory.getLogger(TestSparkRunner.class);
    private SparkRunner delegate;
    private boolean isForceStreaming;

    private TestSparkRunner(TestSparkPipelineOptions options) {
        this.delegate = SparkRunner.fromOptions(options);
        this.isForceStreaming = options.isForceStreaming();
    }

    public static TestSparkRunner fromOptions(PipelineOptions options) {
        TestSparkPipelineOptions sparkOptions = (TestSparkPipelineOptions)PipelineOptionsValidator.validate(TestSparkPipelineOptions.class, (PipelineOptions)options);
        return new TestSparkRunner(sparkOptions);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public SparkPipelineResult run(Pipeline pipeline) {
        TestSparkPipelineOptions testSparkPipelineOptions = (TestSparkPipelineOptions)pipeline.getOptions().as(TestSparkPipelineOptions.class);
        if (this.isForceStreaming) {
            this.adaptBoundedReads(pipeline);
        }
        SparkPipelineResult result = null;
        int expectedNumberOfAssertions = PAssert.countAsserts((Pipeline)pipeline);
        AggregatorsAccumulator.clear();
        SparkMetricsContainer.clear();
        GlobalWatermarkHolder.clear();
        LOG.info("About to run test pipeline " + testSparkPipelineOptions.getJobName());
        if (this.isForceStreaming) {
            try {
                result = this.delegate.run(pipeline);
                Long timeout = testSparkPipelineOptions.getTestTimeoutSeconds();
                result.waitUntilFinish(Duration.standardSeconds((long)Preconditions.checkNotNull(timeout)));
                int successAssertions = result.getAggregatorValue("PAssertSuccess", Integer.class);
                MatcherAssert.assertThat((String)String.format("Expected %d successful assertions, but found %d.", expectedNumberOfAssertions, successAssertions), (Object)successAssertions, (Matcher)Matchers.is((Object)expectedNumberOfAssertions));
                int failedAssertions = result.getAggregatorValue("PAssertFailure", Integer.class);
                MatcherAssert.assertThat((String)String.format("Found %d failed assertions.", failedAssertions), (Object)failedAssertions, (Matcher)Matchers.is((Object)0));
                LOG.info(String.format("Successfully asserted pipeline %s with %d successful assertions.", testSparkPipelineOptions.getJobName(), successAssertions));
            }
            finally {
                try {
                    FileUtils.deleteDirectory((File)new File(testSparkPipelineOptions.getCheckpointDir()));
                }
                catch (IOException e) {
                    throw new RuntimeException("Failed to clear checkpoint tmp dir.", e);
                }
            }
        }
        result = this.delegate.run(pipeline);
        result.waitUntilFinish();
        MatcherAssert.assertThat((Object)result, (Matcher)testSparkPipelineOptions.getOnCreateMatcher());
        MatcherAssert.assertThat((Object)result, (Matcher)testSparkPipelineOptions.getOnSuccessMatcher());
        return result;
    }

    @VisibleForTesting
    void adaptBoundedReads(Pipeline pipeline) {
        pipeline.replace(PTransformMatchers.classEqualTo(BoundedReadFromUnboundedSource.class), new AdaptedBoundedAsUnbounded.Factory());
    }

    private static class AdaptedBoundedAsUnbounded<T>
    extends PTransform<PBegin, PCollection<T>> {
        private final BoundedReadFromUnboundedSource<T> source;

        AdaptedBoundedAsUnbounded(BoundedReadFromUnboundedSource<T> source) {
            this.source = source;
        }

        public PCollection<T> expand(PBegin input) {
            UnboundedReadFromBoundedSource replacingTransform = new UnboundedReadFromBoundedSource(this.source.getAdaptedSource());
            return (PCollection)((PCollection)input.apply((PTransform)replacingTransform)).apply("StripIds", (PTransform)ParDo.of((DoFn)new ValueWithRecordId.StripIdsDoFn()));
        }

        static class Factory<T>
        implements PTransformOverrideFactory<PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> {
            Factory() {
            }

            public PTransform<PBegin, PCollection<T>> getReplacementTransform(BoundedReadFromUnboundedSource<T> transform) {
                return new AdaptedBoundedAsUnbounded<T>(transform);
            }

            public PBegin getInput(List<TaggedPValue> inputs, Pipeline p) {
                return p.begin();
            }

            public Map<PValue, PTransformOverrideFactory.ReplacementOutput> mapOutputs(List<TaggedPValue> outputs, PCollection<T> newOutput) {
                return ReplacementOutputs.singleton(outputs, newOutput);
            }
        }
    }
}

