from random import uniform

class Clepsydre:
    """3x3 maze with a clepsydra (water clock) in each corner.

    A state is the tuple ``(x, y, c0, c1, c2, c3)``: the agent position on the
    grid (0..2 on each axis) and the levels of the four clepsydras, each
    between 0 and ``self.max``. Clepsydra ``c0`` sits at corner (0, 0), ``c1``
    at (2, 0), ``c2`` at (0, 2) and ``c3`` at (2, 2).
    Actions are "N", "S", "E", "O" (west) and "Remplir" (refill).
    """

    # Grid corner -> index (0..3) of the clepsydra located there.
    _COINS = {(0, 0): 0, (2, 0): 1, (0, 2): 2, (2, 2): 3}

    def __init__(self, proba_baisser=0.1, niveau_max=3):
        """Create the problem.

        :param proba_baisser: probability that one given clepsydra loses a
            level during a step. The "nothing drains" outcome gets
            ``1 - 4 * proba_baisser``, so the value must be in [0, 0.25].
        :param niveau_max: level a clepsydra is set to when refilled.
        :raises ValueError: if ``proba_baisser`` is outside [0, 0.25].
        """
        if not 0 <= proba_baisser <= 0.25:
            raise ValueError("proba_baisser must be in [0, 0.25]")
        self.proba_baisser = proba_baisser
        self.max = niveau_max

    def transition(self, s, a):
        """Return the distribution of successor states of ``s`` under ``a``.

        The move is deterministic and clamped to the 3x3 grid. "Remplir" on a
        corner sets that corner's clepsydra to ``self.max``; it does nothing
        elsewhere. Then, with probability ``1 - 4 * proba_baisser`` no
        clepsydra drains. With probability ``proba_baisser`` each clepsydra
        alone loses one level, never going below 0.

        :param s: state ``(x, y, c0, c1, c2, c3)``
        :param a: one of ``self.actions()``
        :return: list of ``(state, probability)`` pairs. The same state may
            appear several times, e.g. when an empty clepsydra "drains".
        """
        x, y = s[0], s[1]
        if a == "N":
            y = max(y - 1, 0)
        elif a == "S":
            y = min(y + 1, 2)
        elif a == "E":
            x = min(x + 1, 2)
        elif a == "O":
            x = max(x - 1, 0)

        niveaux = list(s[2:6])
        if a == "Remplir":
            coin = self._COINS.get((x, y))
            if coin is not None:
                niveaux[coin] = self.max

        res = [((x, y, *niveaux), 1. - 4 * self.proba_baisser)]
        # One exclusive outcome per clepsydra: only that one drains.
        for i in range(len(niveaux)):
            baisse = list(niveaux)
            baisse[i] = max(baisse[i] - 1, 0)
            res.append(((x, y, *baisse), self.proba_baisser))
        return res

    def executer(self, s, a):
        """Sample and return one successor state of ``s`` under ``a``."""
        dist = self.transition(s, a)
        hasard = uniform(0, 1)
        dernier = len(dist) - 1
        i = 0
        # Float probabilities may sum to slightly less than 1. Never walk past
        # the last outcome, which used to raise IndexError.
        while i < dernier and hasard > dist[i][1]:
            hasard -= dist[i][1]
            i += 1
        return dist[i][0]

    def recompense(self, s, a, sarr):
        """Reward for arriving in ``sarr``; ``s`` and ``a`` are unused.

        Each clepsydra contributes minus the square of its missing level
        (``self.max - level``). With the default max of 3, levels 3/2/1/0 give
        0/-1/-4/-9. Levels outside [0, self.max] contribute 0.
        """
        som = 0
        for niv in sarr[2:6]:
            if 0 <= niv <= self.max:
                som -= (self.max - niv) ** 2
        return som

    def etats(self):
        """List every state, with x varying slowest and c3 fastest."""
        from itertools import product

        niveaux = range(self.max + 1)
        return list(product(range(3), range(3), niveaux, niveaux, niveaux, niveaux))

    def actions(self):
        """List the available actions."""
        return ["N", "S", "E", "O", "Remplir"]




