Hace poco estuve leyendo unas notas de un curso en línea sobre planeación en IA. En una de ellas me encontré con un algoritmo que tenía rato que no veía ni utilizaba, y me dio curiosidad por implementarlo en Scala; me refiero al algoritmo A*.
En vez de explicar qué hace específicamente el algoritmo, un googlazo o una búsqueda en Wikipedia proveen información más detallada al respecto. El problema a resolver era el famoso 8-puzzle, aquel cuadro con números del 1 al 8 en el que hay que ponerlos en orden:
El algoritmo A* aplicado a este problema lo pueden encontrar fácilmente con una búsqueda en internet, pero como yo quería practicar Scala (lenguaje que uso en mis proyectos) me puse a ver qué tal me quedaba. Solamente tuve un problema en el algoritmo: tuve que usar un mutable hashset (horror, lo sé), porque al usar uno inmutable el tiempo de ejecución se hacía muy largo. Si hay alguien por ahí que quiera optimizar el código, adelante. También implementé la solución de forma imperativa nada más para comparar.
Aquí el código. Recuerden que esto no es la mejor implementación, y que por ende, puede mejorar. Los heurísticos implementados son Manhattan Distance (distancia de un estado x a uno meta) y Misplaced tiles (contar el número de cuadros que no están en su lugar. El segundo también lleva a la solución, pero tarda más en encontrarla. La función principal (solve) está optimizada como tail recursive para evitar un posible stack overflow. Además, van a ver muchos val quizá innecesarios que puse para darle legibilidad al depurarlo en el caso de que fuera necesario.
Sugerencias y comentarios son bienvenidos:
package org.mmg.astar import scala.math import scala.collection.mutable.HashSet import scala.annotation.tailrec sealed trait Definitions { type Board = List[List[Int]] type Position = (Int,Int) } sealed class Node(val board: List[List[Int]], val parent: Option[Node], val hcost: Int, val gcost: Int) { override def hashCode(): Int = board.hashCode override def equals(obj: Any) = { obj match { case o : Node => o.board == this.board case _ => false } } def f = hcost + gcost } class EightPuzzle (initialState: List[List[Int]], goalState: List[List[Int]], size: Int) extends Definitions { private def getPosition(value: Int)(b: Board) = { val row = b.indexWhere(_.contains(value)) (row, b(row).indexWhere(_ == value)) } private val getZeroPosition = getPosition(0) _ private def getElement(p: Position, b: Board) = b(p._1)(p._2) // Heuristic private def manhattanDistance(p1: Position, p2: Position) = math.abs(p2._1 - p1._1) + math.abs(p2._2 - p1._2) val distance = (current: Board, target: Board) => current.filterNot(_ == 0).foldLeft(0)((acc, row) => { acc + row.foldLeft(0)((res, value) => { val cPos = getPosition(value)(current) val tPos = getPosition(value)(target) res + manhattanDistance(cPos,tPos) }) }) // Another heuristic. Not so good val misplaced = (current: Board, target: Board) => { current.foldLeft(0)((acc, row) => { acc + row.foldLeft(0)( (res, value) => { if (getPosition(value)(current) != getPosition(value)(target)) res + 1 else res }) }) } private def move(move: (Position,Position), b: Board) = { val p1 = move._1 val p2 = move._2 val row1 = p1._1 val row2 = p2._1 val oElem = getElement(p1,b) val nElem = getElement(p2,b) val newRow1 = b(row1).updated(p1._2, nElem) val tempBoard = b.updated(row1, newRow1) val newRow2 = tempBoard(row2).updated(p2._2, oElem) b.updated(row1, newRow1).updated(row2,newRow2) } // Always look for the empty block private def getValidMoves(b: Board): List[(Position, Position)] = { val p = getZeroPosition(b) val left = if (p._2 > 0) p._2 - 1 else -1 val right = if (p._2 < size - 1) p._2 + 1 else -1 val up = if (p._1 > 0) p._1 - 1 else -1 val down = if (p._1 < size - 1) p._1 + 1 else -1 sealed abstract class Direction case class LEFT extends Direction case class RIGHT extends Direction case class UP extends Direction case class DOWN extends Direction val valid = ((LEFT,left) :: (RIGHT,right) :: (UP,up) :: (DOWN,down) :: Nil).filterNot(_._2 == -1) valid.foldLeft(List[(Position,Position)]())( (acc, m) => { val nPos = m._1 match { case LEFT => ((p._1,left), (p._1, left + 1)) case RIGHT => ((p._1, right), (p._1, right - 1)) case UP => ((up, p._2), (up + 1, p._2)) case _ => ((down, p._2), (down - 1, p._2)) } nPos :: acc } ) } // Build the path from the goal to the start point private def getPath(goal: Node): List[Board] = { goal.parent match { case Some(node) => node.board :: getPath(node) case _ => Nil } } /* Solve using A* * h is the heuristic (selected when calling the function) */ @tailrec final def solve(fringe: List[Node], closed: HashSet[Node], h: (Board,Board) => Int) : List[Board] = { if (fringe.isEmpty) { List() } else { val currentState = fringe.head if (currentState.board == goalState) { (currentState.board :: getPath(currentState)).reverse } else { closed += currentState val expanded = getValidMoves(currentState.board).foldLeft(List[Board]())((acc, m) => move(m, currentState.board) :: acc) .map {b => new Node(b,Some(currentState),h(b,goalState),0)} .filterNot(n => closed.contains(n)) // newNeighbors contains the nodes we have to add or update in the fringe val newNeighbors = expanded.map {n => val tempG = currentState.gcost + h(currentState.board,n.board) if (!fringe.tail.contains(n) || tempG < n.gcost) new Node(n.board, Some(currentState), h(n.board,goalState), tempG) else new Node(List(),None,0,0) } filterNot(_.board.isEmpty) val newFringe = newNeighbors.foldLeft(fringe.tail)((acc, n) => { if (!acc.contains(n)) n :: acc else { n :: acc.filterNot(_ == n) } }).sortBy(_.f) solve(newFringe, closed,h) } } } private def printNode(n: Node) { println("H = " + n.hcost + " G = " + n.gcost + " F = " + n.f) EightPuzzle.printBoard(n.board) } // The imperative way def solve_imperative(open: scala.collection.mutable.MutableList[Node], closed: HashSet[Node], h:(Board,Board) => Int) : List[Board] = { var found: Boolean = false var fringe = open while (!fringe.isEmpty) { val currentState = fringe.head /* println("Current state: ") EightPuzzle.printBoard(currentState.board)*/ fringe = fringe.tail if (currentState.board == goalState) { return (currentState.board :: getPath(currentState)).reverse } else { closed += currentState val expanded = getValidMoves(currentState.board).foldLeft(List[Board]())((acc, m) => move(m, currentState.board) :: acc) .map {b => new Node(b,Some(currentState),h(b,goalState),0)} .filterNot(n => closed.contains(n)) expanded.foreach(n => { val tempG = currentState.gcost + h(currentState.board,n.board) if (!fringe.contains(n) || tempG < n.gcost) { val newNode = new Node(n.board, Some(currentState), h(n.board, goalState), tempG) if (!fringe.contains(n)) fringe += newNode else { fringe.update(fringe.indexWhere(_.board == n.board), newNode) } } }) fringe = fringe.sortWith(_.f < _.f) } } List() } } object EightPuzzle extends Definitions { def board2String(b: Board) = { b.foldLeft("")((acc, row) => { acc + row.foldLeft("")((str, value) => str + value + " ") + "\n" }) } def toBoard(state: Array[Int], size: Int) = state.toList.grouped(size).toList def printBoard(b: List[List[Int]]) = println(board2String(b)) def main(a: Array[String]): Unit = { val start = toBoard(Array(1,6,4,8,7,0,3,2,5),3) // 21 steps to the goal //val start = toBoard(Array(8,1,7,4,5,6,2,0,3),3) // 25 steps to the goal //val start = toBoard(Array(1,2,5,3,0,4,6,7,8), 3) val goal = toBoard(Array(0,1,2,3,4,5,6,7,8),3) //val goal = toBoard(Array(1,2,3,4,5,6,7,8,0),3) val e = new EightPuzzle(start,goal,3) println(" ========= Initial State =============") printBoard(start) println(" ========= Goal ===============") printBoard(goal) println("Press ENTER to start the search") readLine() println("Solving. Please wait...") val timer = new Stopwatch().start val solution = e.solve(List(new Node(start,None,0,0)), HashSet(), e.distance) //val solution = e.solve(List(new Node(start,None,0,0)), HashSet(), e.misplaced) // Takes loooooonger to find the solution //val solution = e.solve_imperative(scala.collection.mutable.MutableList[Node](new Node(start,None,0,0)), HashSet(), e.distance) timer.stop solution match { case l if (!l.isEmpty) => println(" ############ Solution ########### ") l.zipWithIndex foreach { case (b,i) => println("(" + i + "): ") printBoard(b) } println("Solution found in " + timer.stop.getElapsedTime + " seconds") case _ => println("No solution found.") } } }
La clase Stopwatch no es de mi autoría; la bajé de internet porque me daba flojera crear una:
package org.mmg.astar // // Stopwatch for benchmarking // class Stopwatch { private var startTime = -1L private var stopTime = -1L private var running = false def start(): Stopwatch = { startTime = System.currentTimeMillis() running = true this } def stop(): Stopwatch = { stopTime = System.currentTimeMillis() running = false this } def isRunning(): Boolean = running private def formatTime(time: Long) = (time / 1000) + "." + (time % 1000) def getElapsedTime() = { if (startTime == -1) { 0 } if (running) { formatTime(System.currentTimeMillis() - startTime) } else { formatTime(stopTime - startTime) } } def reset() { startTime = -1 stopTime = -1 running = false } }
Lo que sigue es hacer lo mismo, pero en Haskell. Ya tengo un buen rato que no lo uso en serio, así que creo que es buena oportunidad para retomarlo.
ya lo vi, pero ahora sí me hablaste quien sabe en que idioma, jajajajajaja, brillante siempre, y yo solo sé que este jueguito lo sé resolver desde niña, pero recuerdo que más tarde salieron otros con más números…….
Antonio Fulano Detal liked this on Facebook.
Monica De la Paz liked this on Facebook.
Podrías hacer un post de comparación entre haskell y scala? Por ejemplo: si es válido alguno en un ambiente de producción; si es más fácil exponer servicios con esos lenguajes en vez de aplicaciones web, una opinión general de lo que te han parecido en el aspecto técnico, etc…