/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.resources;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.runtime.common.resources.InitialMemoryAllocator;
import org.apache.tez.runtime.common.resources.InitialMemoryRequestContext;
import org.apache.tez.runtime.library.input.OrderedGroupedInputLegacy;
import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.UnorderedKVInput;
import org.apache.tez.runtime.library.output.OrderedPartitionedKVOutput;
import org.apache.tez.runtime.library.output.UnorderedPartitionedKVOutput;

@InterfaceAudience.Public
@InterfaceStability.Unstable
/**
 * An {@link InitialMemoryAllocator} that splits the memory available to a task among the
 * components that requested memory. Each component's share is proportional to its requested
 * size, weighted by a configurable per-{@link RequestType} ratio
 * ({@code tez.task.scale.task.memory.ratios}).
 *
 * <p>Before distributing, a fraction of the available memory is held back. That fraction is a
 * base reserve ({@code tez.task.scale.task.memory.reserve-fraction}, default 0.3) plus a
 * per-request reserve ({@code ...additional-reservation.fraction.per-io}, default
 * {@value #RESERVATION_FRACTION_PER_IO}). The per-request part is capped at
 * {@code ...additional-reservation.fraction.max} (default
 * {@value #MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO}). No component is allocated more than it
 * requested.
 *
 * <p>{@link #setConf(Configuration)} must be called before {@link #assignMemory}. The per-call
 * bookkeeping lives in unsynchronized instance fields, so an instance should not be shared across
 * threads.
 */
public class WeightedScalingMemoryDistributor implements InitialMemoryAllocator {
    private static final Log LOG = LogFactory.getLog(WeightedScalingMemoryDistributor.class);

    private static final String CONF_MEMORY_RATIOS = "tez.task.scale.task.memory.ratios";
    private static final String CONF_RESERVE_FRACTION = "tez.task.scale.task.memory.reserve-fraction";
    private static final String CONF_RESERVE_FRACTION_PER_IO =
        "tez.task.scale.task.memory.additional-reservation.fraction.per-io";
    private static final String CONF_MAX_ADDITIONAL_RESERVE_FRACTION =
        "tez.task.scale.task.memory.additional-reservation.fraction.max";
    private static final double DEFAULT_RESERVE_FRACTION = 0.3;

    /** Default upper bound on the additional (per-request) reserve fraction. */
    static final double MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO = 0.1;
    /** Default additional reserve fraction added for every request. */
    static final double RESERVATION_FRACTION_PER_IO = 0.015;
    /** Default per-type weights used when no ratios are configured. */
    static final String[] DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS =
        WeightedScalingMemoryDistributor.generateWeightStrings(1, 1, 12, 12, 1, 1);

    private Configuration conf;
    private final EnumMap<RequestType, Integer> typeScaleMap = Maps.newEnumMap(RequestType.class);
    private int numRequests = 0;
    /** Sum of the weights of all requests seen in the current call. */
    private int numRequestsScaled = 0;
    private long totalRequested = 0L;
    private final List<Request> requests = Lists.newArrayList();

