"""A simple module implementing the lamplighter group Z/2 ≀ Z, and
domino tilings on it.

Written by Laurent Bartholdi, last version 2020-05-03.
"""
module LL

using Pkg
# Names of the packages in the active environment's dependency graph, used to pick the
# optional backends below (Pkg.dependencies() is queried only once).
const _installed = Set(info.name for info in values(Pkg.dependencies()))

# SAT backend used by `solve`: the first available of CryptoMiniSat and PicoSAT.
if "CryptoMiniSat" in _installed
    using CryptoMiniSat
    const solver = CryptoMiniSat
elseif "PicoSAT" in _installed
    using PicoSAT
    const solver = PicoSAT
else
    @warn "No SAT solver seems installed. If you don't add package CryptoMiniSat or PicoSAT, you won't be able to compute tilings"
    # const in every branch, so that `solver` is never an untyped global
    const solver = nothing
end

# Plotting backend used by `tetrahedron`.
if "GLMakie" in _installed
    using GLMakie
else
    @warn "Makie package doesn't seem installed. You won't be able to plot tetrahedra"
end

using DataStructures, Combinatorics, Distributed
import Base.*, Base.==
export Element, tile, graph, solve, flipinvert, flipab, prodtiles

"""
    reversebits(x::UInt, k::Int)

Return the number whose `k` low bits are the `k` low bits of `x` in reverse order.
Bits of `x` at positions `k` and above are ignored.
"""
function reversebits(x::UInt,k::Int)
    acc = zero(x)
    rest = x
    for _ in 1:k
        acc = (acc << 1) | (rest & 1)
        rest >>= 1
    end
    return acc
end

"""
    Element(bits::UInt, shift::Int)

Element implements very crudely elements of the lamplighter group Z/2 ≀ Z.

`bits` is the binary representation of the lamps: bit `i` (counting from 0, least
significant first) is the lamp at position `i - WORD_SIZE÷2 + 0.5`.
`shift` is the position of the lamplighter + (WORD_SIZE÷2), which is assumed to be
non-negative.

Elements are printed with the lamps read from left to right, with the lamps before the
origin in reverse video and the lamps between the origin and the lamplighter underlined.
Thus, on a 64-bit machine, the element printed as 1101011, with its first three digits
in reverse video and the next three underlined, means: "lamps at positions
-2.5, -1.5, 0.5, 2.5 and 3.5 are on, and the lamplighter is at position 3". This is
represented, internally, as `Element(2^29+2^30+2^32+2^34+2^35, 3+32)`.

The right action of the group on itself is implemented via `*`: `x*n` is the result of
the action of the `n`th generator, numbered as 1=a, 2=b, 3=a^-1, 4=b^-1. A product `x*y`
of two Elements is also defined.

Thus, for example, the group relations are 1^n4^n1^n4^n, i.e. (a^n b^-n)^2, for all n.
"""
struct Element
    bits::UInt
    shift::Int
end

# Offset added to lamplighter positions so that `shift` stays non-negative: half a machine word.
const LLSHIFT = Sys.WORD_SIZE÷2

# No lamp lit, lamplighter at the origin.
Element() = Element(0,LLSHIFT)

# Element(left, right, shift): the lamplighter is at position `shift`; bit j of `right` is
# the lamp at position shift+j+0.5, and bit j of `left` the lamp at position shift-j-0.5.
function Element(left,right,shift)
    offset = shift+LLSHIFT
    leftbits = reversebits(UInt(left),offset)
    return Element(leftbits | (right<<offset), offset)
end
                                        
# Print x with its lamps read from left to right (bit 0 of x.bits first). Lamps left of the
# origin are shown in reverse video (ANSI "\033[7m") and lamps between the origin and the
# lamplighter are underlined ("\033[4m"; both effects when the lamplighter is left of the
# origin). Zero lamps at either end are trimmed, but never across the origin or the lamplighter.
# NOTE(review): assumes 0 ≤ x.shift ≤ Sys.WORD_SIZE; otherwise the slices below go out of bounds.
function Base.show(io::IO,x::Element)
    # s[i] is bit i-1 of x.bits, padded with '0' up to Sys.WORD_SIZE characters
    s = reverse(string(x.bits, base=2))
    while length(s)<Sys.WORD_SIZE s *= '0' end
    # first printed character: skip leading zeros, stopping at the origin or the lamplighter
    start = 1
    while s[start]=='0' && start <= LLSHIFT && start <= x.shift start += 1 end
    # last printed character: drop trailing zeros, stopping at the origin or the lamplighter;
    # when nothing is printed left of the origin, keep at least one character after it
    stop = Sys.WORD_SIZE
    while s[stop]=='0' && stop > LLSHIFT && stop > x.shift && (stop > LLSHIFT+1 || start <= LLSHIFT) stop -= 1 end
    if x.shift < LLSHIFT
        # lamplighter left of the origin: reverse video up to the lamplighter, then
        # reverse+underline up to the origin, then plain
        print(io, "\033[7m",s[start:x.shift],"\033[4m",s[x.shift+1:LLSHIFT],"\033[0m",s[LLSHIFT+1:stop])
    else
        # lamplighter at or right of the origin: reverse video up to the origin, then
        # underline (reverse off) up to the lamplighter, then plain
        print(io, "\033[7m",s[start:LLSHIFT],"\033[27;4m",s[LLSHIFT+1:x.shift],"\033[0m",s[x.shift+1:stop])
    end
