#!/usr/bin/python3

import heapq
import queue
import random
import sys

from showmaze import mazeref, inbounds

def makeEmpty(h, w):
    '''Creates an empty h by w maze.

    Returns a list of h+2 rows of w+2 booleans, where True is a wall.  The
    interior (rows 1..h, columns 1..w) is open and enclosed by wall, apart
    from an entrance at (1, 0) and an exit at (h, w+1).

    Raises ValueError if h or w is less than 1.
    '''
    if h < 1 or w < 1:
        raise ValueError("maze must be at least 1x1, got {}x{}".format(h, w))
    if h == 1:
        # Entrance and exit share the only interior row, so both ends are open.
        middle = [[False] * (w + 2)]
    else:
        middle = [[False] * (w + 1) + [True]]              # entrance at (1, 0)
        middle += [[True] + [False] * w + [True] for _ in range(h - 2)]
        middle.append([True] + [False] * (w + 1))           # exit at (h, w+1)
    # Top and bottom walls are built separately so the rows are not aliased.
    return [[True] * (w + 2)] + middle + [[True] * (w + 2)]

def below(pos):
    '''Returns the (row, col) position one row down from pos.'''
    return pos[0] + 1, pos[1]


def above(pos):
    '''Returns the (row, col) position one row up from pos.'''
    return pos[0] - 1, pos[1]


def leftof(pos):
    '''Returns the (row, col) position one column left of pos.'''
    return pos[0], pos[1] - 1


def rightof(pos):
    '''Returns the (row, col) position one column right of pos.'''
    return pos[0], pos[1] + 1


def neighbors(pos):
    '''Returns the four orthogonal neighbours of pos, ordered below, above, left, right.'''
    return tuple(step(pos) for step in (below, above, leftof, rightof))

def dist(pos1, pos2):
    '''Returns the Manhattan (grid-step) distance between two (row, col) positions.'''
    drow = pos1[0] - pos2[0]
    dcol = pos1[1] - pos2[1]
    return abs(drow) + abs(dcol)

def connected(sofar, pos):
    '''Tells whether any of the four squares orthogonally adjacent to pos is in sofar.'''
    return any(adj in sofar for adj in neighbors(pos))

def reachable(maze, p1, p2):
    '''Returns True if p2 can be reached from p1 by moving through open squares.

    Depth-first search over 4-neighbours, stepping only onto positions for
    which mazeref(maze, pos) is falsy.  p1 itself is not checked against the
    maze.  Neighbours are pushed farthest-first by Manhattan distance to p2,
    so the one closest to p2 sits on top of the stack and is explored next;
    this only affects speed, not the result.
    '''
    visited = set()
    stack = [p1]   # a plain list: no need for the locking queue.LifoQueue
    while stack:
        here = stack.pop()
        if here == p2:
            return True
        if here in visited:
            continue
        visited.add(here)
        # Descending order so the neighbour nearest p2 is pushed last (popped first).
        candidates = sorted(neighbors(here), key=lambda nxt: dist(nxt, p2), reverse=True)
        for nxt in candidates:
            if nxt not in visited and not mazeref(maze, nxt):
                stack.append(nxt)
    return False

def fillReachable(h, w, param):
    '''Creates an h by w maze by adding random walls that keep the exit reachable.

    Makes round(h*w*param) attempts.  Each attempt picks a random interior
    square; if it is open it is turned into a wall, and the wall is kept only
    if reachable() still finds a route from the entrance (1, 0) to the exit
    (h, w+1).  Runs a full search per accepted candidate, hence slow.
    '''
    maze = makeEmpty(h, w)
    entrance, exit_ = (1, 0), (h, w + 1)
    assert mazeref(maze, entrance) is False
    assert mazeref(maze, exit_) is False

    attempts = round(h * w * param)
    for _ in range(attempts):
        row = random.randrange(1, h + 1)
        col = random.randrange(1, w + 1)
        if maze[row][col]:
            continue
        maze[row][col] = True
        if reachable(maze, entrance, exit_):
            continue
        # This wall cut the only route; take it back out.
        maze[row][col] = False

    return maze

def fillFromPath(h, w, param):
    '''Creates an h by w maze by carving a guaranteed path, then walling around it.

    A random path of only down/right steps is walked from the entrance
    (1, 0) to the exit (h, w+1).
    - If one of the two candidate steps is a wall, the other is taken.
    - Otherwise the choice is random, weighting each direction by the distance
      still to cover along it.

    Afterwards round(free * param) random interior squares are drawn, where
    free is the number of interior squares not on the path.  Every drawn
    square that is not on the path becomes a wall, and squares may be drawn
    more than once.  The path is never walled, so the exit stays reachable.
    '''
    maze = makeEmpty(h, w)
    start = (1, 0)
    end = (h, w + 1)
    assert mazeref(maze, start) is False
    assert mazeref(maze, end) is False

    path = [start]
    while path[-1] != end:
        p1 = below(path[-1])
        d1 = end[0] - p1[0] + 1     # rows still to descend
        p2 = rightof(path[-1])
        d2 = end[1] - p2[1] + 1     # columns still to cross
        # Randomize which direction is considered first so ties are unbiased.
        if random.randrange(2):
            p1, p2 = p2, p1
            d1, d2 = d2, d1
        if mazeref(maze, p1):
            np = p2
        elif mazeref(maze, p2):
            np = p1
        elif random.uniform(0, 1) < d1 / (d1 + d2):
            np = p1
        else:
            np = p2
        path.append(np)

    # Set for O(1) membership; the list scan made each fill attempt O(len(path)).
    on_path = set(path)
    # start and end lie on the border, so only len(path)-2 path squares are interior.
    free = h * w - (len(path) - 2)
    for _ in range(round(free * param)):
        y = random.randrange(1, h + 1)
        x = random.randrange(1, w + 1)
        if (y, x) not in on_path:
            maze[y][x] = True

    return maze

