/*
 * Decompiled with CFR 0.152.
 */
package floetteroed.utilities.math;

import floetteroed.utilities.math.MathHelpers;
import floetteroed.utilities.math.Matrix;
import floetteroed.utilities.math.Vector;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/**
 * Multinomial logit choice model over a fixed-size choice set.
 *
 * <p>For every alternative {@code i} the utility is
 * {@code V_i = coeff . attr_i + asc_i} and the choice probability is
 * {@code P_i = exp(utilityScale * V_i) / sum_k exp(utilityScale * V_k)}.
 * Besides the probabilities, the class maintains the first derivatives of all
 * probabilities with respect to the attribute coefficients and with respect to
 * the alternative-specific constants (ASCs).
 *
 * <p>All derived quantities are cached. Every setter marks the cache as stale
 * and every getter of a derived quantity recomputes it on demand (see
 * {@link #conditionalUpdate()}).
 *
 * <p>Parameters are addressed by a flat index: indices
 * {@code 0 .. attrCount-1} are the attribute coefficients, indices
 * {@code attrCount .. attrCount+choiceSetSize-1} are the ASCs.
 *
 * <p>The class does no synchronization of its own.
 */
public class MultinomialLogit {
    /** Flat parameter indices of all attribute coefficients: 0 .. attrCount-1. */
    private final List<Integer> ALL_ATTRIBUTE_INDICES;
    /** Flat parameter indices of all ASCs: attrCount .. attrCount+choiceSetSize-1. */
    private final List<Integer> ALL_ASC_INDICES;
    /** Factor applied to every utility inside the exponential. */
    private double utilityScale = 1.0;
    /** Attribute coefficients, one per attribute. */
    private final Vector coeff;
    /** Alternative-specific constants, one per alternative. */
    private final Vector asc;
    /** Attribute values; row = alternative, column = attribute. */
    private final Matrix attr;
    /** Cached (unscaled) utilities, one per alternative. */
    private final Vector utilities;
    /** Cached choice probabilities, one per alternative. */
    private final Vector choiceProbs;
    /** Cached dP_i/dcoeff_j; row = alternative i, column = attribute j. */
    private final Matrix dProbs_dCoeffs;
    /** Cached dP_i/dasc_j; row = alternative i, column = alternative j. */
    private final Matrix dProbs_dASCs;
    /** True iff the cached quantities reflect the current parameters and attributes. */
    private boolean consistent;

    /**
     * Creates a model with all coefficients, ASCs and attributes at the initial
     * values provided by {@code Vector}/{@code Matrix} (presumably zero - confirm
     * against those classes) and a utility scale of 1.
     *
     * @param choiceSetSize number of alternatives
     * @param attributeCount number of attributes per alternative
     */
    public MultinomialLogit(int choiceSetSize, int attributeCount) {
        this.coeff = new Vector(attributeCount);
        this.asc = new Vector(choiceSetSize);
        this.attr = new Matrix(choiceSetSize, attributeCount);
        this.utilities = new Vector(choiceSetSize);
        this.choiceProbs = new Vector(choiceSetSize);
        this.dProbs_dCoeffs = new Matrix(choiceSetSize, attributeCount);
        this.dProbs_dASCs = new Matrix(choiceSetSize, choiceSetSize);
        this.consistent = false;

        final List<Integer> allAttrInd = new ArrayList<Integer>(attributeCount);
        for (int i = 0; i < attributeCount; ++i) {
            allAttrInd.add(i);
        }
        this.ALL_ATTRIBUTE_INDICES = Collections.unmodifiableList(allAttrInd);

        // ASC indices follow directly after the coefficient indices.
        final List<Integer> allASCInd = new ArrayList<Integer>(choiceSetSize);
        for (int i = attributeCount; i < attributeCount + choiceSetSize; ++i) {
            allASCInd.add(i);
        }
        this.ALL_ASC_INDICES = Collections.unmodifiableList(allASCInd);
    }

    // -------------------- SETTERS --------------------

    /**
     * Sets the factor by which all utilities are multiplied inside the exponential.
     *
     * @param value the new utility scale
     */
    public void setUtilityScale(double value) {
        this.consistent = false;
        this.utilityScale = value;
    }

    /**
     * Sets the coefficient of one attribute.
     *
     * @param attrIndex index of the attribute
     * @param value the new coefficient
     */
    public void setCoefficient(int attrIndex, double value) {
        this.consistent = false;
        this.coeff.set(attrIndex, value);
    }

    /**
     * Sets the alternative-specific constant of one alternative.
     *
     * @param choiceIndex index of the alternative
     * @param value the new ASC
     */
    public void setASC(int choiceIndex, double value) {
        this.consistent = false;
        this.asc.set(choiceIndex, value);
    }

