package cdm.observable.asset.calculatedrate.functions;

import cdm.base.math.ArithmeticOperationEnum;
import cdm.base.math.functions.VectorGrowthOperation;
import cdm.base.math.functions.VectorOperation;
import cdm.base.math.functions.VectorScalarOperation;
import cdm.observable.asset.calculatedrate.CalculatedRateDetails;
import cdm.observable.asset.calculatedrate.CalculatedRateDetails.CalculatedRateDetailsBuilder;
import com.google.inject.ImplementedBy;
import com.rosetta.model.lib.expression.MapperMaths;
import com.rosetta.model.lib.functions.ModelObjectValidator;
import com.rosetta.model.lib.functions.RosettaFunction;
import com.rosetta.model.lib.mapper.Mapper;
import com.rosetta.model.lib.mapper.MapperC;
import com.rosetta.model.lib.mapper.MapperS;
import java.math.BigDecimal;
import java.util.List;
import java.util.Optional;
import javax.inject.Inject;


@ImplementedBy(ApplyCompoundingFormula.ApplyCompoundingFormulaDefault.class)
public abstract class ApplyCompoundingFormula implements RosettaFunction {
	
	@Inject protected ModelObjectValidator objectValidator;
	
	// RosettaFunction dependencies
	//
	@Inject protected VectorGrowthOperation vectorGrowthOperation;
	@Inject protected VectorOperation vectorOperation;
	@Inject protected VectorScalarOperation vectorScalarOperation;

	/**
	* Runs the compounding calculation and validates the resulting details object.
	*
	* @param observations A vector of observation values.
	* @param weights A vector of weights (should be the same size as observations, one weight per observation).
	* @param yearFrac Year fraction of a single day (i.e. 1/basis).
	* @return results Details of the compounding calculation, or {@code null} if the implementation produced no builder.
	*/
	public CalculatedRateDetails evaluate(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
		CalculatedRateDetails.CalculatedRateDetailsBuilder builder = doEvaluate(observations, weights, yearFrac);
		if (builder == null) {
			// Nothing to build or validate.
			return null;
		}
		CalculatedRateDetails details = builder.build();
		objectValidator.validate(CalculatedRateDetails.class, details);
		return details;
	}