end

# Hash consistently with `==` below: both look only at `bits` and `shift`. Base's hashing
# protocol is the two-argument method. It also serves one-argument `hash(x)` and hashing of
# containers holding Elements, and it allocates nothing (unlike hashing a temporary vector).
Base.hash(x::Element, h::UInt) = hash(x.shift, hash(x.bits, h))

# Two elements are equal when both their lamps and their lamplighter positions agree.
function ==(x::Element,y::Element)
    (x.bits == y.bits && x.shift == y.shift)
end

"""
    x * n

Right action of the `n`th generator on `x`, numbered as 1=a, 2=b, 3=a^-1, 4=b^-1.
- a moves the lamplighter one step right.
- b toggles the lamp just right of the lamplighter (bit `x.shift`), then moves right.
- a^-1 and b^-1 undo a and b.

Moving left from `shift == 0` is an error, and so is a generator outside 1:4.
Note that `1 << s == 0` for `s ≥ Sys.WORD_SIZE`, so lamps beyond the word are not toggled.
"""
function *(x::Element,y::Int)
    1 ≤ y ≤ 4 || throw(ArgumentError("generator must be in 1:4, got $y"))
    if y≥3 && x.shift==0
        error("cannot shift beyond 0")
    elseif y==1
        Element(x.bits,x.shift+1)
    elseif y==2
        Element(x.bits⊻(1<<x.shift),x.shift+1)
    elseif y==3
        Element(x.bits,x.shift-1)
    else # y==4
        Element(x.bits⊻(1<<(x.shift-1)),x.shift-1)
    end
end

"""
    x * y

Product in the lamplighter group. The lamps of `y`, translated by the position of the
lamplighter of `x`, are added (xor) to those of `x`, and the lamplighter positions add up.
With this law `Element()` is a two-sided identity, and `x * (Element()*n) == x*n` for
each generator `n`.
"""
function *(x::Element,y::Element)
    # position of x's lamplighter relative to the origin; in Julia `<<` by a negative
    # amount shifts right, which covers lamplighters left of the origin
    offset = x.shift - LLSHIFT
    Element(x.bits ⊻ (y.bits << offset), x.shift + y.shift - LLSHIFT)
end

# Extend Base's `one`/`isone` (instead of defining separate LL functions with the same names)
# so that generic code sees the identity: no lamps lit, lamplighter at the origin.
Base.one(::Type{Element}) = Element()
Base.one(::Element) = Element()

Base.isone(x::Element) = (x.bits == 0 && x.shift == LLSHIFT)

# A compact graph representation: the vertices `v` are in bijection with 1:n (looked up via
# `label`), and `nbhd[i,s]` is the index of the neighbour of v[i] along generator s (0 if none).
struct Graph{T}
    v::Vector{T}
    n::Int
    label::Dict{T,Int} # lookup indices of elements of v
    nbhd::Array{Int,2}
end

# SAT variables are numbered i + (j-1)*n for vertex index i and tile index j:
# `srav` encodes a (vertex, tile) pair, `vars` decodes a variable number.
function vars(g::Graph,i::Int)
    tile, vertex = divrem(i-1, g.n)
    return (vertex+1, tile+1)
end
srav(g::Graph,i::Int,j::Int) = i + (j-1)*g.n

