#!/usr/bin/python3

# SI 335: Computer Algorithms
# Unit 4

import random

def add(X, Y, B):
    """Add two base-B numbers stored as digit lists, least significant first.

    X must have at least as many digits as Y. The result always has
    len(X) + 1 digits; the last one is the final carry (possibly 0).
    """
    assert(len(X) >= len(Y))
    short_len = len(Y)
    out = []
    carry = 0
    for pos, x_digit in enumerate(X):
        # Past the end of Y, treat its missing digits as zero.
        y_digit = Y[pos] if pos < short_len else 0
        carry, digit = divmod(x_digit + y_digit + carry, B)
        out.append(digit)
    out.append(carry)
    return out

def sub(X, Y, B):
    """Compute X - Y for base-B digit lists stored least significant first.

    X must have at least as many digits as Y and must represent a value
    no smaller than Y. The result has exactly len(X) digits.
    """
    assert(len(X) >= len(Y))
    # Pad Y with zeros so both operands can be walked in lockstep.
    padded = list(Y) + [0] * (len(X) - len(Y))
    borrow = 0
    out = []
    for x_digit, y_digit in zip(X, padded):
        # divmod floors, so a negative partial difference yields borrow == -1
        # and a digit already normalized into [0, B).
        borrow, digit = divmod(x_digit - y_digit + borrow, B)
        out.append(digit)
    # A leftover borrow would mean X < Y.
    assert(borrow == 0)
    return out

def smul(X, Y, B):
    """Schoolbook multiplication of two equal-length base-B digit lists.

    Digits are stored least significant first. Returns 2 * len(X) digits.
    Runs in Theta(n^2) digit operations.
    """
    assert(len(X) == len(Y))
    n = len(X)
    A = [0] * (2 * n)
    for shift, y_digit in enumerate(Y):
        # Build one row of the product: X times a single digit of Y.
        row = []
        carry = 0
        for x_digit in X:
            carry, digit = divmod(x_digit * y_digit + carry, B)
            row.append(digit)
        # Accumulate the row, shifted left by `shift` digits, into A.
        A[shift : shift + n + 1] = add(A[shift : shift + n], row, B)
        # The row's own final carry belongs in the next digit up.
        A[shift + n] += carry
    return A

def kmul(X, Y, B):
    """Multiply two equal-length base-B digit lists with Karatsuba's algorithm.

    Digits are stored least significant first. Each operand is split at
    m = n // 2 into a low half (X0, Y0) and a high half (X1, Y1), and only
    three recursive products are formed:
        P0 = X0*Y0,  P1 = X1*Y1,  P2 = (X0 + X1)*(Y0 + Y1)
    These are combined as X*Y = P0 + (P2 - P0 - P1)*B**m + P1*B**(2m).
    Operands of length 3 or less use smul instead. Returns 2n digits.
    """
    assert(len(X) == len(Y))
    n = len(X)
    if n <= 3:
        return smul(X, Y, B)
    half = n // 2
    x_lo, x_hi = X[:half], X[half:]
    y_lo, y_hi = Y[:half], Y[half:]
    x_sum = add(x_hi, x_lo, B)
    y_sum = add(y_hi, y_lo, B)
    lo_prod = kmul(x_lo, y_lo, B)
    hi_prod = kmul(x_hi, y_hi, B)
    mid_prod = kmul(x_sum, y_sum, B)
    # P0 occupies digits [0, 2m) and P1 digits [2m, 2n); the extra digit at
    # index 2n absorbs carries while the middle term is folded in.
    result = lo_prod + hi_prod + [0]
    # Work on digits m and up: add P2, then take away P0 and P1.
    upper = add(result[half : 2 * n], mid_prod, B)
    upper = sub(upper, lo_prod, B)
    upper = sub(upper, hi_prod, B)
    result[half:] = upper
    # An n-digit by n-digit product always fits in 2n digits.
    assert(result[2 * n] == 0)
    return result[: 2 * n]

def fib(n):
    """Return the n-th Fibonacci number by direct recursion.

    Deliberately unmemoized, so it takes exponential time. Any n <= 1 is
    returned unchanged.
    """
    if n > 1:
        return fib(n - 1) + fib(n - 2)
    return n

fib_table = {}  # memo: n -> fib(n), filled in by fib_memo for n >= 2
def fib_memo(n):
    """Return the n-th Fibonacci number, memoizing results in fib_table.

    Top-down recursion. Each n >= 2 is computed once and stored. Base cases
    n <= 1 are returned directly and are not stored.
    """
    if n in fib_table:
        return fib_table[n]
    if n <= 1:
        return n
    fib_table[n] = fib_memo(n - 1) + fib_memo(n - 2)
    return fib_table[n]

