class LabyOrientation:
    """Grid maze in which the agent has an orientation.

    Moving forward is stochastic: the agent advances one cell with
    probability 0.8 and two cells with probability 0.2.

    A state is either a tuple ``(x, y, orientation)`` -- with
    ``0 <= x <= dx``, ``0 <= y <= dy`` and orientation one of ``"N"``,
    ``"E"``, ``"S"``, ``"O"`` (O = Ouest, i.e. West) -- or one of the two
    absorbing states ``"plouf"`` (fell into a hole) and ``"tresor"``
    (reached the treasure).

    Actions: ``"devant"`` (move forward), ``"gauche"`` (turn left),
    ``"droite"`` (turn right).
    """

    # Cell offset of one forward step: "N" decreases y, "S" increases y,
    # "O" (West) decreases x, "E" increases x.
    _PAS = {"N": (0, -1), "S": (0, 1), "O": (-1, 0), "E": (1, 0)}
    # Orientation reached after a quarter turn to the left / to the right.
    _A_GAUCHE = {"N": "O", "O": "S", "S": "E", "E": "N"}
    _A_DROITE = {"N": "E", "E": "S", "S": "O", "O": "N"}
    # Absorbing states: once reached, every action keeps the agent there.
    _TERMINAUX = ("plouf", "tresor")

    def __init__(self, dx=5, dy=5, tresor=None, trou=None):
        """Build the maze.

        dx, dy: largest valid x / y coordinate, so the grid has
            (dx + 1) x (dy + 1) cells.
        tresor: goal cells; defaults to [(4, 4)].
        trou: hole cells; defaults to the six original holes.
        """
        self.dx = dx
        self.dy = dy
        # goal cells: entering one ends the episode in state "tresor"
        self.tresor = [(4, 4)] if tresor is None else list(tresor)
        # holes: entering one ends the episode in state "plouf"
        if trou is None:
            trou = [(2, 3), (2, 0), (2, 2), (0, 3), (3, 3), (4, 3)]
        self.trou = list(trou)

    def transition(self, s, a):
        """Return the successor distribution of state ``s`` under action ``a``.

        The result is a list of ``(state, probability)`` pairs.  The same
        state may appear twice, e.g. when a wall stops the agent.

        Raises ValueError for an unknown action.
        """
        if s in self._TERMINAUX:
            return [(s, 1.0)]
        if a == "gauche":
            return [((s[0], s[1], self.tournerG(s[2])), 1.0)]
        if a == "droite":
            return [((s[0], s[1], self.tournerD(s[2])), 1.0)]
        if a == "devant":
            un_pas = self.avancer(s)
            # The second step starts from the first one, so it is also
            # stopped by walls and absorbed by holes / the treasure.
            deux_pas = self.avancer(un_pas)
            return [(un_pas, 0.8), (deux_pas, 0.2)]
        raise ValueError("action inconnue : %r" % (a,))

    def tournerG(self, o):
        """Return the orientation obtained by turning left from ``o``.

        Raises ValueError for an unknown orientation.
        """
        try:
            return self._A_GAUCHE[o]
        except KeyError:
            raise ValueError("orientation inconnue : %r" % (o,)) from None

    def tournerD(self, o):
        """Return the orientation obtained by turning right from ``o``.

        Raises ValueError for an unknown orientation.
        """
        try:
            return self._A_DROITE[o]
        except KeyError:
            raise ValueError("orientation inconnue : %r" % (o,)) from None

    def avancer(self, s):
        """Return the state reached by moving one cell forward from ``s``.

        Absorbing states are returned unchanged.  Walking into a wall leaves
        the agent on its cell.  Entering a hole gives ``"plouf"``, and
        entering a treasure cell gives ``"tresor"``.

        Raises ValueError for an unknown orientation.
        """
        if s in self._TERMINAUX:
            return s
        x, y, o = s
        try:
            px, py = self._PAS[o]
        except KeyError:
            raise ValueError("orientation inconnue : %r" % (o,)) from None
        # clamp to the grid: bumping into a wall leaves the agent in place
        x = min(max(x + px, 0), self.dx)
        y = min(max(y + py, 0), self.dy)
        # special cells (holes are checked first)
        if (x, y) in self.trou:
            return "plouf"
        if (x, y) in self.tresor:
            return "tresor"
        return (x, y, o)

    def recompense(self, s, a, sarr):
        """Reward for going from ``s`` to ``sarr`` with action ``a``.

        -10 for falling into a hole, 100 for reaching the treasure, 0 once in
        an absorbing state, and -1 for any other move.
        """
        if s != "plouf" and sarr == "plouf":
            return -10
        if s != "tresor" and sarr == "tresor":
            return 100
        if s in self._TERMINAUX:
            return 0
        return -1

    def etats(self):
        """List every state: all (x, y, orientation) tuples, then the two absorbing ones."""
        res = [
            (x, y, o)
            for x in range(self.dx + 1)
            for y in range(self.dy + 1)
            for o in ("N", "E", "S", "O")
        ]
        res += ["plouf", "tresor"]
        return res

    def actions(self):
        """List the available actions."""
        return ["devant", "gauche", "droite"]

              
    