    /**
     * Sets a single attribute value of one alternative.
     *
     * @param choiceIndex index of the alternative
     * @param attrIndex index of the attribute
     * @param value the new attribute value
     */
    public void setAttribute(int choiceIndex, int attrIndex, double value) {
        this.consistent = false;
        // Relies on getRow returning a live row of the matrix, not a copy.
        this.attr.getRow(choiceIndex).set(attrIndex, value);
    }

    // -------------------- UPDATE --------------------

    /**
     * Recomputes utilities, choice probabilities and both derivative matrices
     * from the current parameters and attributes, regardless of whether the
     * cache is stale, and marks the cache as consistent.
     */
    public void enforcedUpdate() {
        this.updateUtilitiesAndProbs();
        this.updateCoeffDerivatives();
        this.updateASCDerivatives();
        this.consistent = true;
    }

    /** Computes the utilities and the normalized choice probabilities. */
    private void updateUtilitiesAndProbs() {
        final int choiceCnt = this.getChoiceSetSize();

        double vMax = Double.NEGATIVE_INFINITY;
        double vMin = Double.POSITIVE_INFINITY;
        for (int i = 0; i < choiceCnt; ++i) {
            final double v = this.coeff.innerProd(this.attr.getRow(i)) + this.asc.get(i);
            this.utilities.set(i, v);
            vMax = Math.max(vMax, v);
            vMin = Math.min(vMin, v);
        }

        // Shift all utilities by a common reference before exponentiating; the
        // shift cancels in the normalization. The reference is the utility whose
        // SCALED value is largest, so that every exponent is <= 0 and Math.exp
        // cannot overflow: the maximum for a non-negative scale, the minimum
        // for a negative one.
        final double vRef = (this.utilityScale >= 0.0) ? vMax : vMin;

        double pSum = 0.0;
        for (int i = 0; i < choiceCnt; ++i) {
            final double p = Math.exp(this.utilityScale * (this.utilities.get(i) - vRef));
            this.choiceProbs.set(i, p);
            pSum += p;
        }
        this.choiceProbs.mult(1.0 / pSum);
    }

    /**
     * Computes dP_i/dcoeff_j = utilityScale * P_i * (attr_ij - sum_k P_k * attr_kj).
     * Requires up-to-date choice probabilities.
     */
    private void updateCoeffDerivatives() {
        final int choiceCnt = this.getChoiceSetSize();
        final int attrCnt = this.getAttrCount();

        this.dProbs_dCoeffs.clear();
        // Used below as the probability-weighted attribute average per column
        // (presumably choiceProbs^T * attr - confirm against Matrix).
        final Vector probsTimesAttr = this.attr.timesVectorFromLeft(this.choiceProbs);
        for (int i = 0; i < choiceCnt; ++i) {
            final Vector dProbi_dCoeff = this.dProbs_dCoeffs.getRow(i);
            final double probi = this.choiceProbs.get(i);
            final Vector attri = this.attr.getRow(i);
            for (int j = 0; j < attrCnt; ++j) {
                dProbi_dCoeff.set(j, probi * (attri.get(j) - probsTimesAttr.get(j)));
            }
        }
        this.dProbs_dCoeffs.mult(this.utilityScale);
    }

    /**
     * Computes dP_i/dasc_j = utilityScale * P_i * (delta_ij - P_j).
     * Requires up-to-date choice probabilities.
     */
    private void updateASCDerivatives() {
        final int choiceCnt = this.getChoiceSetSize();

        // The rows are accumulated into below, so they are cleared first.
        this.dProbs_dASCs.clear();
        for (int i = 0; i < choiceCnt; ++i) {
            final double probi = this.choiceProbs.get(i);
            final Vector dProbi_dASC = this.dProbs_dASCs.getRow(i);
            dProbi_dASC.set(i, probi);
            dProbi_dASC.add(this.choiceProbs, -probi);
        }
        this.dProbs_dASCs.mult(this.utilityScale);
    }

    /** Recomputes the cached quantities only if a setter was called since the last update. */
    public void conditionalUpdate() {
        if (!this.consistent) {
            this.enforcedUpdate();
        }
    }

    // -------------------- GETTERS --------------------

    /** @return the factor applied to every utility inside the exponential */
    public double getUtilityScale() {
        return this.utilityScale;
    }

    /** @return the number of alternatives */
    public int getChoiceSetSize() {
        return this.choiceProbs.size();
    }

    /** @return the number of attributes (and hence of coefficients) */
    public int getAttrCount() {
        return this.coeff.size();
    }

    /** @return an immutable view on the attribute coefficients */
    public Vector getCoeff() {
        return this.coeff.newImmutableView();
    }

