class Probleme:
    """Base description of a decision problem.

    Subclasses override the four hooks below. The defaults describe a
    problem with no actions and no states, whose transition leaves the
    state unchanged and whose reward is always zero.
    """

    def actions(self):
        """Return the list of available actions (empty by default)."""
        aucune_action = []
        return aucune_action

    def etats(self):
        """Return the list of states (empty by default)."""
        aucun_etat = []
        return aucun_etat

    def transition(self, s, a):
        """Return the state reached by applying action ``a`` in state ``s``.

        By default the action has no effect: ``s`` itself is returned.
        """
        return s

    def recompense(self, s, a, sarr):
        """Return the reward for going from ``s`` to ``sarr`` with action ``a``.

        The default reward is always 0.
        """
        return 0





class Laby(Probleme):
    """Grid maze in which an agent moves one cell at a time.

    - known actions are 'N' (y - 1), 'S' (y + 1), 'E' (x + 1), 'O' (x - 1);
    - a state is an (x, y) couple with 0 <= x <= dx and 0 <= y <= dy;
    - leaving a treasure cell (``tresor``) leads to the terminal state
      ``FIN``, which is absorbing;
    - hole cells (``trou``) do not block movement, they only give a penalty
      when the agent arrives on them.
    """

    # Absorbing terminal state, outside the grid.
    FIN = (-1, -1)

    def __init__(self):
        """Build the 11x11 maze with its treasure and its holes."""
        # Largest valid coordinate on each axis (coordinates run 0..dx, 0..dy).
        self.dx = 10
        self.dy = 10

        self.tresor = [(10, 10)]

        # Hole cells, each listed once.
        self.trou = [
            (1, 0), (1, 1),
            (0, 3), (1, 3), (2, 3), (3, 3),
            (4, 2),
            (6, 4), (6, 5), (6, 6), (7, 6), (8, 6), (9, 6), (10, 6),
            (5, 6), (4, 6),
            (3, 5), (3, 4), (3, 2), (3, 1), (3, 0),
        ]

    def printLaby(self):
        """Print the maze row by row: 'X' for a hole, '.' otherwise."""
        for j in range(self.dy + 1):
            print("".join("X" if (i, j) in self.trou else "."
                          for i in range(self.dx + 1)))

    def transition(self, s, a):
        """Return the state reached from ``s`` by the single move ``a``.

        From a treasure cell, or from ``FIN``, the result is ``FIN``.
        A move that would leave the grid keeps the coordinate on the border;
        an unknown action leaves the state unchanged.
        """
        x, y = s[0], s[1]
        if (x, y) in self.tresor or (x, y) == self.FIN:
            return self.FIN

        if a == 'N':
            y = max(y - 1, 0)
        elif a == 'S':
            y = min(y + 1, self.dy)
        elif a == 'E':
            x = min(x + 1, self.dx)
        elif a == 'O':
            x = max(x - 1, 0)
        return (x, y)

    def recompense(self, s, a, sarr):
        """Return the reward for arriving in ``sarr``.

        100 on a treasure cell, -20 on a hole, 0 in the terminal state
        ``FIN``, and -1 for any other arrival cell.
        """
        if (sarr[0], sarr[1]) in self.tresor:
            return 100
        if sarr in self.trou:
            return -20
        if (sarr[0], sarr[1]) == self.FIN:
            return 0
        return -1

    def actions(self):
        """Return the four moves: north, south, east, west ('O' = ouest)."""
        return ['N', 'S', 'E', 'O']

    def etats(self):
        """Return every grid cell (x outer loop, y inner loop), then ``FIN``."""
        cellules = [(i, j)
                    for i in range(self.dx + 1)
                    for j in range(self.dy + 1)]
        return cellules + [self.FIN]

    def executer(self, s, a):
        """Apply action ``a`` in state ``s`` and return the new state."""
        return self.transition(s, a)



class Systeme:
    """Run and solve a problem.

    - pb: the problem; it must provide ``dx``, ``dy``, ``etats()``,
      ``actions()``, ``transition(s, a)`` and ``recompense(s, a, sarr)``.
    """

    def __init__(self, pb):
        """Build a system around problem ``pb``."""
        self.pb = pb
        self.dx = self.pb.dx
        self.dy = self.pb.dy

    def executer(self, s, a):
        """Apply action ``a`` in state ``s`` and return the arrival state."""
        return self.pb.transition(s, a)

    def intialiseQ(self):
        """Return an empty Q-value table (missing entries are read as 0)."""
        return {}

    def valueIteration(self, nb, gamma=0.999):
        """Compute Q-values with ``nb`` sweeps of value iteration.

        Each sweep sets, for every (state, action) pair,
        ``Q(s, a) = r(s, a, s') + gamma * max_b Q(s', b)`` with
        ``s' = transition(s, a)``. It uses the Q-values of the previous
        sweep, and 0 where a value has not been computed yet.

        :param nb: number of sweeps; with ``nb == 0`` an empty dict is returned.
        :param gamma: discount factor (0.999 by default).
        :return: dict mapping ``(state, action)`` to its Q-value.
        """
        etats = self.pb.etats()
        actions = self.pb.actions()
        Q = self.intialiseQ()
        for _ in range(nb):
            Q2 = {}
            for etat in etats:
                for action in actions:
                    sArriv = self.pb.transition(etat, action)
                    r = self.pb.recompense(etat, action, sArriv)
                    # Best value obtainable from the arrival state.
                    meilleur = max(Q.get((sArriv, b), 0) for b in actions)
                    # Update for this (state, action) pair.
                    Q2[(etat, action)] = r + gamma * meilleur
            Q = Q2
        return Q

    def executerPi(self, pi, depart, nb):
        """Follow policy ``pi`` for ``nb`` steps from ``depart``, printing each step."""
        s = depart
        for _ in range(nb):
            action = pi[s]
            sFin = self.executer(s, action)
            r = self.pb.recompense(s, action, sFin)
            print(s, " -> ", action, " : ", sFin, "<", r, ">")
            s = sFin

    def afficherQ(self, Q):
        """Print the Q-values of every state, one state per line."""
        for etat in self.pb.etats():
            self.afficherQS(Q, etat)

    def afficherQS(self, Q, etat):
        """Print the Q-value of every action in state ``etat`` on one line."""
        valeurs = "".join(action + " -> " + str(Q[(etat, action)]) + ", "
                          for action in self.pb.actions())
        print(str(etat) + " - " + valeurs)

    def politiqueFromQ(self, Q):
        """Return the greedy policy of ``Q``: state -> action of largest Q-value.

        Ties go to the first action of ``pb.actions()``. A state with no
        available action is mapped to -1.
        """
        actions = self.pb.actions()
        return {etat: max(actions, key=lambda b: Q[(etat, b)], default=-1)
                for etat in self.pb.etats()}

    def apprentissage(self, Sdep):
        """Placeholder for a learning procedure: only prints an empty line."""
        print("")




