/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.common.resources;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Function;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.Preconditions;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.dag.api.EntityDescriptor;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.dag.api.TezUncheckedException;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.api.ProcessorContext;
import org.apache.tez.runtime.api.TaskContext;
import org.apache.tez.runtime.common.resources.InitialMemoryAllocator;
import org.apache.tez.runtime.common.resources.InitialMemoryRequestContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@InterfaceAudience.Private
public class MemoryDistributor {
    private static final Logger LOG = LoggerFactory.getLogger(MemoryDistributor.class);
    private final int numTotalInputs;
    private final int numTotalOutputs;
    private final Configuration conf;
    private AtomicInteger numInputsSeen = new AtomicInteger(0);
    private AtomicInteger numOutputsSeen = new AtomicInteger(0);
    private long totalJvmMemory;
    private final boolean isEnabled;
    private final boolean isInputOutputConcurrent;
    private final String allocatorClassName;
    private final Set<TaskContext> dupSet = Collections.newSetFromMap(new ConcurrentHashMap());
    private final List<RequestorInfo> requestList;

    public MemoryDistributor(int numTotalInputs, int numTotalOutputs, Configuration conf) {
        this.conf = conf;
        this.isEnabled = conf.getBoolean("tez.task.scale.memory.enabled", true);
        this.isInputOutputConcurrent = conf.getBoolean("tez.task.scale.memory.input-output-concurrent", true);
        this.allocatorClassName = this.isEnabled ? conf.get("tez.task.scale.memory.allocator.class", "org.apache.tez.runtime.library.resources.WeightedScalingMemoryDistributor") : null;
        this.numTotalInputs = numTotalInputs;
        this.numTotalOutputs = numTotalOutputs;
        this.totalJvmMemory = Runtime.getRuntime().maxMemory();
        this.requestList = Collections.synchronizedList(new LinkedList());
        LOG.info("InitialMemoryDistributor (isEnabled=" + this.isEnabled + ") invoked with: numInputs=" + numTotalInputs + ", numOutputs=" + numTotalOutputs + ", JVM.maxFree=" + this.totalJvmMemory + ", allocatorClassName=" + this.allocatorClassName);
    }

    public void requestMemory(long requestSize, MemoryUpdateCallback callback, TaskContext taskContext, EntityDescriptor<?> descriptor) {
        this.registerRequest(requestSize, callback, taskContext, descriptor);
    }

    public void makeInitialAllocations() throws TezException {
        Preconditions.checkState((this.numInputsSeen.get() == this.numTotalInputs ? 1 : 0) != 0, (Object)"All inputs are expected to ask for memory");
        Preconditions.checkState((this.numOutputsSeen.get() == this.numTotalOutputs ? 1 : 0) != 0, (Object)"All outputs are expected to ask for memory");
        this.logInitialRequests(this.requestList);
        Iterable requestContexts = Iterables.transform(this.requestList, (Function)new Function<RequestorInfo, InitialMemoryRequestContext>(){

            public InitialMemoryRequestContext apply(RequestorInfo requestInfo) {
                return requestInfo.getRequestContext();
            }
        });
        Iterable allocations = null;
        if (!this.isEnabled) {
            allocations = Iterables.transform(this.requestList, (Function)new Function<RequestorInfo, Long>(){

                public Long apply(RequestorInfo requestInfo) {
                    return requestInfo.getRequestContext().getRequestedSize();
                }
            });
        } else {
            InitialMemoryAllocator allocator = (InitialMemoryAllocator)ReflectionUtils.createClazzInstance((String)this.allocatorClassName);
            allocator.setConf(this.conf);
            allocations = allocator.assignMemory(this.totalJvmMemory, this.numTotalInputs, this.numTotalOutputs, Iterables.unmodifiableIterable((Iterable)requestContexts));
            this.validateAllocations(allocations, this.requestList.size());
            this.logFinalAllocations(allocations, this.requestList);
        }
        Iterator allocatedIter = allocations.iterator();
        for (RequestorInfo rInfo : this.requestList) {
            long allocated = (Long)allocatedIter.next();
            if (LOG.isDebugEnabled()) {
                LOG.info("Informing: " + rInfo.getRequestContext().getComponentType() + ", " + rInfo.getRequestContext().getComponentVertexName() + ", " + rInfo.getRequestContext().getComponentClassName() + ": requested=" + rInfo.getRequestContext().getRequestedSize() + ", allocated=" + allocated);
            }
            rInfo.getCallback().memoryAssigned(allocated);
        }
    }

    @InterfaceAudience.Private
    @VisibleForTesting
    void setJvmMemory(long size) {
        this.totalJvmMemory = size;
    }

