/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ContextParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.ClassificationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/**
 * Classification evaluation metric that builds a multiclass confusion matrix.
 *
 * <p>The metric is computed in two search steps driven by {@link #aggs} and {@link #process}:
 * <ol>
 *   <li>A {@code terms} aggregation over the actual-class field selects the {@code size} actual classes with the
 *       highest document counts (ties broken by ascending key).</li>
 *   <li>For each of those classes, a {@code filters} aggregation on the actual-class field is broken down by a nested
 *       {@code filters} aggregation on the predicted-class field over the same class names, with an {@code _other_}
 *       bucket collecting every other prediction. A {@code cardinality} aggregation over the actual-class field is
 *       used to report how many actual classes were left out of the matrix.</li>
 * </ol>
 * Once the step 2 response has been processed, {@link #aggs} returns an empty list and {@link #getResult} returns
 * the matrix.
 *
 * <p>Instances are stateful (they remember the intermediate results of each step); {@link #equals} and
 * {@link #hashCode} only consider the configured {@code size}.
 */
public class MulticlassConfusionMatrix implements ClassificationMetric {

    public static final ParseField NAME = new ParseField("multiclass_confusion_matrix");
    public static final ParseField SIZE = new ParseField("size");

    private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();

    // Aggregation names are prefixed with the metric name, presumably to keep them distinct from the aggregations
    // of other metrics evaluated in the same search request.
    private static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
    private static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
    private static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS =
        NAME.getPreferredName() + "_step_2_by_predicted_class";
    private static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS =
        NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
    private static final String OTHER_BUCKET_KEY = "_other_";
    private static final int DEFAULT_SIZE = 10;
    private static final int MAX_SIZE = 1000;

    /** Maximum number of actual classes (rows) reported in the matrix. */
    private final int size;
    /** Top actual class names found in step 1, in lexicographic order; {@code null} until step 1 is processed. */
    private List<String> topActualClassNames;
    /** Final matrix built from the step 2 response; {@code null} until step 2 is processed. */
    private Result result;

