diff --git a/src/magic_squares/MagicSquareSolver.scala b/src/magic_squares/MagicSquareSolver.scala new file mode 100644 index 0000000..dabc328 --- /dev/null +++ b/src/magic_squares/MagicSquareSolver.scala @@ -0,0 +1,146 @@ +package magic_squares + +class MagicSquareSolver(initialGrid: Grid) { + protected var _initial: Grid = initialGrid + protected val _size: Int = _initial.length + protected val SUM: Int = 15 + protected val DEBUG: Boolean = false + + protected def print(grid: Grid): Unit = { + for (y: Int <- 0 until _size) { + println(grid(y).mkString(",")) + } + } + + protected def copy(grid: Grid): Grid = { + return grid.map(_.clone()) + } + + def solve(): Unit = { + val sol: Option[Grid] = solveFrom(_initial, 0, 0) + if (sol.isEmpty) { + println("No solution") + } else { + print(sol.get) + } + } + + private def solveFrom(grid: Grid, x: Int, y: Int): Option[Grid] = { + if (DEBUG) println(s"Solving from $x, $y") + if (DEBUG) print(grid) + if (!isValid(grid)) { + if (DEBUG) println(" Grid is invalid") + return None + } + if (y >= _size) { + if (DEBUG) println(" Found solution") + return Some(grid) + } + var values: Array[Int] = Array(_initial(y)(x)) + if (values(0) == 0) values = (1 to 9).toArray + if (DEBUG) println(s" Values to test: " + values.mkString("[", ", ", "]")) + + val newGrid: Grid = copy(grid) + var x2: Int = x + 1 + var y2: Int = y + if (x2 >= _size) { + x2 -= _size + y2 += 1 + } + + for (i: Int <- values) { + if (DEBUG) println(s" Testing $i") + newGrid(y)(x) = i + val sol: Option[Grid] = solveFrom(newGrid, x2, y2) + if (sol.isDefined) { + if (DEBUG) println(" Found solution, collapsing call stack") + return sol + } + } + + if (DEBUG) println(s" No solution for this configuration") + return None + } + + private def isValid(grid: Grid): Boolean = { + val values: Array[Int] = grid.reduce((a, b) => a.concat(b)).filter(_ != 0) + if (values.distinct.length != values.length) return false + + val diag1: Array[Int] = new Array(_size) + val diag2: Array[Int] = new Array(_size) + for (i: Int <- 0 until _size) { + val row: Array[Int] = grid(i) + val col: Array[Int] = grid.map(_(i)) + diag1(i) = row(i) + diag2(i) = col(_size - i - 1) + + if (!isLineValid(row)) { + if (DEBUG) println(s" -> row $i is invalid: " + row.mkString("[", ", ", "]")) + return false + } + if (!isLineValid(col)) { + if (DEBUG) println(s" -> column $i is invalid: " + col.mkString("[", ", ", "]")) + return false + } + } + if (!isLineValid(diag1)) { + if (DEBUG) println(s" -> diag1 is invalid: " + diag1.mkString("[", ", ", "]")) + return false + } + if (!isLineValid(diag2)) { + if (DEBUG) println(s" -> diag2 is invalid: " + diag2.mkString("[", ", ", "]")) + return false + } + + return true + } + + private def isLineValid(line: Array[Int]): Boolean = { + val sum: Int = line.sum + if (line.contains(0)) { + if (sum > SUM) return false + } else if (sum != SUM) return false + + return true + } +} + +object MagicSquareSolver { + def main(args: Array[String]): Unit = { + val solver1: MagicSquareSolver = new MagicSquareSolver(Array( + Array(8, 0, 0), + Array(0, 0, 7), + Array(0, 9, 0) + )) + solver1.solve() + println() + /* + 8,1,6 + 3,5,7 + 4,9,2 + */ + + val solver2: MagicSquareSolver = new MagicSquareSolver(Array( + Array(9, 0, 0), + Array(0, 0, 7), + Array(0, 8, 0) + )) + solver2.solve() + println() + /* + no solution + */ + + val solver3: MagicSquareSolver = new MagicSquareSolver(Array( + Array(0, 0, 2), + Array(0, 5, 7), + Array(0, 0, 0) + )) + solver3.solve() + /* + 4,9,2 + 3,5,7 + 8,1,6 + */ + } +} \ No newline at end of file diff --git a/src/magic_squares/package.scala b/src/magic_squares/package.scala new file mode 100644 index 0000000..873927f --- /dev/null +++ b/src/magic_squares/package.scala @@ -0,0 +1,3 @@ +package object magic_squares { + type Grid = Array[Array[Int]] +}