"""
Dutch flag algorithm / three-way partitioning and sorting.

Requirements (Debian packages):
$ sudo aptitude install python-hypothesis python-pytest python-pytest-cov

Usage (demo run, tests, test statistics, coverage reports):
$ python find.py
$ pytest find.py
$ pytest --hypothesis-show-statistics find.py
$ pytest --cov=. find.py
$ python-coverage run find.py
$ python-coverage html
"""

# This file has been style-checked using
# $ pylint --disable=C0326,W1114 --good-names=a,b,i,j,k,n find.py

from hypothesis import given
from hypothesis.strategies import integers, lists

def swap(a,i,j):
    """
    Swap, in place, the items at positions i and j of the list a.
    """
    a[j], a[i] = a[i], a[j]

def partition(a,mid,start,length):
    """
    Partition a[start:start+length] in place into three zones:
    first the elements strictly smaller than the mid-value,
    then the elements equal to it, finally the strictly larger elements.

    mid does not need to occur in the range. Only positions inside the
    range are ever swapped; the rest of a is left untouched.

    Return (i, j), the zone boundaries, such that afterwards
    a[start:i] < mid, a[i:j] == mid and a[j:start+length] > mid.
    """
    # Loop invariant (checked at the end of the loop body):
    #   a[start:i]            < mid   (placed smaller elements)
    #   a[i:j]               == mid   (placed equal elements)
    #   a[j:n+1]                      (not yet examined)
    #   a[n+1:start+length]   > mid   (placed larger elements)
    i = j = start
    n = start+length-1
    while j <= n:
        # For the record, testing revealed several bugs:
        # I initially had j<n instead of j<=n above;
        # I was increasing j all the time, and didn't
        # have elif but just a sequence of ifs.
        if a[j]<mid:
            swap(a, j, i)
            i += 1
            j += 1
        elif a[j]>mid:
            # The element brought in from position n is unexamined,
            # so j must not advance here.
            swap(a, j, n)
            n -= 1
        else:
            # a[j] == mid for totally ordered values. A plain `else`
            # (rather than `elif a[j] == mid`) guarantees that every
            # iteration makes progress. With an incomparable value such
            # as NaN, none of three explicit tests would hold and the
            # loop would never end; here the invariant check below
            # fails instead.
            j += 1
        # Invariant. These checks rescan the zones on every iteration,
        # which makes partition() quadratic; they serve as an executable
        # specification for testing and are stripped under `python -O`.
        assert start <= i
        assert i <= j
        assert j <= n+1
        assert all(elt < mid for elt in a[start:i])
        assert all(elt == mid for elt in a[i:j])
        assert all(elt > mid for elt in a[n+1:start+length])
    assert j == n+1
    return i, j

