/*
 * Decompiled with CFR 0.152.
 */
package org.datavec.api.transform.transform.string;

import java.io.File;
import java.io.IOException;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@JsonIgnoreProperties(value={"inputSchema", "map", "columnIdx"})
public class StringListToCountsNDArrayTransform
extends BaseTransform {
    protected final String columnName;
    protected final String newColumnName;
    protected final List<String> vocabulary;
    protected final String delimiter;
    protected final boolean binary;
    protected final boolean ignoreUnknown;
    protected final Map<String, Integer> map;
    protected int columnIdx = -1;

    public StringListToCountsNDArrayTransform(String columnName, List<String> vocabulary, String delimiter, boolean binary, boolean ignoreUnknown) {
        this(columnName, columnName + "[BOW]", vocabulary, delimiter, binary, ignoreUnknown);
    }

    public StringListToCountsNDArrayTransform(@JsonProperty(value="columnName") String columnName, @JsonProperty(value="newColumnName") String newColumnName, @JsonProperty(value="vocabulary") List<String> vocabulary, @JsonProperty(value="delimiter") String delimiter, @JsonProperty(value="binary") boolean binary, @JsonProperty(value="ignoreUnknown") boolean ignoreUnknown) {
        this.columnName = columnName;
        this.newColumnName = newColumnName;
        this.vocabulary = vocabulary;
        this.delimiter = delimiter;
        this.binary = binary;
        this.ignoreUnknown = ignoreUnknown;
        this.map = new HashMap<String, Integer>();
        for (int i = 0; i < vocabulary.size(); ++i) {
            this.map.put(vocabulary.get(i), i);
        }
    }

    public static List<String> readVocabFromFile(String path) throws IOException {
        return FileUtils.readLines((File)new File(path), (String)"utf-8");
    }

    @Override
    public Schema transform(Schema inputSchema) {
        int colIdx = inputSchema.getIndexOfColumn(this.columnName);
        List<ColumnMetaData> oldMeta = inputSchema.getColumnMetaData();
        ArrayList<ColumnMetaData> newMeta = new ArrayList<ColumnMetaData>();
        List<String> oldNames = inputSchema.getColumnNames();
        Iterator<ColumnMetaData> typesIter = oldMeta.iterator();
        Iterator<String> namesIter = oldNames.iterator();
        int i = 0;
        while (typesIter.hasNext()) {
            ColumnMetaData t = typesIter.next();
            String name = namesIter.next();
            if (i++ == colIdx) {
                if (t.getColumnType() != ColumnType.String) {
                    throw new IllegalStateException("Cannot convert non-string type");
                }
                NDArrayMetaData meta = new NDArrayMetaData(this.newColumnName, new int[]{this.vocabulary.size()});
                newMeta.add(meta);
                continue;
            }
            newMeta.add(t);
        }
        return inputSchema.newSchema(newMeta);
    }

    @Override
    public void setInputSchema(Schema inputSchema) {
        this.inputSchema = inputSchema;
        this.columnIdx = inputSchema.getIndexOfColumn(this.columnName);
    }

    @Override
    public String toString() {
        return "StringListToCountsTransform(columnName=" + this.columnName + ",vocabularySize=" + this.vocabulary.size() + ",delimiter=\"" + this.delimiter + "\")";
    }

    protected Collection<Integer> getIndices(String text) {
        AbstractCollection indices = this.binary ? new HashSet() : new ArrayList();
        if (text != null && !text.isEmpty()) {
            String[] split;
            for (String s : split = text.split(this.delimiter)) {
                Integer idx = this.map.get(s);
                if (idx == null && !this.ignoreUnknown) {
                    throw new IllegalStateException("Encountered unknown String: \"" + s + "\"");
                }
                if (idx == null) continue;
                indices.add((Integer)idx);
            }
        }
        return indices;
    }

    protected INDArray makeBOWNDArray(Collection<Integer> indices) {
        INDArray counts = Nd4j.zeros((int)this.vocabulary.size());
        for (Integer idx : indices) {
            counts.putScalar(idx.intValue(), counts.getDouble(idx.intValue()) + 1.0);
        }
        Nd4j.getExecutioner().commit();
        return counts;
    }

    @Override
    public List<Writable> map(List<Writable> writables) {
        if (writables.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not match expected number of elements (schema: " + this.inputSchema.numColumns() + "). Transform = " + this.toString());
        }
        int n = writables.size();
        ArrayList<Writable> out = new ArrayList<Writable>(n);
        int i = 0;
        for (Writable w : writables) {
            if (i++ == this.columnIdx) {
                String text = w.toString();
                Collection<Integer> indices = this.getIndices(text);
                INDArray counts = this.makeBOWNDArray(indices);
                out.add(new NDArrayWritable(counts));
                continue;
            }
            out.add(w);
        }
        return out;
    }

    @Override
    public Object map(Object input) {
        return null;
    }

    @Override
    public Object mapSequence(Object sequence) {
        return null;
    }

    @Override
    public String outputColumnName() {
        throw new UnsupportedOperationException("New column names is always more than 1 in length");
    }

    @Override
    public String[] outputColumnNames() {
        return this.vocabulary.toArray(new String[this.vocabulary.size()]);
    }

    @Override
    public String[] columnNames() {
        return new String[]{this.columnName()};
    }

    @Override
    public String columnName() {
        return this.columnName();
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof StringListToCountsNDArrayTransform)) {
            return false;
        }
        StringListToCountsNDArrayTransform other = (StringListToCountsNDArrayTransform)o;
        if (!other.canEqual(this)) {
            return false;
        }
        String this$columnName = this.columnName;
        String other$columnName = other.columnName;
        if (this$columnName == null ? other$columnName != null : !this$columnName.equals(other$columnName)) {
            return false;
        }
        String this$newColumnName = this.newColumnName;
        String other$newColumnName = other.newColumnName;
        if (this$newColumnName == null ? other$newColumnName != null : !this$newColumnName.equals(other$newColumnName)) {
            return false;
        }
        List<String> this$vocabulary = this.vocabulary;
        List<String> other$vocabulary = other.vocabulary;
        if (this$vocabulary == null ? other$vocabulary != null : !((Object)this$vocabulary).equals(other$vocabulary)) {
            return false;
        }
        String this$delimiter = this.delimiter;
        String other$delimiter = other.delimiter;
        if (this$delimiter == null ? other$delimiter != null : !this$delimiter.equals(other$delimiter)) {
            return false;
        }
        if (this.binary != other.binary) {
            return false;
        }
        if (this.ignoreUnknown != other.ignoreUnknown) {
            return false;
        }
        Map<String, Integer> this$map = this.map;
        Map<String, Integer> other$map = other.map;
        return !(this$map == null ? other$map != null : !((Object)this$map).equals(other$map));
    }

    protected boolean canEqual(Object other) {
        return other instanceof StringListToCountsNDArrayTransform;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        String $columnName = this.columnName;
        result = result * 59 + ($columnName == null ? 43 : $columnName.hashCode());
        String $newColumnName = this.newColumnName;
        result = result * 59 + ($newColumnName == null ? 43 : $newColumnName.hashCode());
        List<String> $vocabulary = this.vocabulary;
        result = result * 59 + ($vocabulary == null ? 43 : ((Object)$vocabulary).hashCode());
        String $delimiter = this.delimiter;
        result = result * 59 + ($delimiter == null ? 43 : $delimiter.hashCode());
        result = result * 59 + (this.binary ? 79 : 97);
        result = result * 59 + (this.ignoreUnknown ? 79 : 97);
        Map<String, Integer> $map = this.map;
        result = result * 59 + ($map == null ? 43 : ((Object)$map).hashCode());
        return result;
    }
}

