class Lampe:
    """Two-state lamp problem.

    States are 'allume' and 'eteint'; actions are 'appuyer' and 'rien'.
    'rien' keeps the current state, 'appuyer' toggles between the two
    states, and any other combination yields the pseudo-state 'erreur'.
    Pressing while the lamp is 'eteint' earns a reward of 10; every other
    step earns 0.
    """

    def actions(self):
        """Return the list of available actions."""
        return ['appuyer', 'rien']

    def etats(self):
        """Return the list of states."""
        return ['allume', 'eteint']

    def transition(self, s, a):
        """Return the state reached by taking action ``a`` in state ``s``.

        Returns 'erreur' for an unknown action, or for 'appuyer' in an
        unknown state.
        """
        if a == 'rien':
            return s
        if a != 'appuyer':
            return 'erreur'
        # 'appuyer' toggles between the two known states
        if s == 'allume':
            return 'eteint'
        if s == 'eteint':
            return 'allume'
        return 'erreur'

    def recompense(self, s, a, sarr):
        """Return the reward for the step ``s --a--> sarr``.

        Only pressing the switch while the lamp is off is rewarded (10);
        ``sarr`` is accepted for interface compatibility but unused.
        """
        return 10 if s == 'eteint' and a == 'appuyer' else 0

      
class SystemeExecute:
    """Execute a fixed policy on a problem and print a trace of each step.

    Attributes:
        pb: the problem; it must provide ``transition(s, a)`` and
            ``recompense(s, a, s_arrivee)``.
    """

    def __init__(self, pb):
        """Build an executor for problem ``pb``."""
        self.pb = pb

    def executerPi(self, pi, depart, nb):
        """Follow policy ``pi`` for ``nb`` steps from ``depart``, printing each step.

        Args:
            pi: dict mapping each state to the action to take in it.
            depart: starting state.
            nb: number of steps to execute.

        Raises:
            KeyError: if a visited state has no entry in ``pi``.
        """
        s = depart
        for _ in range(nb):
            action = pi[s]
            # Use the problem bound to this instance, not a module-level global.
            sFin = self.pb.transition(s, action)
            print(s, " -> ", action, " : ", sFin)
            s = sFin

    def executerPiRec(self, pi, depart, nb):
        """Like ``executerPi``, but also print and sum the reward of each step.

        Args:
            pi: dict mapping each state to the action to take in it.
            depart: starting state.
            nb: number of steps to execute.

        Returns:
            The sum of the rewards collected over the ``nb`` steps (0 if nb <= 0).

        Raises:
            KeyError: if a visited state has no entry in ``pi``.
        """
        s = depart
        somme = 0
        for _ in range(nb):
            action = pi[s]
            # Use the problem bound to this instance, not a module-level global.
            sFin = self.pb.transition(s, action)
            r = self.pb.recompense(s, action, sFin)
            somme += r
            print(s, " -> ", action, " : ", sFin, "<", r, ">")
            s = sFin
        return somme

            

#************************************************************
# Demo 1: run the "always press" policy on the lamp and sum the rewards
# collected over 10 steps, starting from the 'eteint' state.
pb = Lampe()

pi = {'allume': 'appuyer', 'eteint': 'appuyer'}
print("*** politique ***")
print(pi)

print("*** test execution recompense ***")
systemExec = SystemeExecute(pb)
somme = systemExec.executerPiRec(pi, depart='eteint', nb=10)
print("somme: ", somme)

#************************************************************