def sort(a,start,length):
    """
    Sort a[start:start+length] in place (three-way quicksort).

    Elements outside that range are left untouched. The pivot is the
    middle element of the range, so already-sorted or reverse-sorted
    input is split evenly instead of degenerating. Only the smaller of
    the two remaining sub-ranges is sorted recursively; the larger one
    is handled by the loop. Each recursive call therefore gets at most
    half of the current range, and the recursion depth is at most
    log2(length).
    """
    while length >= 2:
        mid = a[start+length//2]
        i,j = partition(a,mid,start,length)
        end = start+length
        # The pivot itself lands in the "equal" zone a[i:j], so both
        # remaining sub-ranges are strictly shorter: progress is guaranteed.
        assert i-start < length
        assert end-j < length
        if i-start < end-j:
            # Recurse on the smaller left part, iterate on the right part.
            sort(a,start,i-start)
            start, length = j, end-j
        else:
            # Recurse on the smaller right part, iterate on the left part.
            sort(a,j,end-j)
            length = i-start

#
# Tests
#

def test_sort_example():
    """Test sort() on some fixed input."""
    a = [1,2,1,5,1,2,6,2,2,1]
    sort(a,0,len(a))
    assert a == [1,1,1,1,2,2,2,2,5,6]

def test_partition_a():
    """Test partition() on some fixed input, including its boundaries."""
    a = [1,2,1,5,1,2,6,2,2,1]
    i,j = partition(a,2,0,len(a))
    assert a == [1,1,1,1,2,2,2,2,6,5]
    # Four elements < 2, then four elements == 2, then the rest.
    assert (i,j) == (4,8)

def assert_valid(a,mid,i,j):
    """
    Verify three-way-partitioning with known boundary indices.

    Assert that 0 <= i <= j <= len(a), and that element-wise
    a[:i] < mid, a[i:j] == mid and a[j:] > mid.
    Raise AssertionError if any of these fails.
    """
    # Check the bounds explicitly: slicing silently clamps out-of-range
    # indices and wraps negative ones, so bogus boundaries could
    # otherwise pass the zone checks below.
    assert 0 <= i <= j <= len(a), (i, j, len(a))
    assert all(elt < mid for elt in a[:i])
    assert all(elt == mid for elt in a[i:j])
    assert all(elt > mid for elt in a[j:])

@given(integers())
def test_partition_rnd_singleton(i):
    """Partition a one-element list holding a random integer around 2."""
    single = [i]
    lt_end, eq_end = partition(single,2,0,len(single))
    assert_valid(single,2,lt_end,eq_end)

@given(lists(elements=integers()))
def test_partition_rnd(a):
    """Test partition() on a random list."""
    init = a[:] # copy a
    i,j = partition(a,2,0,len(a))
    # partition() may only rearrange the elements, never drop or
    # duplicate them; the zone checks alone would not catch that.
    assert sorted(a) == sorted(init), (init, a)
    assert_valid(a,2,i,j)

def is_sorted(a):
    """
    Indicate whether the sequence a is in non-decreasing order.

    Compare neighbouring elements instead of sorting a copy. This runs
    in linear time and also accepts immutable sequences such as tuples.
    """
    return all(left <= right for left, right in zip(a, a[1:]))

@given(lists(elements=integers()))
def test_sort(a):
    """Test sort() on random list."""
    init = a[:] # copy a
    sort(a,0,len(a))
    assert is_sorted(a)
    # ... and the result is a permutation of the input (a "sort" that
    # dropped or duplicated elements would otherwise pass).
    assert sorted(a) == sorted(init), (init, a)

# Complete line coverage is achieved by the tests above alone,
# although they always call partition() and sort() on the whole list,
# i.e. with start=0 and length=len(a). It is nevertheless a good idea
# to also test with other values of start and length, as done below.
#
# Note that we always use mid=2, but this isn't restrictive since list
# elements are taken randomly.

@given(lists(elements=integers(),min_size=1),integers(),integers())
def test_subpart(a,start,length):
    """Test partition() on a random sublist, leaving the rest untouched."""
    # Renormalize start and length so that
    # 0 <= start < len(a) and 0 < length <= len(a)-start,
    # i.e. the sublist is non-empty and may extend to the end of a.
    start = start % len(a)
    length = 1 + (length % (len(a)-start))
    end = start+length
    init = a[:] # copy a
    i,j = partition(a,2,start,length)
    # Only the sublist may change, and only by being permuted.
    assert a[:start] == init[:start]
    assert a[end:] == init[end:]
    assert sorted(a[start:end]) == sorted(init[start:end])
    assert_valid(a[start:end],2,i-start,j-start)

@given(lists(elements=integers(),min_size=1),integers(),integers())
def test_subsort(a,start,length):
    """Test sort() on a random sublist, leaving the rest untouched."""
    # Renormalize start and length so that
    # 0 <= start < len(a) and 0 < length <= len(a)-start,
    # i.e. the sublist is non-empty and may extend to the end of a.
    start = start % len(a)
    length = 1 + (length % (len(a)-start))
    end = start+length
    init = a[:] # copy a
    sort(a,start,length)
    # Only the sublist may change, and only by being permuted.
    assert a[:start] == init[:start]
    assert a[end:] == init[end:]
    assert is_sorted(a[start:end])
    assert sorted(a[start:end]) == sorted(init[start:end])

#
# Main
#

def main():
    """Demonstrate partition() and sort() when the file is run as a script."""
    # partition() of a whole list around the mid-value 2
    demo = [2,2,2,1]
    print(demo)
    partition(demo, 2, 0, len(demo))
    print(demo)

    # sort() of a whole list
    demo = [-1,0,0]
    print(demo)
    sort(demo, 0, len(demo))
    print(demo)

if __name__ == "__main__":
    main()