    /** @return an immutable view on the alternative-specific constants */
    public Vector getASC() {
        return this.asc.newImmutableView();
    }

    /** @return an immutable view on the up-to-date choice probabilities */
    public Vector getProbs() {
        this.conditionalUpdate();
        return this.choiceProbs.newImmutableView();
    }

    /** @return an immutable view on the up-to-date (unscaled) utilities */
    public Vector getUtils() {
        this.conditionalUpdate();
        return this.utilities.newImmutableView();
    }

    /**
     * @return an immutable view on the up-to-date matrix of dP_i/dcoeff_j
     *         (row = alternative i, column = attribute j)
     */
    public Matrix get_dProbs_dCoeffs() {
        this.conditionalUpdate();
        return this.dProbs_dCoeffs.newImmutableView();
    }

    /**
     * @return an immutable view on the up-to-date matrix of dP_i/dasc_j
     *         (row = alternative i, column = alternative j)
     */
    public Matrix get_dProbs_dASCs() {
        this.conditionalUpdate();
        return this.dProbs_dASCs.newImmutableView();
    }

    /**
     * Passes the current choice probabilities to {@code MathHelpers.draw}.
     *
     * @param rnd the random number source handed to {@code MathHelpers.draw}
     * @return the result of {@code MathHelpers.draw}; presumably the index of a
     *         randomly selected alternative - confirm against MathHelpers
     */
    public int draw(Random rnd) {
        return MathHelpers.draw(this.getProbs(), rnd);
    }

    // -------------------- FLAT PARAMETER ACCESS --------------------

    /**
     * @param attributeIndices the selected attribute coefficients
     * @param withASC whether all ASCs are counted as well
     * @return {@code attributeIndices.size()}, plus the choice set size if {@code withASC}
     */
    public int getParameterSize(List<Integer> attributeIndices, boolean withASC) {
        return attributeIndices.size() + (withASC ? this.getChoiceSetSize() : 0);
    }

    /**
     * @param withASC whether all ASCs are counted as well
     * @return the number of coefficients, plus the choice set size if {@code withASC}
     */
    public int getParameterSize(boolean withASC) {
        return this.getParameterSize(this.ALL_ATTRIBUTE_INDICES, withASC);
    }

    /**
     * Returns one parameter by its flat index.
     *
     * @param j flat index; values below the attribute count address a
     *          coefficient, all others the ASC at {@code j - attrCount}
     * @return the parameter value
     */
    public double getParameter(int j) {
        final int attrCnt = this.getAttrCount();
        // Read the backing vectors directly; no need for a view per lookup.
        if (j < attrCnt) {
            return this.coeff.get(j);
        }
        return this.asc.get(j - attrCnt);
    }

    /**
     * @param withASC whether the ASCs are appended after the coefficients
     * @return a new vector holding all coefficients, followed by all ASCs if
     *         {@code withASC}
     */
    public Vector getParameters(boolean withASC) {
        final Vector result = new Vector(this.getParameterSize(withASC));
        for (int j = 0; j < result.size(); ++j) {
            result.set(j, this.getParameter(j));
        }
        return result;
    }

    /**
     * Sets one parameter by its flat index.
     *
     * @param j flat index; values below the attribute count address a
     *          coefficient, all others the ASC at {@code j - attrCount}
     * @param value the new parameter value
     */
    public void setParameter(int j, double value) {
        // The delegated setters mark the cache as stale.
        final int attrCnt = this.getAttrCount();
        if (j < attrCnt) {
            this.setCoefficient(j, value);
        } else {
            this.setASC(j - attrCnt, value);
        }
    }

    /**
     * Sets the first {@code parameters.size()} parameters in flat-index order.
     *
     * @param parameters the new parameter values
     */
    public void setParameters(Vector parameters) {
        this.consistent = false;
        for (int j = 0; j < parameters.size(); ++j) {
            this.setParameter(j, parameters.get(j));
        }
    }

    // -------------------- DERIVATIVES W.R.T. FLAT PARAMETERS --------------------

    /**
     * Assembles the first derivatives of all choice probabilities with respect
     * to a selection of parameters.
     *
     * @param attributeIndices the attribute coefficients to differentiate by;
     *        they define the leading columns of the result, in iteration order
     * @param withASC whether one column per ASC is appended
     * @return a new matrix; row = alternative, column = selected parameter
     */
    public Matrix get_dProb_dParameters(List<Integer> attributeIndices, boolean withASC) {
        this.conditionalUpdate();
        final int choiceCnt = this.getChoiceSetSize();
        final Matrix result = new Matrix(choiceCnt, this.getParameterSize(attributeIndices, withASC));
        for (int i = 0; i < choiceCnt; ++i) {
            final Vector resulti = result.getRow(i);
            // The cache is up to date (see above) and is only read here, so the
            // backing matrices are accessed directly instead of through new views.
            final Vector dProbi_dCoeff = this.dProbs_dCoeffs.getRow(i);
            int l = 0;
            for (int j : attributeIndices) {
                resulti.set(l++, dProbi_dCoeff.get(j));
            }
            if (withASC) {
                final Vector dProbi_dASC = this.dProbs_dASCs.getRow(i);
                for (int k = 0; k < choiceCnt; ++k) {
                    resulti.set(l++, dProbi_dASC.get(k));
                }
            }
        }
        return result;
    }