	protected abstract CalculatedRateDetails.CalculatedRateDetailsBuilder doEvaluate(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> weightedObservations(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> scaledAndWeightedObservations(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> growthFactors(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> growthCurve(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> finalValue(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> totalWeight(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> overallYearFrac(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	protected abstract Mapper<BigDecimal> calculatedRate(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac);

	/**
	* Default implementation of the compounding formula.
	*
	* Pipeline: weighted_i = observation_i * weight_i; scaled_i = weighted_i * yearFrac;
	* factor_i = scaled_i + 1.0; the growth curve is produced by VectorGrowthOperation from a
	* base value of 1.0 and those factors; the calculated rate is
	* (last curve value - 1) / (sum(weights) * yearFrac).
	*/
	public static class ApplyCompoundingFormulaDefault extends ApplyCompoundingFormula {
		@Override
		protected CalculatedRateDetails.CalculatedRateDetailsBuilder doEvaluate(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			return assignOutput(CalculatedRateDetails.builder(), observations, weights, yearFrac);
		}
		
		/**
		* Populates the builder with every computed output, then prunes it.
		* The outputs are computed in the same order as the setters are applied.
		*/
		protected CalculatedRateDetails.CalculatedRateDetailsBuilder assignOutput(CalculatedRateDetails.CalculatedRateDetailsBuilder results, List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			BigDecimal aggregateValue = MapperS.of(finalValue(observations, weights, yearFrac).get()).get();
			results.setAggregateValue(aggregateValue);
			
			BigDecimal aggregateWeight = MapperS.of(totalWeight(observations, weights, yearFrac).get()).get();
			results.setAggregateWeight(aggregateWeight);
			
			BigDecimal rate = MapperS.of(calculatedRate(observations, weights, yearFrac).get()).get();
			results.setCalculatedRate(rate);
			
			List<BigDecimal> compounded = MapperC.<BigDecimal>of(growthCurve(observations, weights, yearFrac).getMulti()).getMulti();
			results.addCompoundedGrowth(compounded);
			
			List<BigDecimal> factors = MapperC.<BigDecimal>of(growthFactors(observations, weights, yearFrac).getMulti()).getMulti();
			results.addGrowthFactor(factors);
			
			List<BigDecimal> weightedRates = MapperC.<BigDecimal>of(weightedObservations(observations, weights, yearFrac).getMulti()).getMulti();
			results.addWeightedRates(weightedRates);
			
			// results has already been dereferenced above, so it cannot be null here.
			return results.prune();
		}
		
		/** Element-wise product of observations and weights, via VectorOperation MULTIPLY. */
		@Override
		protected Mapper<BigDecimal> weightedObservations(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			ArithmeticOperationEnum multiply = MapperS.of(ArithmeticOperationEnum.MULTIPLY).get();
			List<BigDecimal> lhs = MapperC.<BigDecimal>of(observations).getMulti();
			List<BigDecimal> rhs = MapperC.<BigDecimal>of(weights).getMulti();
			return MapperC.<BigDecimal>of(vectorOperation.evaluate(multiply, lhs, rhs));
		}
		
		/** Weighted observations each multiplied by the single-day year fraction. */
		@Override
		protected Mapper<BigDecimal> scaledAndWeightedObservations(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			ArithmeticOperationEnum multiply = MapperS.of(ArithmeticOperationEnum.MULTIPLY).get();
			List<BigDecimal> weighted = MapperC.<BigDecimal>of(weightedObservations(observations, weights, yearFrac).getMulti()).getMulti();
			BigDecimal scale = MapperS.of(yearFrac).get();
			return MapperC.<BigDecimal>of(vectorScalarOperation.evaluate(multiply, weighted, scale));
		}
		
		/** Scaled, weighted observations each increased by 1.0. */
		@Override
		protected Mapper<BigDecimal> growthFactors(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			ArithmeticOperationEnum add = MapperS.of(ArithmeticOperationEnum.ADD).get();
			List<BigDecimal> scaled = MapperC.<BigDecimal>of(scaledAndWeightedObservations(observations, weights, yearFrac).getMulti()).getMulti();
			BigDecimal one = MapperS.of(new BigDecimal("1.0")).get();
			return MapperC.<BigDecimal>of(vectorScalarOperation.evaluate(add, scaled, one));
		}
		
		/** Growth curve produced by VectorGrowthOperation from a base value of 1.0 and the growth factors. */
		@Override
		protected Mapper<BigDecimal> growthCurve(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			BigDecimal baseValue = MapperS.of(new BigDecimal("1.0")).get();
			List<BigDecimal> factors = MapperC.<BigDecimal>of(growthFactors(observations, weights, yearFrac).getMulti()).getMulti();
			return MapperC.<BigDecimal>of(vectorGrowthOperation.evaluate(baseValue, factors));
		}
		
		/** Last element of the growth curve. */
		@Override
		protected Mapper<BigDecimal> finalValue(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			List<BigDecimal> curve = growthCurve(observations, weights, yearFrac).getMulti();
			return MapperC.<BigDecimal>of(curve).last();
		}
		
		/** Sum of all weights. */
		@Override
		protected Mapper<BigDecimal> totalWeight(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			MapperC<BigDecimal> weightValues = MapperC.<BigDecimal>of(weights);
			return weightValues.sumBigDecimal();
		}
		
		/** Total weight multiplied by the single-day year fraction. */
		@Override
		protected Mapper<BigDecimal> overallYearFrac(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			MapperS<BigDecimal> weightSum = MapperS.of(totalWeight(observations, weights, yearFrac).get());
			MapperS<BigDecimal> dayFraction = MapperS.of(yearFrac);
			return MapperMaths.<BigDecimal, BigDecimal, BigDecimal>multiply(weightSum, dayFraction);
		}
		
		/**
		* (finalValue - 1) / overallYearFrac.
		* NOTE(review): there is no guard for a zero overall year fraction (e.g. all-zero weights).
		* The outcome then depends on MapperMaths.divide. Confirm the intended behavior.
		*/
		@Override
		protected Mapper<BigDecimal> calculatedRate(List<BigDecimal> observations, List<BigDecimal> weights, BigDecimal yearFrac) {
			MapperS<BigDecimal> endValue = MapperS.of(finalValue(observations, weights, yearFrac).get());
			Mapper<BigDecimal> growth = MapperMaths.<BigDecimal, BigDecimal, Integer>subtract(endValue, MapperS.of(Integer.valueOf(1)));
			MapperS<BigDecimal> period = MapperS.of(overallYearFrac(observations, weights, yearFrac).get());
			return MapperMaths.<BigDecimal, BigDecimal, BigDecimal>divide(growth, period);
		}
	}
}