"""
    graph(height; periodic=false, periods=Element[])

Construct a `Graph{Element}` object on a tetrahedron of the lamplighter group.

`graph(height)` produces a tetrahedron with 2^height*(height+1) vertices: the elements
`Element(r<<LLSHIFT, k+LLSHIFT)` for `0 ≤ r < 2^height` and `0 ≤ k ≤ height`. For the
generators s = 1 (a) and s = 2 (b), `nbhd[i,s]` is the index of `v[i]*s`, or 0 if that
product is not a vertex.

`graph(height; periodic=true)` produces the same tetrahedron, but identifies the top and
bottom 2^height vertices: level `height` is not stored, and its elements carry the labels
of level 0.

`periods` is an optional list of elements `p`. Each vertex `x` is identified with `p*x`
whenever that product is labelled in the graph, and the quotient graph is returned.
"""
function graph(height::Int;periodic=false,periods=Element[])
    levels = periodic ? (0:height-1) : (0:height)
    ball = [Element(r<<LLSHIFT,k+LLSHIFT) for k=levels for r=0:UInt(2^height-1)]

    nelem = length(ball)
    llab = Dict(ball[i]=>i for i=1:nelem)
    if periodic
        # the top level gets the labels of the bottom level
        for r=0:UInt(2^height-1)
            llab[Element(r<<LLSHIFT,height+LLSHIFT)] = llab[Element(r<<LLSHIFT,LLSHIFT)]
        end
    end

    nbhd = [get(llab,x*s,0) for x=ball,s=1:2]

    if !isempty(periods)
        # union-find on vertex indices; a single pass suffices because the disjoint-set
        # structure already closes the relation x ~ p*x transitively
        classes = IntDisjointSets(nelem)
        for i=1:nelem, p=periods
            j = get(llab,p*ball[i],0)
            j==0 || in_same_set(classes,i,j) || union!(classes,i,j)
        end
        roots = [i for i=1:nelem if find_root(classes,i)==i]
        # stoor[i] is the index, in the quotient, of the class of vertex i
        stoor = zeros(Int,nelem)
        stoor[roots] .= 1:length(roots)
        for i=1:nelem; stoor[i] = stoor[find_root(classes,i)]; end

        # relabel every key of llab, including the periodic top-level labels added above
        llab = Dict(x=>stoor[i] for (x,i)=llab)
        ball = ball[roots]
        nelem = length(roots)
        # NOTE(review): only the neighbours of each class root are kept; neighbours of the
        # other members of a class are not consulted
        nbhd = [nbhd[i,s]==0 ? 0 : stoor[nbhd[i,s]] for i=roots,s=1:2]
    end

    Graph{Element}(ball,nelem,llab,nbhd)
end

"""
    solve(g::Graph, tiles; seed=nothing, numsols=-1, verbose=0, proplimit=0)

Produce solutions to a tiling problem on the graph `g`, encoded as a SAT problem with one
variable per (vertex, tile) pair.

The tiles are Wang tiles, specified as a list of legal tiles [a,b,c,d]. Their entries are
any values compared with `==`, giving the colours that lie on the 1,2,3,4-generators
(a, b, a^-1, b^-1, see `Element`). Each vertex carries exactly one tile. Along every edge
x → x*k (k = 1, 2) of `g`, side k of the tile at x must equal side k+2 of the tile at x*k.
The tile indices are on vertices of the Cayley graph, and the tile colours are on edges.

With the default `numsols=-1`, return a Dict mapping Element to tile indices in the list
of tiles. If there is no solution, return instead the code of the solver (typically
:unsolvable). With `numsols ≥ 0`, return a vector of at most `numsols` such Dicts.

Optional extra arguments:
* seed: a list of Element => tile index pairs that the solution must contain
* numsols: a limit on the number of solutions to return
* verbose, proplimit: options to pass to the SAT solver

Throws an error if no SAT solver package (CryptoMiniSat or PicoSAT) is available.
"""
function solve(g::Graph,tiles;seed=nothing,numsols::Integer=-1,verbose::Integer=0,proplimit::Integer=0)
    solver === nothing && error("No SAT solver available: add package CryptoMiniSat or PicoSAT to compute tilings")
    rules = Vector{Int}[]
    ntiles = length(tiles)
    for x=1:g.n
        push!(rules,[srav(g,x,i) for i=1:ntiles]) # at least one tile
        for i=1:ntiles, j=i+1:ntiles
            push!(rules,[-srav(g,x,i),-srav(g,x,j)]) # at most one tile
        end
        # matching colours along the edge x -> nbhd[x,k]: side k against side k+2
        for i=1:ntiles, j=1:ntiles, k=1:2
            tiles[i][k]≠tiles[j][k+2] && g.nbhd[x,k]≠0 && push!(rules,[-srav(g,x,i),-srav(g,g.nbhd[x,k],j)])
        end
    end
    if seed !== nothing
        for (element,tile) in seed
            push!(rules,[srav(g,g.label[element],tile)])
        end
    end

    # turn a satisfying assignment (signed variable numbers) into Element => tile index
    decode(sat) = Dict((v = vars(g,s); g.v[v[1]]=>v[2]) for s=sat if s>0)

    if numsols == -1
        sat = solver.solve(rules,verbose=verbose,proplimit=proplimit)
        return sat isa Vector ? decode(sat) : sat
    else
        sols = Dict[]
        for sat=Iterators.take(solver.itersolve(rules,verbose=verbose,proplimit=proplimit),numsols)
            push!(sols, decode(sat))
        end
        return sols
    end
