/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.vectors.query;

import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
import org.elasticsearch.xpack.vectors.query.VectorScriptDocValues;

/**
 * Scoring helpers for vector fields that are called from score scripts.
 *
 * <p>Each nested function class is constructed with a user-supplied query vector and exposes one
 * scoring method that takes a document's vector doc values. Dense functions take the query as a
 * {@code List<Number>}; sparse functions take it as a {@code Map<String, Number>} from dimension
 * (an integer written as a string) to value.
 *
 * <p>Numerics: element-wise products, differences and absolute values are computed in
 * {@code float} and then accumulated into a {@code double}. This mixed precision is kept
 * deliberately, because changing it would perturb the scores returned to callers.
 */
public class ScoreScriptUtils {

    /**
     * Computes the dot product of two sparse vectors, each given as parallel value/dimension arrays.
     * Only dimensions present in both vectors contribute to the result.
     *
     * <p>The merge-walk below requires both dimension arrays to be sorted in ascending order.
     *
     * @param v1Values values of the first vector
     * @param v1Dims dimensions of the first vector, parallel to {@code v1Values}
     * @param v2Values values of the second vector
     * @param v2Dims dimensions of the second vector, parallel to {@code v2Values}
     * @return the dot product over the shared dimensions, or {@code 0.0} if there are none
     */
    private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
        double dotProduct = 0.0;
        int i1 = 0;
        int i2 = 0;
        // Advance whichever side currently sits on the smaller dimension; multiply on a match.
        while (i1 < v1Values.length && i2 < v2Values.length) {
            int dim1 = v1Dims[i1];
            int dim2 = v2Dims[i2];
            if (dim1 == dim2) {
                // float product, widened to double on accumulation
                dotProduct += v1Values[i1] * v2Values[i2];
                i1++;
                i2++;
            } else if (dim1 > dim2) {
                i2++;
            } else {
                i1++;
            }
        }
        return dotProduct;
    }

    /**
     * Cosine similarity between a sparse query vector and a document's sparse vector value.
     */
    public static final class CosineSimilaritySparse
    extends SparseVectorFunction {
        /** Euclidean magnitude of the (unnormalized) query vector. */
        final double queryVectorMagnitude;

        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector map from dimension (integer as string) to value
         * @throws IllegalArgumentException if a dimension key is not an integer
         */
        public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
            super(scoreScript, queryVector);
            double sumOfSquares = 0.0;
            for (float value : this.queryValues) {
                sumOfSquares += value * value;
            }
            this.queryVectorMagnitude = Math.sqrt(sumOfSquares);
        }

        /**
         * Computes {@code dot(query, doc) / (|doc| * |query|)}.
         *
         * @param dvs the document's sparse vector doc values
         * @return the cosine similarity (NaN or infinite if either magnitude is zero)
         * @throws IllegalArgumentException if the document has no vector value
         */
        public double cosineSimilaritySparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            Version indexVersion = scoreScript._getIndexVersion();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, vector);
            float[] docValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, vector);
            double docQueryDotProduct = intDotProductSparse(queryValues, queryDims, docValues, docDims);

            double docVectorMagnitude;
            if (indexVersion.onOrAfter(Version.V_7_5_0)) {
                // For indices created on or after 7.5.0 the magnitude is decoded from the encoded value.
                docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vector);
            } else {
                double sumOfSquares = 0.0;
                for (float docValue : docValues) {
                    sumOfSquares += docValue * docValue;
                }
                // Rounded to float, presumably to match the precision of the decoded magnitude
                // used by newer indices — NOTE(review): confirm.
                docVectorMagnitude = (float) Math.sqrt(sumOfSquares);
            }
            return docQueryDotProduct / (docVectorMagnitude * queryVectorMagnitude);
        }
    }

    /**
     * Dot product between a sparse query vector and a document's sparse vector value.
     */
    public static final class DotProductSparse
    extends SparseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector map from dimension (integer as string) to value
         * @throws IllegalArgumentException if a dimension key is not an integer
         */
        public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
            super(scoreScript, queryVector);
        }

        /**
         * @param dvs the document's sparse vector doc values
         * @return the dot product over dimensions shared by the query and the document
         * @throws IllegalArgumentException if the document has no vector value
         */
        public double dotProductSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            Version indexVersion = scoreScript._getIndexVersion();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, vector);
            float[] docValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, vector);
            return intDotProductSparse(queryValues, queryDims, docValues, docDims);
        }
    }

    /**
     * Euclidean (L2) distance between a sparse query vector and a document's sparse vector value.
     * A dimension present on only one side is treated as zero on the other.
     */
    public static final class L2NormSparse
    extends SparseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector map from dimension (integer as string) to value
         * @throws IllegalArgumentException if a dimension key is not an integer
         */
        public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
            super(scoreScript, queryVector);
        }

        /**
         * @param dvs the document's sparse vector doc values
         * @return the L2 distance between the query and the document vector
         * @throws IllegalArgumentException if the document has no vector value
         */
        public double l2normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            Version indexVersion = scoreScript._getIndexVersion();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, vector);
            float[] docValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, vector);
            int queryIndex = 0;
            int docIndex = 0;
            double l2norm = 0.0;
            // Merge-walk both sorted dimension lists while both still have elements.
            while (queryIndex < queryDims.length && docIndex < docDims.length) {
                double diff;
                if (queryDims[queryIndex] == docDims[docIndex]) {
                    diff = queryValues[queryIndex] - docValues[docIndex];
                    queryIndex++;
                    docIndex++;
                } else if (queryDims[queryIndex] > docDims[docIndex]) {
                    diff = docValues[docIndex];
                    docIndex++;
                } else {
                    diff = queryValues[queryIndex];
                    queryIndex++;
                }
                l2norm += diff * diff;
            }
            // NOTE(review): leftover elements are squared in float, while unmatched elements above
            // are squared in double. This is kept as-is so scores are not perturbed.
            for (; queryIndex < queryDims.length; queryIndex++) {
                l2norm += queryValues[queryIndex] * queryValues[queryIndex];
            }
            for (; docIndex < docDims.length; docIndex++) {
                l2norm += docValues[docIndex] * docValues[docIndex];
            }
            return Math.sqrt(l2norm);
        }
    }

    /**
     * Manhattan (L1) distance between a sparse query vector and a document's sparse vector value.
     * A dimension present on only one side is treated as zero on the other.
     */
    public static final class L1NormSparse
    extends SparseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector map from dimension (integer as string) to value
         * @throws IllegalArgumentException if a dimension key is not an integer
         */
        public L1NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector) {
            super(scoreScript, queryVector);
        }

        /**
         * @param dvs the document's sparse vector doc values
         * @return the L1 distance between the query and the document vector
         * @throws IllegalArgumentException if the document has no vector value
         */
        public double l1normSparse(VectorScriptDocValues.SparseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            Version indexVersion = scoreScript._getIndexVersion();
            int[] docDims = VectorEncoderDecoder.decodeSparseVectorDims(indexVersion, vector);
            float[] docValues = VectorEncoderDecoder.decodeSparseVector(indexVersion, vector);
            int queryIndex = 0;
            int docIndex = 0;
            double l1norm = 0.0;
            // Merge-walk both sorted dimension lists while both still have elements.
            while (queryIndex < queryDims.length && docIndex < docDims.length) {
                if (queryDims[queryIndex] == docDims[docIndex]) {
                    l1norm += Math.abs(queryValues[queryIndex] - docValues[docIndex]);
                    queryIndex++;
                    docIndex++;
                } else if (queryDims[queryIndex] > docDims[docIndex]) {
                    l1norm += Math.abs(docValues[docIndex]);
                    docIndex++;
                } else {
                    l1norm += Math.abs(queryValues[queryIndex]);
                    queryIndex++;
                }
            }
            // Whatever remains on either side has no counterpart, so it contributes its full magnitude.
            for (; queryIndex < queryDims.length; queryIndex++) {
                l1norm += Math.abs(queryValues[queryIndex]);
            }
            for (; docIndex < docDims.length; docIndex++) {
                l1norm += Math.abs(docValues[docIndex]);
            }
            return l1norm;
        }
    }

    /**
     * Base class for sparse-vector score functions. Parses the query map into parallel
     * dimension/value arrays.
     */
    public static class SparseVectorFunction {
        final ScoreScript scoreScript;
        /** Query values, parallel to {@link #queryDims}. */
        final float[] queryValues;
        /** Query dimensions, parallel to {@link #queryValues}. */
        final int[] queryDims;

        /**
         * @param scoreScript the script this function is evaluated for
         * @param queryVector map from dimension (integer as string) to value
         * @throws IllegalArgumentException if a dimension key cannot be parsed as an integer
         */
        public SparseVectorFunction(ScoreScript scoreScript, Map<String, Number> queryVector) {
            this.scoreScript = scoreScript;
            int size = queryVector.size();
            this.queryValues = new float[size];
            this.queryDims = new int[size];
            int i = 0;
            for (Map.Entry<String, Number> entry : queryVector.entrySet()) {
                try {
                    queryDims[i] = Integer.parseInt(entry.getKey());
                } catch (NumberFormatException e) {
                    throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
                }
                queryValues[i] = entry.getValue().floatValue();
                i++;
            }
            // The scoring merge-walks require ordered dimensions. This call presumably sorts dims
            // ascending and permutes values in step — NOTE(review): confirm in VectorEncoderDecoder.
            VectorEncoderDecoder.sortSparseDimsFloatValues(queryDims, queryValues, size);
        }

        /**
         * @param vector the document's encoded vector value
         * @throws IllegalArgumentException if {@code vector} is {@code null}
         */
        public void validateDocVector(BytesRef vector) {
            if (vector == null) {
                throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
            }
        }
    }

    /**
     * Cosine similarity between a dense query vector and a document's dense vector value.
     * The query is normalized once at construction, so only the document magnitude is divided out
     * per document.
     */
    public static final class CosineSimilarity
    extends DenseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector the query vector values
         */
        public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector) {
            super(scoreScript, queryVector, true);
        }

        /**
         * @param dvs the document's dense vector doc values
         * @return the cosine similarity (NaN or infinite if the document magnitude is zero)
         * @throws IllegalArgumentException if the document has no vector value or its dimension
         *     count differs from the query's
         */
        public double cosineSimilarity(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            Version indexVersion = scoreScript._getIndexVersion();
            ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
            double dotProduct = 0.0;
            double docVectorMagnitude;
            if (indexVersion.onOrAfter(Version.V_7_5_0)) {
                for (float queryValue : queryVector) {
                    dotProduct += queryValue * byteBuffer.getFloat();
                }
                // For indices created on or after 7.5.0 the magnitude is decoded from the encoded value.
                docVectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(indexVersion, vector);
            } else {
                double sumOfSquares = 0.0;
                for (float queryValue : queryVector) {
                    float docValue = byteBuffer.getFloat();
                    dotProduct += queryValue * docValue;
                    sumOfSquares += docValue * docValue;
                }
                // Rounded to float, presumably to match the precision of the decoded magnitude
                // used by newer indices — NOTE(review): confirm.
                docVectorMagnitude = (float) Math.sqrt(sumOfSquares);
            }
            return dotProduct / docVectorMagnitude;
        }
    }

    /**
     * Dot product between a dense query vector and a document's dense vector value.
     */
    public static final class DotProduct
    extends DenseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector the query vector values
         */
        public DotProduct(ScoreScript scoreScript, List<Number> queryVector) {
            super(scoreScript, queryVector);
        }

        /**
         * @param dvs the document's dense vector doc values
         * @return the dot product of the query and the document vector
         * @throws IllegalArgumentException if the document has no vector value or its dimension
         *     count differs from the query's
         */
        public double dotProduct(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
            double dotProduct = 0.0;
            for (float queryValue : queryVector) {
                dotProduct += queryValue * byteBuffer.getFloat();
            }
            return dotProduct;
        }
    }

    /**
     * Euclidean (L2) distance between a dense query vector and a document's dense vector value.
     */
    public static final class L2Norm
    extends DenseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector the query vector values
         */
        public L2Norm(ScoreScript scoreScript, List<Number> queryVector) {
            super(scoreScript, queryVector);
        }

        /**
         * @param dvs the document's dense vector doc values
         * @return the L2 distance between the query and the document vector
         * @throws IllegalArgumentException if the document has no vector value or its dimension
         *     count differs from the query's
         */
        public double l2norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
            double l2norm = 0.0;
            for (float queryValue : queryVector) {
                // float difference, then squared in double
                double diff = queryValue - byteBuffer.getFloat();
                l2norm += diff * diff;
            }
            return Math.sqrt(l2norm);
        }
    }

    /**
     * Manhattan (L1) distance between a dense query vector and a document's dense vector value.
     */
    public static final class L1Norm
    extends DenseVectorFunction {
        /**
         * @param scoreScript the script this function is evaluated for; supplies the index version
         * @param queryVector the query vector values
         */
        public L1Norm(ScoreScript scoreScript, List<Number> queryVector) {
            super(scoreScript, queryVector);
        }

        /**
         * @param dvs the document's dense vector doc values
         * @return the L1 distance between the query and the document vector
         * @throws IllegalArgumentException if the document has no vector value or its dimension
         *     count differs from the query's
         */
        public double l1norm(VectorScriptDocValues.DenseVectorScriptDocValues dvs) {
            BytesRef vector = dvs.getEncodedValue();
            validateDocVector(vector);
            ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
            double l1norm = 0.0;
            for (float queryValue : queryVector) {
                l1norm += Math.abs(queryValue - byteBuffer.getFloat());
            }
            return l1norm;
        }
    }

    /**
     * Base class for dense-vector score functions. Copies the query into a {@code float[]},
     * optionally scaling it to unit length.
     */
    public static class DenseVectorFunction {
        final ScoreScript scoreScript;
        /** Query values, normalized to unit length when requested at construction. */
        final float[] queryVector;

        /**
         * Creates a function with an unnormalized query vector.
         *
         * @param scoreScript the script this function is evaluated for
         * @param queryVector the query vector values
         */
        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector) {
            this(scoreScript, queryVector, false);
        }

        /**
         * @param scoreScript the script this function is evaluated for
         * @param queryVector the query vector values
         * @param normalizeQuery whether to divide every query value by the query's Euclidean
         *     magnitude (a zero-magnitude query yields NaN values)
         */
        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, boolean normalizeQuery) {
            this.scoreScript = scoreScript;
            this.queryVector = new float[queryVector.size()];
            // Iterate rather than index so that non-random-access lists are copied in linear time.
            int i = 0;
            for (Number value : queryVector) {
                this.queryVector[i++] = value.floatValue();
            }
            // The magnitude is only needed for normalization, so it is computed only in that case.
            if (normalizeQuery) {
                double sumOfSquares = 0.0;
                for (float value : this.queryVector) {
                    sumOfSquares += value * value;
                }
                double magnitude = Math.sqrt(sumOfSquares);
                for (int dim = 0; dim < this.queryVector.length; dim++) {
                    this.queryVector[dim] = (float) (this.queryVector[dim] / magnitude);
                }
            }
        }

        /**
         * Checks that a document vector exists and has the same number of dimensions as the query.
         *
         * @param vector the document's encoded vector value
         * @throws IllegalArgumentException if {@code vector} is {@code null} or its dimension count
         *     differs from the query's
         */
        public void validateDocVector(BytesRef vector) {
            if (vector == null) {
                throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
            }
            int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector);
            if (queryVector.length != vectorLength) {
                throw new IllegalArgumentException("The query vector has a different number of dimensions [" + queryVector.length + "] than the document vectors [" + vectorLength + "].");
            }
        }
    }
}