def rPrims(h, w, param):
    '''Creates an h by w maze with randomized Prim's algorithm.

    Cells sit at odd (row, col) coordinates.  The even rows and columns
    between them start as walls.  Starting from a random cell, the lowest
    randomly weighted wall that leads to a not-yet-included cell is
    repeatedly knocked down until no such wall remains.

    Then round(scale*h*w) random interior squares are opened (repeats or
    already-open squares are possible), with
    scale = (1-param) * (1/8) / (1/8 + param).  For example, param=0.5 gives
    scale 0.1, and any param >= 1 opens nothing.
    '''
    maze = makeEmpty(h, w)

    # Divide into cells: wall off every even interior row and column.
    for i in range(2, h, 2):
        for j in range(1, w + 1):
            maze[i][j] = True
    for j in range(2, w, 2):
        for i in range(1, h + 1):
            maze[i][j] = True

    included = set()
    # Min-heap of (random weight, wall position, cell beyond that wall).
    frontier = []

    def addNeighbors(pos):
        # maze/included/frontier are only mutated, never rebound: no nonlocal needed.
        for direction in (above, below, rightof, leftof):
            wp = direction(pos)
            if not mazeref(maze, wp):
                continue    # already open, nothing to knock down
            other = direction(wp)
            if not inbounds(maze, other):
                continue    # no cell beyond this wall inside the grid
            if other in included:
                continue
            heapq.heappush(frontier, (random.uniform(0, 1), wp, other))

    # Choose a random starting cell.
    y = random.randrange(1, h + 1, 2)
    x = random.randrange(1, w + 1, 2)
    included.add((y, x))
    addNeighbors((y, x))

    while frontier:
        _, (wy, wx), opp = heapq.heappop(frontier)
        if opp in included:
            continue    # stale entry: that cell was reached through another wall
        included.add(opp)
        maze[wy][wx] = False
        addNeighbors(opp)

    # Randomly remove some walls to introduce loops.
    s = 1 / 8
    scale = (1 - param) * s / (s + param)
    for _ in range(round(scale * h * w)):
        y = random.randrange(1, h + 1)
        x = random.randrange(1, w + 1)
        maze[y][x] = False

    return maze

def makeCharMaze(maze):
    '''Converts a boolean maze into a grid of characters: 'X' for truthy (wall) squares, ' ' otherwise.'''
    def glyph(cell):
        return 'X' if cell else ' '

    chars = []
    for row in maze:
        chars.append(list(map(glyph, row)))
    return chars

def addPrizes(maze, nprizes):
    '''Marks random open interior squares of a character maze with 'O', in place.

    Only squares inside the one-square border that currently hold ' ' are
    eligible; each is chosen at most once.  If fewer than nprizes squares are
    open, all of them get a prize.  The original rejection loop spun forever
    in that case.  Non-positive nprizes adds nothing.
    '''
    h, w = len(maze) - 2, len(maze[0]) - 2
    open_cells = [(y, x)
                  for y in range(1, h + 1)
                  for x in range(1, w + 1)
                  if maze[y][x] == ' ']
    count = max(0, min(nprizes, len(open_cells)))
    for y, x in random.sample(open_cells, count):
        maze[y][x] = 'O'

def printMaze(maze):
    '''Writes each row of a character maze to stdout as one line of text.'''
    for line in map(''.join, maze):
        print(line)

# Generation methods selectable on the command line, mapped to their help
# text.  Each key must be the name of a generator function in this module
# taking (h, w, param), because the main block dispatches via globals()[method].
genmeths = dict(
    fillFromPath='Creates a path and then fills in around it',
    fillReachable="Adds obstacles that don't prevent reaching the finish (slow)",
    rPrims="Creates a traditional-looking maze using Prim's MST algorithm",
)
# Method used when none is given on the command line.
defmeth = 'rPrims'

def usage(exitval=0):
    '''Prints command-line help, listing every method in genmeths, then exits with status exitval.

    Uses sys.exit rather than the exit() builtin: exit() is installed by the
    site module for the interactive shell and is not guaranteed to exist
    (e.g. under python -S).
    '''
    print("Usage: {} height width [method] [param]".format(sys.argv[0]))
    print("method (optional) is one of:")
    for meth, desc in genmeths.items():
        df = " (default)" if meth == defmeth else ""
        print("  {}{}: {}".format(meth, df, desc))
    print('param (optional) indicates how "dense" the output will be.')
    sys.exit(exitval)

if __name__ == '__main__':
    # Command line: height width [method] [param]
    if not 3 <= len(sys.argv) <= 5:
        usage(1)
    try:
        h = int(sys.argv[1])
        w = int(sys.argv[2])
        param = 0.5 if len(sys.argv) <= 4 else float(sys.argv[4])
    except ValueError:
        # Report bad numbers with the usage text instead of a traceback.
        print("height and width must be integers and param a number", file=sys.stderr)
        usage(1)
    if h < 1 or w < 1:
        print("height and width must be at least 1", file=sys.stderr)
        usage(1)
    method = defmeth if len(sys.argv) <= 3 else sys.argv[3]

    if method not in genmeths:
        print("Invalid method: {}".format(method), file=sys.stderr)
        usage(2)

    # genmeths keys are the names of the generator functions in this module.
    genfunc = globals()[method]
    maze = genfunc(h, w, param)

    cm = makeCharMaze(maze)
    addPrizes(cm, 10)
    printMaze(cm)
