#!/usr/bin/python3

# SI 335: Computer Algorithms
# Unit 6

import sys
from collections import deque
from copy import copy
from heapq import heappush, heappop
from random import randrange

infinity = float('inf')

class Graph:
    '''Base class spelling out the interface every graph offers.

       On its own it behaves as the empty graph: no vertices, no edges,
       and every edge weight is infinity.  The subclasses in this file
       override nodes(), edgeWeight() and edgesFrom(); edges() and
       isEdge() are derived from those and normally need no override.'''

    # Both counters are 0 for the empty graph; the subclasses in this
    # file overwrite them per instance from len(edges) and len(vertices).
    m = 0
    n = 0

    def nodes(self):
        '''Returns the collection of vertices (none for the empty graph).'''
        return []

    def edges(self):
        '''Returns a list of (u, v, w) triples, one for each entry (v, w)
           that edgesFrom(u) reports, visiting u in nodes() order.'''
        return [(u, v, w)
                for u in self.nodes()
                for (v, w) in self.edgesFrom(u)]

    def isEdge(self, u, v):
        '''True exactly when edgeWeight(u, v) is below infinity.'''
        weight = self.edgeWeight(u, v)
        return weight < infinity

    def edgeWeight(self, u, v):
        '''Returns the weight from u to v; infinity stands for "no edge".'''
        return infinity

    def edgesFrom(self, u):
        '''Returns a list of (v, w) pairs for the edges leaving u.'''
        return []


class ALGraph(Graph):
    '''Graph stored as adjacency lists.

       self.AL maps every vertex to a list of (neighbor, weight) pairs,
       one pair per outgoing edge, kept in the order the edges were given.'''

    def __init__(self, vertices, edges):
        '''vertices is a sized collection of vertex names and edges a
           sized collection of (u, v, w) triples.  An edge whose source u
           is not among the vertices raises KeyError.'''
        self.n = len(vertices)
        self.m = len(edges)
        self.V = vertices

        # Every vertex gets its own, initially empty, neighbor list.
        self.AL = {vertex: [] for vertex in self.V}

        # File each edge under the vertex it leaves from.
        for (src, dst, weight) in edges:
            self.AL[src].append((dst, weight))

    def nodes(self):
        '''Returns the vertex collection exactly as passed to __init__.'''
        return self.V

    def edgeWeight(self, u, v):
        '''Returns the weight of the first stored edge from u to v, or
           infinity when u has no edge to v.'''
        matches = (w for (other, w) in self.AL[u] if other == v)
        return next(matches, infinity)

    def edgesFrom(self, u):
        '''Returns u's own neighbor list (the stored list, not a copy).'''
        return self.AL[u]


class AMGraph(Graph):
    '''Graph stored as an adjacency matrix.

       self.AM[i][j] holds the weight of the edge from self.V[i] to
       self.V[j], or infinity when there is none.  The diagonal starts at
       0 so the matrix can be handed directly to the shortest-path
       routines as a table of initial distances.'''

    def __init__(self, vertices, edges):
        '''vertices is a sized collection of vertex names and edges a
           sized collection of (u, v, w) triples.  An edge that mentions
           a vertex not in vertices raises KeyError; if the same (u, v)
           pair appears more than once, the last weight wins.'''
        self.n = len(vertices)
        self.m = len(edges)
        self.V = list(vertices)

        # lookup table: vertex name -> its row/column index
        self.vertind = {u: i for (i, u) in enumerate(self.V)}

        # self.AM is the actual matrix: infinity everywhere, 0 on the
        # diagonal.  Each row is a separate list object.
        self.AM = [[infinity] * self.n for _ in range(self.n)]
        for i in range(self.n):
            self.AM[i][i] = 0

        # add each edge weight to the matrix
        for (u, v, w) in edges:
            self.AM[self.vertind[u]][self.vertind[v]] = w

    def nodes(self):
        '''Returns the list of vertices in matrix-index order.'''
        return self.V

    def edgeWeight(self, u, v):
        '''Returns the matrix entry for (u, v): the edge weight, or
           infinity when there is no edge.  For u == v this is the
           diagonal entry, i.e. 0 unless a self-loop edge overwrote it.'''
        return self.AM[self.vertind[u]][self.vertind[v]]

    def edgesFrom(self, u):
        '''Returns a list of (v, w) pairs for the edges leaving u, in
           matrix-index order.

           A 0 on the diagonal is the filler written by __init__, not an
           edge, so it is left out.  Reporting it would add a bogus
           (u, u, 0) entry per vertex to edges(), which would then no
           longer match self.m or the adjacency-list representation.
           A self-loop given with a non-zero weight is still reported.'''
        uind = self.vertind[u]
        L = []
        for (i, w) in enumerate(self.AM[uind]):
            if i == uind and w == 0:
                continue
            if w < infinity:
                L.append((self.V[i], w))
        return L


