# /projects/lattice-enumeration/
# mathematicspythoncode_project
"Lattice enumeration algorithm"
class EnumError(Exception):
	"""Error raised by the enumeration algorithm."""

def enumeration(A, m, r, update=None, trace=True):
	"""Enumerate lattice points to find the coordinates of a shortest vector.

	Schnorr-Euchner (zig-zag) variant of the KFP enumeration algorithm for
	the shortest vector problem.

	Args:
		A: bound on the squared length; only complete vectors whose squared
			length is <= A are handed to `update`.
		m: Gram-Schmidt coefficients; only m[j][i] with j > i is read.
		r: squared lengths of the orthogonal (Gram-Schmidt) basis vectors;
			len(r) is the lattice dimension d.
		update: callable (sol, A, x, l0) -> (sol, A), called for every
			complete coordinate vector x whose squared length l0 is <= A.
			Defaults to update1.
		trace: when true (the default, which keeps the historical
			behaviour) one line visualising the search state is printed
			per loop iteration; pass False to silence it.

	Returns:
		The solution most recently returned by `update`, or a list of d
		zeros when no vector within the bound was found.

	Raises:
		EnumError: if r is empty.
	"""
	# dimension of the lattice
	d = len(r)
	if d <= 0:
		raise EnumError("dimension of lattice is to small")
	if update is None:
		update = update1

	# step 1
	# integer coordinates of the candidate vector
	x = [1] + [0] * (d - 1)
	# state of the (Schnorr-Euchner) zig-zag path:
	# last step taken for each coordinate ...
	dx = [1] + [0] * (d - 1)
	# ... and the direction of that last step
	ddx = [1] + [-1] * (d - 1)
	# shortest vector found so far; stays zero until the first vector with
	# squared length <= A is found
	sol = [0] * d

	# step 2
	# centre of the search interval of each coordinate (the x-y shift)
	c = [0] * d
	# sq_len[i]: squared length contributed by coordinates i .. d-1;
	# sq_len[d] is a permanent 0 so sq_len[i + 1] is always readable
	sq_len = [0] * (d + 1)
	# squared coordinates with respect to the orthogonal basis
	ys = [0] * d

	# step 3
	# depth in the computation tree: coordinates x[i:d] are currently valid
	i = 0
	# number of loop iterations, only used by the trace output
	loop = 0
	while True:
		# x and sol must stay in Z^d
		assert isinstance(x[0], int) and isinstance(sol[0], int)

		# step 4: squared coordinate and squared length at this depth
		ys[i] = (x[i] - c[i]) ** 2
		sq_len[i] = sq_len[i + 1] + ys[i] * r[i]

		# step 5: the vector is complete and within the bound, so hand it
		# to the update strategy (which may also tighten A)
		found = i == 0 and sq_len[i] <= A
		if found:
			sol, A = update(sol, A, x, sq_len[0])

		loop += 1
		if trace:
			# one line per iteration: counter, valid coordinates x[i:d]
			# (indented by depth), a "<" marker for solutions, the depth
			pieces = ['{:>4}'.format(loop), ":", "   " * i]
			pieces.extend('{:>3}'.format(coord) for coord in x[i:])
			pieces.append(" <" if found else "  ")
			pieces.append('{:>3}'.format(i))
			print("".join(pieces))

		if i > 0 and sq_len[i] <= A:
			# step 6: vector incomplete but still short enough, so descend
			# and fix the next coordinate in the following iteration
			i -= 1

			# step 7: centre of the interval for x[i]. Every already fixed
			# coordinate j = i+1 .. d-1 contributes, and the sum is negated
			# because ys[i] is computed from x[i] - c[i].
			c[i] = -sum(x[j] * m[j][i] for j in range(i + 1, d))

			# step 8: start the zig-zag path at the integer nearest to the
			# centre, with a zero first step
			x[i] = int(round(c[i]))
			dx[i] = 0
			# choose the direction so that the first real step moves
			# towards the side on which the centre lies
			ddx[i] = 1 if c[i] <= x[i] else -1

		elif i == d - 1:
			# step 9: the top coordinate cannot be advanced any further
			# (for d >= 2 this branch is only reached when it is already
			# too long; for d == 1 there is no higher coordinate to move to)
			return sol

		else:
			# step 10: this branch is exhausted, go one level up ...
			i += 1

			# step 11: ... and advance that coordinate along its zig-zag
			ddx[i] = -ddx[i]
			dx[i] = -dx[i] + ddx[i]
			x[i] = x[i] + dx[i]
	

def update1(sol, A, x, l0):
	"""Accept x as the new solution and l0 as the new bound.

	The previous solution `sol` and bound `A` are ignored. The coordinates
	are copied, so later changes to `x` do not affect the returned solution.

	Returns:
		The pair (copy of x as a list, l0).
	"""
	snapshot = list(x)
	return snapshot, l0

def update2(sol, A, x, l0):
	"""Placeholder for an alternative update strategy (not implemented).

	Takes the same arguments as update1 but ignores them and returns the
	NotImplemented singleton instead of a (sol, A) pair, so it is not usable
	where such a pair is expected.
	"""
	# TODO: the abandoned draft that used to follow this return apparently
	# meant to keep the first solution found and leave the bound A unchanged;
	# its emptiness test (0 == sol) was itself marked as incomplete.
	return NotImplemented

if __name__ == "__main__" :
	# Self-test: run the enumeration on every basis in test_bases and compare
	# the vector it finds with the expected one in test_sols, up to sign.
	from vector import Vector
	from test_bases import test_bases
	from test_bases import test_sols
	from lattice import Lattice
	import datetime
	print(datetime.date.today())
	print("# test enumeration.py")
	print("")
	for idx, basis in enumerate(test_bases):
		(r, m, bo) = Lattice(basis).gram_schmidt_coeffs()
		d = len(r)
		print("")
		print("### dimension: " + str(d))
		# initial bound: product of the first basis vector with itself
		# (presumably its squared length; depends on Vector.__mul__)
		bound = basis[0] * basis[0]
		short_coords = enumeration(bound, m, r)
		# turn the coordinates back into a lattice vector
		short_vector = Vector([0] * d)
		for k in range(d):
			short_vector += basis[k] * short_coords[k]
		zero_vector = Vector(d)
		print("coords: " + str(short_coords))
		print("vector: " + str(short_vector))
		# a vector and its negation have the same length, so accept either
		expected = test_sols[idx]
		assert (short_vector - expected == zero_vector) or \
			(short_vector + expected == zero_vector)