/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.exec.tez;

import java.io.DataInput;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import javolution.testing.AssertionException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluator;
import org.apache.hadoop.hive.ql.exec.ExprNodeEvaluatorFactory;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.MapWork;
import org.apache.hadoop.hive.ql.plan.PartitionDesc;
import org.apache.hadoop.hive.ql.plan.TableDesc;
import org.apache.hadoop.hive.serde2.Deserializer;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.StandardStructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.AbstractPrimitiveWritableObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.util.ReflectionUtils;
import org.apache.tez.dag.api.event.VertexState;
import org.apache.tez.runtime.api.InputInitializerContext;
import org.apache.tez.runtime.api.events.InputInitializerEvent;

/**
 * Prunes the partitions of a {@link MapWork} using partition-column values that upstream
 * ("source") vertices deliver as {@link InputInitializerEvent} payloads.
 *
 * <p>{@link #prune} registers for SUCCEEDED state updates of every source vertex and then blocks
 * on an internal queue. {@link #addEvent} enqueues events from vertices that are still awaited, and
 * {@link #processVertex} marks a vertex as done; once no vertex is awaited, an end-of-events
 * sentinel is queued so {@link #prune} can continue. Every path whose (transformed) partition
 * value was not reported is then removed from the work.
 *
 * <p>The {@code sourcesWaitingForEvents} set doubles as the monitor guarding the hand-off between
 * {@link #prune}, {@link #addEvent} and {@link #processVertex}.
 */
public class DynamicPartitionPruner {
    private static final Log LOG = LogFactory.getLog(DynamicPartitionPruner.class);

    /** Source vertex name -> one entry per (table, column) pair that vertex sends values for. */
    private final Map<String, List<SourceInfo>> sourceInfoMap = new HashMap<String, List<SourceInfo>>();
    /** Reusable buffer that each serialized row of an event payload is read into. */
    private final BytesWritable writable = new BytesWritable();
    /** Events handed from {@link #addEvent} to {@link #processEvents}, terminated by {@link #endOfEvents}. */
    private final BlockingQueue<Object> queue = new LinkedBlockingQueue<Object>();
    /** Source vertices not yet reported as succeeded; also the lock object for this class. */
    private final Set<String> sourcesWaitingForEvents = new HashSet<String>();
    /** Number of SourceInfo entries created by {@link #initialize}; used for logging only. */
    private int sourceInfoCount = 0;
    /** Sentinel queued once every awaited source vertex has succeeded. */
    private final Object endOfEvents = new Object();
    /** Number of events accepted by {@link #addEvent}; compared against the expected count. */
    private int totalEventCount = 0;

    /**
     * Waits for all dynamic-pruning events targeting {@code work} and removes the partitions
     * they rule out, modifying {@code work} in place.
     *
     * <p>Returns immediately when {@code work} has no event sources.
     *
     * @param work the map work whose partitions are pruned
     * @param jobConf configuration used to initialize the per-source deserializers
     * @param context used to register for vertex state updates and to look up task counts
     * @throws SerDeException if a source deserializer cannot be set up or a payload row cannot be
     *     deserialized
     * @throws IOException if an event payload cannot be read
     * @throws InterruptedException if interrupted while waiting for events
     * @throws HiveException if the number of received events differs from the expected number, or
     *     if evaluating a partition-key expression fails
     */
    public void prune(MapWork work, JobConf jobConf, InputInitializerContext context)
            throws SerDeException, IOException, InterruptedException, HiveException {
        synchronized (this.sourcesWaitingForEvents) {
            this.initialize(work, jobConf);
            if (this.sourcesWaitingForEvents.isEmpty()) {
                return;
            }
            Set<VertexState> states = Collections.singleton(VertexState.SUCCEEDED);
            for (String source : this.sourcesWaitingForEvents) {
                context.registerForVertexStateUpdates(source, states);
            }
        }
        LOG.info("Waiting for events (" + this.sourceInfoCount + " items) ...");
        this.processEvents();
        this.prunePartitions(work, context);
        LOG.info("Ok to proceed.");
    }

    /** Returns the queue that accepted events (and the end-of-events sentinel) are placed on. */
    public BlockingQueue<Object> getQueue() {
        return this.queue;
    }

    /** Drops all per-source bookkeeping built by a previous {@link #initialize} call. */
    private void clear() {
        this.sourceInfoMap.clear();
        this.sourceInfoCount = 0;
    }