end

"""Return whether `solve(g,tiles)` finds a tiling (a Dict)."""
solvable(g,tiles;verbose=0,proplimit=0) = isa(solve(g,tiles,verbose=verbose,proplimit=proplimit),Dict)

"""
    point3d(x, y, z)

Pinch a cube into a tetrahedron: the point (x, y) of level z (as used by `tetrahedron`,
with x, y ∈ [-1,1] and z ∈ [0,1]) is placed at z*(x, x) + (1-z)*(y, -y), at height 2z.
"""
function point3d(x,y,z)
    top = z*x
    bottom = (1-z)*y
    return Point3f(top+bottom, top-bottom, 2z)
end

"""plot a tetrahedron.

tetrahedron(height) creates a Makie 3D plot object of a tetrahedron of given height
(height ≥ 1). Its vertices are the elements `Element(l,m,k)` with 0 ≤ k ≤ height,
0 ≤ l < 2^k and 0 ≤ m < 2^(height-k), placed by `point3d`. Its edges join each vertex
of level k < height to its a- and b-neighbours on level k+1.

The extra optional arguments are functions:
* vertexcolorizer(Element), returning a colour, or `nothing` to hide the vertex
* edgecolorizer(Element,index) where index∈[1,2] is a generator number, a or b;
  it returns a colour, or `nothing` to hide the edge

Without edgecolorizer all edges are drawn thin and gray; without vertexcolorizer no
vertices are drawn.
"""
function tetrahedron(height::Int; edgecolorizer = nothing, vertexcolorizer = nothing)
    height ≥ 1 || throw(ArgumentError("height must be positive, got $height"))
    # slightly below 1, so that 2^k-almostone stays positive on single-vertex levels
    # (2^k == 1), whose only vertex is then placed at coordinate -1
    almostone = 1-1000*eps()
    edges = Pair{Point3f,Point3f}[]
    edgecolors = Symbol[]
    vertices = Point3f[]
    vertexcolors = Symbol[]
    for k=0:height
        z = k/height
        for l=0:2^k-1
            x = -1+2l/(2^k-almostone)
            for m=0:2^(height-k)-1
                y = -1+2m/(2^(height-k)-almostone)
                e = Element(l,m,k)
                if k<height
                    for c=1:2
                        # e*c is Element(2l+bit, m÷2, k+1), with bit = m%2 for a and 1-m%2 for b
                        nx = -1+2(2l+(c == 1 ? m%2 : 1-m%2))/(2^(k+1)-almostone)
                        ny = -1+2(m÷2)/(2^(height-(k+1))-almostone)
                        nz = (k+1)/height
                        if edgecolorizer === nothing
                            push!(edges,point3d(x,y,z)=>point3d(nx,ny,nz))
                        else
                            w = edgecolorizer(e,c)
                            if w !== nothing
                                push!(edges,point3d(x,y,z)=>point3d(nx,ny,nz))
                                push!(edgecolors,w)
                            end
                        end
                    end
                end
                if vertexcolorizer !== nothing
                    w = vertexcolorizer(e)
                    if w !== nothing
                        push!(vertices,point3d(x,y,z))
                        push!(vertexcolors,w)
                    end
                end
            end
        end
    end
    (co,lw) = edgecolorizer === nothing ? (:gray,1) : (edgecolors,4)
    linesegments(edges,color=co,linewidth=lw,axis=(xgridvisible=false,ygridvisible=false,zgridvisible=false))
    # only add a scatter plot when there is something to show
    isempty(vertices) || scatter!(vertices,color=vertexcolors, markersize=50)
    current_figure()
end

