class LabyOrientation:
    """Grid labyrinth in which the agent has a position and an orientation.

    States are ``(x, y, o)`` tuples with ``0 <= x <= dx``, ``0 <= y <= dy`` and
    ``o`` one of ``"N"``, ``"S"``, ``"E"``, ``"O"`` (O = ouest, west), plus the
    two terminal (absorbing) states ``"plouf"`` (fell into a hole) and
    ``"tresor"`` (reached the treasure).

    Moving forward is stochastic: the agent advances one cell or two cells,
    each with probability 0.5.
    """

    def __init__(self):
        # largest x / y coordinate: the grid has (dx + 1) x (dy + 1) cells
        self.dx = 5
        self.dy = 5
        # goal cells
        self.tresor = [(4, 4)]
        # hole cells
        self.trou = [(2, 3), (2, 0), (2, 2), (0, 3), (3, 3), (4, 3)]

    def transition(self, s, a):
        """Return the distribution of next states as ``[(state, proba), ...]``.

        ``s`` is ``(x, y, orientation)`` or a terminal state; ``a`` is
        ``"devant"`` (move forward), ``"gauche"`` (turn left) or ``"droite"``
        (turn right). Terminal states are absorbing whatever the action.

        Raises:
            ValueError: if ``a`` is not a known action (non-terminal ``s``).
        """
        if s == "plouf":
            return [("plouf", 1.0)]
        if s == "tresor":
            return [("tresor", 1.0)]

        if a == "gauche":
            return [((s[0], s[1], self.tournerG(s[2])), 1.0)]
        if a == "droite":
            return [((s[0], s[1], self.tournerD(s[2])), 1.0)]
        if a == "devant":
            # advance one step or two steps, each with probability 0.5
            s1 = self.avancer(s)
            s2 = self.avancer(s1)
            return [(s1, 0.5), (s2, 0.5)]
        # previously fell through and returned None, crashing callers later
        raise ValueError(f"unknown action: {a!r}")

    def tournerG(self, o):
        """Return the orientation after turning left (None if ``o`` is unknown)."""
        return {"N": "O", "O": "S", "S": "E", "E": "N"}.get(o)

    def tournerD(self, o):
        """Return the orientation after turning right (None if ``o`` is unknown)."""
        return {"N": "E", "E": "S", "S": "O", "O": "N"}.get(o)

    def avancer(self, s):
        """Move one cell forward from ``s`` and return the resulting state.

        Terminal states are returned unchanged. Walls clamp the position to
        the grid (the agent stays in place). Landing on a hole gives
        ``"plouf"``; landing on the treasure gives ``"tresor"``.
        """
        if s == "plouf" or s == "tresor":
            return s

        x, y, o = s
        if o == "N":
            y = max(y - 1, 0)
        elif o == "S":
            y = min(y + 1, self.dy)
        elif o == "O":
            x = max(x - 1, 0)
        elif o == "E":
            x = min(x + 1, self.dx)

        # special cells
        if (x, y) in self.trou:
            return "plouf"
        if (x, y) in self.tresor:
            return "tresor"
        return (x, y, o)

    def executer(self, s, a):
        """Sample and return one next state for action ``a`` in state ``s``."""
        from random import uniform

        dist = self.transition(s, a)
        hasard = uniform(0, 1)
        print(dist)
        # walk the cumulative distribution; every entry except the last is
        # tested explicitly, and the last one absorbs any float residue so we
        # can never index past the end of ``dist``
        for etat, proba in dist[:-1]:
            if hasard <= proba:
                return etat
            hasard -= proba
        return dist[-1][0]

    def recompense(self, s, a, sarr):
        """Reward for the transition ``s --a--> sarr``.

        -10 for falling into a hole, +100 for reaching the treasure, 0 when
        already in a terminal state, -1 for any other step.
        """
        if s != "plouf" and sarr == "plouf":
            return -10
        if s != "tresor" and sarr == "tresor":
            return 100
        if s in ("tresor", "plouf"):
            return 0
        return -1

    def etats(self):
        """Return all states: every (x, y, orientation) triple, then the terminals."""
        res = [(x, y, o)
               for x in range(self.dx + 1)
               for y in range(self.dy + 1)
               for o in ("N", "E", "S", "O")]
        return res + ["plouf", "tresor"]

    def actions(self):
        """Return the list of available actions."""
        return ["devant", "gauche", "droite"]




