from sage.rings.integer_ring import ZZ
from sage.rings.finite_rings.finite_field_constructor import FiniteField
from sage.schemes.elliptic_curves.constructor import EllipticCurve
from sage.misc.prandom import randrange

def double_affine(P, a):
    """
    Doubling 2*P in affine coordinates on a curve in short Weierstrass form.

    INPUT:
    - `P`: point in affine coordinates
    - `a`: curve coefficient in y^2 = x^3 + a*x + b

    RETURN:
    the point 2*P on the curve of P; this is the point at infinity when P is
    the point at infinity or a point of order 2 (y = 0)

    Source: http://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
    """
    E = P.curve()
    x1, y1 = P[0], P[1]
    # P = O (third coordinate 0) or vertical tangent (y1 = 0): 2*P = O.
    # The slope below is undefined in both cases (it would divide by 2*y1 = 0,
    # or be computed from the coordinates (0, 1) of the point at infinity).
    if P[2] == 0 or y1 == 0:
        return E(0)
    lambd = (3*x1**2+a)/(2*y1) # slope of the tangent at P
    x3 = lambd**2-x1-x1
    y3 = lambd*(x1-x3)-y1
    # cast the result to the curve
    return E((x3,y3,1))

def add_affine(P, Q):
    """
    Addition P + Q in affine coordinates on a curve in short Weierstrass
    form y^2 = x^3 + a*x + b.

    INPUT:
    - `P`: point in affine coordinates
    - `Q`: point in affine coordinates, on the same curve as P

    RETURN:
    the point P + Q on the curve of P

    The chord formula is only defined for two finite points with distinct
    x-coordinates. The other cases (an operand at infinity, Q = -P, Q = P)
    are handled explicitly; for Q = P the tangent slope uses the coefficient
    a = a4 read from the curve of P, which assumes a1 = a2 = a3 = 0.

    Source: http://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
    """
    E = P.curve()
    # the point at infinity (third coordinate 0) is the neutral element
    if P[2] == 0:
        return Q
    if Q[2] == 0:
        return P
    x1, y1 = P[0], P[1]
    x2, y2 = Q[0], Q[1]
    if x1 == x2:
        # same x: either Q = -P (including P = Q of order 2), or Q = P
        if y1 != y2 or y1 == 0:
            return E(0)
        lambd = (3*x1**2+E.a4())/(2*y1) # slope of the tangent at P
    else:
        lambd = (y2-y1)/(x2-x1) # slope of the chord through P and Q
    x3 = lambd**2-x1-x2
    y3 = lambd*(x1-x3)-y1
    # cast the result to the curve
    return E((x3,y3,1))

def group_law_affine(P, Q):
    """
    Complete group law P + Q in affine coordinates on a curve in short
    Weierstrass form, dispatching to `double_affine` or `add_affine`.

    INPUT:
    - `P`: point on a curve with a1 = a2 = a3 = 0
    - `Q`: point on the same curve

    RETURN:
    the point P + Q on the curve of P
    """
    E = P.curve()
    assert E.a2() == 0 and E.a1() == 0 and E.a3() == 0 # short Weierstrass form
    # the point at infinity (third coordinate 0) is the neutral element; its
    # coordinates (0 : 1 : 0) must not be fed to the affine formulas
    if P[2] == 0:
        return Q
    if Q[2] == 0:
        return P
    x1, y1 = P[0], P[1]
    x2, y2 = Q[0], Q[1]
    if x1 == x2:
        if y1 == -y2: # P == -Q, P+Q = P-P = 0
            return E(0) # the point at infinity
        # same x and y1 != -y2: on a short Weierstrass curve this means P == Q
        return double_affine(P, E.a4())
    return add_affine(P, Q)