"""
    tetrahedron(height, tiles, dict)

Plot a tetrahedron whose edges are coloured by a tiling. `dict` maps Elements to tile
indices in `tiles` (as returned by `solve`), and the edge leaving `e` along generator
`c ∈ 1:2` gets the colour of side `c` of `tiles[dict[e]]`.

Distinct side colours are assigned the palette colours :red, :green, :black, :blue, in
order of first appearance in `tiles`, cycling when there are more than four. Edges
leaving elements absent from `dict` are hidden.
"""
function tetrahedron(height::Int, tiles::Vector, dict::Dict)
    palette = [:red,:green,:black,:blue]
    # number the distinct side colours in order of first appearance
    colourindex = Dict{Any,Int}()
    for t in tiles, side in t
        get!(colourindex, side, length(colourindex)+1)
    end
    function colourof(e, c)
        haskey(dict, e) || return nothing
        return palette[mod1(colourindex[tiles[dict[e]][c]], length(palette))]
    end
    tetrahedron(height, edgecolorizer=colourof)
end

# some helper functions

"""
    prodtiles(tiles1, tiles2, filter=(i,j)->true)

Take the Cartesian product of two tilesets. For every pair of indices (i,j) accepted by
`filter`, the product contains the tile whose side k carries (tiles1[i][k], tiles2[j][k]).
Pairs are enumerated with i in the outer loop and j in the inner loop.
"""
prodtiles(tiles1,tiles2,filter=(i,j)->true) =
    [[(t1[k],t2[k]) for k=1:4] for (i,t1)=enumerate(tiles1) for (j,t2)=enumerate(tiles2) if filter(i,j)]

"""Flip a tileset by exchanging the sides a ↔ a^-1 and b ↔ b^-1 of every tile."""
flipinvert(tiles) = map(t -> t[[3,4,1,2]], tiles)

"""Flip a tileset by exchanging the sides a ↔ b and a^-1 ↔ b^-1 of every tile."""
flipab(tiles) = map(t -> t[[2,1,4,3]], tiles)

"""
    GWAutomaton(label, next, initial, final)

A graph-walking automaton, described by four lists indexed by state:
* `label[i]`: the tiling values (tile indices) allowed under the walker in state `i`;
* `next[i]`: the moves `(generator, new state)` leaving state `i`;
* `initial[i]`: whether state `i` is initial;
* `final[i]`: whether state `i` is final.

TODO: should we set a bound on the recursion depth?
"""
struct GWAutomaton
    label::Vector{Vector{Int}}
    next::Vector{Vector{Tuple{Int,Int}}}
    initial::Vector{Bool}
    final::Vector{Bool}
end

"""
    a * b

Product of two graph-walking automata, running `a` and `b` in lockstep. Its state for the
pair (i, j) is numbered n*(i-1)+j, with n = length(b.label).
- It allows the label pairs (p, q) with p ∈ a.label[i] and q ∈ b.label[j].
- It moves along generator g to (i', j') when `a` moves from i to i' and `b` from j to j'
  along the same g.
- It is initial (final) when both i and j are.

NOTE(review): the `label` field of GWAutomaton is a Vector{Vector{Int}}, so building the
product with non-empty tuple labels (p,q) looks like it will fail to convert; confirm the
intended label type. The local `m` is unused.
"""
function *(a::GWAutomaton, b::GWAutomaton)
    m = length(a.label)
    n = length(b.label)
    LL.GWAutomaton([[(p,q) for p=x for q=y] for x=a.label for y=b.label],
                   [[(i,n*(j-1)+l) for (i,j)=x for (k,l)=y if i==k] for x=a.next for y=b.next],
                   [x&&y for x=a.initial for y=b.initial],
                   [x&&y for x=a.final for y=b.final])
end

# Depth-first search over the runs of `a`, starting from element `g` in `state`. A run is cut
# as soon as the tile at the current element is not allowed in the current state. Each element
# reached in a final state is appended to `solutions`. No depth bound is enforced.
function recursivewalk(g::Element, state::Int, a::GWAutomaton, tiling::Dict, solutions::Vector{Element})
    if !(tiling[g] in a.label[state])
        return
    end
    if a.final[state]
        push!(solutions, g)
    end
    for (generator, target) in a.next[state]
        recursivewalk(g * generator, target, a, tiling, solutions)
    end
end

"""
    walk(g, a, tiling)

Walk through the Cayley graph of the Lamplighter group, starting at `g`, using the
GWAutomaton `a` and the Dict `tiling` (Element => tile index). Every run of `a` from an
initial state at `g` is followed, and the elements reached in final states are collected.

Return the collected element if there is exactly one. Throw an error if there is none,
or if there are several.
"""
function walk(g::Element, a::GWAutomaton, tiling::Dict)
    solutions = Element[]
    for state in eachindex(a.label)
        a.initial[state] && recursivewalk(g,state,a,tiling,solutions)
    end
    isempty(solutions) && error("No solution found when walking from ", g)
    length(solutions) == 1 || error("Non-unique solution ", solutions)
    solutions[1]