class Systeme:
    """Run and solve a problem by value iteration on Q-values.

    ``pb`` is the problem object. This class calls its ``etats()``,
    ``actions()``, ``transition(s, a)`` (list of ``(next_state, probability)``
    pairs), ``recompense(s, a, s_next)`` and, to execute actions,
    ``executer(s, a)`` (which draws one next state).
    """

    def __init__(self, pb):
        """Build a system around problem ``pb``."""
        self.pb = pb

    def execute(self, s, a):
        """Apply action ``a`` in state ``s`` and return ONE sampled next state.

        The state is drawn with ``pb.executer``. The whole distribution from
        ``pb.transition`` is not returned, so the result can be fed back as
        the next state (and used as a policy key).
        """
        return self.pb.executer(s, a)

    def intialiseQ(self):
        """Return a Q table with every (state, action) pair set to 0."""
        return {(etat, action): 0
                for etat in self.pb.etats()
                for action in self.pb.actions()}

    def planifie(self, nb, gamma=0.99):
        """Run ``nb`` iterations of Q-value iteration and return the Q table.

        Starting from Q = 0, each iteration computes, for every (s, a):
            Q'(s, a) = sum_{s'} P(s'|s, a) * (R(s, a, s') + gamma * max_b Q(s', b))

        Args:
            nb: number of iterations.
            gamma: discount factor (default 0.99).

        Returns:
            dict mapping ``(state, action)`` to its value.
        """
        etats = self.pb.etats()
        actions = self.pb.actions()
        Q = self.intialiseQ()

        for _ in range(nb):
            Q2 = {}
            for etat in etats:
                for action in actions:
                    valeur = 0
                    for etatfin, proba in self.pb.transition(etat, action):
                        r = self.pb.recompense(etat, action, etatfin)
                        meilleur = max(Q[(etatfin, b)] for b in actions)
                        valeur += proba * (r + gamma * meilleur)
                    Q2[(etat, action)] = valeur
            # synchronous update: every Q2 entry was computed from the old Q
            Q = Q2
        return Q

    def executerPi(self, pi, depart, nb):
        """Follow policy ``pi`` for ``nb`` steps from ``depart``, printing each step."""
        s = depart
        for _ in range(nb):
            action = pi[s]
            sFin = self.execute(s, action)
            r = self.pb.recompense(s, action, sFin)
            print(s, " -> ", action, " : ", sFin, "<", r, ">")
            s = sFin

    def afficherQ(self, Q):
        """Print every entry of ``Q``, one ``(state, action)`` per line."""
        for (etat, action), valeur in Q.items():
            print(f"{etat} - {action} -> {valeur}, ")

    def afficherQS(self, Q):
        """Print ``Q`` grouped by state: the state, then one line per action."""
        for etat in self.pb.etats():
            lignes = [str(etat)]
            lignes += [f" - {action} -> {Q[(etat, action)]}, "
                       for action in self.pb.actions()]
            print("\n".join(lignes))

    def politiqueFromQ(self, Q):
        """Return the greedy policy ``{state: best action}`` derived from ``Q``.

        Ties are broken in favour of the first action of ``pb.actions()``.
        """
        actions = self.pb.actions()
        pi = {}
        for etat, _ in Q:
            # Q holds one key per (state, action): compute each state only once
            if etat not in pi:
                pi[etat] = max(actions, key=lambda a: Q[(etat, a)])
        return pi

    def apprentissage(self, Sdep):
        """Placeholder for a learning method; not implemented (prints an empty line)."""
        print("")



