/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression;

import java.io.IOException;
import java.text.MessageFormat;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.script.Script;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.regression.RegressionMetric;

/**
 * Regression evaluation metric that computes the mean squared error between an actual and a predicted field.
 * <p>
 * The value is obtained from a single {@code avg} aggregation over a Painless script that evaluates
 * {@code (actual - predicted) * (actual - predicted)} for each document. The metric has no configurable
 * parameters: it writes nothing to the wire, renders as an empty object, and all instances compare equal
 * (the computed result is deliberately not part of equality).
 */
public class MeanSquaredError implements RegressionMetric {

    public static final ParseField NAME = new ParseField("mean_squared_error");

    /**
     * {@link MessageFormat} template for the per-document squared difference. In MessageFormat syntax
     * {@code ''} denotes a literal single quote, so both field names end up inside Painless single-quoted
     * string literals, e.g. {@code doc['price']}.
     */
    private static final String PAINLESS_TEMPLATE = "def diff = doc[''{0}''].value - doc[''{1}''].value;return diff * diff;";

    private static final String AGG_NAME = "regression_" + NAME.getPreferredName();

    // NOTE(review): the `true` flag presumably makes the parser ignore unknown fields — confirm against ObjectParser.
    private static final ObjectParser<MeanSquaredError, Void> PARSER =
        new ObjectParser<>(NAME.getPreferredName(), true, MeanSquaredError::new);

    /** The computed result, or {@code null} until {@link #process(Aggregations)} has run. */
    private EvaluationMetricResult result;

    /**
     * Renders the Painless script for the given fields.
     *
     * @param actualField    name of the field holding the actual value
     * @param predictedField name of the field holding the predicted value
     * @return the script source
     */
    private static String buildScript(String actualField, String predictedField) {
        Object[] fields = new Object[] { escapePainlessString(actualField), escapePainlessString(predictedField) };
        return new MessageFormat(PAINLESS_TEMPLATE, Locale.ROOT).format(fields);
    }

    /**
     * Escapes {@code value} so it can be embedded in a Painless single-quoted string literal. Without this,
     * a field name containing {@code '} or {@code \} would yield an invalid (or different) script.
     * A {@code null} value is passed through so MessageFormat renders it exactly as before.
     */
    private static String escapePainlessString(String value) {
        if (value == null) {
            return null;
        }
        // Escape backslashes first so the backslashes introduced for quotes are not doubled again.
        return value.replace("\\", "\\\\").replace("'", "\\'");
    }

    /**
     * Parses a {@link MeanSquaredError} from XContent.
     *
     * @param parser the parser positioned on the metric's object
     * @return the parsed metric
     */
    public static MeanSquaredError fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * Deserialization constructor. Nothing is read because {@link #writeTo(StreamOutput)} writes nothing.
     */
    public MeanSquaredError(StreamInput in) {
    }

    public MeanSquaredError() {
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /**
     * Returns the aggregations needed to compute this metric.
     *
     * @param actualField    name of the field holding the actual value
     * @param predictedField name of the field holding the predicted value
     * @return a single {@code avg} aggregation over the squared-difference script, or an empty list if a
     *         result has already been computed
     */
    @Override
    public List<AggregationBuilder> aggs(String actualField, String predictedField) {
        if (result != null) {
            return Collections.emptyList();
        }
        Script script = new Script(buildScript(actualField, predictedField));
        return Arrays.asList(AggregationBuilders.avg(AGG_NAME).script(script));
    }

    /**
     * Stores the averaged squared difference as this metric's result. If the aggregation is absent from
     * {@code aggs}, the result is recorded as {@code 0.0}.
     */
    @Override
    public void process(Aggregations aggs) {
        NumericMetricsAggregation.SingleValue value = (NumericMetricsAggregation.SingleValue) aggs.get(AGG_NAME);
        result = value == null ? new Result(0.0) : new Result(value.value());
    }

    @Override
    public Optional<EvaluationMetricResult> getResult() {
        return Optional.ofNullable(result);
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes nothing: the metric carries no configuration. */
    public void writeTo(StreamOutput out) throws IOException {
    }

    /** Renders the metric as an empty object, since it has no parameters. */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.endObject();
        return builder;
    }

    /** All instances are equal; the computed result is not considered. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        return o != null && getClass() == o.getClass();
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(NAME.getPreferredName());
    }

    /**
     * Result of the mean squared error metric, rendered as {@code {"error": <value>}}.
     */
    public static class Result implements EvaluationMetricResult {

        private static final String ERROR = "error";

        private final double error;

        public Result(double error) {
            this.error = error;
        }

        public Result(StreamInput in) throws IOException {
            this.error = in.readDouble();
        }

        /** @return the computed mean squared error */
        public double getError() {
            return error;
        }

        public String getWriteableName() {
            return NAME.getPreferredName();
        }

        @Override
        public String getMetricName() {
            return NAME.getPreferredName();
        }

        public void writeTo(StreamOutput out) throws IOException {
            out.writeDouble(error);
        }

        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(ERROR, error);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Result that = (Result) o;
            // Double.compare (not ==) keeps equals consistent with Double.hashCode: with ==, 0.0 and -0.0
            // would be equal yet hash differently, and NaN results would never be equal.
            return Double.compare(error, that.error) == 0;
        }

        @Override
        public int hashCode() {
            return Double.hashCode(error);
        }
    }
}