end

end

################################################################
# now we use the module LL to construct some tilesets

using Test

################################################################
# the comb

# Each tile lists its side colours in the order [a, b, a^-1, b^-1] (see LL.solve). Along an
# edge x -> x*k (k = 1 for a, 2 for b), side k of x's tile must equal side k+2 of the next tile.
comb = [[:a,:b,:a,:d], # our first way of emulating a grid
        [:o,:b,:s,:b],
        [:s,:o,:s,:o],
        [:s,:d,:r,:d],
        [:r,:o,:r,:o],
        [:o,:o,:o,:o]]

"""Edge colorizer for `LL.tetrahedron` on a comb tiling `dict` (Element => index in `comb`).
The edge leaving `g` along generator `x` gets the colour of side `x` of its tile, and
edges of colour `:o` are hidden."""
function combedgecolor(dict::Dict)
    # built once per colorizer instead of once per edge
    edgecol = Dict(:a => :red, :b => :blue, :d => :green, :s => :yellow, :r => :orange, :o => nothing)
    # named so as not to shadow Base.any
    colorize(g::LL.Element,x::Int) = edgecol[comb[dict[g]][x]]
    return colorize
end

"""Vertex colorizer for `LL.tetrahedron` on a comb tiling `dict` (Element => index in `comb`).
Tiles 1, 2 and 3 are drawn black, grey and white; other tiles are hidden."""
function combvertexcolor(dict::Dict)
    # built once per colorizer instead of once per vertex
    vertexcol = [:black,:grey,:white,nothing,nothing,nothing,nothing]
    # named so as not to shadow Base.any
    colorize(g::LL.Element) = vertexcol[dict[g]]
    return colorize
end

"""Sample plot:

julia> s = LL.solve(LL.graph(6),comb,seed=[LL.Element(0,0,1)=>1]);
julia> LL.tetrahedron(6,vertexcolorizer=combvertexcolor(s))
julia> LL.tetrahedron(6,vertexcolorizer=combvertexcolor(s),edgecolorizer=combedgecolor(s))
"""

# Walkers on the comb tiling, for LL.walk. Each is LL.GWAutomaton(label, next, initial, final),
# with generators 1=a, 2=b, 3=a^-1, 4=b^-1 and labels that are indices into `comb`.

# east: from a tile 1 or 2, one step b onto a tile 2.
combeast = LL.GWAutomaton([[1,2],[2]],
                          [[(2,2)],[]],
                          [true,false],
                          [false,true])
# west: the reverse of combeast. From a tile 2, one step b^-1 onto a tile 1 or 2.
combwest = LL.GWAutomaton([[1,2],[2]],
                          [[],[(4,1)]],
                          [false,true],
                          [true,false])

# north: two branches.
# - From a tile 1 (state 1): one step a onto a tile 1.
# - From a tile 2 (state 3): one or more steps a^-1 over tiles 3, ending on a tile 4
#   (state 5), then one step b. Either this lands directly on a tile 1 or 2, or it lands on
#   a tile 4 followed by steps a over tiles 3 up to a tile 1 or 2 (state 8). Finally, one
#   step b onto a tile 2 (state 9).
combnorth = LL.GWAutomaton([[1],[1], [2],[3],[4],[4],[3],[1,2],[2]],
                           [[(1,2)],[], [(3,4),(3,5)],[(3,4),(3,5)],[(2,6),(2,8)],[(1,7),(1,8)],[(1,7),(1,8)],[(2,9)],[]],
                           [true,false, true,[false for _=1:6]...],
                           [false,true, [false for _=1:6]...,true])

# south: the reverse of combnorth. Every transition is inverted, and initial and final
# states are swapped.
combsouth = LL.GWAutomaton([[1],[1], [2],[3],[4],[4],[3],[1,2],[2]],
                           [[],[(3,1)], [],[(1,3),(1,4)],[(1,3),(1,4)],[(4,5)],[(3,7),(3,6)],[(3,7),(3,6),(4,5)],[(4,8)]],
                           [false,true, [false for _=1:6]...,true],
                           [true,false, true,[false for _=1:6]...])

"""Sample crawling:
julia> root = LL.Element(0,0,0);
julia> s = LL.solve(LL.graph(6),comb,seed=[root=>1]);
julia> LL.walk(root,combnorth,s)
"""

