/*
 * Decompiled with CFR 0.152.
 */
package explicit;

import explicit.Belief;
import explicit.Distribution;
import explicit.MDPSimple;
import explicit.ObservationsSimple;
import explicit.POMDP;
import explicit.rewards.MDPRewards;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import parser.State;
import prism.PrismException;
import prism.PrismUtils;

/**
 * Simple explicit-state POMDP: an {@link MDPSimple} whose states are additionally mapped to
 * observations (and "unobservations") held in an {@link ObservationsSimple} instance.
 *
 * @param <Value> value type parameter passed through to {@link MDPSimple}
 */
public class POMDPSimple<Value>
extends MDPSimple<Value>
implements POMDP<Value> {
    /**
     * Belief entries at (or, depending on the call site, below) this value are ignored when
     * propagating beliefs and when aggregating observation probabilities.
     */
    private static final double BELIEF_PRUNE_THRESHOLD = 1.0e-6;

    /** Per-state observation / unobservation information. */
    protected ObservationsSimple observations;

    /** Creates an empty POMDP with empty observation information. */
    public POMDPSimple() {
        this.observations = new ObservationsSimple();
    }

    /**
     * Creates a POMDP with {@code numStates} states.
     *
     * @param numStates number of states (also used to size the observation information)
     */
    public POMDPSimple(int numStates) {
        super(numStates);
        this.observations = new ObservationsSimple(numStates);
    }

    /**
     * Copy constructor; the observation information is copied as well.
     *
     * @param pomdp the POMDP to copy
     */
    public POMDPSimple(POMDPSimple<Value> pomdp) {
        super(pomdp);
        this.observations = new ObservationsSimple(pomdp.observations);
    }

    /**
     * Copy constructor that also applies a state permutation, to both the underlying MDP and
     * the observation information.
     *
     * @param pomdp the POMDP to copy
     * @param permut state permutation passed to the superclass and to {@link ObservationsSimple}
     */
    public POMDPSimple(POMDPSimple<Value> pomdp, int[] permut) {
        super(pomdp, permut);
        this.observations = new ObservationsSimple(pomdp.observations, permut);
    }

    /**
     * Builds a POMDP from an MDP, giving every state its own observation
     * (via {@link ObservationsSimple#setIdentityObservations()}).
     *
     * @param mdp the MDP to copy
     */
    public POMDPSimple(MDPSimple<Value> mdp) {
        super(mdp);
        this.observations = new ObservationsSimple(mdp.numStates);
        this.observations.setIdentityObservations();
    }

    @Override
    public void clearState(int s) {
        super.clearState(s);
        this.observations.clearState(s);
    }

    @Override
    public void addStates(int numToAdd) {
        super.addStates(numToAdd);
        this.observations.addStates(numToAdd);
    }

    /** Sets the list of observations; delegates to the observation information object. */
    public void setObservationsList(List<State> observationsList) {
        this.observations.setObservationsList(observationsList);
    }

    /** Sets the list of unobservations; delegates to the observation information object. */
    public void setUnobservationsList(List<State> unobservationsList) {
        this.observations.setUnobservationsList(unobservationsList);
    }

    /**
     * Assigns an observation/unobservation to state {@code s}; delegates to
     * {@link ObservationsSimple}, passing this model along.
     *
     * @throws PrismException if the underlying observation information rejects the assignment
     */
    public void setObservation(int s, State observ, State unobserv, List<String> observableNames) throws PrismException {
        this.observations.setObservation(s, observ, unobserv, observableNames, this);
    }

    /**
     * Assigns observation index {@code o} to state {@code s}; delegates to
     * {@link ObservationsSimple}, passing this model along.
     *
     * @throws PrismException if the underlying observation information rejects the assignment
     */
    protected void setObservation(int s, int o) throws PrismException {
        this.observations.setObservation(s, o, this);
    }

    @Override
    public List<State> getObservationsList() {
        return this.observations.getObservationsList();
    }

    @Override
    public List<State> getUnobservationsList() {
        return this.observations.getUnobservationsList();
    }

    @Override
    public int getObservation(int s) {
        return this.observations.getObservation(s);
    }

    @Override
    public int getUnobservation(int s) {
        return this.observations.getUnobservation(s);
    }

    /**
     * Returns the number of choices available under observation {@code o}, read from the
     * representative state that {@link ObservationsSimple#getObservationState(int)} returns.
     */
    @Override
    public int getNumChoicesForObservation(int o) {
        return this.getNumChoices(this.observations.getObservationState(o));
    }

    /**
     * Returns the action label of choice {@code i} under observation {@code o}, read from the
     * representative state that {@link ObservationsSimple#getObservationState(int)} returns.
     */
    @Override
    public Object getActionForObservation(int o, int i) {
        return this.getAction(this.observations.getObservationState(o), i);
    }

    /**
     * Returns the initial belief: uniform over the initial states.
     */
    @Override
    public Belief getInitialBelief() {
        return new Belief(this.initialStateDistribution(), this);
    }

    /**
     * Returns the initial belief as a distribution over states: uniform over the initial
     * states, zero elsewhere.
     */
    @Override
    public double[] getInitialBeliefInDist() {
        return this.initialStateDistribution();
    }

    /** Builds a fresh array with equal mass on every initial state, normalised to sum to 1. */
    private double[] initialStateDistribution() {
        double[] dist = new double[this.numStates];
        for (Integer s : this.initialStates) {
            dist[s.intValue()] = 1.0;
        }
        PrismUtils.normalise(dist);
        return dist;
    }

    /**
     * Returns the belief reached from {@code belief} by taking choice {@code choice}
     * (before any observation is made).
     */
    @Override
    public Belief getBeliefAfterChoice(Belief belief, int choice) {
        double[] beliefInDist = belief.toDistributionOverStates(this);
        return new Belief(this.getBeliefInDistAfterChoice(beliefInDist, choice), this);
    }

    /**
     * Propagates a distribution over states through choice {@code choice}.
     * States whose probability is below {@value #BELIEF_PRUNE_THRESHOLD} are ignored.
     *
     * @param beliefInDist current distribution over states
     * @param choice choice index, applied to every state with non-negligible probability
     * @return a fresh array holding the (unnormalised) successor distribution
     */
    @Override
    public double[] getBeliefInDistAfterChoice(double[] beliefInDist, int choice) {
        int n = beliefInDist.length;
        double[] next = new double[n];
        for (int s = 0; s < n; ++s) {
            double p = beliefInDist[s];
            // Written as a negated >= so that NaN entries are skipped as well.
            if (!(p >= BELIEF_PRUNE_THRESHOLD)) {
                continue;
            }
            Distribution distribution = this.getChoice(s, choice);
            for (Map.Entry<?, ?> entry : distribution) {
                int target = (Integer) entry.getKey();
                double prob = (Double) entry.getValue();
                next[target] += p * prob;
            }
        }
        return next;
    }

    /**
     * Returns the belief reached from {@code belief} by taking choice {@code choice} and then
     * observing {@code observation}.
     */
    @Override
    public Belief getBeliefAfterChoiceAndObservation(Belief belief, int choice, int observation) {
        double[] beliefInDist = belief.toDistributionOverStates(this);
        double[] nextInDist = this.getBeliefInDistAfterChoiceAndObservation(beliefInDist, choice, observation);
        Belief next = new Belief(nextInDist, this);
        assert next.so == observation;
        return next;
    }

    /**
     * Propagates a distribution over states through choice {@code choice}, conditions it on
     * {@code observation} (weighting by {@code getObservationProb}), and normalises the result.
     * NOTE(review): if the observation has probability 0, the result depends on how
     * {@link PrismUtils#normalise(double[])} handles a zero sum.
     */
    @Override
    public double[] getBeliefInDistAfterChoiceAndObservation(double[] beliefInDist, int choice, int observation) {
        // getBeliefInDistAfterChoice returns a fresh array, so it is safe to scale it in place.
        double[] next = this.getBeliefInDistAfterChoice(beliefInDist, choice);
        for (int s = 0; s < next.length; ++s) {
            next[s] *= this.getObservationProb(s, observation);
        }
        PrismUtils.normalise(next);
        return next;
    }

    /**
     * Returns the probability of seeing {@code observation} after taking {@code choice}
     * from {@code belief}.
     */
    @Override
    public double getObservationProbAfterChoice(Belief belief, int choice, int observation) {
        double[] beliefInDist = belief.toDistributionOverStates(this);
        return this.getObservationProbAfterChoice(beliefInDist, choice, observation);
    }

    /**
     * Returns the probability of seeing {@code observation} after taking {@code choice} from a
     * distribution over states: the sum over successor states of their probability times
     * {@code getObservationProb(s, observation)}.
     */
    @Override
    public double getObservationProbAfterChoice(double[] beliefInDist, int choice, int observation) {
        double[] next = this.getBeliefInDistAfterChoice(beliefInDist, choice);
        double total = 0.0;
        for (int s = 0; s < next.length; ++s) {
            total += next[s] * this.getObservationProb(s, observation);
        }
        return total;
    }

    /**
     * Maps each observation reachable after taking {@code choice} to its total probability.
     * Successor states with probability at or below {@value #BELIEF_PRUNE_THRESHOLD} are ignored.
     */
    @Override
    public HashMap<Integer, Double> computeObservationProbsAfterAction(double[] beliefInDist, int choice) {
        HashMap<Integer, Double> obsProbs = new HashMap<>();
        double[] next = this.getBeliefInDistAfterChoice(beliefInDist, choice);
        for (int s = 0; s < next.length; ++s) {
            double p = next[s];
            if (p > BELIEF_PRUNE_THRESHOLD) {
                obsProbs.merge(this.getObservation(s), p, Double::sum);
            }
        }
        return obsProbs;
    }

    /**
     * Returns the expected reward of taking {@code choice} in {@code belief}.
     */
    @Override
    public double getRewardAfterChoice(Belief belief, int choice, MDPRewards<Double> mdpRewards) {
        double[] beliefInDist = belief.toDistributionOverStates(this);
        return this.getRewardAfterChoice(beliefInDist, choice, mdpRewards);
    }

    /**
     * Returns the expected reward of taking {@code choice} from a distribution over states:
     * the sum over states of probability times (transition reward + state reward).
     * Zero-probability states are skipped, so their rewards are never queried.
     */
    @Override
    public double getRewardAfterChoice(double[] beliefInDist, int choice, MDPRewards<Double> mdpRewards) {
        double expected = 0.0;
        for (int s = 0; s < beliefInDist.length; ++s) {
            double p = beliefInDist[s];
            if (p == 0.0) {
                continue;
            }
            double transReward = mdpRewards.getTransitionReward(s, choice);
            double stateReward = mdpRewards.getStateReward(s);
            expected += p * (transReward + stateReward);
        }
        return expected;
    }

    /**
     * Converts a distribution over states into a {@link Belief}. The observation is taken from
     * the last state with non-zero probability, and probability mass is summed per
     * unobservation.
     *
     * @return the belief, or {@code null} (after printing to stderr) if every entry is zero
     */
    protected Belief beliefInDistToBelief(double[] beliefInDist) {
        int observation = -1;
        double[] unobsDist = new double[this.getNumUnobservations()];
        for (int s = 0; s < beliefInDist.length; ++s) {
            if (beliefInDist[s] == 0.0) {
                continue;
            }
            observation = this.getObservation(s);
            unobsDist[this.getUnobservation(s)] += beliefInDist[s];
        }
        if (observation == -1) {
            System.err.println("Something wrong in POMDPSimple.beliefInDistToBelief(double[] beliefInDist)");
            return null;
        }
        return new Belief(observation, unobsDist);
    }

    /**
     * Renders every state as {@code s(obs/unobs): [action:distribution,...]}, all wrapped
     * in {@code [ ... ]} and followed by a newline.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("[ ");
        for (int s = 0; s < this.numStates; ++s) {
            if (s > 0) {
                sb.append(", ");
            }
            sb.append(s).append('(').append(this.getObservation(s))
              .append('/').append(this.getUnobservation(s)).append("): ");
            sb.append('[');
            int numChoices = this.getNumChoices(s);
            for (int j = 0; j < numChoices; ++j) {
                if (j > 0) {
                    sb.append(',');
                }
                Object action = this.getAction(s, j);
                if (action != null) {
                    sb.append(action).append(':');
                }
                sb.append(((List<?>) this.trans.get(s)).get(j));
            }
            sb.append(']');
        }
        sb.append(" ]\n");
        return sb.toString();
    }

    /**
     * Two POMDPs are equal if they have the same number of states, the same initial states,
     * the same transitions, and the same observation/unobservation for every state.
     */
    @Override
    public boolean equals(Object object) {
        if (this == object) {
            return true;
        }
        if (!(object instanceof POMDPSimple)) {
            return false;
        }
        POMDPSimple<?> other = (POMDPSimple<?>) object;
        if (this.numStates != other.numStates) {
            return false;
        }
        if (!this.initialStates.equals(other.initialStates)) {
            return false;
        }
        if (!this.trans.equals(other.trans)) {
            return false;
        }
        // Observations are part of a POMDP's definition, so they must match too.
        for (int s = 0; s < this.numStates; ++s) {
            if (this.getObservation(s) != other.getObservation(s)
                    || this.getUnobservation(s) != other.getUnobservation(s)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hash code consistent with {@link #equals(Object)}: it uses only fields that equals
     * requires to be equal.
     */
    @Override
    public int hashCode() {
        return Objects.hash(this.numStates, this.initialStates);
    }
}