    /**
     * Builds one {@link SourceInfo} per (source vertex, table) pair described by {@code work} and
     * marks every event source as awaited.
     *
     * <p>SourceInfos that target the same column name share a single value set and skip flag,
     * even across different source vertices.
     *
     * @throws SerDeException if a source table's deserializer cannot be initialized
     */
    public void initialize(MapWork work, JobConf jobConf) throws SerDeException {
        this.clear();
        // Column name -> most recently created SourceInfo for that column.
        Map<String, SourceInfo> columnMap = new HashMap<String, SourceInfo>();
        Set<String> sources = work.getEventSourceTableDescMap().keySet();
        this.sourcesWaitingForEvents.addAll(sources);
        for (String s : sources) {
            List<TableDesc> tables = work.getEventSourceTableDescMap().get(s);
            List<String> columnNames = work.getEventSourceColumnNameMap().get(s);
            List<String> columnTypes = work.getEventSourceColumnTypeMap().get(s);
            List<ExprNodeDesc> partKeyExprs = work.getEventSourcePartKeyExprMap().get(s);
            // The four per-source lists are walked in lock-step: element i of each describes
            // the same table.
            Iterator<String> cit = columnNames.iterator();
            Iterator<String> typit = columnTypes.iterator();
            Iterator<ExprNodeDesc> pit = partKeyExprs.iterator();
            for (TableDesc t : tables) {
                ++this.sourceInfoCount;
                String columnName = cit.next();
                String columnType = typit.next();
                ExprNodeDesc partKeyExpr = pit.next();
                SourceInfo si = new SourceInfo(t, partKeyExpr, columnName, columnType, jobConf);
                List<SourceInfo> sis = this.sourceInfoMap.get(s);
                if (sis == null) {
                    sis = new ArrayList<SourceInfo>();
                    this.sourceInfoMap.put(s, sis);
                }
                sis.add(si);
                SourceInfo previous = columnMap.get(columnName);
                if (previous != null) {
                    // Share state so values from every source targeting this column accumulate
                    // in one set.
                    si.values = previous.values;
                    si.skipPruning = previous.skipPruning;
                }
                columnMap.put(columnName, si);
            }
        }
    }

    /**
     * Applies pruning for every SourceInfo, then verifies that one event per task of each
     * SourceInfo's source vertex was received.
     *
     * @throws HiveException if the received event count does not match the expected count
     */
    private void prunePartitions(MapWork work, InputInitializerContext context) throws HiveException {
        int expectedEvents = 0;
        for (String source : this.sourceInfoMap.keySet()) {
            for (SourceInfo si : this.sourceInfoMap.get(source)) {
                int taskNum = context.getVertexNumTasks(source);
                LOG.info("Expecting " + taskNum + " events for vertex " + source);
                expectedEvents += taskNum;
                this.prunePartitionSingleSource(source, si, work);
            }
        }
        // Sanity check: every task of every source must have submitted its event.
        if (expectedEvents != this.totalEventCount) {
            LOG.error("Expecting: " + expectedEvents + ", received: " + this.totalEventCount);
            throw new HiveException("Incorrect event count in dynamic partition pruning");
        }
    }

    /**
     * Prunes {@code work} against the values collected for one SourceInfo, unless any event
     * asked for pruning on this column to be skipped.
     */
    private void prunePartitionSingleSource(String source, SourceInfo si, MapWork work) throws HiveException {
        if (si.skipPruning.get()) {
            LOG.info("Skip pruning on " + source + ", column " + si.columnName);
            return;
        }
        Set<Object> values = si.values;
        String columnName = si.columnName;
        if (LOG.isDebugEnabled()) {
            StringBuilder sb = new StringBuilder("Pruning ");
            sb.append(columnName);
            sb.append(" with ");
            for (Object value : values) {
                sb.append(value == null ? null : value.toString());
                sb.append(", ");
            }
            LOG.debug(sb.toString());
        }
        // Partition values are stored as strings in the partition spec. The converter turns them
        // into writables of the column's declared type, which the partition-key expression is
        // then evaluated over.
        ObjectInspector oi = PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(
                TypeInfoFactory.getPrimitiveTypeInfo(si.columnType));
        ObjectInspectorConverters.Converter converter = ObjectInspectorConverters.getConverter(
                PrimitiveObjectInspectorFactory.javaStringObjectInspector, oi);
        StandardStructObjectInspector soi = ObjectInspectorFactory.getStandardStructObjectInspector(
                Collections.singletonList(columnName), Collections.singletonList(oi));
        ExprNodeEvaluator eval = ExprNodeEvaluatorFactory.get(si.partKey);
        eval.initialize(soi);
        this.applyFilterToPartitions(work, converter, eval, columnName, values);
    }