    /**
     * Computes a memory allocation for every request in {@code initialRequests}.
     *
     * @param availableForAllocation memory that may be handed out before the reserve fraction is
     *     subtracted (same unit as the requested sizes, since allocations are capped by them)
     * @param numTotalInputs not used by this allocator
     * @param numTotalOutputs not used by this allocator
     * @param initialRequests the requests to satisfy; iteration order determines result order
     * @return one allocation per request, in the same order as {@code initialRequests}; each value
     *     is at most the corresponding requested size
     * @throws IllegalArgumentException if the configured ratios or reserve fractions are invalid
     * @throws IllegalStateException if the configured ratios or reserve fractions are invalid
     */
    public Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs,
            int numTotalOutputs, Iterable<InitialMemoryRequestContext> initialRequests) {
        // Start from a clean slate. Otherwise reusing an instance would fold earlier requests into
        // this call's totals and into the returned list.
        resetState();
        populateTypeScaleMap();
        for (InitialMemoryRequestContext context : initialRequests) {
            initialProcessMemoryRequestContext(context);
        }

        // If the weights sum to zero (e.g. all applicable ratios configured as 0), fall back to
        // weighting every request equally instead of dividing by zero below.
        if (numRequestsScaled == 0) {
            numRequestsScaled = numRequests;
            for (Request request : requests) {
                request.requestWeight = 1;
            }
        }

        double totalScaledRequest = 0.0;
        for (Request request : requests) {
            totalScaledRequest += scaledSize(request);
        }

        double reserveFraction = computeReservedFraction(numRequests);
        Preconditions.checkState(reserveFraction >= 0.0 && reserveFraction <= 1.0,
            "Reserve fraction must be within [0, 1], was %s", reserveFraction);
        availableForAllocation =
            (long) ((double) availableForAllocation - reserveFraction * (double) availableForAllocation);

        long totalJvmMem = Runtime.getRuntime().maxMemory();
        double ratio = (double) totalRequested / (double) totalJvmMem;
        LOG.info("Scaling Requests. NumRequests: " + numRequests
            + ", numScaledRequests: " + numRequestsScaled
            + ", TotalRequested: " + totalRequested
            + ", TotalRequestedScaled: " + totalScaledRequest
            + ", TotalJVMHeap: " + totalJvmMem
            + ", TotalAvailable: " + availableForAllocation
            + ", TotalRequested/TotalJVMHeap:" + new DecimalFormat("0.00").format(ratio));

        List<Long> allocations = Lists.newArrayListWithCapacity(numRequests);
        for (Request request : requests) {
            if (request.requestSize == 0L) {
                allocations.add(0L);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Scaling requested " + request.componentClassname + " of type "
                        + request.requestType + " 0 to allocated: 0");
                }
                continue;
            }
            // An all-zero scaled total would otherwise produce 0.0/0.0 = NaN (which truncates to 0);
            // make that outcome explicit.
            double share = totalScaledRequest == 0.0 ? 0.0 : scaledSize(request) / totalScaledRequest;
            long allocated = Math.min((long) (share * (double) availableForAllocation), request.requestSize);
            allocations.add(allocated);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Scaling requested " + request.componentClassname + " of type "
                    + request.requestType + " " + request.requestSize + "  to allocated: " + allocated);
            }
        }
        return allocations;
    }

    /** Clears all bookkeeping accumulated by a previous {@link #assignMemory} call. */
    private void resetState() {
        typeScaleMap.clear();
        numRequests = 0;
        numRequestsScaled = 0;
        totalRequested = 0L;
        requests.clear();
    }

    /** Returns the request's size multiplied by its weight's share of the total weight. */
    private double scaledSize(Request request) {
        return (double) request.requestSize * ((double) request.requestWeight / (double) numRequestsScaled);
    }

    /** Records one request, classifying it by component class and looking up its weight. */
    private void initialProcessMemoryRequestContext(InitialMemoryRequestContext context) {
        long requestedSize = context.getRequestedSize();
        String className = context.getComponentClassName();
        ++numRequests;
        totalRequested += requestedSize;
        RequestType requestType = getRequestTypeForClass(className);
        int typeScaleFactor = getScaleFactorForType(requestType);
        requests.add(new Request(className, requestedSize, requestType, typeScaleFactor));
        LOG.info("ScaleFactor: " + typeScaleFactor + ", for type: " + requestType);
        numRequestsScaled += typeScaleFactor;
    }

    /** Returns the configured weight for {@code requestType}, or 0 (with a warning) if none. */
    private int getScaleFactorForType(RequestType requestType) {
        Integer typeScaleFactor = typeScaleMap.get(requestType);
        if (typeScaleFactor == null) {
            LOG.warn("Bad scale factor for requestType: " + requestType + ", Using factor 0");
            return 0;
        }
        return typeScaleFactor;
    }

    /** Maps a component class name to its {@link RequestType}; unknown classes become OTHER. */
    private RequestType getRequestTypeForClass(String className) {
        if (className.equals(OrderedPartitionedKVOutput.class.getName())) {
            return RequestType.SORTED_OUTPUT;
        }
        if (className.equals(OrderedGroupedKVInput.class.getName())
                || className.equals(OrderedGroupedInputLegacy.class.getName())) {
            return RequestType.SORTED_MERGED_INPUT;
        }
        if (className.equals(UnorderedKVInput.class.getName())) {
            return RequestType.UNSORTED_INPUT;
        }
        if (className.equals(UnorderedPartitionedKVOutput.class.getName())) {
            return RequestType.PARTITIONED_UNSORTED_OUTPUT;
        }
        LOG.info("Falling back to RequestType.OTHER for class: " + className);
        return RequestType.OTHER;
    }

    /**
     * Parses the configured {@code TYPE:ratio} entries into {@link #typeScaleMap}.
     *
     * <p>Exactly one entry per {@link RequestType} is required. If the configuration yields no
     * entries, every type gets weight 1. Whitespace around names and values is ignored.
     */
    private void populateTypeScaleMap() {
        RequestType[] types = RequestType.values();
        String[] ratios = conf.getStrings(CONF_MEMORY_RATIOS, DEFAULT_TASK_MEMORY_WEIGHTED_RATIOS);
        if (ratios == null) {
            LOG.info("No ratio specified. Falling back to Linear scaling");
            ratios = new String[types.length];
            for (int i = 0; i < types.length; i++) {
                ratios[i] = types[i].name() + ":1";
            }
        } else if (ratios.length != types.length) {
            throw new IllegalArgumentException("Number of entries in the configured ratios should be equal to the number of entries in RequestType: " + types.length);
        }
        HashSet<RequestType> seenTypes = new HashSet<RequestType>();
        for (String entry : ratios) {
            String[] parts = entry.split(":");
            Preconditions.checkState(parts.length == 2,
                "Invalid ratio entry '%s', expected <RequestType>:<ratio>", entry);
            RequestType requestType = RequestType.valueOf(parts[0].trim());
            int ratioVal = Integer.parseInt(parts[1].trim());
            if (!seenTypes.add(requestType)) {
                throw new IllegalArgumentException("Cannot configure the same RequestType: " + requestType + " multiple times");
            }
            Preconditions.checkState(ratioVal >= 0, "Ratio must be >= 0");
            typeScaleMap.put(requestType, ratioVal);
        }
    }

    /**
     * Returns the fraction of available memory to hold back. This is the configured base reserve
     * plus {@code numTotalRequests * perIoFraction}, with the second term capped at the
     * configured maximum.
     *
     * @throws IllegalArgumentException if the per-IO or max additional fractions are out of range
     * @throws IllegalStateException if the resulting fraction exceeds 1.0
     */
    private double computeReservedFraction(int numTotalRequests) {
        double reserveFractionPerIo =
            conf.getDouble(CONF_RESERVE_FRACTION_PER_IO, RESERVATION_FRACTION_PER_IO);
        double maxAdditionalReserveFraction =
            conf.getDouble(CONF_MAX_ADDITIONAL_RESERVE_FRACTION, MAX_ADDITIONAL_RESERVATION_FRACTION_PER_IO);
        Preconditions.checkArgument(maxAdditionalReserveFraction >= 0.0 && maxAdditionalReserveFraction <= 1.0,
            "Max additional reserve fraction must be within [0, 1], was %s", maxAdditionalReserveFraction);
        Preconditions.checkArgument(reserveFractionPerIo <= maxAdditionalReserveFraction && reserveFractionPerIo >= 0.0,
            "Per-IO reserve fraction must be within [0, %s], was %s",
            maxAdditionalReserveFraction, reserveFractionPerIo);
        if (LOG.isDebugEnabled()) {
            LOG.debug("ReservationFractionPerIO=" + reserveFractionPerIo + ", MaxPerIOReserveFraction=" + maxAdditionalReserveFraction);
        }
        double initialReserveFraction = conf.getDouble(CONF_RESERVE_FRACTION, DEFAULT_RESERVE_FRACTION);
        double additionalReserveFraction =
            Math.min(maxAdditionalReserveFraction, (double) numTotalRequests * reserveFractionPerIo);
        double reserveFraction = initialReserveFraction + additionalReserveFraction;
        Preconditions.checkState(reserveFraction <= 1.0,
            "Total reserve fraction must be <= 1.0, was %s", reserveFraction);
        LOG.info("InitialReservationFraction=" + initialReserveFraction + ", AdditionalReservationFractionForIOs=" + additionalReserveFraction + ", finalReserveFractionUsed=" + reserveFraction);
        return reserveFraction;
    }

    /**
     * Builds a {@code TYPE:weight} entry for every {@link RequestType}, suitable as the value of
     * {@code tez.task.scale.task.memory.ratios}. {@code UNSORTED_OUTPUT} is always given weight 0.
     */
    public static String[] generateWeightStrings(int unsortedPartitioned, int broadcastIn, int sortedOut, int scatterGatherShuffleIn, int proc, int other) {
        String[] weights = new String[RequestType.values().length];
        weights[0] = RequestType.PARTITIONED_UNSORTED_OUTPUT.name() + ":" + unsortedPartitioned;
        weights[1] = RequestType.UNSORTED_OUTPUT.name() + ":" + 0;
        weights[2] = RequestType.UNSORTED_INPUT.name() + ":" + broadcastIn;
        weights[3] = RequestType.SORTED_OUTPUT.name() + ":" + sortedOut;
        weights[4] = RequestType.SORTED_MERGED_INPUT.name() + ":" + scatterGatherShuffleIn;
        weights[5] = RequestType.PROCESSOR.name() + ":" + proc;
        weights[6] = RequestType.OTHER.name() + ":" + other;
        return weights;
    }

    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    public Configuration getConf() {
        return this.conf;
    }

    /** One memory request together with its classification and weight. */
    private static class Request {
        final String componentClassname;
        final long requestSize;
        private final RequestType requestType;
        /** Mutable: reset to 1 when all configured weights sum to zero. */
        private int requestWeight;

        Request(String componentClassname, long requestSize, RequestType requestType, int requestWeight) {
            this.componentClassname = componentClassname;
            this.requestSize = requestSize;
            this.requestType = requestType;
            this.requestWeight = requestWeight;
        }
    }

    /** Categories of memory requesters; each category has its own configurable weight. */
    @InterfaceAudience.Private
    @VisibleForTesting
    public static enum RequestType {
        PARTITIONED_UNSORTED_OUTPUT,
        UNSORTED_INPUT,
        UNSORTED_OUTPUT,
        SORTED_OUTPUT,
        SORTED_MERGED_INPUT,
        PROCESSOR,
        OTHER;

    }
}