LL.solver !== nothing && @testset "Comb" begin
    global root = LL.Element(0,0,0)
    global combtiling = LL.solve(LL.graph(10),comb,seed=[root=>1])
    # solve returns the solver's status instead of a Dict when no tiling exists
    @test combtiling isa Dict

    # on a 5×5 patch of the emulated grid: north and east commute, and south/west undo
    # north/east
    if combtiling isa Dict
        for i=0:4, j=0:4
            g = root
            for _=1:i g = LL.walk(g,combnorth,combtiling) end
            for _=1:j g = LL.walk(g,combeast,combtiling) end

            gn = LL.walk(g,combnorth,combtiling)
            ge = LL.walk(g,combeast,combtiling)
            @test LL.walk(gn,combeast,combtiling) == LL.walk(ge,combnorth,combtiling)
            @test LL.walk(gn,combsouth,combtiling) == g
            @test LL.walk(ge,combwest,combtiling) == g
        end
    end
end

################################################################
# the sealevel

# ray0: Boolean marks on the four sides [a, b, a^-1, b^-1]; tile 3 carries no mark.
ray0 = [[true,false,true,true],
        [false,true,true,true],
        [false,false,false,false]] # mark a face
# ray: a ray0 tile paired with a flipinvert(ray0) tile. Only the pair (1,1) and the pairs
# where at least one component is the unmarked tile 3 are kept.
ray = LL.prodtiles(ray0,LL.flipinvert(ray0),(i,j)->(i,j)==(1,1) || i==3 || j==3)

# ville's sea-level tiling: 13 Wang tiles whose side colours are arrow strings, listed in
# the order [a, b, a^-1, b^-1]
sea0 = [["↖", "↗", "↙", "↘"], #1
        ["↗", "↖", "↙", "↘"],
        ["↖", "↗", "↘", "↙"],
        ["↗", "↖", "↘", "↙"],
        
        ["↖", "↕", "↖", "↖"], #5
        ["↕", "↖", "↖", "↖"],
        ["↗", "↕", "↗", "↗"],
        ["↕", "↗", "↗", "↗"],
        ["↕", "↕", "↕", "↕"],
         
        ["↙", "↙", "↙", "↕"], #10
        ["↙", "↙", "↕", "↙"],
        ["↘", "↘", "↘", "↕"],
        ["↘", "↘", "↕", "↘"]]

"""
sea0 = [["↖", "↗", "↙", "↘"], #1
        ["↗", "↖", "↙", "↘"],
        ["↖", "↗", "↘", "↙"],
        ["↗", "↖", "↘", "↙"],
        
        ["↖", "↑", "↖", "↖"], #5
        ["↑", "↖", "↖", "↖"],
        ["↗", "↑", "↗", "↗"],
        ["↑", "↗", "↗", "↗"],
        ["↑", "↑", "↑", "↑"],
         
        ["↙", "↙", "↙", "↓"], #10
        ["↙", "↙", "↓", "↙"],
        ["↘", "↘", "↘", "↓"],
        ["↘", "↘", "↓", "↘"],
        ["↓", "↓", "↓", "↓"]]
"""

# sea: pairs of a ray tile and a sea0 tile. On Bools `p <= q` is the implication p ⇒ q.
# A mark on side 1 (a) of the second ray component excludes sea0 tiles 2,4,6,8, and a
# mark on side 3 (a^-1) of the first ray component excludes sea0 tiles 3,4,11,13.
sea = LL.prodtiles(ray,sea0,(i,j)->(ray[i][1][2] <= (j∉[2,4,6,8])) && (ray[i][3][1] <= (j∉[3,4,11,13])))

# proj1sea[t] / proj2sea[t]: index in `ray` / in `sea0` of the two components of tile sea[t].
# The first components are ray tiles, so they must be looked up in `ray` (not in `sea0`).
proj1sea = [findfirst(==([t[i][1] for i=1:4]),ray) for t=sea]
proj2sea = [findfirst(==([t[i][2] for i=1:4]),sea0) for t=sea]
            
"""Edge colorizer for `LL.tetrahedron` on a sea-level tiling `dict`. Edges along a (x = 1)
are red and edges along b (x = 2) black, independently of the tiles."""
function seaedgecolor(dict::Dict)
    palette = (:red,:black)
    # named so as not to shadow Base.any
    colorize(g::LL.Element, x::Int) = palette[x]
    return colorize
end