    /**
     * Removes from {@code work} every path whose partition value for {@code columnName}, after
     * conversion and evaluation by {@code eval}, is not contained in {@code values}.
     *
     * @throws AssertionException if a partition has no spec or no value for {@code columnName}
     * @throws HiveException if evaluating the partition-key expression fails
     */
    private void applyFilterToPartitions(MapWork work, ObjectInspectorConverters.Converter converter,
            ExprNodeEvaluator eval, String columnName, Set<Object> values) throws HiveException {
        Object[] row = new Object[1];
        Iterator<String> it = work.getPathToPartitionInfo().keySet().iterator();
        while (it.hasNext()) {
            String p = it.next();
            PartitionDesc desc = work.getPathToPartitionInfo().get(p);
            LinkedHashMap<String, String> spec = desc.getPartSpec();
            if (spec == null) {
                throw new AssertionException("No partition spec found in dynamic pruning");
            }
            String partValueString = spec.get(columnName);
            if (partValueString == null) {
                throw new AssertionException("Could not find partition value for column: " + columnName);
            }
            Object partValue = converter.convert(partValueString);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Converted partition value: " + partValue + " original (" + partValueString + ")");
            }
            row[0] = partValue;
            partValue = eval.evaluate(row);
            if (LOG.isDebugEnabled()) {
                LOG.debug("part key expr applied: " + partValue);
            }
            if (values.contains(partValue)) {
                continue;
            }
            LOG.info("Pruning path: " + p);
            // Removing through the iterator drops the entry from getPathToPartitionInfo() itself.
            it.remove();
            work.getPathToAliases().remove(p);
            work.getPaths().remove(p);
            work.getPartitionDescs().remove(desc);
        }
    }

    /**
     * Blocks on the queue, processing each event's payload until the end-of-events sentinel is
     * taken.
     */
    private void processEvents() throws SerDeException, IOException, InterruptedException {
        Object element;
        int eventCount = 0;
        while ((element = this.queue.take()) != this.endOfEvents) {
            InputInitializerEvent event = (InputInitializerEvent) element;
            LOG.info("Input event: " + event.getTargetInputName() + ", " + event.getTargetVertexName()
                    + ", " + (event.getUserPayload().limit() - event.getUserPayload().position()));
            this.processPayload(event.getUserPayload(), event.getSourceVertexName());
            ++eventCount;
        }
        LOG.info("Received events: " + eventCount);
    }

    /**
     * Decodes one event payload and adds its values to the matching SourceInfo.
     *
     * <p>The payload format is read as: a UTF column name, a boolean "skip pruning" flag, then
     * zero or more {@link BytesWritable}-serialized rows until the buffer is exhausted. Reading
     * consumes {@code payload} (advances its position).
     *
     * @return {@code sourceName}, unchanged
     * @throws AssertionException if no SourceInfo matches the source vertex or column name
     */
    private String processPayload(ByteBuffer payload, String sourceName) throws SerDeException, IOException {
        // DataInputStream does not buffer, so payload.hasRemaining() tracks exactly what has been
        // read.
        DataInputStream in = new DataInputStream(new ByteBufferBackedInputStream(payload));
        try {
            String columnName = in.readUTF();
            boolean skip = in.readBoolean();
            LOG.info("Source of event: " + sourceName);
            List<SourceInfo> infos = this.sourceInfoMap.get(sourceName);
            if (infos == null) {
                throw new AssertionException("no source info for event source: " + sourceName);
            }
            SourceInfo info = null;
            for (SourceInfo si : infos) {
                if (columnName.equals(si.columnName)) {
                    info = si;
                    break;
                }
            }
            if (info == null) {
                throw new AssertionException("no source info for column: " + columnName);
            }
            if (skip) {
                info.skipPruning.set(true);
            }
            while (payload.hasRemaining()) {
                this.writable.readFields(in);
                Object row = info.deserializer.deserialize(this.writable);
                Object value = info.soi.getStructFieldData(row, info.field);
                // Copy so the stored value does not alias the reused deserialization buffers.
                value = ObjectInspectorUtils.copyToStandardObject(value, info.fieldInspector);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("Adding: " + value + " to list of required partitions");
                }
                info.values.add(value);
            }
        } finally {
            // Close on every path, including read/deserialization failures.
            in.close();
        }
        return sourceName;
    }

    /**
     * Enqueues {@code event} and counts it, but only if its source vertex is still awaited;
     * events from any other vertex are silently dropped.
     */
    public void addEvent(InputInitializerEvent event) {
        synchronized (this.sourcesWaitingForEvents) {
            if (this.sourcesWaitingForEvents.contains(event.getSourceVertexName())) {
                ++this.totalEventCount;
                this.queue.offer(event);
            }
        }
    }

    /**
     * Marks vertex {@code name} as succeeded; once no source vertex remains awaited, queues the
     * end-of-events sentinel so {@link #prune} stops waiting.
     */
    public void processVertex(String name) {
        LOG.info("Vertex succeeded: " + name);
        synchronized (this.sourcesWaitingForEvents) {
            this.sourcesWaitingForEvents.remove(name);
            if (this.sourcesWaitingForEvents.isEmpty()) {
                this.queue.offer(this.endOfEvents);
            } else {
                LOG.info("Waiting for " + this.sourcesWaitingForEvents.size() + " events.");
            }
        }
    }

    /** Minimal {@link InputStream} view over a {@link ByteBuffer}; reads advance the buffer. */
    private static class ByteBufferBackedInputStream extends InputStream {
        ByteBuffer buf;

        public ByteBufferBackedInputStream(ByteBuffer buf) {
            this.buf = buf;
        }

        @Override
        public int read() throws IOException {
            if (!this.buf.hasRemaining()) {
                return -1;
            }
            return this.buf.get() & 0xFF;
        }

        @Override
        public int read(byte[] bytes, int off, int len) throws IOException {
            // InputStream contract: a zero-length read returns 0, never end-of-stream.
            if (len == 0) {
                return 0;
            }
            if (!this.buf.hasRemaining()) {
                return -1;
            }
            int count = Math.min(len, this.buf.remaining());
            this.buf.get(bytes, off, count);
            return count;
        }
    }

    /**
     * Everything needed to decode events for one (source vertex, table) pair and to prune on
     * the corresponding partition column.
     */
    private static class SourceInfo {
        /** Partition-key expression evaluated over each partition value before matching. */
        public final ExprNodeDesc partKey;
        public final Deserializer deserializer;
        public final StructObjectInspector soi;
        /** The single struct field whose values the events carry. */
        public final StructField field;
        /** Standard inspector used to copy field values out of the deserializer's objects. */
        public final ObjectInspector fieldInspector;
        /** Values received so far; may be shared with other SourceInfos on the same column. */
        public Set<Object> values = new HashSet<Object>();
        /** Set when an event asks to skip pruning; may be shared like {@link #values}. */
        public AtomicBoolean skipPruning = new AtomicBoolean();
        public final String columnName;
        public final String columnType;

        /**
         * @throws SerDeException if the table's deserializer cannot be initialized or exposes no
         *     struct fields
         */
        public SourceInfo(TableDesc table, ExprNodeDesc partKey, String columnName, String columnType,
                JobConf jobConf) throws SerDeException {
            this.skipPruning.set(false);
            this.partKey = partKey;
            this.columnName = columnName;
            this.columnType = columnType;
            this.deserializer = (Deserializer) ReflectionUtils.newInstance(table.getDeserializerClass(), null);
            this.deserializer.initialize((Configuration) jobConf, table.getProperties());
            ObjectInspector inspector = this.deserializer.getObjectInspector();
            LOG.debug("Type of obj insp: " + inspector.getTypeName());
            this.soi = (StructObjectInspector) inspector;
            List<? extends StructField> fields = this.soi.getAllStructFieldRefs();
            if (fields.isEmpty()) {
                throw new SerDeException("expecting single field in input, found none");
            }
            if (fields.size() > 1) {
                LOG.error("expecting single field in input");
            }
            this.field = fields.get(0);
            this.fieldInspector = ObjectInspectorUtils.getStandardObjectInspector(this.field.getFieldObjectInspector());
        }
    }
}