    private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
        ConstructingObjectParser<MulticlassConfusionMatrix, Void> parser =
            new ConstructingObjectParser<>(NAME.getPreferredName(), true, args -> new MulticlassConfusionMatrix((Integer) args[0]));
        parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), SIZE);
        return parser;
    }

    /**
     * Parses a metric definition such as {@code {"size": 5}}; unknown fields are ignored.
     */
    public static MulticlassConfusionMatrix fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /** Creates the metric with the default size. */
    public MulticlassConfusionMatrix() {
        this((Integer) null);
    }

    /**
     * Creates the metric.
     *
     * @param size maximum number of actual classes to report, in {@code [1, MAX_SIZE]};
     *             {@code null} selects {@code DEFAULT_SIZE}
     * @throws RuntimeException (a bad-request exception from {@code ExceptionsHelper}) when {@code size} is out of range
     */
    public MulticlassConfusionMatrix(@Nullable Integer size) {
        if (size != null && (size <= 0 || size > MAX_SIZE)) {
            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), MAX_SIZE);
        }
        this.size = size != null ? size : DEFAULT_SIZE;
    }

    /** Deserializes the metric definition written by {@link #writeTo}. */
    public MulticlassConfusionMatrix(StreamInput in) throws IOException {
        this.size = in.readVInt();
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    public int getSize() {
        return size;
    }

    /**
     * Returns the aggregations needed for the next evaluation step: the step 1 {@code terms} aggregation until
     * step 1 has been processed, then the step 2 {@code cardinality} and nested {@code filters} aggregations until
     * step 2 has been processed, and finally an empty list.
     */
    @Override
    public final List<AggregationBuilder> aggs(String actualField, String predictedField) {
        if (topActualClassNames == null) {
            // Step 1: the `size` most frequent actual classes, ties broken by ascending class name.
            return Arrays.asList(
                AggregationBuilders.terms(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)
                    .field(actualField)
                    .order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true)))
                    .size(size));
        }
        if (result == null) {
            // Step 2: one filter per top class on each field; predictions outside the top classes go to _other_.
            FiltersAggregator.KeyedFilter[] keyedFiltersActual = topActualClassNames.stream()
                .map(className -> new FiltersAggregator.KeyedFilter(className, QueryBuilders.termQuery(actualField, className)))
                .toArray(FiltersAggregator.KeyedFilter[]::new);
            FiltersAggregator.KeyedFilter[] keyedFiltersPredicted = topActualClassNames.stream()
                .map(className -> new FiltersAggregator.KeyedFilter(className, QueryBuilders.termQuery(predictedField, className)))
                .toArray(FiltersAggregator.KeyedFilter[]::new);
            return Arrays.asList(
                AggregationBuilders.cardinality(STEP_2_CARDINALITY_OF_ACTUAL_CLASS).field(actualField),
                AggregationBuilders.filters(STEP_2_AGGREGATE_BY_ACTUAL_CLASS, keyedFiltersActual)
                    .subAggregation(
                        AggregationBuilders.filters(STEP_2_AGGREGATE_BY_PREDICTED_CLASS, keyedFiltersPredicted)
                            .otherBucket(true)
                            .otherBucketKey(OTHER_BUCKET_KEY)));
        }
        return Collections.emptyList();
    }

    /**
     * Consumes the aggregation results of the current step. Results for a step that has already been processed
     * are ignored.
     */
    @Override
    public void process(Aggregations aggs) {
        if (topActualClassNames == null && aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS) != null) {
            Terms termsAgg = aggs.get(STEP_1_AGGREGATE_BY_ACTUAL_CLASS);
            topActualClassNames = termsAgg.getBuckets().stream()
                .map(MultiBucketsAggregation.Bucket::getKeyAsString)
                .sorted()
                .collect(Collectors.toList());
        }
        if (result == null && aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS) != null) {
            Cardinality cardinalityAgg = aggs.get(STEP_2_CARDINALITY_OF_ACTUAL_CLASS);
            Filters filtersAgg = aggs.get(STEP_2_AGGREGATE_BY_ACTUAL_CLASS);
            List<ActualClass> actualClasses = new ArrayList<>(filtersAgg.getBuckets().size());
            for (Filters.Bucket bucket : filtersAgg.getBuckets()) {
                String actualClass = bucket.getKeyAsString();
                long actualClassDocCount = bucket.getDocCount();
                Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
                List<PredictedClass> predictedClasses = new ArrayList<>();
                long otherPredictedClassDocCount = 0L;
                for (Filters.Bucket subBucket : subAgg.getBuckets()) {
                    String predictedClass = subBucket.getKeyAsString();
                    long docCount = subBucket.getDocCount();
                    if (OTHER_BUCKET_KEY.equals(predictedClass)) {
                        // Predictions that are not one of the top actual classes are reported as a single count.
                        otherPredictedClassDocCount = docCount;
                    } else {
                        predictedClasses.add(new PredictedClass(predictedClass, docCount));
                    }
                }
                predictedClasses.sort(Comparator.comparing(PredictedClass::getPredictedClass));
                actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
            }
            // Number of distinct actual classes beyond the `size` reported ones, clamped at zero.
            result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0L));
        }
    }

    @Override
    public Optional<EvaluationMetricResult> getResult() {
        return Optional.ofNullable(result);
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeVInt(size);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(SIZE.getPreferredName(), size);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MulticlassConfusionMatrix that = (MulticlassConfusionMatrix) o;
        return size == that.size;
    }

    @Override
    public int hashCode() {
        return Objects.hash(size);
    }

    /**
     * Returns {@code value} unchanged.
     *
     * @throws RuntimeException (a server-error exception from {@code ExceptionsHelper}) when {@code value} is negative
     */
    private static long requireNonNegative(long value, ParseField field) {
        if (value < 0L) {
            throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value);
        }
        return value;
    }

    /** One cell of a confusion-matrix row: the number of documents predicted as a given class. */
    public static class PredictedClass implements ToXContentObject, Writeable {

        private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class");
        private static final ParseField COUNT = new ParseField("count");

        private static final ConstructingObjectParser<PredictedClass, Void> PARSER = new ConstructingObjectParser<>(
            "multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));

        static {
            PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_CLASS);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), COUNT);
        }

        private final String predictedClass;
        private final long count;

        /**
         * @param predictedClass predicted class name, must not be {@code null}
         * @param count          number of documents, must be non-negative
         */
        public PredictedClass(String predictedClass, long count) {
            this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS);
            this.count = requireNonNegative(count, COUNT);
        }

        public PredictedClass(StreamInput in) throws IOException {
            this.predictedClass = in.readString();
            this.count = in.readVLong();
        }

        public String getPredictedClass() {
            return predictedClass;
        }

        public long getCount() {
            return count;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(predictedClass);
            out.writeVLong(count);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
            builder.field(COUNT.getPreferredName(), count);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            PredictedClass that = (PredictedClass) o;
            return Objects.equals(predictedClass, that.predictedClass) && count == that.count;
        }

        @Override
        public int hashCode() {
            return Objects.hash(predictedClass, count);
        }
    }

    /** One row of the confusion matrix: the prediction breakdown for documents of a single actual class. */
    public static class ActualClass implements ToXContentObject, Writeable {

        private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
        private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
        private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
        private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");

        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<ActualClass, Void> PARSER = new ConstructingObjectParser<>(
            "multiclass_confusion_matrix_actual_class",
            true,
            a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));

        static {
            PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_CLASS);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), ACTUAL_CLASS_DOC_COUNT);
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
        }

        private final String actualClass;
        private final long actualClassDocCount;
        private final List<PredictedClass> predictedClasses;
        private final long otherPredictedClassDocCount;

        /**
         * @param actualClass                 actual class name, must not be {@code null}
         * @param actualClassDocCount         number of documents of this actual class, must be non-negative
         * @param predictedClasses            per-predicted-class counts, must not be {@code null}; copied
         * @param otherPredictedClassDocCount documents predicted as a class not listed in {@code predictedClasses},
         *                                    must be non-negative
         */
        public ActualClass(
            String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
            this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
            this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT);
            // Copy so later changes to the caller's list cannot leak into this value object.
            this.predictedClasses =
                Collections.unmodifiableList(new ArrayList<>(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES)));
            this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT);
        }

        public ActualClass(StreamInput in) throws IOException {
            this.actualClass = in.readString();
            this.actualClassDocCount = in.readVLong();
            this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new));
            this.otherPredictedClassDocCount = in.readVLong();
        }

        public String getActualClass() {
            return actualClass;
        }

        public long getActualClassDocCount() {
            return actualClassDocCount;
        }

        /** Returns an unmodifiable view of the per-predicted-class counts. */
        public List<PredictedClass> getPredictedClasses() {
            return predictedClasses;
        }

        public long getOtherPredictedClassDocCount() {
            return otherPredictedClassDocCount;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeString(actualClass);
            out.writeVLong(actualClassDocCount);
            out.writeList(predictedClasses);
            out.writeVLong(otherPredictedClassDocCount);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
            builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
            builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
            builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ActualClass that = (ActualClass) o;
            return Objects.equals(actualClass, that.actualClass)
                && actualClassDocCount == that.actualClassDocCount
                && Objects.equals(predictedClasses, that.predictedClasses)
                && otherPredictedClassDocCount == that.otherPredictedClassDocCount;
        }

        @Override
        public int hashCode() {
            return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
        }
    }

    /** The computed confusion matrix: one {@link ActualClass} row per reported actual class. */
    public static class Result implements EvaluationMetricResult {

        private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix");
        private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count");

        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<Result, Void> PARSER = new ConstructingObjectParser<>(
            "multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));

        static {
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), OTHER_ACTUAL_CLASS_COUNT);
        }

        private final List<ActualClass> actualClasses;
        private final long otherActualClassCount;

        public static Result fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        /**
         * @param actualClasses         matrix rows, must not be {@code null}; copied
         * @param otherActualClassCount number of actual classes not included in {@code actualClasses},
         *                              must be non-negative
         */
        public Result(List<ActualClass> actualClasses, long otherActualClassCount) {
            // Copy so later changes to the caller's list cannot leak into this value object.
            this.actualClasses = Collections.unmodifiableList(new ArrayList<>(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX)));
            this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT);
        }

        public Result(StreamInput in) throws IOException {
            this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
            this.otherActualClassCount = in.readVLong();
        }

        @Override
        public String getWriteableName() {
            return NAME.getPreferredName();
        }

        @Override
        public String getMetricName() {
            return NAME.getPreferredName();
        }

        /** Returns an unmodifiable view of the matrix rows. */
        public List<ActualClass> getConfusionMatrix() {
            return actualClasses;
        }

        public long getOtherActualClassCount() {
            return otherActualClassCount;
        }

        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeList(actualClasses);
            out.writeVLong(otherActualClassCount);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(CONFUSION_MATRIX.getPreferredName(), actualClasses);
            builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Result that = (Result) o;
            return Objects.equals(actualClasses, that.actualClasses) && otherActualClassCount == that.otherActualClassCount;
        }

        @Override
        public int hashCode() {
            return Objects.hash(actualClasses, otherActualClassCount);
        }
    }
}

