class Cafe:
    """Deterministic grid world in which a robot fetches and delivers a coffee.

    A state is a tuple ``(posX, posY, orientation, cafe)`` where:

    * ``posX`` and ``posY`` are 1-based coordinates in ``1..taille``;
    * ``orientation`` is 0 = north, 1 = east, 2 = south, 3 = west
      (moving north decreases ``posY``, moving east increases ``posX``);
    * ``cafe`` is 1 when the robot carries the coffee, 0 otherwise.
    """

    # Side length of the square grid. Declared on the class so that
    # transition() works even when etats() has not been called beforehand.
    taille = 3

    def actions(self):
        """Return the list of available action names."""
        return ['tournerG', 'tournerD', 'prendre', 'poser', 'avancer']

    def etats(self):
        """Return every state ``(posX, posY, orientation, cafe)`` as a list.

        States are enumerated with ``posX`` varying slowest and ``cafe``
        varying fastest.
        """
        cotes = range(1, self.taille + 1)
        return [
            (posX, posY, orientation, cafe)
            for posX in cotes
            for posY in cotes
            for orientation in range(4)
            for cafe in range(2)
        ]

    def transition(self, s, a):
        """Return the state reached by applying action ``a`` in state ``s``.

        Args:
            s: current state ``(posX, posY, orientation, cafe)``.
            a: action name, one of the values returned by ``actions()``.

        Returns:
            The successor state as a 4-tuple. An action that has no effect
            (moving into a wall, picking up away from cell (3, 3), or an
            unknown action name) leaves the state unchanged.
        """
        posX, posY, orientation, cafe = s[0], s[1], s[2], s[3]

        if a == 'tournerG':
            # Turning left is a rotation of -1, written as +3 modulo 4.
            orientation = (orientation + 3) % 4
        elif a == 'tournerD':
            orientation = (orientation + 1) % 4
        elif a == 'prendre':
            # The coffee can only be picked up in cell (3, 3).
            if posX == 3 and posY == 3:
                cafe = 1
        elif a == 'poser':
            # Putting down always empties the robot's hands, wherever it is.
            cafe = 0
        elif a == 'avancer':
            # Move one cell forward unless that would leave the grid.
            if orientation == 0 and posY > 1:
                posY -= 1
            elif orientation == 1 and posX < self.taille:
                posX += 1
            elif orientation == 2 and posY < self.taille:
                posY += 1
            elif orientation == 3 and posX > 1:
                posX -= 1

        return (posX, posY, orientation, cafe)

    def recompense(self, s, a, sarr):
        """Return the reward for doing action ``a`` in state ``s``.

        Args:
            s: state in which the action is taken.
            a: action name.
            sarr: arrival state (not used by this reward function).

        Returns:
            100 when the coffee is put down in cell (1, 1) while being
            carried, -1 for every other step.
        """
        posX, posY, cafe = s[0], s[1], s[3]
        if posX == 1 and posY == 1 and cafe == 1 and a == 'poser':
            return 100
        return -1

    def afficherEtat(self, s):
        """Print the position, orientation and coffee flag of state ``s``."""
        # The state is (posX, posY, orientation, cafe): the position spans
        # the first two entries, so orientation and cafe are s[2] and s[3].
        print("pos:", (s[0], s[1]), " orientation: ", s[2], " cafe:", s[3])


#************************************************************

class SystemeExecute:
    """Run a fixed policy on a problem.

    Attributes:
        pb: the problem object; it must provide ``transition(s, a)`` and
            ``recompense(s, a, s_next)``.
    """

    def __init__(self, pb):
        """Build a system around the problem ``pb``."""
        self.pb = pb

    def executerPi(self, pi, depart, nb):
        """Run policy ``pi`` for ``nb`` steps from state ``depart``.

        Each step is printed as ``state -> action : next state``.

        Args:
            pi: mapping from state to action.
            depart: initial state.
            nb: number of steps to execute.
        """
        s = depart
        for _ in range(nb):
            action = pi[s]
            # Use the problem attached to this instance, not a module global.
            sFin = self.pb.transition(s, action)
            print(s," -> ",action," : ",sFin)
            s = sFin

    def executerPiRec(self, pi, depart, nb, gamma=0.99):
        """Run policy ``pi`` and return its discounted sum of rewards.

        Each step is printed together with the reward it produced.

        Args:
            pi: mapping from state to action.
            depart: initial state.
            nb: number of steps to execute.
            gamma: discount factor applied as ``gamma ** step``.

        Returns:
            The sum over the ``nb`` steps of ``gamma ** step * reward``.
        """
        s = depart
        somme = 0
        for i in range(nb):
            action = pi[s]
            sFin = self.pb.transition(s, action)
            r = self.pb.recompense(s, action, sFin)
            somme = somme + (gamma ** i) * r
            print(s," -> ",action," : ",sFin,"<",r,">")
            s = sFin
        return somme



#************************************************************