class RenduGraphique:
    """Pygame viewer that animates a policy ``pi`` on a labyrinth ``laby``.

    Each frame, the agent plays ``pi[etat]`` (the next state is drawn by
    ``laby.executer``). Then the grid, the holes, the treasure and the agent
    are redrawn. Clicking a grid cell restarts the agent there, facing north.
    """

    # unit vector (x, y) in screen coordinates for each orientation
    _DIRECTIONS = {"N": (0, -1), "S": (0, 1), "E": (1, 0), "O": (-1, 0)}

    def __init__(self, laby, etat, pi):
        self.laby = laby
        self.etat = etat
        self.pi = pi
        self.etatdep = etat  # starting state
        self.taille = 50     # cell size in pixels

    def evoluer(self):
        """Advance the agent by one step, following the policy."""
        print(self.etat)
        action = self.pi[self.etat]
        self.etat = self.laby.executer(self.etat, action)

    def dessiner(self, screen):
        """Draw the grid, the holes, the treasure and the agent on ``screen``."""
        import pygame

        WHITE = (0xFF, 0xFF, 0xFF)
        GREY = (0x50, 0x50, 0x50)
        YELLOW = (0xAD, 0xFF, 0x2F)  # actually "green-yellow"
        GREEN = (0x00, 0xFF, 0x00)
        taille = self.taille

        def case(cx, cy):
            # pixel rectangle of cell (cx, cy), leaving a 2-pixel gap
            return (cx * taille, cy * taille, taille - 2, taille - 2)

        for x in range(self.laby.dx + 1):
            for y in range(self.laby.dy + 1):
                pygame.draw.rect(screen, WHITE, case(x, y))
        for t in self.laby.trou:
            pygame.draw.ellipse(screen, GREY, case(t[0], t[1]))
        for t in self.laby.tresor:
            pygame.draw.ellipse(screen, YELLOW, case(t[0], t[1]))

        # terminal states have no position on the grid: nothing to draw
        if self.etat in ("plouf", "tresor"):
            return
        ex, ey, o = self.etat
        pygame.draw.ellipse(screen, GREEN, case(ex, ey))
        # short line from the cell centre towards the facing direction
        vx, vy = self._DIRECTIONS.get(o, (0, 0))
        dep = [ex * taille + taille / 2, ey * taille + taille / 2]
        arr = [dep[0] + vx * taille / 2, dep[1] + vy * taille / 2]
        pygame.draw.line(screen, GREY, dep, arr, 5)

    def reinitEtat(self, x, y):
        """Restart the agent, facing north, in the cell under pixel (x, y).

        Clicks outside the grid are ignored. The window is larger than the
        grid, and such a click would produce a state missing from ``pi``,
        making the next ``evoluer`` crash with KeyError.
        """
        cx = int(x / self.taille)
        cy = int(y / self.taille)
        if 0 <= cx <= self.laby.dx and 0 <= cy <= self.laby.dy:
            self.etat = (cx, cy, "N")

    def lancer(self):
        """Open the window and run the animation loop (2 frames per second)."""
        import pygame

        pygame.init()
        try:
            screen = pygame.display.set_mode((700, 500))
            pygame.display.set_caption("Moteur")
            clock = pygame.time.Clock()
            done = False
            while not done:
                # --- events
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        done = True
                    # a click restarts the agent in the clicked cell
                    if event.type == pygame.MOUSEBUTTONDOWN:
                        self.reinitEtat(event.pos[0], event.pos[1])

                # --- update, draw, display
                self.evoluer()
                self.dessiner(screen)
                pygame.display.flip()

                clock.tick(2)
        finally:
            # always release the window, even if the loop raised
            pygame.quit()






# Demo: build the labyrinth, solve it by value iteration with a growing
# number of iterations, derive the greedy policy and animate it.
SEPARATEUR = "****************************************"

pb = LabyOrientation()
print(len(pb.etats()))
print(pb.actions())
print(SEPARATEUR)

systeme = Systeme(pb)

# (number of iterations, printed title) -- titles reproduced as-is
for nb_iterations, titre in ((1, "planification 1 "),
                             (2, "planification 2 "),
                             (50, "planification 50"),
                             (1000, "planification 1000")):
    print(SEPARATEUR)
    print(titre)
    Q = systeme.planifie(nb_iterations)
    print(len(Q))
    systeme.afficherQS(Q)

print(SEPARATEUR)
pi = systeme.politiqueFromQ(Q)
print(pi)
# to replay the policy in the console instead:
# systeme.executerPi(pi, (0, 0, "N"), 30)

# pygame is only needed for the graphical viewer below
import pygame
r = RenduGraphique(pb, (0, 0, "N"), pi)
r.lancer()