class Systeme:
    """Execute a problem and plan on it with value iteration.

    Attributes:
        pb: the problem; it must provide ``etats()``, ``actions()``,
            ``transition(s, a)`` and ``recompense(s, a, s_arrivee)``.
    """

    def __init__(self, pb):
        """Build a system from a problem."""
        self.pb = pb

    def execute(self, s, a):
        """Apply action ``a`` in state ``s`` and return the arrival state."""
        return self.pb.transition(s, a)

    def intialiseQ(self):
        """Build a Q-table with every (state, action) pair set to 0.

        Returns:
            dict mapping each ``(etat, action)`` pair to 0.
        """
        actions = list(self.pb.actions())
        return {(etat, action): 0 for etat in self.pb.etats() for action in actions}

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    initialiseQ = intialiseQ

    def valueIteration(self, nb, gamma=0.999):
        """Compute Q-values with ``nb`` synchronous Bellman updates.

        Each update sets ``Q[(s, a)] = r(s, a, s') + gamma * max_b Q[(s', b)]``
        with ``s' = transition(s, a)``. Arrival states missing from the table
        (e.g. a pseudo-state outside ``etats()``) are valued 0.

        Args:
            nb: number of iterations; 0 returns the zero-initialised table.
            gamma: discount factor (0.999 by default, the former hard-coded value).

        Returns:
            dict mapping every ``(etat, action)`` pair to its Q-value.
        """
        # Materialise once: the actions are iterated many times below.
        actions = list(self.pb.actions())
        Q = self.intialiseQ()
        for _ in range(nb):
            Q2 = {}
            for etat in self.pb.etats():
                for action in actions:
                    sArriv = self.pb.transition(etat, action)
                    r = self.pb.recompense(etat, action, sArriv)
                    # Builtin max instead of a -100000 sentinel, which gave
                    # wrong results when every Q-value is very negative.
                    meilleur = max(Q.get((sArriv, b), 0) for b in actions)
                    Q2[(etat, action)] = r + gamma * meilleur
            # Synchronous update: the next sweep reads only this sweep's values.
            Q = Q2
        return Q

    def executerPi(self, pi, depart, nb):
        """Follow policy ``pi`` for ``nb`` steps from ``depart``, printing each step and its reward.

        Raises:
            KeyError: if a visited state has no entry in ``pi``.
        """
        s = depart
        for _ in range(nb):
            action = pi[s]
            sFin = self.execute(s, action)
            r = self.pb.recompense(s, action, sFin)
            print(s, " -> ", action, " : ", sFin, "<", r, ">")
            s = sFin

    def afficherQ(self, Q):
        """Print the Q-values of every state, one state per line."""
        for etat in self.pb.etats():
            self.afficherQS(Q, etat)

    def afficherQS(self, Q, etat):
        """Print the Q-values of every action for state ``etat`` on one line.

        Raises:
            KeyError: if a (etat, action) pair is missing from ``Q``.
        """
        valeurs = "".join(
            action + " -> " + str(Q[(etat, action)]) + ", "
            for action in self.pb.actions()
        )
        print(str(etat) + " - " + valeurs)

    def politiqueFromQ(self, Q):
        """Derive the greedy policy from Q-values.

        Returns:
            dict mapping each state to the action with the highest Q-value
            (the first one in ``actions()`` order on ties; -1 when there are
            no actions).

        Raises:
            KeyError: if a (state, action) pair is missing from ``Q``.
        """
        actions = list(self.pb.actions())
        # max() keeps the first maximal element, matching the original strict '>' scan.
        return {
            etat: max(actions, key=lambda a: Q[(etat, a)], default=-1)
            for etat in self.pb.etats()
        }

    def apprentissage(self, Sdep):
        """Placeholder: learning is not implemented; only prints an empty line."""
        print("")
        

# Demo 2: plan on the lamp problem with value iteration, then run the
# greedy policy derived from the resulting Q-values for 30 steps.
_SEPARATEUR = "****************************************"

print(_SEPARATEUR)

pb = Lampe()
print(pb.etats())
print(pb.actions())

print(_SEPARATEUR)
print(_SEPARATEUR)

print("planification")
systeme = Systeme(pb)
Q = systeme.valueIteration(nb=1000)

print(_SEPARATEUR)
pi = systeme.politiqueFromQ(Q)
systeme.executerPi(pi, depart='allume', nb=30)

        
    