def DFS(G, start, end):
    '''Depth-first search in G from start, looking for end.

       Returns the total weight of the path along which end was reached
       (some path, not necessarily the lightest one), or None if the
       search finishes without reaching end.  G must provide G.V and
       G.edgesFrom(u); a start vertex that is not in G.V raises KeyError.

       Vertices are "white" (unseen), "gray" (expanded, still on the
       stack) or "black" (finished).'''
    colors = {}
    for u in G.V:
        colors[u] = "white"
    fringe = [(start, 0)]  # used as a stack of (vertex, path weight so far)
    while fringe:
        (u, w1) = fringe[-1]  # peek at the top of the stack
        if colors[u] == "white":
            if u == end:
                return w1
            colors[u] = "gray"
            for (v, w2) in G.edgesFrom(u):
                if colors[v] == "white":
                    fringe.append((v, w1 + w2))
        elif colors[u] == "gray":
            # Back on top after everything pushed above it is gone.
            colors[u] = "black"
        else:
            # The entry being discarded is the one on top, so pop() is
            # enough; there is no need to search the list for it.
            fringe.pop()
    return None

def BFS(G, start, end):
    '''Breadth-first search in G from start, looking for end.

       Returns the total weight of the path along which end was first
       taken from the front of the queue, or None if the queue empties
       without reaching end.  G must provide G.V and G.edgesFrom(u); a
       start vertex that is not in G.V raises KeyError.

       Same coloring scheme as DFS; the difference is that the fringe is
       served from the front (a queue) instead of the back (a stack).'''
    colors = {}
    for u in G.V:
        colors[u] = "white"
    # A deque removes from the front in constant time; deleting the head
    # of a plain list shifts every remaining element.
    fringe = deque([(start, 0)])
    while fringe:
        (u, w1) = fringe[0]  # peek at the front of the queue
        if colors[u] == "white":
            if u == end:
                return w1
            colors[u] = "gray"
            for (v, w2) in G.edgesFrom(u):
                if colors[v] == "white":
                    fringe.append((v, w1 + w2))
        elif colors[u] == "gray":
            colors[u] = "black"
        else:
            fringe.popleft()
    return None

def linearize(G):
    '''Returns an ordering of the vertices in the DAG G that has
       no backward edges (every edge goes from an earlier vertex in the
       list to a later one).

       G must provide G.V and G.edgesFrom(u).  The result is a new list
       containing each vertex of G.V once.'''
    colors = {}
    fringe = []
    # Seed the stack with every vertex so that all of them get visited,
    # whether or not they are reachable from one another.
    for u in G.V:
        colors[u] = "white"
        fringe.append(u)
    finished = []
    while fringe:
        u = fringe[-1]
        if colors[u] == "white":
            colors[u] = "gray"
            for (v, w2) in G.edgesFrom(u):
                if colors[v] == "white":
                    fringe.append(v)
        elif colors[u] == "gray":
            # Everything pushed above u is done, so u is finished too.
            colors[u] = "black"
            finished.append(u)
        else:
            fringe.pop()
    # Vertices were recorded as they finished; the wanted order is the
    # reverse.  One reversal here replaces a front-insert per vertex.
    finished.reverse()
    return finished

