/*
 * Decompiled with CFR 0.152.
 */
package ai.h2o.targetencoding;

import ai.h2o.targetencoding.ColumnsToSingleMapping;
import ai.h2o.targetencoding.TargetEncoderHelper;
import ai.h2o.targetencoding.TargetEncoderModel;
import ai.h2o.targetencoding.interaction.InteractionSupport;
import hex.ModelBuilder;
import hex.ModelCategory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import water.DKV;
import water.Key;
import water.Lockable;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMap;

public class TargetEncoder
extends ModelBuilder<TargetEncoderModel, TargetEncoderModel.TargetEncoderParameters, TargetEncoderModel.TargetEncoderOutput> {
    private static final Logger LOG = Logger.getLogger(TargetEncoder.class);
    private TargetEncoderModel _targetEncoderModel;
    private String[][] _columnsToEncode;

    /**
     * Creates a target encoder builder for the given parameters and immediately runs the
     * cheap (non-expensive) validation pass.
     *
     * <p>Note that {@link #init(boolean)} is overridable and is invoked from the constructor.
     *
     * @param parms the target encoding parameters, forwarded to the parent builder
     */
    public TargetEncoder(TargetEncoderModel.TargetEncoderParameters parms) {
        super(parms);
        init(false);
    }

    /**
     * Creates a target encoder builder for the given parameters and model key, then immediately
     * runs the cheap (non-expensive) validation pass.
     *
     * <p>Note that {@link #init(boolean)} is overridable and is invoked from the constructor.
     *
     * @param parms the target encoding parameters, forwarded to the parent builder
     * @param key   the key forwarded to the parent builder together with the parameters
     */
    public TargetEncoder(TargetEncoderModel.TargetEncoderParameters parms, Key<TargetEncoderModel> key) {
        super(parms, key);
        init(false);
    }

    /**
     * Creates a builder backed by a freshly constructed, default parameter object.
     *
     * <p>Unlike the other constructors, this one does not call {@code init(false)}.
     *
     * @param startupOnce forwarded unchanged to the parent {@code ModelBuilder} constructor
     *                    (NOTE(review): presumably marks the one-off registration instance
     *                    created at startup - confirm against ModelBuilder)
     */
    public TargetEncoder(boolean startupOnce) {
        super(new TargetEncoderModel.TargetEncoderParameters(), startupOnce);
    }

    /**
     * Validates the builder parameters and, on the expensive pass, resolves the column groups
     * that will be target-encoded.
     *
     * <p>Before delegating to the parent validation, constant-column filtering is switched off
     * and (expensive pass only) unused columns are added to the ignored columns. On the expensive
     * pass a missing leakage handling strategy defaults to {@code None}, a {@code KFold} strategy
     * without a fold column is reported as an error, and {@code _columnsToEncode} is either taken
     * from the parameters (and validated) or detected from the categorical predictors of the
     * training frame.
     *
     * @param expensive {@code true} to also run the checks that inspect the training frame
     */
    @Override
    public void init(boolean expensive) {
        disableIgnoreConstColsFeature(expensive);
        ignoreUnusedColumns(expensive);
        super.init(expensive);
        TargetEncoderModel.TargetEncoderParameters params = (TargetEncoderModel.TargetEncoderParameters) this._parms;
        assert params._nfolds == 0 : "nfolds usage forbidden in TargetEncoder";
        if (!expensive) {
            return;
        }
        if (params._data_leakage_handling == null) {
            params._data_leakage_handling = TargetEncoderModel.DataLeakageHandlingStrategy.None;
        }
        if (params._data_leakage_handling == TargetEncoderModel.DataLeakageHandlingStrategy.KFold && params._fold_column == null) {
            error("_fold_column", "Fold column is required when using KFold leakage handling strategy.");
        }
        Frame train = train();
        this._columnsToEncode = params._columns_to_encode;
        if (train == null) {
            // Nothing to inspect without a training frame; bail out instead of dereferencing null.
            // NOTE(review): presumably super.init(...) has already reported the missing frame - confirm.
            return;
        }
        if (this._columnsToEncode == null) {
            this._columnsToEncode = detectCategoricalPredictors(train, params);
        } else {
            validateColumnGroups(train, this._columnsToEncode);
        }
    }

    /**
     * Builds one single-column group for every categorical column of {@code train} that is not
     * listed among the non-predictors; non-categorical predictors are skipped with a warning.
     */
    private String[][] detectCategoricalPredictors(Frame train, TargetEncoderModel.TargetEncoderParameters params) {
        List<String> nonPredictors = Arrays.asList(params.getNonPredictors());
        List<String[]> detected = new ArrayList<>(train.numCols());
        for (int i = 0; i < train.numCols(); ++i) {
            String colName = train.name(i);
            if (nonPredictors.contains(colName)) {
                continue;
            }
            if (!train.vec(i).isCategorical()) {
                warn("_train", "Column `" + colName + "` is not categorical and will therefore be ignored by target encoder.");
                continue;
            }
            detected.add(new String[]{colName});
        }
        return detected.toArray(new String[0][]);
    }

    /**
     * Reports an error for every group containing duplicate columns and for every referenced
     * column that is missing from {@code train} or is not categorical. Each distinct column is
     * checked only once, even when it appears in several groups.
     */
    private void validateColumnGroups(Frame train, String[][] columnGroups) {
        HashSet<String> validated = new HashSet<>();
        for (String[] colGroup : columnGroups) {
            if (colGroup.length != new HashSet<>(Arrays.asList(colGroup)).size()) {
                error("_columns_to_encode", "Columns interaction " + Arrays.toString(colGroup) + " contains duplicate columns.");
            }
            for (String col : colGroup) {
                // add() returns false when the column was already checked as part of an earlier group
                if (!validated.add(col)) {
                    continue;
                }
                Vec vec = train.vec(col);
                if (vec == null) {
                    error("_columns_to_encode", "Column `" + col + "` from interaction " + Arrays.toString(colGroup) + " is not categorical or is missing from the training frame.");
                } else if (!vec.isCategorical()) {
                    error("_columns_to_encode", "Column `" + col + "` from interaction " + Arrays.toString(colGroup) + " must first be converted into categorical to be used by target encoder.");
                }
            }
        }
    }

    /**
     * Forces {@code _ignore_const_cols} to {@code false} on the builder parameters and, on the
     * expensive pass, logs that decision at INFO level.
     *
     * @param expensive whether this is the expensive validation pass (controls logging only)
     */
    private void disableIgnoreConstColsFeature(boolean expensive) {
        TargetEncoderModel.TargetEncoderParameters params = (TargetEncoderModel.TargetEncoderParameters) this._parms;
        params._ignore_const_cols = false;
        if (!expensive) {
            return;
        }
        if (LOG.isInfoEnabled()) {
            LOG.info("We don't want to ignore any columns during target encoding transformation therefore `_ignore_const_cols` parameter was set to `false`");
        }
    }

    /**
     * Extends {@code _ignored_columns} with every training column that is neither a non-predictor
     * nor part of any explicitly requested column group.
     *
     * <p>Does nothing unless this is the expensive pass, {@code _columns_to_encode} is set and the
     * training frame can be obtained from the parameters. Columns that were already ignored stay
     * ignored. The order of the resulting array is unspecified (it is produced from a hash set).
     *
     * @param expensive whether this is the expensive validation pass
     */
    private void ignoreUnusedColumns(boolean expensive) {
        TargetEncoderModel.TargetEncoderParameters params = (TargetEncoderModel.TargetEncoderParameters) this._parms;
        if (!expensive || params._columns_to_encode == null) {
            return;
        }
        // Fetch the training frame once and reuse it for both the null check and the column names.
        Frame train = params.train();
        if (train == null) {
            return;
        }
        HashSet<String> usedColumns = new HashSet<>(Arrays.asList(params.getNonPredictors()));
        for (String[] colGroup : params._columns_to_encode) {
            usedColumns.addAll(Arrays.asList(colGroup));
        }
        HashSet<String> columnsToIgnore = new HashSet<>(Arrays.asList(train._names));
        columnsToIgnore.removeAll(usedColumns);
        if (params._ignored_columns != null) {
            columnsToIgnore.addAll(Arrays.asList(params._ignored_columns));
        }
        params._ignored_columns = columnsToIgnore.toArray(new String[0]);
    }

    /**
     * Always returns {@code false}.
     *
     * <p>Consistent with the assertion in {@code init} that forbids {@code _nfolds} usage for
     * this builder. NOTE(review): presumably tells the framework to skip its generic n-fold
     * cross-validation path - confirm against ModelBuilder.
     */
    @Override
    public boolean nFoldCV() {
        return false;
    }

    /**
     * Creates the driver that performs the actual work of this builder.
     *
     * @return a new {@code TargetEncoderDriver} bound to this builder instance
     */
    @Override
    protected ModelBuilder.Driver trainModelImpl() {
        return new TargetEncoderDriver();
    }

    /**
     * Lists the model categories this builder produces.
     *
     * @return a single-element array containing {@code ModelCategory.TargetEncoder}
     */
    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.TargetEncoder};
    }

    /**
     * Always returns {@code true}: the encodings are computed against the response column
     * configured in the parameters (see {@code prepareEncodingMap}).
     */
    @Override
    public boolean isSupervised() {
        return true;
    }

    /**
     * Reports the visibility level of this builder.
     *
     * @return {@code ModelBuilder.BuilderVisibility.Stable}
     */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    /**
     * Always returns {@code true}.
     *
     * <p>NOTE(review): presumably advertises that models built here can be exported as MOJO -
     * confirm against ModelBuilder.
     */
    @Override
    public boolean haveMojo() {
        return true;
    }

    private class TargetEncoderDriver
    extends ModelBuilder.Driver {
        /**
         * Creates a driver tied to the enclosing {@code TargetEncoder} instance.
         *
         * <p>NOTE(review): this file is decompiler output; the explicit {@code TargetEncoder.this}
         * argument looks like the decompiler's rendering of the implicit enclosing-instance
         * parameter of {@code ModelBuilder.Driver} - confirm against the constructors that
         * {@code ModelBuilder.Driver} actually declares before relying on this signature.
         */
        private TargetEncoderDriver() {
            super(TargetEncoder.this);
        }

        /**
         * Builds the target encoder model.
         *
         * <p>Steps, in order: run the expensive validation and abort with an
         * {@code H2OModelBuilderIllegalArgumentException} if it produced errors; create and
         * write-lock the model; add one interaction column per column group to a working copy of
         * the training frame; compute the encodings for each of those columns; untrack the
         * resulting frames from the current {@code Scope}; and store them, together with the
         * column mappings, in the model output.
         *
         * <p>If anything fails after the model was created, the model is handed to
         * {@code Scope.track_generic} before the exception is rethrown. In every case a created
         * model is updated and unlocked at the end.
         */
        @Override
        public void computeImpl() {
            final TargetEncoder builder = TargetEncoder.this;
            builder._targetEncoderModel = null;
            try {
                builder.init(true);
                if (builder.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
                }
                TargetEncoderModel.TargetEncoderOutput output = new TargetEncoderModel.TargetEncoderOutput(builder);
                TargetEncoderModel model = new TargetEncoderModel(builder.dest(), (TargetEncoderModel.TargetEncoderParameters) builder._parms, output);
                builder._targetEncoderModel = (TargetEncoderModel) model.delete_and_lock(builder._job);

                // Work on a copy of the training frame: interaction columns get appended to it.
                Frame workingFrame = new Frame(builder.train());
                String[][] columnGroups = builder._columnsToEncode;
                ColumnsToSingleMapping[] mappings = new ColumnsToSingleMapping[columnGroups.length];
                for (int g = 0; g < mappings.length; g++) {
                    String[] colGroup = columnGroups[g];
                    int interactionCol = InteractionSupport.addFeatureInteraction(workingFrame, colGroup);
                    String[] interactionDomain = workingFrame.vec(interactionCol).domain();
                    mappings[g] = new ColumnsToSingleMapping(colGroup, workingFrame.name(interactionCol), interactionDomain);
                }
                String[] singleColumnsToEncode = Arrays.stream(mappings)
                        .map(ColumnsToSingleMapping::toSingle)
                        .toArray(String[]::new);

                IcedHashMap<String, Frame> encodingMap = prepareEncodingMap(workingFrame, singleColumnsToEncode);
                for (Map.Entry<String, Frame> entry : encodingMap.entrySet()) {
                    Scope.untrack(entry.getValue());
                }
                output.init(encodingMap, mappings);
                builder._job.update(1L);
            } catch (Exception e) {
                if (builder._targetEncoderModel != null) {
                    Scope.track_generic(builder._targetEncoderModel);
                }
                throw e;
            } finally {
                if (builder._targetEncoderModel != null) {
                    builder._targetEncoderModel.update(builder._job);
                    builder._targetEncoderModel.unlock(builder._job);
                }
            }
        }

        /**
         * Delegates to {@code TargetEncoderHelper.filterOutNAsInColumn} for the target column.
         *
         * @param data              the frame to filter
         * @param targetColumnIndex index of the target column within {@code data}
         * @return the frame produced by the helper
         */
        private Frame filterOutNAsFromTargetColumn(Frame data, int targetColumnIndex) {
            Frame withoutTargetNAs = TargetEncoderHelper.filterOutNAsInColumn(data, targetColumnIndex);
            return withoutTargetNAs;
        }

        /**
         * Computes one encodings frame per column to encode.
         *
         * <p>The input frame is first filtered on the response column via
         * {@code filterOutNAsFromTargetColumn}. For every requested column, the column is imputed
         * with the category {@code <column>_NA}, raw encodings are built, and the configured
         * leakage handling strategy is applied to them. The raw encodings are then deleted and the
         * final frame is (re-)registered in the DKV under the key
         * {@code <result key>_encodings_<column>}. The filtered working frame is always deleted
         * before returning, even on failure.
         *
         * @param fr              frame holding the response column and all columns to encode
         * @param columnsToEncode names of the (single) columns to compute encodings for
         * @return a map from column name to its encodings frame
         */
        private IcedHashMap<String, Frame> prepareEncodingMap(Frame fr, String[] columnsToEncode) {
            final TargetEncoderModel.TargetEncoderParameters params = (TargetEncoderModel.TargetEncoderParameters) TargetEncoder.this._parms;
            Frame workingFrame = null;
            try {
                int targetIdx = fr.find(params._response_column);
                int foldColIdx = -1;
                if (params._fold_column != null) {
                    foldColIdx = fr.find(params._fold_column);
                }
                workingFrame = filterOutNAsFromTargetColumn(fr, targetIdx);

                IcedHashMap<String, Frame> columnToEncodings = new IcedHashMap<>();
                for (String columnToEncode : columnsToEncode) {
                    int colIdx = workingFrame.find(columnToEncode);
                    TargetEncoderHelper.imputeCategoricalColumn(workingFrame, colIdx, columnToEncode + "_NA");
                    Frame rawEncodings = TargetEncoderHelper.buildEncodingsFrame(workingFrame, colIdx, targetIdx, foldColIdx, TargetEncoder.this.nclasses());
                    Frame finalEncodings = applyLeakageStrategyToEncodings(rawEncodings, columnToEncode, params._data_leakage_handling, params._fold_column);
                    rawEncodings.delete();

                    // Re-key the final encodings so they live under a name derived from the model key.
                    if (finalEncodings._key != null) {
                        DKV.remove(finalEncodings._key);
                    }
                    finalEncodings._key = Key.make(TargetEncoder.this._result.toString() + "_encodings_" + columnToEncode);
                    DKV.put(finalEncodings);
                    columnToEncodings.put(columnToEncode, finalEncodings);
                }
                return columnToEncodings;
            } finally {
                if (workingFrame != null) {
                    workingFrame.delete();
                }
            }
        }

        /**
         * Groups the raw encodings by category according to the data leakage handling strategy.
         *
         * <p>For {@code KFold}, each distinct value of the fold column yields one grouped frame,
         * built from the frame returned by {@code getOutOfFoldEncodings} for that value and
         * extended via {@code TargetEncoderHelper.addCon} with the fold column name and value;
         * the per-fold frames are combined with {@code TargetEncoderHelper.rBind}. For
         * {@code LeaveOneOut} and {@code None}, the encodings are grouped in a single call.
         *
         * <p>Intermediate frames are tracked in a {@code Scope} entered by this method; the
         * returned frame is untracked before that scope is exited.
         *
         * @param encodings               the raw encodings frame
         * @param columnToEncode          name of the encoded column inside {@code encodings}
         * @param leakageHandlingStrategy the strategy to apply; must not be {@code null}
         * @param foldColumn              name of the fold column, or {@code null} if there is none
         * @return the grouped encodings frame
         * @throws IllegalStateException if the strategy is {@code null} or not handled here
         */
        private Frame applyLeakageStrategyToEncodings(Frame encodings, String columnToEncode, TargetEncoderModel.DataLeakageHandlingStrategy leakageHandlingStrategy, String foldColumn) {
            // A switch on a null enum throws NullPointerException before reaching `default`,
            // so the null case has to be rejected explicitly to get the intended exception.
            if (leakageHandlingStrategy == null) {
                throw new IllegalStateException("null or unsupported leakageHandlingStrategy");
            }
            Frame groupedEncodings = null;
            int encodingsTEColIdx = encodings.find(columnToEncode);
            // Enter the scope before `try` so that Scope.exit only runs for a scope that was entered.
            Scope.enter();
            try {
                switch (leakageHandlingStrategy) {
                    case KFold: {
                        long[] foldValues = TargetEncoderHelper.getUniqueColumnValues(encodings, encodings.find(foldColumn));
                        for (long foldValue : foldValues) {
                            Frame outOfFoldEncodings = getOutOfFoldEncodings(encodings, foldColumn, foldValue);
                            Scope.track(outOfFoldEncodings);
                            Frame tmpEncodings = TargetEncoderHelper.register(TargetEncoderHelper.groupEncodingsByCategory(outOfFoldEncodings, encodingsTEColIdx));
                            Scope.track(tmpEncodings);
                            TargetEncoderHelper.addCon(tmpEncodings, foldColumn, foldValue);
                            if (groupedEncodings == null) {
                                groupedEncodings = tmpEncodings;
                            } else {
                                Frame newHoldoutEncodings = TargetEncoderHelper.rBind(groupedEncodings, tmpEncodings);
                                groupedEncodings.delete();
                                groupedEncodings = newHoldoutEncodings;
                            }
                            Scope.track(groupedEncodings);
                        }
                        break;
                    }
                    case LeaveOneOut:
                    case None: {
                        groupedEncodings = TargetEncoderHelper.groupEncodingsByCategory(encodings, encodingsTEColIdx, foldColumn != null);
                        break;
                    }
                    default: {
                        throw new IllegalStateException("null or unsupported leakageHandlingStrategy");
                    }
                }
                // The result is handed to the caller, so take it out of this scope's tracking.
                Scope.untrack(groupedEncodings);
            } finally {
                Scope.exit(new Key[0]);
            }
            return groupedEncodings;
        }

        /**
         * Looks up the fold column in the encodings frame and delegates to
         * {@code TargetEncoderHelper.filterNotByValue} with that column index and fold value.
         *
         * @param encodingsFrame the encodings frame containing the fold column
         * @param foldColumn     name of the fold column
         * @param foldValue      the fold value passed to the filter
         * @return the frame produced by the helper
         */
        private Frame getOutOfFoldEncodings(Frame encodingsFrame, String foldColumn, long foldValue) {
            return TargetEncoderHelper.filterNotByValue(encodingsFrame, encodingsFrame.find(foldColumn), foldValue);
        }
    }
}