# Note: the * in "def mm(*D)" makes Python collect all of the arguments
# into a single tuple D. A tuple is basically a list that
# can't be changed. So for example, mm(5,2,6,3) is the
# correct way to call this function (and it returns 66).
def mm(*D):
    """Return the fewest scalar multiplications needed to multiply a matrix chain.

    Uses plain (exponential-time) recursion. D lists the chain's dimensions:
    matrix k (1-based) is D[k-1] x D[k], so len(D) - 1 matrices are
    multiplied. For example, mm(5, 2, 6, 3) == 66.

    A chain of zero or one matrix needs no multiplications and costs 0.
    """
    n = len(D) - 1
    if n <= 1:
        # Nothing to multiply. Testing only n == 1 would let n == 0 reach the
        # split loop below with no candidates and report infinity.
        return 0
    # Try every split point i: the left part D[0..i] becomes a D[0] x D[i]
    # matrix and the right part D[i..n] a D[i] x D[n] matrix. Multiplying
    # those two costs D[0]*D[i]*D[n].
    return min(
        mm(*D[0 : i+1]) + D[0]*D[i]*D[n] + mm(*D[i : n+1])
        for i in range(1, n)
    )

mm_table = {}  # memo: dimension tuple D -> fewest multiplications for that chain
def mmm(*D):
    """Memoized version of mm: the fewest scalar multiplications for a matrix chain.

    D lists the chain's dimensions: matrix k (1-based) is D[k-1] x D[k].
    Each distinct dimension tuple is solved once, and the answer is cached
    in mm_table under the key D.

    A chain of zero or one matrix needs no multiplications and costs 0.
    """
    if D not in mm_table:
        n = len(D) - 1
        if n <= 1:
            # Nothing to multiply (n == 0 used to fall through and return inf).
            mm_table[D] = 0
        else:
            # Best over all split points i. See mm for the cost of each split.
            mm_table[D] = min(
                mmm(*D[0 : i+1]) + D[0]*D[i]*D[n] + mmm(*D[i : n+1])
                for i in range(1, n)
            )
    return mm_table[D]

def dmm(*D):
    """Bottom-up dynamic-programming solution to matrix-chain multiplication.

    D lists the chain's dimensions, exactly as for mm. cost[r][c] holds the
    fewest multiplications for the sub-chain described by D[r..c]. The table
    is filled one diagonal at a time (sub-chain length c - r increasing), so
    every entry a split needs is already known. The finished table is
    printed with printTable before returning cost[0][n].
    """
    n = len(D) - 1
    cost = [[0] * (n + 1) for _ in range(n + 1)]
    # Sub-chains of a single matrix (c - r == 1) cost 0, which the table
    # already holds, so start with sub-chains of two matrices.
    for span in range(2, n + 1):
        for left in range(0, n - span + 1):
            right = left + span
            best = float('inf')
            # Same recurrence as mm, but reading sub-results from the table.
            for split in range(left + 1, right):
                t = cost[left][split] + D[left]*D[split]*D[right] + cost[split][right]
                if t < best:
                    best = t
            cost[left][right] = best
    printTable(cost)
    return cost[0][n]


# The rest is just for testing purposes

def toDigits(n, B=10):
    """Return the base-B digits of a non-negative integer, least significant first.

    This is the layout add, sub, smul and kmul expect. toDigits(0) == [0].
    Raises ValueError if n is negative or B < 2.

    Repeated divmod is used instead of str(n). That way any base works, and
    large inputs are not rejected by CPython's int-to-str conversion limit
    (4300 decimal digits by default since Python 3.11).
    """
    if B < 2:
        raise ValueError("base must be at least 2")
    if n < 0:
        raise ValueError("toDigits requires a non-negative integer")
    if n == 0:
        return [0]
    digits = []
    while n:
        n, d = divmod(n, B)
        digits.append(d)
    return digits

def fromDigits(X, B=10):
    """Return the integer whose base-B digits, least significant first, are X.

    This is the inverse of toDigits. Extra high-order zeros (trailing entries
    of X) are harmless, and an empty list gives 0.

    Horner's rule is used instead of joining digits into a decimal string.
    Joining would give wrong answers for digit values >= 10 (e.g. [12, 1]
    would read as 112), and long lists would hit CPython's int/str
    conversion limit.
    """
    value = 0
    for d in reversed(X):
        value = value * B + d
    return value

def printTable(A):
    """Print the 2-D list A as a grid labelled with row and column indices.

    Every cell is right-aligned in a 4-character field followed by a space.
    Each row starts with its index right-aligned in 3 characters and " [",
    and ends with "]". Column labels come from the length of the first row.
    """
    cell = "{:>4} ".format

    def show(values):
        # Emit one formatted cell per value, all on the current line.
        for v in values:
            print(cell(v), end="")

    # Pad the header by the width of a row label: 3 digits + " [".
    print(" " * 5, end="")
    show(range(len(A[0])))
    print()
    for idx, row in enumerate(A):
        print("{:>3} [".format(idx), end="")
        show(row)
        print("]")
