/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.common.resources;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.dag.api.TezEntityDescriptor;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.api.TezInputContext;
import org.apache.tez.runtime.api.TezOutputContext;
import org.apache.tez.runtime.api.TezProcessorContext;
import org.apache.tez.runtime.api.TezTaskContext;

@InterfaceAudience.Private
public class MemoryDistributor {
    private static final Log LOG = LogFactory.getLog(MemoryDistributor.class);
    private final int numTotalInputs;
    private final int numTotalOutputs;
    private AtomicInteger numInputsSeen = new AtomicInteger(0);
    private AtomicInteger numOutputsSeen = new AtomicInteger(0);
    private long totalJvmMemory;
    private volatile long totalAssignableMemory;
    private final boolean isEnabled;
    private final boolean reserveFractionConfigured;
    private float reserveFraction;
    private final Set<TezTaskContext> dupSet = Collections.newSetFromMap(new ConcurrentHashMap());
    private final List<RequestorInfo> requestList;
    @VisibleForTesting
    static final float RESERVE_FRACTION_NO_PROCESSOR = 0.3f;
    @VisibleForTesting
    static final float RESERVE_FRACTION_WITH_PROCESSOR = 0.05f;

    public MemoryDistributor(int numTotalInputs, int numTotalOutputs, Configuration conf) {
        this.isEnabled = conf.getBoolean("tez.task.scale.memory.enabled", true);
        if (conf.get("tez.task.scale.memory.reserve-fraction") != null) {
            this.reserveFractionConfigured = true;
            this.reserveFraction = conf.getFloat("tez.task.scale.memory.reserve-fraction", 0.3f);
            Preconditions.checkArgument((this.reserveFraction >= 0.0f && this.reserveFraction <= 1.0f ? 1 : 0) != 0);
        } else {
            this.reserveFractionConfigured = false;
            this.reserveFraction = 0.3f;
        }
        this.numTotalInputs = numTotalInputs;
        this.numTotalOutputs = numTotalOutputs;
        this.totalJvmMemory = Runtime.getRuntime().maxMemory();
        this.computeAssignableMemory();
        this.requestList = Collections.synchronizedList(new LinkedList());
        LOG.info((Object)("InitialMemoryDistributor (isEnabled=" + this.isEnabled + ") invoked with: numInputs=" + numTotalInputs + ", numOutputs=" + numTotalOutputs + ". Configuration: reserveFractionSpecified= " + this.reserveFractionConfigured + ", reserveFraction=" + this.reserveFraction + ", JVM.maxFree=" + this.totalJvmMemory + ", assignableMemory=" + this.totalAssignableMemory));
    }

    public void requestMemory(long requestSize, MemoryUpdateCallback callback, TezTaskContext taskContext, TezEntityDescriptor descriptor) {
        this.registerRequest(requestSize, callback, taskContext, descriptor);
    }

    public void makeInitialAllocations() {
        Preconditions.checkState((this.numInputsSeen.get() == this.numTotalInputs ? 1 : 0) != 0, (Object)"All inputs are expected to ask for memory");
        Preconditions.checkState((this.numOutputsSeen.get() == this.numTotalOutputs ? 1 : 0) != 0, (Object)"All outputs are expected to ask for memory");
        Iterable requestContexts = Iterables.transform(this.requestList, (Function)new Function<RequestorInfo, RequestContext>(){

            public RequestContext apply(RequestorInfo requestInfo) {
                return requestInfo.getRequestContext();
            }
        });
        Iterable<Long> allocations = null;
        if (!this.isEnabled) {
            allocations = Iterables.transform(this.requestList, (Function)new Function<RequestorInfo, Long>(){

                public Long apply(RequestorInfo requestInfo) {
                    return requestInfo.getRequestContext().getRequestedSize();
                }
            });
        } else {
            ScalingAllocator allocator = new ScalingAllocator();
            allocations = allocator.assignMemory(this.totalAssignableMemory, this.numTotalInputs, this.numTotalOutputs, Iterables.unmodifiableIterable((Iterable)requestContexts));
            this.validateAllocations(allocations, this.requestList.size());
        }
        Iterator allocatedIter = allocations.iterator();
        for (RequestorInfo rInfo : this.requestList) {
            long allocated = (Long)allocatedIter.next();
            LOG.info((Object)("Informing: " + (Object)((Object)rInfo.getRequestContext().getComponentType()) + ", " + rInfo.getRequestContext().getComponentVertexName() + ", " + rInfo.getRequestContext().getComponentClassName() + ": requested=" + rInfo.getRequestContext().getRequestedSize() + ", allocated=" + allocated));
            rInfo.getCallback().memoryAssigned(allocated);
        }
    }