    private long registerRequest(long requestSize, MemoryUpdateCallback callback, TaskContext entityContext, EntityDescriptor<?> descriptor) {
        Preconditions.checkArgument((requestSize >= 0L ? 1 : 0) != 0);
        Objects.requireNonNull(callback);
        Objects.requireNonNull(entityContext);
        Objects.requireNonNull(descriptor);
        if (!this.dupSet.add(entityContext)) {
            throw new TezUncheckedException("A single entity can only make one call to request resources for now");
        }
        RequestorInfo requestInfo = new RequestorInfo(entityContext, requestSize, callback, descriptor);
        switch (requestInfo.getRequestContext().getComponentType()) {
            case INPUT: {
                this.numInputsSeen.incrementAndGet();
                Preconditions.checkState((this.numInputsSeen.get() <= this.numTotalInputs ? 1 : 0) != 0, (Object)("Num Requesting Inputs higher than total # of inputs: " + this.numInputsSeen + ", " + this.numTotalInputs));
                break;
            }
            case OUTPUT: {
                this.numOutputsSeen.incrementAndGet();
                Preconditions.checkState((this.numOutputsSeen.get() <= this.numTotalOutputs ? 1 : 0) != 0, (Object)("Num Requesting Inputs higher than total # of outputs: " + this.numOutputsSeen + ", " + this.numTotalOutputs));
                break;
            }
            case PROCESSOR: {
                break;
            }
        }
        this.requestList.add(requestInfo);
        return -1L;
    }

    private void validateAllocations(Iterable<Long> allocations, int numRequestors) {
        Objects.requireNonNull(allocations);
        long totalAllocated = 0L;
        int numAllocations = 0;
        for (Long l : allocations) {
            totalAllocated += l.longValue();
            ++numAllocations;
        }
        Preconditions.checkState((numAllocations == numRequestors ? 1 : 0) != 0, (Object)("Number of allocations must match number of requestors. Allocated=" + numAllocations + ", Requests: " + numRequestors));
        if (this.isInputOutputConcurrent) {
            Preconditions.checkState((totalAllocated <= this.totalJvmMemory ? 1 : 0) != 0, (Object)("Total allocation should be <= availableMem. TotalAllocated: " + totalAllocated + ", totalJvmMemory: " + this.totalJvmMemory));
        }
    }

    private void logInitialRequests(List<RequestorInfo> initialRequests) {
        if (initialRequests != null && !initialRequests.isEmpty()) {
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < initialRequests.size(); ++i) {
                InitialMemoryRequestContext context = initialRequests.get(i).getRequestContext();
                sb.append("[");
                sb.append(context.getComponentVertexName()).append(":");
                sb.append(context.getComponentType()).append(":");
                sb.append(context.getRequestedSize()).append(":").append(context.getComponentClassName());
                sb.append("]");
                if (i >= initialRequests.size() - 1) continue;
                sb.append(", ");
            }
            LOG.info("InitialRequests=" + sb.toString());
        }
    }

    private void logFinalAllocations(Iterable<Long> allocations, List<RequestorInfo> requestList) {
        if (requestList != null && !requestList.isEmpty()) {
            Iterator<Long> allocatedIter = allocations.iterator();
            StringBuilder sb = new StringBuilder();
            for (int i = 0; i < requestList.size(); ++i) {
                long allocated = allocatedIter.next();
                InitialMemoryRequestContext context = requestList.get(i).getRequestContext();
                sb.append("[");
                sb.append(context.getComponentVertexName()).append(":");
                sb.append(context.getComponentClassName()).append(":");
                sb.append(context.getComponentType()).append(":");
                sb.append(context.getRequestedSize()).append(":").append(allocated);
                sb.append("]");
                if (i >= requestList.size() - 1) continue;
                sb.append(", ");
            }
            LOG.info("Allocations=" + sb.toString());
        }
    }

    private static class RequestorInfo {
        private static final Logger LOG = LoggerFactory.getLogger(RequestorInfo.class);
        private final MemoryUpdateCallback callback;
        private final InitialMemoryRequestContext requestContext;

        public RequestorInfo(TaskContext taskContext, long requestSize, MemoryUpdateCallback callback, EntityDescriptor<?> descriptor) {
            String componentVertexName;
            InitialMemoryRequestContext.ComponentType type;
            if (taskContext instanceof InputContext) {
                type = InitialMemoryRequestContext.ComponentType.INPUT;
                componentVertexName = ((InputContext)taskContext).getSourceVertexName();
            } else if (taskContext instanceof OutputContext) {
                type = InitialMemoryRequestContext.ComponentType.OUTPUT;
                componentVertexName = ((OutputContext)taskContext).getDestinationVertexName();
            } else if (taskContext instanceof ProcessorContext) {
                type = InitialMemoryRequestContext.ComponentType.PROCESSOR;
                componentVertexName = ((ProcessorContext)taskContext).getTaskVertexName();
            } else {
                throw new IllegalArgumentException("Unknown type of entityContext: " + taskContext.getClass().getName());
            }
            this.requestContext = new InitialMemoryRequestContext(requestSize, descriptor.getClassName(), type, componentVertexName);
            this.callback = callback;
        }

        public MemoryUpdateCallback getCallback() {
            return this.callback;
        }

        public InitialMemoryRequestContext getRequestContext() {
            return this.requestContext;
        }
    }
}

