package io.improbable.keanu.tensor.validate;

import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.validate.check.CustomElementwiseTensorValueChecker;
import io.improbable.keanu.tensor.validate.check.CustomTensorValueChecker;
import io.improbable.keanu.tensor.validate.check.TensorValueChecker;
import io.improbable.keanu.tensor.validate.check.TensorValueNotEqualsCheck;
import io.improbable.keanu.tensor.validate.policy.TensorValidationPolicy;
import java.util.function.Function;

/**
 * A {@link TensorValueChecker} that can also validate a tensor, returning a tensor from {@link #validate}.
 *
 * <p>The static factories build {@link TensorCheckAndRespondValidator}s from one of two checks:
 * <ul>
 *   <li>a "this value must not appear" check;</li>
 *   <li>a caller-supplied check function.</li>
 * </ul>
 * A factory may also pair the check with a {@link TensorValidationPolicy}, which decides how to
 * respond to elements that fail it.
 *
 * @param <DATATYPE> element type of the validated tensor
 * @param <TENSOR>   concrete tensor type being validated
 */
public interface TensorValidator<DATATYPE, TENSOR extends Tensor<DATATYPE, TENSOR>>
    extends TensorValueChecker<DATATYPE, TENSOR> {

    /**
     * Expects no element of a {@link DoubleTensor} to equal {@code 0.0}. Failures are handled by
     * {@link TensorCheckAndRespondValidator}'s default response policy.
     */
    TensorValidator<Double, DoubleTensor> ZERO_CATCHER = thatExpectsNotToFind(0.0);

    /**
     * Expects every element of a {@link DoubleTensor} to be non-NaN (checked with
     * {@code DoubleTensor::notNaN}). The validator is wrapped in a {@link DebugTensorValidator}.
     */
    // NOTE(review): DebugTensorValidator presumably only runs the wrapped validator when debug
    // validation is enabled — confirm against its implementation.
    DebugTensorValidator<Double, DoubleTensor> NAN_CATCHER =
        new DebugTensorValidator<>(thatExpects(DoubleTensor::notNaN));

    /**
     * A {@link NaNFixingTensorValidator} constructed with {@code 0.0}, wrapped in a
     * {@link DebugTensorValidator}.
     */
    // NOTE(review): by its name, NaNFixingTensorValidator presumably replaces NaN elements with the
    // given value (0.0 here) — confirm.
    DebugTensorValidator<Double, DoubleTensor> NAN_FIXER =
        new DebugTensorValidator<>(new NaNFixingTensorValidator(0.0d));

    /**
     * Validates {@code tensor}.
     *
     * @param tensor the tensor to validate
     * @return the validated tensor. Whether this is {@code tensor} itself or an adjusted tensor
     *         depends on the implementation (TODO confirm per implementation).
     */
    TENSOR validate(TENSOR tensor);

    /**
     * Creates a validator that expects no element to equal {@code value}. The check is done by a
     * {@link TensorValueNotEqualsCheck}; failures use {@link TensorCheckAndRespondValidator}'s
     * default response policy.
     *
     * @param value the element value that must not appear
     * @param <DATATYPE> element type
     * @param <TENSOR>   tensor type
     * @return a new validator
     */
    static <DATATYPE, TENSOR extends Tensor<DATATYPE, TENSOR>> TensorCheckAndRespondValidator<DATATYPE, TENSOR> thatExpectsNotToFind(
        DATATYPE value) {
        return new TensorCheckAndRespondValidator<>(new TensorValueNotEqualsCheck<>(value));
    }

    /**
     * Creates a validator that expects no element to equal {@code value}. It responds with
     * {@link TensorValidationPolicy#changeValueTo} applied to {@code replacement}, which presumably
     * overwrites offending elements with {@code replacement}.
     *
     * @param value       the element value that must not appear
     * @param replacement the value handed to the {@code changeValueTo} response policy
     * @param <DATATYPE>  element type
     * @param <TENSOR>    tensor type
     * @return a new validator
     */
    static <DATATYPE, TENSOR extends Tensor<DATATYPE, TENSOR>> TensorValidator<DATATYPE, TENSOR> thatReplaces(
        DATATYPE value, DATATYPE replacement) {
        // Diamond instead of the former raw type, so the result is properly parameterised
        // rather than reaching the caller through an unchecked conversion.
        return new TensorCheckAndRespondValidator<>(
            new TensorValueNotEqualsCheck<>(value),
            TensorValidationPolicy.changeValueTo(replacement));
    }

    /**
     * Creates a validator from a whole-tensor check. {@code checkFunction} maps the tensor to a
     * {@link BooleanTensor} mask. Presumably {@code true} marks a valid element, as with
     * {@code notNaN} in {@code NAN_CATCHER}. Failures use {@link TensorCheckAndRespondValidator}'s
     * default response policy.
     *
     * @param checkFunction produces the per-element validity mask for a tensor
     * @param <DATATYPE>    element type
     * @param <TENSOR>      tensor type
     * @return a new validator
     */
    static <DATATYPE, TENSOR extends Tensor<DATATYPE, TENSOR>> TensorCheckAndRespondValidator<DATATYPE, TENSOR> thatExpects(
        Function<TENSOR, BooleanTensor> checkFunction) {
        return new TensorCheckAndRespondValidator<>(new CustomTensorValueChecker<>(checkFunction));
    }

    /**
     * Creates a validator from a per-element check, applied through a
     * {@link CustomElementwiseTensorValueChecker}. Failures use
     * {@link TensorCheckAndRespondValidator}'s default response policy.
     *
     * @param checkFunction predicate evaluated for each element
     * @param <DATATYPE>    element type
     * @param <TENSOR>      tensor type
     * @return a new validator
     */
    static <DATATYPE, TENSOR extends Tensor<DATATYPE, TENSOR>> TensorCheckAndRespondValidator<DATATYPE, TENSOR> thatExpectsElementwise(
        Function<DATATYPE, Boolean> checkFunction) {
        return new TensorCheckAndRespondValidator<>(new CustomElementwiseTensorValueChecker<>(checkFunction));
    }

    /**
     * Creates a validator from a per-element check, applied through a
     * {@link CustomElementwiseTensorValueChecker}. It responds to failures with the given policy.
     *
     * @param checkFunction          predicate evaluated for each element
     * @param tensorValidationPolicy how to respond to elements that fail the check
     * @param <DATATYPE>             element type
     * @param <TENSOR>               tensor type
     * @return a new validator
     */
    static <DATATYPE, TENSOR extends Tensor<DATATYPE, TENSOR>> TensorCheckAndRespondValidator<DATATYPE, TENSOR> thatFixesElementwise(
        Function<DATATYPE, Boolean> checkFunction,
        TensorValidationPolicy<DATATYPE, TENSOR> tensorValidationPolicy) {
        return new TensorCheckAndRespondValidator<>(
            new CustomElementwiseTensorValueChecker<>(checkFunction),
            tensorValidationPolicy);
    }
}