class Systeme:
    """Run and solve a problem.

    - pb: the problem.  It must provide ``etats()``, ``actions()``,
      ``transition(s, a)`` (a list of ``(state, probability)`` pairs) and
      ``recompense(s, a, s_arrivee)``.
    """

    def __init__(self, pb):
        """Build a system around the problem ``pb``."""
        self.pb = pb

    def execute(self, s, a):
        """Return the successor distribution of action ``a`` in state ``s``.

        This is ``pb.transition(s, a)`` unchanged: a list of
        ``(state, probability)`` pairs, not a single state.  ``executerPi``
        samples one state from it to simulate a run.
        """
        return self.pb.transition(s, a)

    @staticmethod
    def _tirerEtat(distribution, rng):
        """Draw one state from a non-empty list of ``(state, probability)`` pairs.

        ``rng`` must provide ``random()`` returning a float in [0, 1).
        """
        u = rng.random()
        cumul = 0.0
        for etat, proba in distribution:
            cumul += proba
            if u < cumul:
                return etat
        # float rounding can leave the cumulated probability just below 1
        return distribution[-1][0]

    def intialiseQ(self):
        """Return Q-values initialised to 0 for every (state, action) pair."""
        actions = self.pb.actions()
        return {(etat, action): 0 for etat in self.pb.etats() for action in actions}

    # correctly spelled alias; the historical name is kept for existing callers
    initialiseQ = intialiseQ

    def planifie(self, nb, gamma=0.99):
        """Run ``nb`` sweeps of Q-value iteration and return the Q table.

        Each sweep rebuilds the whole table from the previous one:
            Q'(s, a) = sum over (s', p) in transition(s, a) of
                       p * (recompense(s, a, s') + gamma * max_b Q(s', b))
        ``gamma`` is the discount factor.
        """
        etats = self.pb.etats()
        actions = self.pb.actions()
        Q = self.intialiseQ()
        for _ in range(nb):
            Q2 = {}
            for etat in etats:
                for action in actions:
                    valeur = 0
                    for etatfin, proba in self.pb.transition(etat, action):
                        r = self.pb.recompense(etat, action, etatfin)
                        # best value reachable from the arrival state
                        meilleur = max(Q[(etatfin, b)] for b in actions)
                        valeur += proba * (r + gamma * meilleur)
                    Q2[(etat, action)] = valeur
            Q = Q2
        return Q

    def executerPi(self, pi, depart, nb, rng=None):
        """Simulate ``nb`` steps of policy ``pi`` starting from ``depart``.

        pi: dict mapping each state to the action to take in it.
        rng: object with a ``random()`` method used to draw successor states;
            defaults to the ``random`` module.
        Each step is printed.  Returns the list of ``(s, action, s', r)``
        steps taken.
        """
        import random

        if rng is None:
            rng = random
        trajectoire = []
        s = depart
        for _ in range(nb):
            action = pi[s]
            # transition() returns a distribution: sample the actual successor
            sFin = self._tirerEtat(self.execute(s, action), rng)
            r = self.pb.recompense(s, action, sFin)
            print(s, " -> ", action, " : ", sFin, "<", r, ">")
            trajectoire.append((s, action, sFin, r))
            s = sFin
        return trajectoire

    def afficherQ(self, Q):
        """Print every Q-value on its own line as ``state - action -> value, ``."""
        for (etat, action), valeur in Q.items():
            print(str(etat) + " - " + str(action) + " -> " + str(valeur) + ", ")

    def afficherQS(self, Q):
        """Print the Q-values grouped by state, one action per line."""
        actions = self.pb.actions()
        for etat in self.pb.etats():
            lignes = [str(etat)]
            lignes += [" - " + action + " -> " + str(Q[(etat, action)]) + ", " for action in actions]
            print("\n".join(lignes))

    def politiqueFromQ(self, Q):
        """Return the greedy policy of ``Q`` as a dict mapping state -> action.

        Ties go to the action listed first by ``pb.actions()``.
        """
        actions = self.pb.actions()
        pi = {}
        for etat, _ in Q:
            if etat not in pi:
                # max() returns the first of several equal maxima
                pi[etat] = max(actions, key=lambda b: Q[(etat, b)])
        return pi

    def apprentissage(self, Sdep):
        """Placeholder for a learning procedure (not implemented yet).

        It currently only prints an empty line; ``Sdep`` is unused.
        """
        print("")



# Demo: build the maze, then print the Q-values obtained after
# 1, 2, 50 and 1000 sweeps of value iteration.
pb = LabyOrientation()
print(len(pb.etats()))
print(pb.actions())
SEPARATEUR = "****************************************"
print(SEPARATEUR)

systeme = Systeme(pb)

# (printed title, number of planning sweeps); titles are kept verbatim,
# including the trailing spaces of the first two.
for titre, iterations in (
    ("planification 1 ", 1),
    ("planification 2 ", 2),
    ("planification 50", 50),
    ("planification 1000", 1000),
):
    print(SEPARATEUR)
    print(titre)
    Q = systeme.planifie(iterations)
    print(len(Q))
    systeme.afficherQS(Q)

print(SEPARATEUR)
# To simulate the greedy policy, derive it with systeme.politiqueFromQ(Q)
# and pass it to systeme.executerPi together with a start state
# (the original draft referenced an undefined start state `sDep`).

        
    