#*****************************************************

import pygame

class RenduGraphique:
    """Pygame window animating an agent that follows policy ``pi`` in ``laby``."""

    def __init__(self, laby, etat, pi, Q):
        """
        :param laby: the maze (read: dx, dy, trou, tresor, executer()).
        :param etat: initial state of the agent.
        :param pi: policy, dict mapping a state to an action.
        :param Q: Q-values (stored; not used for drawing).
        """
        self.laby = laby
        self.etat = etat
        self.pi = pi
        self.etatdep = etat
        # Side length of one cell, in pixels.
        self.taille = 30
        self.Q = Q

    def evoluer(self):
        """Print the current state, then apply the policy's action once."""
        print(self.etat)
        action = self.pi[self.etat]
        self.etat = self.laby.executer(self.etat, action)

    def dessiner(self, screen):
        """Draw the maze cells, holes, treasure and the agent on ``screen``."""
        print("ok")
        RED = (0xFF, 0x00, 0x00)
        WHITE = (0xFF, 0xFF, 0xFF)
        GREY = (0x50, 0x50, 0x50)
        YELLOW = (0xAD, 0xFF, 0x2F)  # actually a green-yellow tone
        taille = self.taille
        cote = taille - 2  # leave a 2-pixel gap between neighbouring cells

        def case(couleur, cellule):
            """Fill grid cell ``cellule`` (x, y) with ``couleur``."""
            pygame.draw.rect(screen, couleur,
                             (cellule[0] * taille, cellule[1] * taille,
                              cote, cote))

        # Maze background.
        for x in range(self.laby.dx + 1):
            for y in range(self.laby.dy + 1):
                case(WHITE, (x, y))
        # Holes.
        for t in self.laby.trou:
            case(GREY, t)
        # Treasure.
        for t in self.laby.tresor:
            case(YELLOW, t)
        # Agent, drawn last so it stays visible. States outside the grid
        # (such as the terminal (-1, -1)) are not drawn.
        x, y = self.etat[0], self.etat[1]
        if 0 <= x <= self.laby.dx and 0 <= y <= self.laby.dy:
            case(RED, (x, y))

    def reinitEtat(self, x, y):
        """Move the agent to the grid cell under pixel (x, y).

        Clicks that do not land on a state known to the policy (for
        example in the empty band below the grid) are ignored. This keeps
        ``self.etat`` a valid key of ``self.pi``.
        """
        cellule = (int(x / self.taille), int(y / self.taille))
        if cellule in self.pi:
            self.etat = cellule

    def lancer(self):
        """Open the window and run the animation loop until it is closed.

        Each frame (2 per second) the agent takes one step of the policy.
        A mouse click moves the agent to the clicked cell.
        """
        pygame.init()
        try:
            screen = pygame.display.set_mode((330, 500))
            pygame.display.set_caption("Moteur")
            clock = pygame.time.Clock()
            done = False
            while not done:
                # --- event handling
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        done = True
                    if event.type == pygame.MOUSEBUTTONDOWN:
                        # Position at the time of the click.
                        self.reinitEtat(event.pos[0], event.pos[1])

                # --- game update
                self.evoluer()

                # --- drawing
                self.dessiner(screen)
                pygame.display.flip()

                # --- wait
                clock.tick(2)
        finally:
            pygame.quit()


#******************************************




# Run the demo only when executed as a script, so that importing this module
# does not run value iteration or open a blocking pygame window.
if __name__ == "__main__":
    pb = Laby()
    print("****************************************")
    pb.printLaby()
    print("****************************************")

    systeme = Systeme(pb)

    print("****************************************")
    print("apprentissage")
    Q = systeme.valueIteration(100)

    print("****************************************")
    pi = systeme.politiqueFromQ(Q)
    systeme.executerPi(pi, (0, 0), 30)

    print("****************************************")
    rendu = RenduGraphique(pb, (5, 5), pi, Q)
    rendu.lancer()