    @InterfaceAudience.Private
    @VisibleForTesting
    void setJvmMemory(long size) {
        this.totalJvmMemory = size;
        this.computeAssignableMemory();
    }

    private void computeAssignableMemory() {
        this.totalAssignableMemory = this.totalJvmMemory - (long)(this.reserveFraction * (float)this.totalJvmMemory);
    }

    private long registerRequest(long requestSize, MemoryUpdateCallback callback, TezTaskContext entityContext, TezEntityDescriptor descriptor) {
        Preconditions.checkArgument((requestSize >= 0L ? 1 : 0) != 0);
        Preconditions.checkNotNull((Object)callback);
        Preconditions.checkNotNull((Object)entityContext);
        Preconditions.checkNotNull((Object)descriptor);
        if (!this.dupSet.add(entityContext)) {
            throw new TezUncheckedException("A single entity can only make one call to request resources for now");
        }
        RequestorInfo requestInfo = new RequestorInfo(entityContext, requestSize, callback, descriptor);
        switch (requestInfo.getRequestContext().getComponentType()) {
            case INPUT: {
                this.numInputsSeen.incrementAndGet();
                Preconditions.checkState((this.numInputsSeen.get() <= this.numTotalInputs ? 1 : 0) != 0, (Object)("Num Requesting Inputs higher than total # of inputs: " + this.numInputsSeen + ", " + this.numTotalInputs));
                break;
            }
            case OUTPUT: {
                this.numOutputsSeen.incrementAndGet();
                Preconditions.checkState((this.numOutputsSeen.get() <= this.numTotalOutputs ? 1 : 0) != 0, (Object)("Num Requesting Inputs higher than total # of outputs: " + this.numOutputsSeen + ", " + this.numTotalOutputs));
            }
            case PROCESSOR: {
                break;
            }
        }
        this.requestList.add(requestInfo);
        if (!this.reserveFractionConfigured && requestInfo.getRequestContext().getComponentType() == RequestContext.ComponentType.PROCESSOR) {
            this.reserveFraction = 0.05f;
            this.computeAssignableMemory();
            LOG.info((Object)("Processor request for initial memory. Updating assignableMemory to : " + this.totalAssignableMemory));
        }
        return -1L;
    }

    private void validateAllocations(Iterable<Long> allocations, int numRequestors) {
        Preconditions.checkNotNull(allocations);
        long totalAllocated = 0L;
        int numAllocations = 0;
        for (Long l : allocations) {
            totalAllocated += l.longValue();
            ++numAllocations;
        }
        Preconditions.checkState((numAllocations == numRequestors ? 1 : 0) != 0, (Object)("Number of allocations must match number of requestors. Allocated=" + numAllocations + ", Requests: " + numRequestors));
        Preconditions.checkState((totalAllocated <= this.totalAssignableMemory ? 1 : 0) != 0, (Object)("Total allocation should be <= availableMem. TotalAllocated: " + totalAllocated + ", totalAssignable: " + this.totalAssignableMemory));
    }

