/*
 * Decompiled with CFR 0.152.
 */
package smile.feature;

import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.data.vector.DoubleVector;
import smile.feature.FeatureTransform;
import smile.math.MathEx;

/**
 * Scales each feature by a per-column factor, typically its maximum absolute value.
 *
 * <p>Every value {@code x[i]} is transformed to {@code x[i] / scale[i]}. When the factors are
 * learned with {@link #fit(DataFrame)}, {@code scale[i]} is the largest absolute (non-NaN) value
 * observed in numeric column {@code i}, so that column of the training data maps into
 * {@code [-1, 1]}. The transform only divides; it does not shift or center the data.
 *
 * <p>Scaling factors that {@link MathEx#isZero(double)} treats as zero are replaced with
 * {@code 1.0}, so such columns pass through unchanged instead of being divided by zero.
 */
public class MaxAbsScaler implements FeatureTransform {
    private static final long serialVersionUID = 2L;
    /** The schema of the data this transform accepts and produces. */
    protected StructType schema;
    /** The per-column divisor, in schema order; a private copy owned by this instance. */
    private final double[] scale;

    /**
     * Constructor.
     *
     * @param schema the schema of the data to transform.
     * @param scale the scaling factor of each column, in schema order. The array is copied and
     *              never modified; entries that {@link MathEx#isZero(double)} treats as zero are
     *              replaced with {@code 1.0} in the copy.
     * @throws IllegalArgumentException if {@code scale} and {@code schema} differ in length.
     */
    public MaxAbsScaler(StructType schema, double[] scale) {
        if (schema.length() != scale.length) {
            throw new IllegalArgumentException("Schema and scaling factor size don't match");
        }
        this.schema = schema;
        // Copy so the caller's array is not rewritten below and later changes
        // to it cannot alter this transform.
        this.scale = scale.clone();
        for (int i = 0; i < this.scale.length; ++i) {
            if (MathEx.isZero(this.scale[i])) {
                // A (near-)zero divisor would blow values up to Inf/NaN; 1.0 is the identity.
                this.scale[i] = 1.0;
            }
        }
    }

    /**
     * Learns the scaling factors from a data frame.
     *
     * <p>For each numeric column the factor is the maximum absolute value of the column,
     * ignoring NaN entries. Non-numeric columns, and numeric columns whose values are all NaN,
     * get a factor of {@code 1.0} (they are left unscaled).
     *
     * @param data the training data.
     * @return the fitted scaler.
     * @throws IllegalArgumentException if {@code data} is empty.
     */
    public static MaxAbsScaler fit(DataFrame data) {
        if (data.isEmpty()) {
            throw new IllegalArgumentException("Empty data frame");
        }
        StructType schema = data.schema();
        double[] scale = new double[schema.length()];
        for (int i = 0; i < scale.length; ++i) {
            // Non-numeric columns keep 0.0, which the constructor maps to 1.0.
            if (!schema.field(i).isNumeric()) continue;
            // DoubleStream.max propagates NaN, so drop NaN first; otherwise a single
            // missing value would make the whole column NaN after scaling.
            scale[i] = ((DoubleStream) data.doubleVector(i).stream())
                    .filter(v -> !Double.isNaN(v))
                    .map(Math::abs)
                    .max()
                    .orElse(0.0);
        }
        return new MaxAbsScaler(schema, scale);
    }

    /**
     * Learns the scaling factors from a matrix whose rows are samples.
     *
     * @param data the training data; each column is a feature.
     * @return the fitted scaler.
     * @throws IllegalArgumentException if {@code data} has no rows.
     */
    public static MaxAbsScaler fit(double[][] data) {
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data frame");
        }
        return MaxAbsScaler.fit(DataFrame.of(data, new String[0]));
    }

    /** Divides {@code x} by the scaling factor of column {@code i}. */
    private double scale(double x, int i) {
        return x / this.scale[i];
    }

    /**
     * Scales every element of a feature vector.
     *
     * @param x the feature vector, one value per schema column.
     * @return a new array with {@code x[i] / scale[i]}.
     * @throws IllegalArgumentException if {@code x} does not have one value per column.
     */
    @Override
    public double[] transform(double[] x) {
        if (x.length != this.scale.length) {
            throw new IllegalArgumentException(String.format(
                    "Invalid input vector size %d, expected %d", x.length, this.scale.length));
        }
        double[] y = new double[x.length];
        for (int i = 0; i < y.length; ++i) {
            y[i] = this.scale(x[i], i);
        }
        return y;
    }

    /**
     * Returns a view of the tuple with numeric fields scaled.
     *
     * <p>Scaling is computed on each {@code get} call; non-numeric fields are returned as is.
     *
     * @param x the input tuple.
     * @return the scaled tuple view.
     * @throws IllegalArgumentException if the tuple's schema differs from this scaler's.
     */
    @Override
    public Tuple transform(final Tuple x) {
        if (!this.schema.equals(x.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", x.schema(), this.schema));
        }
        return new AbstractTuple() {

            public Object get(int i) {
                if (MaxAbsScaler.this.schema.field(i).isNumeric()) {
                    return MaxAbsScaler.this.scale(x.getDouble(i), i);
                }
                return x.get(i);
            }

            public StructType schema() {
                return MaxAbsScaler.this.schema;
            }
        };
    }

    /**
     * Returns a new data frame with numeric columns scaled.
     *
     * <p>Numeric columns are materialized into new double vectors; non-numeric columns are
     * reused from {@code data} unchanged.
     *
     * @param data the input data frame.
     * @return the scaled data frame.
     * @throws IllegalArgumentException if the data frame's schema differs from this scaler's.
     */
    @Override
    public DataFrame transform(DataFrame data) {
        if (!this.schema.equals(data.schema())) {
            throw new IllegalArgumentException(String.format("Invalid schema %s, expected %s", data.schema(), this.schema));
        }
        BaseVector[] vectors = new BaseVector[this.schema.length()];
        for (int i = 0; i < this.scale.length; ++i) {
            StructField field = this.schema.field(i);
            if (field.isNumeric()) {
                final int col = i;
                DoubleStream stream = data.stream().mapToDouble(t -> this.scale(t.getDouble(col), col));
                vectors[i] = DoubleVector.of(field, stream);
                continue;
            }
            vectors[i] = data.column(i);
        }
        return DataFrame.of(vectors);
    }

    /** Lists each column name with its scaling factor, e.g. {@code MaxAbsScaler(V1[2.0000],...)}. */
    @Override
    public String toString() {
        return IntStream.range(0, this.scale.length)
                .mapToObj(i -> String.format("%s[%.4f]", this.schema.field(i).name, this.scale[i]))
                .collect(Collectors.joining(",", "MaxAbsScaler(", ")"));
    }
}