class Systeme:
    """Execute and solve a deterministic decision problem.

    Attributes:
        pb: the problem object; it must provide ``etats()``, ``actions()``,
            ``transition(s, a)`` and ``recompense(s, a, s_next)``.
    """

    def __init__(self, pb):
        """Build a system around the problem ``pb``."""
        self.pb = pb

    def execute(self, s, a):
        """Return the state reached by applying action ``a`` in state ``s``."""
        return self.pb.transition(s, a)

    def intialiseQ(self):
        """Return an empty Q-table (a dict keyed by ``(state, action)``)."""
        return {}

    def executerPi(self, pi, depart, nb):
        """Run policy ``pi`` for ``nb`` steps from ``depart``, printing each step."""
        s = depart
        for _ in range(nb):
            action = pi[s]
            sFin = self.execute(s, action)
            r = self.pb.recompense(s, action, sFin)
            print(s," -> ",action," : ",sFin,"<",r,">")
            s = sFin

    def executerPiRec(self, pi, depart, nb, gamma=0.99):
        """Run policy ``pi`` and return its discounted sum of rewards.

        Args:
            pi: mapping from state to action.
            depart: initial state.
            nb: number of steps to execute.
            gamma: discount factor applied as ``gamma ** step``.

        Returns:
            The sum over the ``nb`` steps of ``gamma ** step * reward``.
        """
        s = depart
        somme = 0
        for i in range(nb):
            action = pi[s]
            # Use the problem attached to this instance, not a module global.
            sFin = self.pb.transition(s, action)
            r = self.pb.recompense(s, action, sFin)
            somme = somme + (gamma ** i) * r
            print(s," -> ",action," : ",sFin,"<",r,">")
            s = sFin
        return somme

    def afficherPi(self, pi):
        """Print the action chosen by policy ``pi`` in every state."""
        for etat in self.pb.etats():
            print(str(etat) + " -> " + str(pi[etat]))

    def afficherQ(self, Q):
        """Print the Q-values of every state, rounded to two decimals."""
        for etat in self.pb.etats():
            print()
            print(str(etat))
            for action in self.pb.actions():
                print("   - "+action+" -> "+str(round(Q[(etat,action)],2)))

    def afficherQS(self, Q, etat):
        """Print the Q-values of the single state ``etat`` on one line."""
        chaine = ""+str(etat)+" - "
        for action in self.pb.actions():
            chaine += action+" -> "+str(Q[(etat,action)])+", "
        print(chaine)

    def politiqueFromQ(self, Q):
        """Build the greedy policy associated with the Q-table ``Q``.

        Args:
            Q: dict mapping ``(state, action)`` to a value; it must hold an
                entry for every state/action pair of the problem.

        Returns:
            A dict mapping each state to its highest-valued action. Ties go
            to the first action in ``actions()`` order; ``-1`` is stored only
            when the problem has no action at all.
        """
        actions = self.pb.actions()
        pi = {}
        for etat in self.pb.etats():
            # The builtin max needs no sentinel, so arbitrarily low Q-values
            # still yield a real action.
            pi[etat] = max(actions, key=lambda a: Q[(etat, a)], default=-1)
        return pi

    # ------------------------------------------------------------------
    # Value iteration
    # ------------------------------------------------------------------
    def valueIteration(self, nb, gamma=0.99):
        """Run ``nb`` sweeps of value iteration and return the Q-table.

        Args:
            nb: number of sweeps over all state/action pairs.
            gamma: discount factor.

        Returns:
            The Q-table after the last sweep (an empty dict when ``nb`` is 0).
        """
        Q = self.intialiseQ()
        for i in range(nb):
            Q = self.executerUneIteration(Q, gamma)
            print('*** iteration '+str(i)+' ****')
        return Q

    def executerUneIteration(self, Q, gamma):
        """Perform one sweep of value iteration.

        Args:
            Q: Q-table from the previous sweep; missing entries count as 0
                and are added to it by ``chercherMax``.
            gamma: discount factor.

        Returns:
            A new Q-table where each entry is
            ``reward + gamma * max over a' of Q[(next state, a')]``.
        """
        Q2 = {}
        for etat in self.pb.etats():
            for action in self.pb.actions():
                sArriv = self.pb.transition(etat, action)
                r = self.pb.recompense(etat, action, sArriv)
                Q2[(etat, action)] = r + gamma * self.chercherMax(Q, sArriv)
        return Q2

    def chercherMax(self, Q, sArriv):
        """Return the largest Q-value available in state ``sArriv``.

        Entries missing from ``Q`` are created with value 0 (``Q`` is
        modified in place).

        Args:
            Q: Q-table keyed by ``(state, action)``.
            sArriv: state whose best value is requested.

        Returns:
            The maximum of ``Q[(sArriv, a)]`` over all actions, or -100000
            when the problem has no action.
        """
        return max(
            (Q.setdefault((sArriv, action), 0) for action in self.pb.actions()),
            default=-100000,
        )


#************************************************************


# --- Problem definition ----------------------------------------------
print(
    "****************************************",
    "******        probleme cafe         ****",
    "****************************************",
    sep="\n",
)

pb = Cafe()
print('etats => ', pb.etats())
print('actions => ', pb.actions())

# --- Planning: 30 sweeps of value iteration --------------------------
print(
    "",
    "****************************************",
    "*******       planif            ********",
    "****************************************",
    "planification",
    sep="\n",
)
systeme = Systeme(pb)
Q = systeme.valueIteration(30)

# --- Show the resulting Q-values --------------------------------------
print(
    "",
    "****************************************",
    "***         Affiche Q               ****",
    "****************************************",
    sep="\n",
)
systeme.afficherQ(Q)

# --- Derive and show the greedy policy --------------------------------
print(
    "",
    "****************************************",
    "***         Affiche pi              ****",
    "****************************************",
    sep="\n",
)
pi = systeme.politiqueFromQ(Q)
systeme.afficherPi(pi)

# --- Run the policy from (1, 1) facing north, without coffee ----------
print(
    "",
    "****************************************",
    "***         Execute pi              ****",
    "****************************************",
    sep="\n",
)

somme = systeme.executerPiRec(pi, (1, 1, 0, 0), 30)
print("somme gamma^t*rec: ", somme)