    /**
     * @param withASC whether one column per ASC is appended
     * @return the first derivatives of all choice probabilities with respect to
     *         all coefficients and, if {@code withASC}, all ASCs
     * @see #get_dProb_dParameters(List, boolean)
     */
    public Matrix get_dProb_dParameters(boolean withASC) {
        return this.get_dProb_dParameters(this.ALL_ATTRIBUTE_INDICES, withASC);
    }

    /**
     * Approximates the second derivatives of all choice probabilities with
     * respect to a selection of parameters by forward finite differences of
     * the analytical first derivatives.
     *
     * <p>Each selected parameter is temporarily increased by {@code delta}. It
     * is reset to its original value afterwards, even if the evaluation fails.
     *
     * @param delta the finite-difference step; must be non-zero, since the
     *        differences are divided by it
     * @param attributeIndices the attribute coefficients to differentiate by
     * @param withASC whether all ASCs are differentiated by as well
     * @return one matrix per alternative; in matrix {@code i}, row {@code r}
     *         holds the change of dP_i/d(parameters) per unit change of the
     *         r-th selected parameter
     */
    public List<Matrix> get_d2P_dbdb(double delta, List<Integer> attributeIndices, boolean withASC) {
        final int choiceCnt = this.getChoiceSetSize();
        final int paramSize = this.getParameterSize(attributeIndices, withASC);

        final List<Integer> paramIndices = new ArrayList<Integer>(paramSize);
        paramIndices.addAll(attributeIndices);
        if (withASC) {
            paramIndices.addAll(this.ALL_ASC_INDICES);
        }

        final List<Matrix> result = new ArrayList<Matrix>(choiceCnt);
        for (int i = 0; i < choiceCnt; ++i) {
            result.add(new Matrix(paramSize, paramSize));
        }

        final Matrix dP_db0 = this.get_dProb_dParameters(attributeIndices, withASC);
        int resultIndex = 0;
        for (int r : paramIndices) {
            final double br0 = this.getParameter(r);
            Matrix dP_dbVaried;
            this.setParameter(r, br0 + delta);
            try {
                dP_dbVaried = this.get_dProb_dParameters(attributeIndices, withASC);
            } finally {
                // Never leave the model in the perturbed state.
                this.setParameter(r, br0);
            }
            for (int i = 0; i < choiceCnt; ++i) {
                // The row starts out empty, so this yields (varied - base) / delta.
                final Vector d2Pi_dbrdb = result.get(i).getRow(resultIndex);
                d2Pi_dbrdb.add(dP_dbVaried.getRow(i), 1.0);
                d2Pi_dbrdb.add(dP_db0.getRow(i), -1.0);
                d2Pi_dbrdb.mult(1.0 / delta);
            }
            ++resultIndex;
        }
        return result;
    }

    /**
     * @param delta the finite-difference step; must be non-zero
     * @param withASC whether all ASCs are differentiated by as well
     * @return the finite-difference second derivatives with respect to all
     *         coefficients and, if {@code withASC}, all ASCs
     * @see #get_d2P_dbdb(double, List, boolean)
     */
    public List<Matrix> get_d2P_dbdb(double delta, boolean withASC) {
        return this.get_d2P_dbdb(delta, this.ALL_ATTRIBUTE_INDICES, withASC);
    }

    // -------------------- BULK ATTRIBUTE ACCESS --------------------

    /**
     * Returns the attribute matrix as produced by {@code Matrix.copy()}; despite
     * the method name this is presumably an independent copy rather than a live
     * view - confirm against Matrix.
     *
     * @return the result of copying the attribute matrix (row = alternative,
     *         column = attribute)
     */
    public Matrix getAttributesView() {
        this.conditionalUpdate();
        return this.attr.copy();
    }

    /**
     * Overwrites all attribute values by clearing the internal matrix and then
     * adding the given one with weight 1.
     *
     * @param attr the new attribute values (row = alternative, column =
     *        attribute); presumably must match the internal dimensions -
     *        confirm against Matrix.add
     */
    public void setAttributes(Matrix attr) {
        this.attr.clear();
        this.attr.add(attr, 1.0);
        this.consistent = false;
    }
}