class Systeme:
    """Executes and plans on a problem.

    - pb : the problem. It must provide ``etats()``, ``actions()``,
      ``transition(s, a)`` (list of ``(next_state, probability)``),
      ``executer(s, a)`` (one sampled next state) and
      ``recompense(s, a, s_next)``.
    """

    def __init__(self, pb):
        """Build a system around problem ``pb``."""
        self.pb = pb

    def execute(self, s, a):
        """Apply action ``a`` in state ``s`` and return the sampled next state."""
        # transition() returns the whole distribution, which is not a state:
        # executerPi needs a single state to compute a reward and look up pi.
        return self.pb.executer(s, a)

    def intialiseQ(self):
        """Return a Q table with every (state, action) pair set to 0."""
        return {(etat, action): 0
                for etat in self.pb.etats()
                for action in self.pb.actions()}

    def planifie(self, nb, gamma=0.99):
        """Run ``nb`` sweeps of Q-value iteration and return the Q table.

        Starting from Q = 0, each sweep applies the Bellman optimality backup
        Q'(s, a) = sum_{s'} P(s'|s, a) * (r(s, a, s') + gamma * max_b Q(s', b))
        to every (state, action) pair. It prints the sweep number and the
        largest change |Q' - Q| observed during that sweep.

        :param nb: number of sweeps
        :param gamma: discount factor
        :return: dict mapping ``(state, action)`` to its Q-value
        """
        etats = self.pb.etats()
        actions = self.pb.actions()
        Q = self.intialiseQ()

        for i in range(nb):
            print("iteration :", i)
            # Best value of each state under the current Q. It is computed once
            # per sweep instead of once per (state, action, successor).
            valeur = {etat: max(Q[(etat, b)] for b in actions) for etat in etats}
            Q2 = {}
            epsilon = 0
            for etat in etats:
                for action in actions:
                    q = 0
                    for etatfin, proba in self.pb.transition(etat, action):
                        r = self.pb.recompense(etat, action, etatfin)
                        q += proba * (r + gamma * valeur[etatfin])
                    Q2[(etat, action)] = q
                    epsilon = max(epsilon, abs(q - Q[(etat, action)]))
            Q = Q2
            print("epsilon: ", epsilon)
        return Q

    def executerPi(self, pi, depart, nb):
        """Follow policy ``pi`` for ``nb`` steps from ``depart``, printing each step."""
        s = depart
        for _ in range(nb):
            action = pi[s]
            sFin = self.execute(s, action)
            r = self.pb.recompense(s, action, sFin)
            print(s, " -> ", action, " : ", sFin, "<", r, ">")
            s = sFin

    def afficherQ(self, Q):
        """Print one line per (state, action) entry of ``Q``."""
        for (etat, action), valeur in Q.items():
            print(str(etat) + " - " + str(action) + " -> " + str(valeur) + ", ")

    def afficherQS(self, Q):
        """Print, for each state, the Q-value of every action."""
        for etat in self.pb.etats():
            chaine = str(etat)
            for action in self.pb.actions():
                chaine += "\n - " + action + " -> " + str(Q[(etat, action)]) + ", "
            print(chaine)

    def politiqueFromQ(self, Q):
        """Return the greedy policy of ``Q`` as a dict state -> action.

        States are taken in the order they first appear among ``Q``'s keys.
        Ties go to the first action in ``pb.actions()``.
        """
        actions = self.pb.actions()
        pi = {}
        for etat, _ in Q:
            if etat not in pi:
                pi[etat] = max(actions, key=lambda b: Q[(etat, b)])
        return pi

    def apprentissage(self, Sdep):
        """Placeholder: learning is not implemented; only prints an empty line."""
        print("")