def dijkstraHeap(G, start):
    '''Dijkstra's algorithm with a binary heap as the fringe.

       Returns a dictionary mapping each vertex reached from start to
       the length of its shortest path.  Vertices that are never reached
       do not appear in the result.'''
    settled = {u: False for u in G.V}
    shortest = {}
    # Heap entries are (distance, vertex): distance first, so the heap
    # orders by it.
    heap = [(0, start)]
    while heap:
        (dist, u) = heappop(heap)
        if settled[u]:
            # Stale entry: u was already taken off with a smaller distance.
            continue
        settled[u] = True
        shortest[u] = dist
        for (v, step) in G.edgesFrom(u):
            heappush(heap, (dist + step, v))
    return shortest

def dijkstraArray(G, start):
    '''Dijkstra's algorithm with a plain table as the fringe, scanned in
       full for its minimum on every round.

       Returns a dictionary mapping every vertex of G to the length of
       its shortest path from start; vertices that cannot be reached are
       mapped to infinity.'''
    # Best distance known so far for every vertex not yet finalized.
    tentative = {u: infinity for u in G.V}
    tentative[start] = 0
    shortest = {}
    while tentative:
        # The closest remaining vertex is final; on ties the first one
        # in the table's order is taken.
        u = min(tentative, key=tentative.get)
        dist = tentative.pop(u)
        shortest[u] = dist
        for (v, step) in G.edgesFrom(u):
            if v in tentative and dist + step < tentative[v]:
                tentative[v] = dist + step
    return shortest


def recShortest(AM, i, j, k):
    '''Length of the shortest path from node number i to node number j
       in adjacency matrix AM whose intermediate nodes are all numbered
       k or lower.  With k == -1 no intermediate node is allowed, so the
       answer is the matrix entry itself.'''
    if k == -1:
        return AM[i][j]
    # Either node k is not used at all ...
    without_k = recShortest(AM, i, j, k - 1)
    # ... or the path runs i -> k -> j with both halves avoiding k.
    through_k = recShortest(AM, i, k, k - 1) + recShortest(AM, k, j, k - 1)
    return through_k if through_k < without_k else without_k

def FloydWarshall(AM):
    '''Calculates EVERY shortest path length between any two vertices
       in the original adjacency matrix graph.

       AM is a square list of lists in which AM[i][j] is the weight of
       the edge from i to j (infinity when there is none).  Returns a
       new matrix L of the same shape where L[i][j] is the length of the
       shortest path from i to j.  AM itself is left unmodified.'''
    # Copy row by row: copying only the outer list would share the row
    # lists with AM, and the updates below would overwrite the caller's
    # matrix.
    L = [list(row) for row in AM]
    n = len(L)
    # After round k, L[i][j] is the best length using only intermediate
    # vertices numbered 0..k.
    for k in range(0, n):
        for i in range(0, n):
            for j in range(0, n):
                L[i][j] = min(L[i][j], L[i][k] + L[k][j])
    return L

def approxVC(G):
    '''Builds a vertex cover of G greedily: walk over G.edges() and,
       whenever an edge has neither endpoint in the cover yet, put both
       of its endpoints in.  Returns the cover as a set of vertices.'''
    cover = set()
    for (u, v, _weight) in G.edges():
        if u in cover or v in cover:
            continue  # this edge is already covered
        cover.update((u, v))
    return cover


# The rest is just for testing/debugging purposes.

# Specifications of my example graphs
def weighted(E):
    '''Turns unweighted edges (u, v) into weighted edges (u, v, 1).
       Returns them as a sorted tuple.'''
    unit_edges = [(u, v, 1) for (u, v) in E]
    unit_edges.sort()
    return tuple(unit_edges)

def directed(E):
    '''Makes a directed edge list from an undirected one: every edge
       (u, v, w) is kept and its reverse (v, u, w) is added.  Returns
       the edges as a sorted tuple without duplicates.'''
    both_ways = set(E)
    both_ways.update((v, u, w) for (u, v, w) in E)
    return tuple(sorted(both_ways))