def double_projective(P, a):
    """
    Doubling 2*P in projective coordinates.

    INPUT:
    - `P`: point in projective coordinates
    - `a`: curve coefficient in Y^2*Z = X^3 + a*X*Z^2 + b*Z^3

    RETURN:
    the point 2*P on the curve of P; this is the point at infinity when P is
    the point at infinity or a point of order 2 (Y = 0)

    Source: http://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#doubling-dbl-1998-cmo-2
    """
    X1,Y1,Z1 = P # gets the projective coordinates of P
    E = P.curve()
    s = Y1*Z1
    # s = 0 means Z1 = 0 (P = O) or Y1 = 0 (P of order 2): 2*P = O.
    # For P = O the formulas below would produce (0, 0, 0), which is not a
    # projective point and cannot be cast to the curve.
    if s == 0:
        return E(0)
    w = a*Z1**2+3*X1**2
    ss = s**2
    sss = s*ss
    R = Y1*s
    B = X1*R
    h = w**2-8*B
    X3 = 2*h*s
    Y3 = w*(4*B-h)-8*R**2
    Z3 = 8*sss
    # at the end, cast the point to the curve
    return E((X3, Y3, Z3))

def add_projective(P, Q):
    """
    Addition in projective coordinates on a curve
    in reduced Weierstrass form Y^2*Z = X^3 + a*X*Z^2 + b*Z^3
    INPUT:
    - `P`: point (X1,Y1,Z1) in projective coordinates
    - `Q`: point (X2,Y2,Z2) in projective coordinates

    RETURN:
    the point P + Q on the curve of P

    The formulas produce (0, 0, 0) when an operand is the point at infinity
    or when P == Q, so these cases are handled explicitly; for P == Q the
    doubling `double_projective` is used with a = a4 read from the curve.

    Source: http://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-add-1998-cmo-2
    """
    E = P.curve()
    X1,Y1,Z1 = P # gets the projective coordinates of P
    X2,Y2,Z2 = Q # gets the projective coordinates of Q
    # the point at infinity (Z = 0) is the neutral element
    if Z1 == 0:
        return Q
    if Z2 == 0:
        return P
    Y1Z2 = Y1*Z2
    X1Z2 = X1*Z2
    Z1Z2 = Z1*Z2
    u = Y2*Z1-Y1Z2
    v = X2*Z1-X1Z2
    if v == 0:
        # same x-coordinate: P == Q if the y-coordinates agree too, else P == -Q
        if u == 0:
            return double_projective(P, E.a4())
        return E(0)
    uu = u**2
    vv = v**2
    vvv = v*vv
    R = vv*X1Z2
    A = uu*Z1Z2-vvv-2*R
    X3 = v*A
    Y3 = u*(R-A)-vvv*Y1Z2
    Z3 = vvv*Z1Z2
    # at the end, cast the point to the curve
    return E((X3, Y3, Z3))

def scalar_mult_projective(P, m):
    """
    Scalar multiplication m*P with the left-to-right double-and-add
    algorithm, using `double_projective` and `add_projective`.

    INPUT:
    - `P`: point on a curve in short Weierstrass form (a1 = a2 = a3 = 0)
    - `m`: scalar, positive or negative integer

    RETURN:
    the point m*P on the curve of P
    """
    E = P.curve()
    if m == 0:
        return E(0)
    assert E.a1() == 0 and E.a2() == 0 and E.a3() == 0 # short Weierstrass form
    if P[2] == 0: # m*O = O, and the projective formulas do not accept O
        return E(0)
    a = E.a4()
    m = ZZ(m)
    if m < 0:
        m = -m
        P = -P # important to negate P here
    R = P # consumes the most significant bit of m, which is always 1
    # the method bits() returns a list of the bits of m, where bits()[0] is
    # the least significant bit: drop the top bit and read the rest downwards
    for bit in reversed(m.bits()[:-1]):
        # The guards below cover the inputs on which the projective formulas
        # yield (0, 0, 0); they only trigger for points of small order.
        if R[2] != 0: # doubling step, 2*O = O
            R = double_projective(R, a)
        if bit == 1:
            if R[2] == 0: # O + P = P
                R = P
            elif R == P: # P + P is a doubling
                R = double_projective(R, a)
            else:
                R = add_projective(R, P)
    return R