class RenduGraphique:
    """Pygame viewer that replays a policy on the clepsydra maze.

    - laby : the problem; its ``executer(etat, action)`` advances the state
    - etat : current state ``(x, y, c0, c1, c2, c3)``
    - pi   : policy, a dict mapping each state to an action
    """

    # Position, in grid cells and before the one-cell drawing offset, of the
    # clepsydras whose levels are etat[2], etat[3], etat[4], etat[5].
    _CLEPSYDRES = [(-1, 0), (3, 0), (-1, 2), (3, 2)]

    def __init__(self, laby, etat, pi):
        self.laby = laby
        self.etat = etat
        self.pi = pi
        self.etatdep = etat
        self.taille = 50  # side of one grid cell, in pixels
        self._font = None  # built on first draw, once pygame is initialised

    def evoluer(self):
        """Play one step of the policy from the current state."""
        action = self.pi[self.etat]
        print(self.etat, "->", action)
        self.etat = self.laby.executer(self.etat, action)

    def dessiner(self, screen):
        """Draw the grid, the four clepsydras with their level, and the agent."""
        BLUE = (0x00, 0x00, 0xFF)
        WHITE = (0xFF, 0xFF, 0xFF)
        GREY = (0x50, 0x50, 0x50)
        RED = (0xFF, 0x00, 0x00)
        taille = self.taille

        pygame.draw.rect(screen, (0, 0, 0), (0, 0, 500, 500))

        # Grid cell (x, y) is drawn at pixel ((x + 1) * taille, (y + 1) * taille),
        # which leaves a one-cell margin on the left for the clepsydras.
        for x in range(1, 4):
            for y in range(1, 4):
                pygame.draw.rect(screen, WHITE, (x * taille, y * taille, taille - 2, taille - 2))

        if self._font is None:
            self._font = pygame.font.Font(None, 20)

        for i, (cx, cy) in enumerate(self._CLEPSYDRES):
            px = (cx + 1) * taille
            py = (cy + 1) * taille
            pygame.draw.ellipse(screen, BLUE, (px, py, taille - 2, taille - 2))
            texte = "niv: " + str(self.etat[i + 2])
            screen.blit(self._font.render(texte, True, RED), (px + 5, py + taille // 4))

        pygame.draw.ellipse(screen, GREY, ((self.etat[0] + 1) * taille,
                                           (self.etat[1] + 1) * taille,
                                           taille - 2, taille - 2))

    def reinitEtat(self, x, y):
        """Move the agent to the grid cell under pixel (x, y).

        The clepsydra levels are kept. Clicks outside the 3x3 grid are ignored.
        """
        # Undo the one-cell drawing offset used by dessiner.
        cx = int(x / self.taille) - 1
        cy = int(y / self.taille) - 1
        if 0 <= cx <= 2 and 0 <= cy <= 2:
            # Keep all six components, otherwise pi[self.etat] has no entry.
            self.etat = (cx, cy) + tuple(self.etat[2:])

    def lancer(self):
        """Open the window and run the policy at 2 steps per second until closed.

        A mouse click moves the agent to the clicked cell.
        """
        pygame.init()
        try:
            screen = pygame.display.set_mode((700, 500))
            pygame.display.set_caption("Moteur")
            clock = pygame.time.Clock()
            done = False
            while not done:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        done = True
                    elif event.type == pygame.MOUSEBUTTONDOWN:
                        self.reinitEtat(*event.pos)

                self.evoluer()
                self.dessiner(screen)
                pygame.display.flip()
                clock.tick(2)
        finally:
            pygame.quit()






# --- Demo: build the problem, run value iteration, then show the policy ---
pb = Clepsydre()
print(len(pb.etats()))
print(pb.actions())

print("****************************************")
print(pb.transition((0, 0, 1, 1, 1, 1), "Remplir"))

systeme = Systeme(pb)

# The 1- and 2-sweep runs are printed in full so the first Bellman backups
# can be checked by hand.
for nb_iter in (1, 2):
    print("****************************************")
    print("planification " + str(nb_iter) + " ")
    Q = systeme.planifie(nb_iter)
    print(len(Q))
    systeme.afficherQS(Q)

print("****************************************")
print("planification 50")
Q = systeme.planifie(50)
print(len(Q))

print("****************************************")
pi = systeme.politiqueFromQ(Q)
print(pi)

# pygame is only needed for the interactive viewer below.
import pygame
r = RenduGraphique(pb, (0, 0, 3, 3, 3, 3), pi)
r.lancer()