def fromE(E):
    '''Works out the vertex set from an edge list.  Returns a pair
       (vertices, E): the sorted tuple of every endpoint that occurs in
       E, followed by E itself, unchanged.'''
    endpoints = set()
    for (u, v, _weight) in E:
        endpoints.update((u, v))
    return tuple(sorted(endpoints)), E

# One-character vertex names 'a' through 'm', so the edge lists below
# can be written without quotes.
a, b, c, d, e, f, g, h, i, j, k, l, m = "abcdefghijklm"

# Each example is a (vertices, edges) pair as returned by fromE.

# dag1: directed edges, all of weight 1.
dag1 = fromE(weighted((
    (b, c), (b, a),
    (c, a), (c, e), (c, d),
    (a, d),
    (e, d),
)))

# ex1: directed edges with explicit weights.
ex1 = fromE((
    (a, c, 10), (a, d, 22),
    (b, c, 53), (b, e, 45),
    (c, a, 21), (c, e, 33),
    (e, d, 19),
))

# ex2: every listed edge is present in both directions.
ex2 = fromE(directed((
    (a, b, 6), (a, c, 6), (a, d, 3),
    (b, d, 2), (b, e, 4),
    (c, d, 5), (c, e, 1),
    (d, e, 4),
)))

# ex3: every listed edge is present in both directions.
ex3 = fromE(directed((
    (a, c, 1), (c, d, 6), (b, e, 1),
    (c, f, 4), (a, f, 6), (b, f, 2),
    (c, e, 5), (e, d, 2), (b, c, 1),
)))

# match1: weight-1 edges, each present in both directions.
match1 = fromE(directed(weighted((
    (l, h), (h, d), (d, a), (a, b), (c, f), (f, e),
    (e, i), (j, m), (g, k), (j, e), (j, i), (g, f),
    (a, h), (d, b), (b, e), (l, m), (h, i), (c, b),
    (k, f), (m, k), (j, f), (d, i), (i, m),
))))


# Self-test: runs only when this file is executed as a script.
if __name__ == '__main__':
    # Smoke test only: the two searches are run but their return values
    # are discarded, not checked.
    BFS(ALGraph(*ex1),b,d)
    DFS(AMGraph(*ex1),a,b)

    # Both graph representations must produce the same ordering of dag1.
    assert linearize(ALGraph(*dag1)) == [b,c,a,e,d]
    assert linearize(AMGraph(*dag1)) == [b,c,a,e,d]

    # Both Dijkstra variants, on both representations, must agree on the
    # shortest path lengths from a in ex2.
    assert dijkstraArray(ALGraph(*ex2), a) == {a:0, b:5, c:6, d:3, e:7}
    assert dijkstraHeap(ALGraph(*ex2), a) == {a:0, b:5, c:6, d:3, e:7}
    assert dijkstraArray(AMGraph(*ex2), a) == {a:0, b:5, c:6, d:3, e:7}
    assert dijkstraHeap(AMGraph(*ex2), a) == {a:0, b:5, c:6, d:3, e:7}

    # The matrix routines work on the raw adjacency matrix; indices
    # 0..4 stand for the vertices a..e.
    ex2am = AMGraph(*ex2).AM
    # From a to b: 6 when only a..c may be used as intermediate nodes,
    # 5 once d is allowed as well.
    assert recShortest(ex2am, 0, 1, 2) == 6
    assert recShortest(ex2am, 0, 1, 3) == 5

    # The recShortest checks above are done before FloydWarshall is run
    # on the same matrix; keep them in this order.
    L = FloydWarshall(ex2am)
    assert L[0][1] == 5
    assert L[0][4] == 7
    assert L[2][1] == 5

    print("All checks passed!")

# Remove the single-letter vertex names from the module namespace.  This
# statement is at top level, so it runs on import as well as when the
# file is executed as a script (presumably to keep the short names from
# leaking to importers -- the example graphs keep their vertex strings).
del a,b,c,d,e,f,g,h,i,j,k,l,m
