8-puzzle: Implementación sencilla de A*

Hace poco estuve leyendo unas notas de un curso en línea sobre planeación en IA. En una de ellas me encontré con un algoritmo que tenía rato que no veía ni utilizaba, y me dio curiosidad por implementarlo en Scala; me refiero al algoritmo A*.

En vez de explicar qué hace específicamente el algoritmo, un googlazo o una búsqueda en Wikipedia proveen información más detallada al respecto. El problema a resolver era el famoso 8-puzzle, aquel cuadro con números del 1 al 8 en el que hay que ponerlos en orden:

El algoritmo A* aplicado a este problema lo pueden encontrar fácilmente con una búsqueda en internet, pero como yo quería practicar Scala (lenguaje que uso en mis proyectos) me puse a ver qué tal me quedaba. Solamente tuve un problema en el algoritmo: tuve que usar un mutable hashset (horror, lo sé), porque al usar uno inmutable el tiempo de ejecución se hacía muy largo. Si hay alguien por ahí que quiera optimizar el código, adelante. También implementé la solución de forma imperativa nada más para comparar.

Aquí el código. Recuerden que esto no es la mejor implementación, y que por ende, puede mejorar. Los heurísticos implementados son Manhattan Distance (distancia de un estado x a uno meta) Misplaced tiles (contar el número de cuadros que no están en su lugar. El segundo también lleva a la solución, pero tarda más en encontrarla. La función principal (solve) está optimizada como tail recursive para evitar un posible stack overflow. Además, van a ver muchos val quizá innecesarios que puse para darle legibilidad al depurarlo en el caso de que fuera necesario.

Sugerencias y comentarios son bienvenidos:

package org.mmg.astar

import scala.math
import scala.collection.mutable.HashSet
import scala.annotation.tailrec

sealed trait Definitions {
  type Board = List[List[Int]]
  type Position = (Int,Int) 
}

sealed class Node(val board: List[List[Int]], val parent: Option[Node], val hcost: Int, val gcost: Int)  {
  override def hashCode(): Int = board.hashCode

  override def equals(obj: Any) = {
    obj match {
    	case o : Node =>  o.board == this.board 
    	case _ => false
    }
  }

  def f = hcost + gcost
}

class EightPuzzle (initialState: List[List[Int]], goalState: List[List[Int]], size: Int) extends Definitions {

  private def getPosition(value: Int)(b: Board) = {
    val row = b.indexWhere(_.contains(value))
    (row, b(row).indexWhere(_ == value))
  }

  private val getZeroPosition = getPosition(0) _
  private def getElement(p: Position, b: Board) = b(p._1)(p._2)

  // Heuristic
  private def manhattanDistance(p1: Position, p2: Position) = 
    math.abs(p2._1 - p1._1) + math.abs(p2._2 - p1._2)    

  val distance = (current: Board, target: Board) => 
    current.filterNot(_ == 0).foldLeft(0)((acc, row) => {
      acc + row.foldLeft(0)((res, value) => {
        val cPos = getPosition(value)(current)
        val tPos = getPosition(value)(target)        
        res + manhattanDistance(cPos,tPos)        
      })

    })

  // Another heuristic. Not so good  
  val misplaced = (current: Board, target: Board) => {
    current.foldLeft(0)((acc, row) => {
      acc + row.foldLeft(0)( (res, value) => {
        if (getPosition(value)(current) != getPosition(value)(target))
           res + 1
         else
           res
      })
    })
  }  

  private def move(move: (Position,Position), b: Board) = {
    val p1 = move._1
    val p2 = move._2
    val row1 = p1._1
    val row2 = p2._1

    val oElem = getElement(p1,b)
    val nElem = getElement(p2,b)

    val newRow1 = b(row1).updated(p1._2, nElem)
    val tempBoard = b.updated(row1, newRow1)
    val newRow2 = tempBoard(row2).updated(p2._2, oElem)

    b.updated(row1, newRow1).updated(row2,newRow2)    
  } 

  // Always look for the empty block
  private def getValidMoves(b: Board): List[(Position, Position)] = {
    val p = getZeroPosition(b)
    val left = if (p._2 > 0) p._2 - 1 else -1
    val right = if (p._2 < size - 1) p._2 + 1 else -1
    val up = if (p._1 > 0) p._1 - 1 else -1
    val down = if (p._1 < size - 1) p._1 + 1 else -1

    sealed abstract class Direction
    case class LEFT extends Direction
    case class RIGHT extends Direction
    case class UP extends Direction
    case class DOWN extends Direction

    val valid = ((LEFT,left) :: (RIGHT,right) :: (UP,up) :: (DOWN,down) :: Nil).filterNot(_._2 == -1)

    valid.foldLeft(List[(Position,Position)]())( (acc, m) => {
      val nPos = m._1 match {
        case LEFT => ((p._1,left), (p._1, left + 1))
        case RIGHT => ((p._1, right), (p._1, right - 1))
        case UP => ((up, p._2), (up + 1, p._2)) 
        case _ => ((down, p._2), (down - 1, p._2))

      } 
      nPos :: acc
    } )

  }

  // Build the path from the goal to the start point
  private def getPath(goal: Node): List[Board] = {
    goal.parent match {
      case Some(node) => node.board :: getPath(node)
      case _ => Nil
    }
  }

  /* Solve using A*
   * h is the heuristic (selected when calling the function)
   */ 
  @tailrec
  final def solve(fringe: List[Node], closed: HashSet[Node], h: (Board,Board) => Int) : List[Board] = {

    if (fringe.isEmpty) {
    	List()
    }
    else {
      val currentState = fringe.head

      if (currentState.board == goalState) {
        (currentState.board :: getPath(currentState)).reverse
      }
      else {
        closed += currentState
        val expanded = getValidMoves(currentState.board).foldLeft(List[Board]())((acc, m) => move(m, currentState.board) :: acc) 
        .map {b => new Node(b,Some(currentState),h(b,goalState),0)}
        .filterNot(n => closed.contains(n))

        // newNeighbors contains the nodes we have to add or update in the fringe
        val newNeighbors = expanded.map {n => 
           	  val tempG =   currentState.gcost + h(currentState.board,n.board)

	          if (!fringe.tail.contains(n) || tempG < n.gcost)
	            new Node(n.board, Some(currentState), h(n.board,goalState), tempG)
	          else
	           new Node(List(),None,0,0)
        } filterNot(_.board.isEmpty)

        val newFringe = newNeighbors.foldLeft(fringe.tail)((acc, n) => {
          if (!acc.contains(n))
            n :: acc
          else {
            n :: acc.filterNot(_ == n)
          }
        }).sortBy(_.f)

        solve(newFringe, closed,h)
      }
    }
  }

  private def printNode(n: Node) {
    println("H = " + n.hcost + " G = " + n.gcost + " F = " + n.f)
    EightPuzzle.printBoard(n.board)
  }

  // The imperative way
  def solve_imperative(open: scala.collection.mutable.MutableList[Node], closed: HashSet[Node], h:(Board,Board) => Int) : List[Board] = {

    var found: Boolean = false
    var fringe = open

    while (!fringe.isEmpty) {
      val currentState = fringe.head
     /* println("Current state: ")
      EightPuzzle.printBoard(currentState.board)*/
      fringe = fringe.tail

      if (currentState.board == goalState) {
    	  return (currentState.board :: getPath(currentState)).reverse
      }
      else {
        closed += currentState

         val expanded = getValidMoves(currentState.board).foldLeft(List[Board]())((acc, m) => move(m, currentState.board) :: acc) 
        .map {b => new Node(b,Some(currentState),h(b,goalState),0)}
        .filterNot(n => closed.contains(n))

        expanded.foreach(n => {
        	val tempG =   currentState.gcost + h(currentState.board,n.board)

        	if (!fringe.contains(n) || tempG < n.gcost) {
        	  val newNode = new Node(n.board, Some(currentState), h(n.board, goalState), tempG)
        	  if (!fringe.contains(n))
        		  fringe += newNode
        	  else {
        	    fringe.update(fringe.indexWhere(_.board == n.board), newNode)
        	  }	  
        	}
        })

       fringe = fringe.sortWith(_.f < _.f)   

      }

    }

    List()
  }

}  

object EightPuzzle extends Definitions {

  def board2String(b: Board) = {
    b.foldLeft("")((acc, row) => {
      acc + row.foldLeft("")((str, value) => str + value + " ") + "\n" 
    })
  }

  def toBoard(state: Array[Int], size: Int) = state.toList.grouped(size).toList

  def printBoard(b: List[List[Int]]) = println(board2String(b))

  def main(a: Array[String]): Unit = {
	  val start = toBoard(Array(1,6,4,8,7,0,3,2,5),3) // 21 steps to the goal
	  //val start = toBoard(Array(8,1,7,4,5,6,2,0,3),3) // 25 steps to the goal
	  //val start = toBoard(Array(1,2,5,3,0,4,6,7,8), 3)
	  val goal = toBoard(Array(0,1,2,3,4,5,6,7,8),3)
	  //val goal = toBoard(Array(1,2,3,4,5,6,7,8,0),3)

	  val e = new EightPuzzle(start,goal,3)

	  println(" ========= Initial State =============")
	  printBoard(start)
	  println(" ========= Goal ===============")
	  printBoard(goal)
	  println("Press ENTER to start the search")
	  readLine()
	  println("Solving. Please wait...")

	  val timer = new Stopwatch().start	  
	  val solution = e.solve(List(new Node(start,None,0,0)), HashSet(), e.distance)
	  //val solution = e.solve(List(new Node(start,None,0,0)), HashSet(), e.misplaced) // Takes loooooonger to find the solution
	  //val solution = e.solve_imperative(scala.collection.mutable.MutableList[Node](new Node(start,None,0,0)), HashSet(), e.distance)
	  timer.stop

	  solution match {
		  case l if (!l.isEmpty) => 
			  println(" ############ Solution ########### ")
			  l.zipWithIndex foreach { case (b,i) =>
				  println("(" + i + "): ")
				  printBoard(b)
			  }
			  println("Solution found in " + timer.stop.getElapsedTime + " seconds")
		  case _ => 
			  println("No solution found.")
	  }

  }
}

La clase Stopwatch no es de mi autoría; la bajé de internet porque me daba flojera crear una:

package org.mmg.astar

//
//   Stopwatch for benchmarking 
//
class Stopwatch {

  private var startTime = -1L
  private var stopTime = -1L
  private var running = false

  def start(): Stopwatch = {
    startTime = System.currentTimeMillis()
    running = true
    this
  }

  def stop(): Stopwatch = {
    stopTime = System.currentTimeMillis()
    running = false
    this 
  }

  def isRunning(): Boolean = running

  private def formatTime(time: Long) = (time / 1000) + "." + (time % 1000)

  def getElapsedTime() = {
    if (startTime == -1) {
      0
    }
    if (running) {
      formatTime(System.currentTimeMillis() - startTime)
    }
    else {
      formatTime(stopTime - startTime)
    }
  }

  def reset() {
    startTime = -1
    stopTime = -1
    running = false
  }
}

Lo que sigue es hacer lo mismo, pero en Haskell. Ya tengo un buen rato que no lo uso en serio, así que creo que es buena oportunidad para retomarlo.