#!/usr/bin/env python3

# SY301 Fall 2016
# Project 3
# Zemor Graph class
# NOTE: You don't need to modify this or understand the math,
# but you may make changes or improvements if you want.

import sys

class Zemor:
    """Class for the Cayley graph of SL2(Fp).
    The vertices are the 2x2 matrices mod p with determinant 1, and there is
    a directed edge from each matrix A to A*[[1,1],[0,1]] and to
    A*[[1,0],[1,1]], as defined in the paper
    "Hash Functions and Graphs with Large Girth" by Zemor, 1991.
    The vertex names will be integers from 0 up to p*(p^2 - 1) - 1,
    using the integer encoding of Matrix22p.
    """

    def __init__(self, p):
        """Sets up the graph with the given prime number p."""
        # graph properties
        self._directed = True
        self._gname = "Zemor" + str(p)
        # these two matrices define the edges in the graph.
        self._p = p
        self._mul0 = Matrix22p((1,1,0,1), p)
        self._mul1 = Matrix22p((1,0,1,1), p)

    def nvertices(self):
        """Returns the number of vertices in the graph, |SL2(Fp)| = p*(p^2 - 1)."""
        return self._p*(self._p**2 - 1)

    def isEdge(self, vertexA, vertexB):
        """Returns True if there is an edge from vertexA to vertexB."""
        return vertexB in self.neighbors(vertexA)

    def neighbors(self, vertexA):
        """Returns a list of the neighbors (outgoing edges) of vertexA.
        Because the Cayley graph is 2-regular, this list will ALWAYS
        have length 2: the vertex times each of the two generator matrices.
        """
        curnode = Matrix22p(vertexA, self._p)
        return [int(curnode * self._mul0), int(curnode * self._mul1)]

    def __str__(self):
        """Returns a string for a dot file equivalent to this graph."""
        if self._directed:
            kind, edgesep = "digraph", '->'
        else:
            kind, edgesep = "graph", '--'
        # Collect the lines and join once at the end instead of repeated
        # string +=, since the output has on the order of p^3 lines.
        lines = [kind + ' ' + self._gname + ' {']
        for a in range(self.nvertices()):
            for b in self.neighbors(a):
                # in an undirected graph, only emit each edge once
                if self._directed or a < b:
                    lines.append('  ' + str(a) + ' ' + edgesep + ' ' + str(b) + ';')
        lines.append('}')
        return '\n'.join(lines)

class Matrix22p:
    """Class for 2x2 matrices mod p.

    The matrix [[a, b], [c, d]] is stored as its four entries, each reduced
    into the range 0..p-1. The single-integer encoding (see __int__ and the
    integer form of the constructor) stores only three entries and recovers
    the fourth from ad - bc = 1, so it applies to determinant-1 matrices
    (elements of SL2(Fp)).
    """

    def __init__(self, val, p):
        """Creates a 2x2 matrix with entries a,b,c,d mod p.
        If val is a tuple or list (a,b,c,d) then those are the matrix entries;
        each one is reduced mod p.
        Otherwise it's assumed val is an integer encoding as in the __int__
        method below, and the opposite decoding is done to reconstruct the
        matrix from that integer, using ad - bc = 1 to recover the entry
        that the encoding does not store.

        Raises ZeroDivisionError if the integer encodes a = 0 together with
        b = 0 (no determinant-1 matrix has that form), and ValueError if p is
        not prime and a needed modular inverse does not exist.
        """
        self._p = p
        if isinstance(val, (tuple, list)):
            # Reduce the entries so that the ``self._a == 0`` test in __int__
            # agrees with the ``(x-1) % p`` arithmetic it uses for encoding.
            self._a, self._b, self._c, self._d = (x % p for x in val)
            return
        # decode from the single integer: its base-p digits are
        # (a-1, b-1, last-1), each mod p, most significant first
        self._c = (val+1) % p
        val //= p
        self._b = (val+1) % p
        val //= p
        self._a = (val+1) % p
        if self._a == 0:
            # the last digit held d; det = -b*c = 1 gives c = -1/b
            self._d = self._c
            self._c = self._moddiv(p-1, self._b)
        else:
            # det = a*d - b*c = 1 gives d = (b*c + 1)/a
            self._d = self._moddiv(self._b*self._c + 1, self._a)

    def __int__(self):
        """Converts from the matrix with four elements a,b,c,d mod p
        to a single integer N.
        The encoding is done by taking a,b,c, subtracting one from each (mod p)
        to get a', b', c', and then storing a'*p^2 + b'*p + c'.
        Except if a=0, then the last part comes from entry d instead of c.
        """
        N = (self._a-1) % self._p
        N *= self._p
        N += (self._b-1) % self._p
        N *= self._p
        if self._a == 0:
            N += (self._d-1) % self._p
        else:
            N += (self._c-1) % self._p
        return N

    def _moddiv(self, numer, denom):
        """Does modular division of numer/denom mod p.
        That is, this computes a number quo in 0..p-1 such that
        denom*quo % p == numer % p.
        Uses the extended Euclidean algorithm. denom may be any integer,
        including a negative one, as long as it is nonzero mod p.

        Raises ZeroDivisionError if denom is 0 mod p, and ValueError if
        denom has no inverse mod p (which can only happen when p is not prime).
        """
        red = denom % self._p
        if red == 0:
            raise ZeroDivisionError("dividing by " + str(denom) + " which is 0 mod " + str(self._p))
        # invariant: x == s*denom and y == t*denom (mod p)
        x, y = self._p, red
        s, t = 0, 1
        while y > 1:
            x, (q, y) = y, divmod(x,y)
            s, t = t, s - q*t
        if y != 1:
            # gcd(red, p) > 1, so no inverse exists
            raise ValueError("error in division because " + str(self._p) + " is not prime")
        return (numer*t) % self._p

    def __mul__(self, other):
        """Multiplies this matrix by another, which must have the same prime p."""
        if not isinstance(other, Matrix22p) or self._p != other._p:
            raise ValueError("Incompatible matrix multiplication")
        a = (self._a * other._a + self._b * other._c) % self._p
        b = (self._a * other._b + self._b * other._d) % self._p
        c = (self._c * other._a + self._d * other._c) % self._p
        d = (self._c * other._b + self._d * other._d) % self._p
        return Matrix22p((a,b,c,d), self._p)

    def __str__(self):
        """Returns a string showing the 2x2 matrix structure."""
        return "[ [ {} {} ]\n  [ {} {} ] ] mod {}".format(self._a, self._b, self._c, self._d, self._p)

if __name__ == '__main__':
    # if run from the command-line, print out the dot file for Zemor(p)
    try:
        p = int(sys.argv[1])
    except (IndexError, ValueError):
        p = None  # missing or non-integer argument; reported below
    # Reject anything that is not a prime >= 2 up front: p < 2 silently
    # produces an empty graph, and a composite p only fails later, deep
    # inside Matrix22p, with a traceback.
    if p is None or p < 2 or any(p % k == 0 for k in range(2, int(p ** 0.5) + 1)):
        # usage goes to stderr so it never ends up in a redirected .dot file
        print("usage: {} p".format(sys.argv[0]), file=sys.stderr)
        print("p must be a prime number.", file=sys.stderr)
        print("Running this program will create the Zemor expander graph with the given p", file=sys.stderr)
        print("and print the corresponding dot file of that graph to standard out.", file=sys.stderr)
        sys.exit(1)
    g = Zemor(p)
    print(g)