    private static class ScalingAllocator
    implements InitialMemoryAllocator {
        private ScalingAllocator() {
        }

        @Override
        public Iterable<Long> assignMemory(long availableForAllocation, int numTotalInputs, int numTotalOutputs, Iterable<RequestContext> requests) {
            int numRequests = 0;
            long totalRequested = 0L;
            for (RequestContext context : requests) {
                totalRequested += context.getRequestedSize();
                ++numRequests;
            }
            long totalJvmMem = Runtime.getRuntime().maxMemory();
            double ratio = (double)totalRequested / (double)totalJvmMem;
            LOG.info((Object)("Scaling Requests. TotalRequested: " + totalRequested + ", TotalJVMMem: " + totalJvmMem + ", TotalAvailable: " + availableForAllocation + ", TotalRequested/TotalHeap:" + new DecimalFormat("0.00").format(ratio)));
            if (totalRequested < availableForAllocation || totalRequested == 0L) {
                return Lists.newArrayList((Iterable)Iterables.transform(requests, (Function)new Function<RequestContext, Long>(){

                    public Long apply(RequestContext requestContext) {
                        return requestContext.getRequestedSize();
                    }
                }));
            }
            ArrayList allocations = Lists.newArrayListWithCapacity((int)numRequests);
            for (RequestContext request : requests) {
                long requestedSize = request.getRequestedSize();
                if (requestedSize == 0L) {
                    allocations.add(0L);
                    if (!LOG.isDebugEnabled()) continue;
                    LOG.debug((Object)"Scaling requested: 0 to allocated: 0");
                    continue;
                }
                long allocated = (long)((double)requestedSize / (double)totalRequested * (double)availableForAllocation);
                allocations.add(allocated);
                if (!LOG.isDebugEnabled()) continue;
                LOG.debug((Object)("Scaling requested: " + requestedSize + " to allocated: " + allocated));
            }
            return allocations;
        }
    }

    @InterfaceAudience.Private
    private static class RequestorInfo {
        private final MemoryUpdateCallback callback;
        private final RequestContext requestContext;

        RequestorInfo(TezTaskContext taskContext, long requestSize, MemoryUpdateCallback callback, TezEntityDescriptor descriptor) {
            String componentVertexName;
            RequestContext.ComponentType type;
            if (taskContext instanceof TezInputContext) {
                type = RequestContext.ComponentType.INPUT;
                componentVertexName = ((TezInputContext)taskContext).getSourceVertexName();
            } else if (taskContext instanceof TezOutputContext) {
                type = RequestContext.ComponentType.OUTPUT;
                componentVertexName = ((TezOutputContext)taskContext).getDestinationVertexName();
            } else if (taskContext instanceof TezProcessorContext) {
                type = RequestContext.ComponentType.PROCESSOR;
                componentVertexName = ((TezProcessorContext)taskContext).getTaskVertexName();
            } else {
                throw new IllegalArgumentException("Unknown type of entityContext: " + taskContext.getClass().getName());
            }
            this.requestContext = new RequestContext(requestSize, descriptor.getClassName(), type, componentVertexName);
            this.callback = callback;
            LOG.info((Object)("Received request: " + requestSize + ", type: " + (Object)((Object)type) + ", componentVertexName: " + componentVertexName));
        }

        public MemoryUpdateCallback getCallback() {
            return this.callback;
        }

        public RequestContext getRequestContext() {
            return this.requestContext;
        }
    }

    private static class RequestContext {
        private long requestedSize;
        private String componentClassName;
        private ComponentType componentType;
        private String componentVertexName;

        public RequestContext(long requestedSize, String componentClassName, ComponentType componentType, String componentVertexName) {
            this.requestedSize = requestedSize;
            this.componentClassName = componentClassName;
            this.componentType = componentType;
            this.componentVertexName = componentVertexName;
        }

        public long getRequestedSize() {
            return this.requestedSize;
        }

        public String getComponentClassName() {
            return this.componentClassName;
        }

        public ComponentType getComponentType() {
            return this.componentType;
        }

        public String getComponentVertexName() {
            return this.componentVertexName;
        }

        private static enum ComponentType {
            INPUT,
            OUTPUT,
            PROCESSOR;

        }
    }

    private static interface InitialMemoryAllocator {
        public Iterable<Long> assignMemory(long var1, int var3, int var4, Iterable<RequestContext> var5);
    }
}

