Hace poco estuve leyendo unas notas de un curso en línea sobre planeación en IA. En una de ellas me encontré con un algoritmo que tenía rato que no veía ni utilizaba, y me dio curiosidad por implementarlo en Scala; me refiero al algoritmo A*.
En vez de explicar qué hace específicamente el algoritmo, un googlazo o una búsqueda en Wikipedia proveen información más detallada al respecto. El problema a resolver era el famoso 8-puzzle, aquel cuadro con números del 1 al 8 en el que hay que ponerlos en orden:
El algoritmo A* aplicado a este problema lo pueden encontrar fácilmente con una búsqueda en internet, pero como yo quería practicar Scala (lenguaje que uso en mis proyectos) me puse a ver qué tal me quedaba. Solamente tuve un problema en el algoritmo: tuve que usar un mutable hashset (horror, lo sé), porque al usar uno inmutable el tiempo de ejecución se hacía muy largo. Si hay alguien por ahí que quiera optimizar el código, adelante. También implementé la solución de forma imperativa nada más para comparar.
Aquí el código. Recuerden que esto no es la mejor implementación, y que por ende, puede mejorar. Los heurísticos implementados son Manhattan Distance (distancia de un estado x a uno meta) y Misplaced tiles (contar el número de cuadros que no están en su lugar. El segundo también lleva a la solución, pero tarda más en encontrarla. La función principal (solve) está optimizada como tail recursive para evitar un posible stack overflow. Además, van a ver muchos val quizá innecesarios que puse para darle legibilidad al depurarlo en el caso de que fuera necesario.
Sugerencias y comentarios son bienvenidos:
package org.mmg.astar
import scala.math
import scala.collection.mutable.HashSet
import scala.annotation.tailrec
sealed trait Definitions {
type Board = List[List[Int]]
type Position = (Int,Int)
}
sealed class Node(val board: List[List[Int]], val parent: Option[Node], val hcost: Int, val gcost: Int) {
override def hashCode(): Int = board.hashCode
override def equals(obj: Any) = {
obj match {
case o : Node => o.board == this.board
case _ => false
}
}
def f = hcost + gcost
}
class EightPuzzle (initialState: List[List[Int]], goalState: List[List[Int]], size: Int) extends Definitions {
private def getPosition(value: Int)(b: Board) = {
val row = b.indexWhere(_.contains(value))
(row, b(row).indexWhere(_ == value))
}
private val getZeroPosition = getPosition(0) _
private def getElement(p: Position, b: Board) = b(p._1)(p._2)
// Heuristic
private def manhattanDistance(p1: Position, p2: Position) =
math.abs(p2._1 - p1._1) + math.abs(p2._2 - p1._2)
val distance = (current: Board, target: Board) =>
current.filterNot(_ == 0).foldLeft(0)((acc, row) => {
acc + row.foldLeft(0)((res, value) => {
val cPos = getPosition(value)(current)
val tPos = getPosition(value)(target)
res + manhattanDistance(cPos,tPos)
})
})
// Another heuristic. Not so good
val misplaced = (current: Board, target: Board) => {
current.foldLeft(0)((acc, row) => {
acc + row.foldLeft(0)( (res, value) => {
if (getPosition(value)(current) != getPosition(value)(target))
res + 1
else
res
})
})
}
private def move(move: (Position,Position), b: Board) = {
val p1 = move._1
val p2 = move._2
val row1 = p1._1
val row2 = p2._1
val oElem = getElement(p1,b)
val nElem = getElement(p2,b)
val newRow1 = b(row1).updated(p1._2, nElem)
val tempBoard = b.updated(row1, newRow1)
val newRow2 = tempBoard(row2).updated(p2._2, oElem)
b.updated(row1, newRow1).updated(row2,newRow2)
}
// Always look for the empty block
private def getValidMoves(b: Board): List[(Position, Position)] = {
val p = getZeroPosition(b)
val left = if (p._2 > 0) p._2 - 1 else -1
val right = if (p._2 < size - 1) p._2 + 1 else -1
val up = if (p._1 > 0) p._1 - 1 else -1
val down = if (p._1 < size - 1) p._1 + 1 else -1
sealed abstract class Direction
case class LEFT extends Direction
case class RIGHT extends Direction
case class UP extends Direction
case class DOWN extends Direction
val valid = ((LEFT,left) :: (RIGHT,right) :: (UP,up) :: (DOWN,down) :: Nil).filterNot(_._2 == -1)
valid.foldLeft(List[(Position,Position)]())( (acc, m) => {
val nPos = m._1 match {
case LEFT => ((p._1,left), (p._1, left + 1))
case RIGHT => ((p._1, right), (p._1, right - 1))
case UP => ((up, p._2), (up + 1, p._2))
case _ => ((down, p._2), (down - 1, p._2))
}
nPos :: acc
} )
}
// Build the path from the goal to the start point
private def getPath(goal: Node): List[Board] = {
goal.parent match {
case Some(node) => node.board :: getPath(node)
case _ => Nil
}
}
/* Solve using A*
* h is the heuristic (selected when calling the function)
*/
@tailrec
final def solve(fringe: List[Node], closed: HashSet[Node], h: (Board,Board) => Int) : List[Board] = {
if (fringe.isEmpty) {
List()
}
else {
val currentState = fringe.head
if (currentState.board == goalState) {
(currentState.board :: getPath(currentState)).reverse
}
else {
closed += currentState
val expanded = getValidMoves(currentState.board).foldLeft(List[Board]())((acc, m) => move(m, currentState.board) :: acc)
.map {b => new Node(b,Some(currentState),h(b,goalState),0)}
.filterNot(n => closed.contains(n))
// newNeighbors contains the nodes we have to add or update in the fringe
val newNeighbors = expanded.map {n =>
val tempG = currentState.gcost + h(currentState.board,n.board)
if (!fringe.tail.contains(n) || tempG < n.gcost)
new Node(n.board, Some(currentState), h(n.board,goalState), tempG)
else
new Node(List(),None,0,0)
} filterNot(_.board.isEmpty)
val newFringe = newNeighbors.foldLeft(fringe.tail)((acc, n) => {
if (!acc.contains(n))
n :: acc
else {
n :: acc.filterNot(_ == n)
}
}).sortBy(_.f)
solve(newFringe, closed,h)
}
}
}
private def printNode(n: Node) {
println("H = " + n.hcost + " G = " + n.gcost + " F = " + n.f)
EightPuzzle.printBoard(n.board)
}
// The imperative way
def solve_imperative(open: scala.collection.mutable.MutableList[Node], closed: HashSet[Node], h:(Board,Board) => Int) : List[Board] = {
var found: Boolean = false
var fringe = open
while (!fringe.isEmpty) {
val currentState = fringe.head
/* println("Current state: ")
EightPuzzle.printBoard(currentState.board)*/
fringe = fringe.tail
if (currentState.board == goalState) {
return (currentState.board :: getPath(currentState)).reverse
}
else {
closed += currentState
val expanded = getValidMoves(currentState.board).foldLeft(List[Board]())((acc, m) => move(m, currentState.board) :: acc)
.map {b => new Node(b,Some(currentState),h(b,goalState),0)}
.filterNot(n => closed.contains(n))
expanded.foreach(n => {
val tempG = currentState.gcost + h(currentState.board,n.board)
if (!fringe.contains(n) || tempG < n.gcost) {
val newNode = new Node(n.board, Some(currentState), h(n.board, goalState), tempG)
if (!fringe.contains(n))
fringe += newNode
else {
fringe.update(fringe.indexWhere(_.board == n.board), newNode)
}
}
})
fringe = fringe.sortWith(_.f < _.f)
}
}
List()
}
}
object EightPuzzle extends Definitions {
def board2String(b: Board) = {
b.foldLeft("")((acc, row) => {
acc + row.foldLeft("")((str, value) => str + value + " ") + "\n"
})
}
def toBoard(state: Array[Int], size: Int) = state.toList.grouped(size).toList
def printBoard(b: List[List[Int]]) = println(board2String(b))
def main(a: Array[String]): Unit = {
val start = toBoard(Array(1,6,4,8,7,0,3,2,5),3) // 21 steps to the goal
//val start = toBoard(Array(8,1,7,4,5,6,2,0,3),3) // 25 steps to the goal
//val start = toBoard(Array(1,2,5,3,0,4,6,7,8), 3)
val goal = toBoard(Array(0,1,2,3,4,5,6,7,8),3)
//val goal = toBoard(Array(1,2,3,4,5,6,7,8,0),3)
val e = new EightPuzzle(start,goal,3)
println(" ========= Initial State =============")
printBoard(start)
println(" ========= Goal ===============")
printBoard(goal)
println("Press ENTER to start the search")
readLine()
println("Solving. Please wait...")
val timer = new Stopwatch().start
val solution = e.solve(List(new Node(start,None,0,0)), HashSet(), e.distance)
//val solution = e.solve(List(new Node(start,None,0,0)), HashSet(), e.misplaced) // Takes loooooonger to find the solution
//val solution = e.solve_imperative(scala.collection.mutable.MutableList[Node](new Node(start,None,0,0)), HashSet(), e.distance)
timer.stop
solution match {
case l if (!l.isEmpty) =>
println(" ############ Solution ########### ")
l.zipWithIndex foreach { case (b,i) =>
println("(" + i + "): ")
printBoard(b)
}
println("Solution found in " + timer.stop.getElapsedTime + " seconds")
case _ =>
println("No solution found.")
}
}
}
La clase Stopwatch no es de mi autoría; la bajé de internet porque me daba flojera crear una:
package org.mmg.astar
//
// Stopwatch for benchmarking
//
class Stopwatch {
private var startTime = -1L
private var stopTime = -1L
private var running = false
def start(): Stopwatch = {
startTime = System.currentTimeMillis()
running = true
this
}
def stop(): Stopwatch = {
stopTime = System.currentTimeMillis()
running = false
this
}
def isRunning(): Boolean = running
private def formatTime(time: Long) = (time / 1000) + "." + (time % 1000)
def getElapsedTime() = {
if (startTime == -1) {
0
}
if (running) {
formatTime(System.currentTimeMillis() - startTime)
}
else {
formatTime(stopTime - startTime)
}
}
def reset() {
startTime = -1
stopTime = -1
running = false
}
}
Lo que sigue es hacer lo mismo, pero en Haskell. Ya tengo un buen rato que no lo uso en serio, así que creo que es buena oportunidad para retomarlo.


ya lo vi, pero ahora sí me hablaste quien sabe en que idioma, jajajajajaja, brillante siempre, y yo solo sé que este jueguito lo sé resolver desde niña, pero recuerdo que más tarde salieron otros con más números…….
Antonio Fulano Detal liked this on Facebook.
Monica De la Paz liked this on Facebook.
Podrías hacer un post de comparación entre haskell y scala? Por ejemplo: si es válido alguno en un ambiente de producción; si es más fácil exponer servicios con esos lenguajes en vez de aplicaciones web, una opinión general de lo que te han parecido en el aspecto técnico, etc…