def test_scalar_mult_projective(E, order):
    """
    Compare `scalar_mult_projective` with the built-in product m*P.

    INPUT:
    - `E`: curve from which random points are drawn
    - `order`: exclusive upper bound for the random scalars

    Draws up to 10 random (scalar, point) pairs, stops at the first mismatch
    and prints whether all the pairs tried agreed.
    """
    trials = 10
    passed = True
    for _ in range(trials):
        m = randrange(order)
        P = E.random_element()
        passed = scalar_mult_projective(P, m) == m*P
        if not passed:
            break # no need to go on after the first failure
    print("test scalar_mult_projective: {}".format(passed))

def multi_scalar_mult(P, Q, m1, m2):
    """
    multi-scalar multiplication m1*P + m2*Q

    Simultaneous double-and-add (Straus-Shamir trick): the bits of m1 and m2
    are read together from the most significant one, with one doubling per
    bit position and at most one addition of P, Q or P+Q. Only the group law
    of the curve (negation and addition of points) is used, never the
    built-in scalar multiplication.

    INPUT:
    - P, Q points on the same curve E
    - m1, m2 scalars, positive or negative integers

    RETURN:
    S = m1*P + m2*Q
    """
    E = P.curve()
    # move the signs onto the points so that both scalars are non-negative
    if m1 < 0:
        m1, P = -m1, -P
    if m2 < 0:
        m2, Q = -m2, -Q
    # table[b1 + 2*b2] = b1*P + b2*Q; nothing is added when both bits are 0
    table = [None, P, Q, P + Q]
    # joint bits of (m1, m2), least significant position first
    joint_bits = []
    while m1 or m2:
        joint_bits.append((m1 & 1) + 2*(m2 & 1))
        m1 >>= 1
        m2 >>= 1
    S = E(0)
    for index in reversed(joint_bits):
        S = S + S # doubling step
        T = table[index]
        if T is not None:
            S = S + T
    return S

def test_multi_scalar_mult(E, order):
    """
    Compare `multi_scalar_mult` with the built-in expression m1*P + m2*Q.

    INPUT:
    - `E`: curve from which random points are drawn
    - `order`: exclusive upper bound for the random scalars

    First prints one line per fixed pair of small scalars (zero, one, two and
    negative values), each on fresh random points. Then tries up to 10 random
    pairs of scalars, stops at the first mismatch and prints the outcome
    together with the number of pairs tried.
    """
    fixed_pairs = [(0,1), (1,0), (1,1), (2,1), (1,2), (-1,1), (1,-1), (-1,-1)]
    for m1, m2 in fixed_pairs:
        P, Q = E.random_element(), E.random_element()
        S = multi_scalar_mult(P, Q, m1, m2)
        print("{}*P + {}*Q == S: {}".format(m1, m2, S == m1*P + m2*Q))

    rounds = 10
    tried = 0
    ok = True
    for tried in range(1, rounds + 1):
        m1, m2 = randrange(order), randrange(order)
        P, Q = E.random_element(), E.random_element()
        ok = multi_scalar_mult(P, Q, m1, m2) == m1*P + m2*Q
        if not ok:
            break # report the failure with the number of pairs tried so far
    print("test multi_scalar_mult: {} ({} random pairs of scalars)".format(ok, tried))

if __name__ == "__main__":
    # curve y^2 = x^3 + 486662*x^2 + x over the prime field of 2^255 - 19
    prime = 2**255-19
    base_field = FiniteField(prime)
    montgomery_curve = EllipticCurve([0, base_field(486662), 0, base_field(1), 0])
    # the formulas in this file need a short Weierstrass model (a1 = a2 = a3 = 0)
    weierstrass_curve = montgomery_curve.short_weierstrass_model()
    # passed to the tests, which use it as the bound for the random scalars
    group_order = ZZ(57896044618658097711785492504343953926856930875039260848015607506283634007912)
    test_scalar_mult_projective(weierstrass_curve, group_order)
    test_multi_scalar_mult(weierstrass_curve, group_order)