"""Vertex colorizer for `LL.tetrahedron` on a sea-level tiling `dict` (Element => index in
`sea`). Vertices with `proj1sea[t] == 1` are black. The others are coloured according to
their sea0 component `proj2sea[t]`, where `nothing` hides the vertex."""
function seavertexcolor(dict::Dict)
    # colour per sea0 tile index; built once per colorizer instead of once per vertex
    d0 = [:white,:gray,:white,:gray,
          :pink,:orange,nothing,nothing,nothing,
          :blue,:green,nothing,nothing,nothing]
    # named so as not to shadow Base.any
    function colorize(g::LL.Element)
        t = dict[g]
        proj1sea[t] == 1 && return :black
        return d0[proj2sea[t]]
    end
    return colorize
end

"""Sample plot:

julia> s = LL.solve(LL.graph(6),sea,seed=[LL.Element(0,0,3)=>1]);
julia> LL.tetrahedron(6,vertexcolorizer=seavertexcolor(s))
julia> LL.tetrahedron(6,vertexcolorizer=seavertexcolor(s),edgecolorizer=seaedgecolor(s))
"""

begin
    # make(a,b,c,d,side) builds a sea-level walker reading the words a^m b c d^n (m,n ≥ 0).
    # It starts and stops on sea-level tiles (sea0 component in 1:4), and every tile in
    # between must have its sea0 component in `side`.
    function make(a,b,c,d,side)
        sealevel = [i for i in eachindex(sea) if proj2sea[i] ∈ 1:4]
        seaside = [i for i in eachindex(sea) if proj2sea[i] ∈ side]

        labels = [sealevel,seaside,seaside,seaside,sealevel]
        moves = [[(a,2),(b,3)],    # 1: start
                 [(a,2),(b,3)],    # 2: after one or more a
                 [(c,4),(c,5)],    # 3: after b
                 [(d,4),(d,5)],    # 4: after c, then possibly some d
                 Tuple{Int,Int}[]] # 5: stop
        initial = [state == 1 for state in 1:5]
        final = [state == 5 for state in 1:5]
        return LL.GWAutomaton(labels, moves, initial, final)
    end
    global seanorth = make(4,3,2,1,[10,11])
    global seasouth = make(3,4,1,2,[10,11])
    global seaeast = make(2,1,4,3,[5,6])
    global seawest = make(1,2,3,4,[5,6])
end

"""Sample crawling:
julia> root = LL.Element(0,0,3);
julia> s = LL.solve(LL.graph(6),sea,seed=[root=>1]);
julia> LL.walk(root,seanorth,s)
"""

LL.solver !== nothing && @testset "Sea level" begin
    global root = LL.Element(0,0,3)
    # ask for up to two tilings: the test expects the seed to force exactly one
    global seatiling = LL.solve(LL.graph(6),sea,seed=[root=>1],numsols=2)
    @test length(seatiling) == 1
    seatiling = seatiling[1]

    # each walker moves by one unit of the emulated grid, and opposite walkers undo each other
    for i=0:7, j=0:6
        here, east = LL.Element(i,j,3), LL.Element(i,j+1,3)
        there, north = LL.Element(j,i,3), LL.Element(j+1,i,3)
        @test LL.walk(here,seaeast,seatiling) == east
        @test LL.walk(there,seanorth,seatiling) == north
        @test LL.walk(east,seawest,seatiling) == here
        @test LL.walk(north,seasouth,seatiling) == there
    end
end

# a sofic shift representing the full preimage of the binary shift {0,1}^ℤ
# plane0: for i ∈ 0:2 and bits b, c, the tile [a, b, a^-1, b^-1] = [(i,b), ((i+1)%3,c), (i,b), (i,b)]
plane0 = [[(i,b),((i+1)%3,c),(i,b),(i,b)] for i=0:2 for b=false:true for c=false:true]
# plane1: plane0 together with its copies with the a and b sides swapped
plane1 = [plane0...,LL.flipab(plane0)...]
# Bit carried by tile `t`: the second entry of the first side k whose first entry is one
# more (mod 3) than that of the opposite side 5-k; `nothing` if no side qualifies.
function planecolor0(t)
    k = findfirst(s -> t[s][1] == (t[5-s][1]+1)%3, 1:4)
    return k === nothing ? nothing : t[k][2]
end
# pairs of plane1 and flipinvert(plane1) tiles that carry the same bit
plane = LL.prodtiles(plane1,LL.flipinvert(plane1),(i,j)->planecolor0(plane1[i])==planecolor0(plane1[j]))
# bit of a product tile, read from its first (plane1) component
planecolor(t) = planecolor0(first.(t))

nothing
